diff --git a/.asf.yaml b/.asf.yaml index 5fe94dc04af5..aab8c1e6df2d 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -50,6 +50,9 @@ github: main: required_pull_request_reviews: required_approving_review_count: 1 + pull_requests: + # enable updating head branches of pull requests + allow_update_branch: true # publishes the content of the `asf-site` branch to # https://datafusion.apache.org/ diff --git a/.github/workflows/docs_pr.yaml b/.github/workflows/docs_pr.yaml index d3c901c5b71b..8d11cdf9d39b 100644 --- a/.github/workflows/docs_pr.yaml +++ b/.github/workflows/docs_pr.yaml @@ -34,26 +34,7 @@ on: workflow_dispatch: jobs: - # Run doc tests - linux-test-doc: - name: cargo doctest (amd64) - runs-on: ubuntu-latest - container: - image: amd64/rust - steps: - - uses: actions/checkout@v4 - with: - submodules: true - fetch-depth: 1 - - name: Setup Rust toolchain - uses: ./.github/actions/setup-builder - with: - rust-version: stable - - name: Run doctests (embedded rust examples) - run: cargo test --doc --features avro,json - - name: Verify Working Directory Clean - run: git diff --exit-code - + # Test doc build linux-test-doc-build: name: Test doc build diff --git a/.github/workflows/extended.yml b/.github/workflows/extended.yml index d80fdb75d932..fb97fbac97d9 100644 --- a/.github/workflows/extended.yml +++ b/.github/workflows/extended.yml @@ -101,7 +101,17 @@ jobs: - name: Run tests (excluding doctests) env: RUST_BACKTRACE: 1 - run: cargo test --profile ci --exclude datafusion-examples --exclude datafusion-benchmarks --workspace --lib --tests --bins --features avro,json,backtrace,extended_tests,recursive_protection + run: | + cargo test \ + --profile ci \ + --exclude datafusion-examples \ + --exclude datafusion-benchmarks \ + --exclude datafusion-cli \ + --workspace \ + --lib \ + --tests \ + --bins \ + --features avro,json,backtrace,extended_tests,recursive_protection - name: Verify Working Directory Clean run: git diff --exit-code - name: Cleanup @@ -126,7 +136,7 @@ jobs: - name: Run tests run: | cd datafusion - cargo test --profile ci --exclude datafusion-examples --exclude datafusion-benchmarks --exclude datafusion-sqllogictest --workspace --lib --tests --features=force_hash_collisions,avro + cargo test --profile ci --exclude datafusion-examples --exclude datafusion-benchmarks --exclude datafusion-sqllogictest --exclude datafusion-cli --workspace --lib --tests --features=force_hash_collisions,avro cargo clean sqllogictest-sqlite: diff --git a/.github/workflows/dev_pr.yml b/.github/workflows/labeler.yml similarity index 96% rename from .github/workflows/dev_pr.yml rename to .github/workflows/labeler.yml index 11c14c5c2fee..8b251552d3b2 100644 --- a/.github/workflows/dev_pr.yml +++ b/.github/workflows/labeler.yml @@ -49,7 +49,7 @@ jobs: uses: actions/labeler@v5.0.0 with: repo-token: ${{ secrets.GITHUB_TOKEN }} - configuration-path: .github/workflows/dev_pr/labeler.yml + configuration-path: .github/workflows/labeler/labeler-config.yml sync-labels: true # TODO: Enable this when eps1lon/actions-label-merge-conflict is available. 
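Aside on the rename above: after this change the PR-labeler workflow and its glob configuration live together under .github/workflows/labeler/. A minimal sketch of how the two files line up (both fragments are taken from this diff, the workflow step from the hunk above and the spark entry from the config hunk below), shown only to make the new configuration-path easier to follow:

    # .github/workflows/labeler.yml (workflow side)
    - uses: actions/labeler@v5.0.0
      with:
        repo-token: ${{ secrets.GITHUB_TOKEN }}
        configuration-path: .github/workflows/labeler/labeler-config.yml
        sync-labels: true

    # .github/workflows/labeler/labeler-config.yml (config side)
    spark:
    - changed-files:
      - any-glob-to-any-file: ['datafusion/spark/**/*']
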
diff --git a/.github/workflows/dev_pr/labeler.yml b/.github/workflows/labeler/labeler-config.yml similarity index 95% rename from .github/workflows/dev_pr/labeler.yml rename to .github/workflows/labeler/labeler-config.yml index da93e6541855..e40813072521 100644 --- a/.github/workflows/dev_pr/labeler.yml +++ b/.github/workflows/labeler/labeler-config.yml @@ -41,7 +41,7 @@ physical-expr: physical-plan: - changed-files: - - any-glob-to-any-file: [datafusion/physical-plan/**/*'] + - any-glob-to-any-file: ['datafusion/physical-plan/**/*'] catalog: @@ -77,6 +77,10 @@ proto: - changed-files: - any-glob-to-any-file: ['datafusion/proto/**/*', 'datafusion/proto-common/**/*'] +spark: +- changed-files: + - any-glob-to-any-file: ['datafusion/spark/**/*'] + substrait: - changed-files: - any-glob-to-any-file: ['datafusion/substrait/**/*'] diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 3fa8ce080474..ecb25483ce07 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -39,14 +39,6 @@ on: workflow_dispatch: jobs: - # Check license header - license-header-check: - runs-on: ubuntu-latest - name: Check License Header - steps: - - uses: actions/checkout@v4 - - uses: korandoru/hawkeye@v6 - # Check crate compiles and base cargo check passes linux-build-lib: name: linux build test @@ -401,8 +393,8 @@ jobs: - name: Run tests with headless mode working-directory: ./datafusion/wasmtest run: | - wasm-pack test --headless --firefox - wasm-pack test --headless --chrome --chromedriver $CHROMEWEBDRIVER/chromedriver + RUSTFLAGS='--cfg getrandom_backend="wasm_js"' wasm-pack test --headless --firefox + RUSTFLAGS='--cfg getrandom_backend="wasm_js"' wasm-pack test --headless --chrome --chromedriver $CHROMEWEBDRIVER/chromedriver # verify that the benchmark queries return the correct results verify-benchmark-results: @@ -476,6 +468,28 @@ jobs: POSTGRES_HOST: postgres POSTGRES_PORT: ${{ job.services.postgres.ports[5432] }} + sqllogictest-substrait: + name: "Run sqllogictest in Substrait round-trip mode" + needs: linux-build-lib + runs-on: ubuntu-latest + container: + image: amd64/rust + steps: + - uses: actions/checkout@v4 + with: + submodules: true + fetch-depth: 1 + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + with: + rust-version: stable + - name: Run sqllogictest + # TODO: Right now several tests are failing in Substrait round-trip mode, so this + # command cannot be run for all the .slt files. Run it for just one that works (limit.slt) + # until most of the tickets in https://github.com/apache/datafusion/issues/16248 are addressed + # and this command can be run without filters. 
+ run: cargo test --test sqllogictests -- --substrait-round-trip limit.slt + # Temporarily commenting out the Windows flow, the reason is enormously slow running build # Waiting for new Windows 2025 github runner # Details: https://github.com/apache/datafusion/issues/13726 diff --git a/Cargo.lock b/Cargo.lock index 2b3eeecf5d9b..1dd9031ce426 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -77,23 +77,23 @@ version = "0.7.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "891477e0c6a8957309ee5c45a6368af3ae14bb510732d2684ffa19af310920f9" dependencies = [ - "getrandom 0.2.15", + "getrandom 0.2.16", "once_cell", "version_check", ] [[package]] name = "ahash" -version = "0.8.11" +version = "0.8.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" +checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" dependencies = [ "cfg-if", "const-random", - "getrandom 0.2.15", + "getrandom 0.3.3", "once_cell", "version_check", - "zerocopy 0.7.35", + "zerocopy", ] [[package]] @@ -199,9 +199,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.95" +version = "1.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34ac096ce696dc2fcabef30516bb13c0a68a11d30131d3df6f04711467681b04" +checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" [[package]] name = "apache-avro" @@ -223,8 +223,8 @@ dependencies = [ "serde_bytes", "serde_json", "snap", - "strum 0.26.3", - "strum_macros 0.26.4", + "strum", + "strum_macros", "thiserror 1.0.69", "typed-builder", "uuid", @@ -246,9 +246,9 @@ checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] name = "arrow" -version = "55.0.0" +version = "55.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3095aaf545942ff5abd46654534f15b03a90fba78299d661e045e5d587222f0d" +checksum = "f3f15b4c6b148206ff3a2b35002e08929c2462467b62b9c02036d9c34f9ef994" dependencies = [ "arrow-arith", "arrow-array", @@ -259,20 +259,20 @@ dependencies = [ "arrow-ipc", "arrow-json", "arrow-ord", + "arrow-pyarrow", "arrow-row", "arrow-schema", "arrow-select", "arrow-string", "half", - "pyo3", - "rand 0.9.0", + "rand 0.9.1", ] [[package]] name = "arrow-arith" -version = "55.0.0" +version = "55.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00752064ff47cee746e816ddb8450520c3a52cbad1e256f6fa861a35f86c45e7" +checksum = "30feb679425110209ae35c3fbf82404a39a4c0436bb3ec36164d8bffed2a4ce4" dependencies = [ "arrow-array", "arrow-buffer", @@ -284,26 +284,26 @@ dependencies = [ [[package]] name = "arrow-array" -version = "55.0.0" +version = "55.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cebfe926794fbc1f49ddd0cdaf898956ca9f6e79541efce62dabccfd81380472" +checksum = "70732f04d285d49054a48b72c54f791bb3424abae92d27aafdf776c98af161c8" dependencies = [ - "ahash 0.8.11", + "ahash 0.8.12", "arrow-buffer", "arrow-data", "arrow-schema", "chrono", "chrono-tz", "half", - "hashbrown 0.15.2", + "hashbrown 0.15.3", "num", ] [[package]] name = "arrow-buffer" -version = "55.0.0" +version = "55.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0303c7ec4cf1a2c60310fc4d6bbc3350cd051a17bf9e9c0a8e47b4db79277824" +checksum = "169b1d5d6cb390dd92ce582b06b23815c7953e9dfaaea75556e89d890d19993d" dependencies = [ "bytes", "half", @@ -312,9 +312,9 @@ dependencies = [ [[package]] 
name = "arrow-cast" -version = "55.0.0" +version = "55.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "335f769c5a218ea823d3760a743feba1ef7857cba114c01399a891c2fff34285" +checksum = "e4f12eccc3e1c05a766cafb31f6a60a46c2f8efec9b74c6e0648766d30686af8" dependencies = [ "arrow-array", "arrow-buffer", @@ -333,9 +333,9 @@ dependencies = [ [[package]] name = "arrow-csv" -version = "55.0.0" +version = "55.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "510db7dfbb4d5761826516cc611d97b3a68835d0ece95b034a052601109c0b1b" +checksum = "012c9fef3f4a11573b2c74aec53712ff9fdae4a95f4ce452d1bbf088ee00f06b" dependencies = [ "arrow-array", "arrow-cast", @@ -343,15 +343,14 @@ dependencies = [ "chrono", "csv", "csv-core", - "lazy_static", "regex", ] [[package]] name = "arrow-data" -version = "55.0.0" +version = "55.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8affacf3351a24039ea24adab06f316ded523b6f8c3dbe28fbac5f18743451b" +checksum = "8de1ce212d803199684b658fc4ba55fb2d7e87b213de5af415308d2fee3619c2" dependencies = [ "arrow-buffer", "arrow-schema", @@ -361,9 +360,9 @@ dependencies = [ [[package]] name = "arrow-flight" -version = "55.0.0" +version = "55.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2e0fad280f41a918d53ba48288a246ff04202d463b3b380fbc0edecdcb52cfd" +checksum = "5cb3e1d2b441e6d1d5988e3f7c4523c9466b18ef77d7c525d92d36d4cad49fbe" dependencies = [ "arrow-arith", "arrow-array", @@ -388,9 +387,9 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "55.0.0" +version = "55.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69880a9e6934d9cba2b8630dd08a3463a91db8693b16b499d54026b6137af284" +checksum = "d9ea5967e8b2af39aff5d9de2197df16e305f47f404781d3230b2dc672da5d92" dependencies = [ "arrow-array", "arrow-buffer", @@ -398,13 +397,14 @@ dependencies = [ "arrow-schema", "flatbuffers", "lz4_flex", + "zstd", ] [[package]] name = "arrow-json" -version = "55.0.0" +version = "55.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8dafd17a05449e31e0114d740530e0ada7379d7cb9c338fd65b09a8130960b0" +checksum = "5709d974c4ea5be96d900c01576c7c0b99705f4a3eec343648cb1ca863988a9c" dependencies = [ "arrow-array", "arrow-buffer", @@ -413,7 +413,7 @@ dependencies = [ "arrow-schema", "chrono", "half", - "indexmap 2.8.0", + "indexmap 2.10.0", "lexical-core", "memchr", "num", @@ -424,9 +424,9 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "55.0.0" +version = "55.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "895644523af4e17502d42c3cb6b27cb820f0cb77954c22d75c23a85247c849e1" +checksum = "6506e3a059e3be23023f587f79c82ef0bcf6d293587e3272d20f2d30b969b5a7" dependencies = [ "arrow-array", "arrow-buffer", @@ -435,11 +435,23 @@ dependencies = [ "arrow-select", ] +[[package]] +name = "arrow-pyarrow" +version = "55.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e55ecf16b9b61d433f6e63c72fc6afcf2597d7db96583de88ebb887d1822268" +dependencies = [ + "arrow-array", + "arrow-data", + "arrow-schema", + "pyo3", +] + [[package]] name = "arrow-row" -version = "55.0.0" +version = "55.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9be8a2a4e5e7d9c822b2b8095ecd77010576d824f654d347817640acfc97d229" +checksum = "52bf7393166beaf79b4bed9bfdf19e97472af32ce5b6b48169d321518a08cae2" dependencies = [ "arrow-array", "arrow-buffer", 
@@ -450,21 +462,22 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "55.0.0" +version = "55.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7450c76ab7c5a6805be3440dc2e2096010da58f7cab301fdc996a4ee3ee74e49" +checksum = "af7686986a3bf2254c9fb130c623cdcb2f8e1f15763e7c71c310f0834da3d292" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.1", "serde", + "serde_json", ] [[package]] name = "arrow-select" -version = "55.0.0" +version = "55.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa5f5a93c75f46ef48e4001535e7b6c922eeb0aa20b73cf58d09e13d057490d8" +checksum = "dd2b45757d6a2373faa3352d02ff5b54b098f5e21dccebc45a21806bc34501e5" dependencies = [ - "ahash 0.8.11", + "ahash 0.8.12", "arrow-array", "arrow-buffer", "arrow-data", @@ -474,9 +487,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "55.0.0" +version = "55.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e7005d858d84b56428ba2a98a107fe88c0132c61793cf6b8232a1f9bfc0452b" +checksum = "0377d532850babb4d927a06294314b316e23311503ed580ec6ce6a0158f49d40" dependencies = [ "arrow-array", "arrow-buffer", @@ -503,9 +516,9 @@ dependencies = [ [[package]] name = "assert_cmd" -version = "2.0.16" +version = "2.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc1835b7f27878de8525dc71410b5a31cdcc5f230aed5ba5df968e09c201b23d" +checksum = "2bd389a4b2970a01282ee455294913c0a43724daedcd1a24c3eb0ec1c1320b66" dependencies = [ "anstyle", "bstr", @@ -551,7 +564,7 @@ checksum = "3b43422f69d8ff38f95f1b2bb76517c91589a924d1559a0e935d7c8ce0274c11" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] @@ -573,7 +586,7 @@ checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] @@ -584,7 +597,7 @@ checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] @@ -610,9 +623,9 @@ checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" [[package]] name = "aws-config" -version = "1.6.1" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c39646d1a6b51240a1a23bb57ea4eebede7e16fbc237fdc876980233dcecb4f" +checksum = "455e9fb7743c6f6267eb2830ccc08686fbb3d13c9a689369562fd4d4ef9ea462" dependencies = [ "aws-credential-types", "aws-runtime", @@ -629,7 +642,7 @@ dependencies = [ "bytes", "fastrand", "hex", - "http 1.2.0", + "http 1.3.1", "ring", "time", "tokio", @@ -640,9 +653,9 @@ dependencies = [ [[package]] name = "aws-credential-types" -version = "1.2.2" +version = "1.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4471bef4c22a06d2c7a1b6492493d3fdf24a805323109d6874f9c94d5906ac14" +checksum = "687bc16bc431a8533fe0097c7f0182874767f920989d7260950172ae8e3c4465" dependencies = [ "aws-smithy-async", "aws-smithy-runtime-api", @@ -652,9 +665,9 @@ dependencies = [ [[package]] name = "aws-lc-rs" -version = "1.12.6" +version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dabb68eb3a7aa08b46fddfd59a3d55c978243557a90ab804769f7e20e67d2b01" +checksum = "93fcc8f365936c834db5514fc45aee5b1202d677e6b40e48468aaaa8183ca8c7" dependencies = [ "aws-lc-sys", "zeroize", @@ -662,9 +675,9 @@ dependencies = [ [[package]] name = "aws-lc-sys" -version = 
"0.27.0" +version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6bbe221bbf523b625a4dd8585c7f38166e31167ec2ca98051dbcb4c3b6e825d2" +checksum = "61b1d86e7705efe1be1b569bab41d4fa1e14e220b60a160f78de2db687add079" dependencies = [ "bindgen", "cc", @@ -675,9 +688,9 @@ dependencies = [ [[package]] name = "aws-runtime" -version = "1.5.6" +version = "1.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0aff45ffe35196e593ea3b9dd65b320e51e2dda95aff4390bc459e461d09c6ad" +checksum = "4f6c68419d8ba16d9a7463671593c54f81ba58cab466e9b759418da606dcc2e2" dependencies = [ "aws-credential-types", "aws-sigv4", @@ -691,7 +704,6 @@ dependencies = [ "fastrand", "http 0.2.12", "http-body 0.4.6", - "once_cell", "percent-encoding", "pin-project-lite", "tracing", @@ -700,9 +712,9 @@ dependencies = [ [[package]] name = "aws-sdk-sso" -version = "1.63.0" +version = "1.73.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1cb45b83b53b5cd55ee33fd9fd8a70750255a3f286e4dca20e882052f2b256f" +checksum = "b2ac1674cba7872061a29baaf02209fefe499ff034dfd91bd4cc59e4d7741489" dependencies = [ "aws-credential-types", "aws-runtime", @@ -716,16 +728,15 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", - "once_cell", "regex-lite", "tracing", ] [[package]] name = "aws-sdk-ssooidc" -version = "1.64.0" +version = "1.74.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8d4d9bc075ea6238778ed3951b65d3cde8c3864282d64fdcd19f2a90c0609f1" +checksum = "3a6a22f077f5fd3e3c0270d4e1a110346cddf6769e9433eb9e6daceb4ca3b149" dependencies = [ "aws-credential-types", "aws-runtime", @@ -739,16 +750,15 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", - "once_cell", "regex-lite", "tracing", ] [[package]] name = "aws-sdk-sts" -version = "1.64.0" +version = "1.75.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "819ccba087f403890fee4825eeab460e64c59345667d2b83a12cf544b581e3a7" +checksum = "e3258fa707f2f585ee3049d9550954b959002abd59176975150a01d5cf38ae3f" dependencies = [ "aws-credential-types", "aws-runtime", @@ -763,16 +773,15 @@ dependencies = [ "aws-types", "fastrand", "http 0.2.12", - "once_cell", "regex-lite", "tracing", ] [[package]] name = "aws-sigv4" -version = "1.3.0" +version = "1.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69d03c3c05ff80d54ff860fe38c726f6f494c639ae975203a101335f223386db" +checksum = "ddfb9021f581b71870a17eac25b52335b82211cdc092e02b6876b2bcefa61666" dependencies = [ "aws-credential-types", "aws-smithy-http", @@ -783,8 +792,7 @@ dependencies = [ "hex", "hmac", "http 0.2.12", - "http 1.2.0", - "once_cell", + "http 1.3.1", "percent-encoding", "sha2", "time", @@ -804,9 +812,9 @@ dependencies = [ [[package]] name = "aws-smithy-http" -version = "0.62.0" +version = "0.62.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5949124d11e538ca21142d1fba61ab0a2a2c1bc3ed323cdb3e4b878bfb83166" +checksum = "99335bec6cdc50a346fda1437f9fefe33abf8c99060739a546a16457f2862ca9" dependencies = [ "aws-smithy-runtime-api", "aws-smithy-types", @@ -814,9 +822,8 @@ dependencies = [ "bytes-utils", "futures-core", "http 0.2.12", - "http 1.2.0", + "http 1.3.1", "http-body 0.4.6", - "once_cell", "percent-encoding", "pin-project-lite", "pin-utils", @@ -825,15 +832,15 @@ dependencies = [ [[package]] name = "aws-smithy-http-client" -version = "1.0.0" +version = "1.0.2" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "0497ef5d53065b7cd6a35e9c1654bd1fefeae5c52900d91d1b188b0af0f29324" +checksum = "7e44697a9bded898dcd0b1cb997430d949b87f4f8940d91023ae9062bf218250" dependencies = [ "aws-smithy-async", "aws-smithy-runtime-api", "aws-smithy-types", "h2", - "http 1.2.0", + "http 1.3.1", "hyper", "hyper-rustls", "hyper-util", @@ -848,21 +855,20 @@ dependencies = [ [[package]] name = "aws-smithy-json" -version = "0.61.3" +version = "0.61.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92144e45819cae7dc62af23eac5a038a58aa544432d2102609654376a900bd07" +checksum = "a16e040799d29c17412943bdbf488fd75db04112d0c0d4b9290bacf5ae0014b9" dependencies = [ "aws-smithy-types", ] [[package]] name = "aws-smithy-observability" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "445d065e76bc1ef54963db400319f1dd3ebb3e0a74af20f7f7630625b0cc7cc0" +checksum = "9364d5989ac4dd918e5cc4c4bdcc61c9be17dcd2586ea7f69e348fc7c6cab393" dependencies = [ "aws-smithy-runtime-api", - "once_cell", ] [[package]] @@ -877,9 +883,9 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.8.1" +version = "1.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0152749e17ce4d1b47c7747bdfec09dac1ccafdcbc741ebf9daa2a373356730f" +checksum = "14302f06d1d5b7d333fd819943075b13d27c7700b414f574c3c35859bfb55d5e" dependencies = [ "aws-smithy-async", "aws-smithy-http", @@ -890,10 +896,9 @@ dependencies = [ "bytes", "fastrand", "http 0.2.12", - "http 1.2.0", + "http 1.3.1", "http-body 0.4.6", "http-body 1.0.1", - "once_cell", "pin-project-lite", "pin-utils", "tokio", @@ -902,15 +907,15 @@ dependencies = [ [[package]] name = "aws-smithy-runtime-api" -version = "1.7.4" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3da37cf5d57011cb1753456518ec76e31691f1f474b73934a284eb2a1c76510f" +checksum = "bd8531b6d8882fd8f48f82a9754e682e29dd44cff27154af51fa3eb730f59efb" dependencies = [ "aws-smithy-async", "aws-smithy-types", "bytes", "http 0.2.12", - "http 1.2.0", + "http 1.3.1", "pin-project-lite", "tokio", "tracing", @@ -919,15 +924,15 @@ dependencies = [ [[package]] name = "aws-smithy-types" -version = "1.3.0" +version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "836155caafba616c0ff9b07944324785de2ab016141c3550bd1c07882f8cee8f" +checksum = "d498595448e43de7f4296b7b7a18a8a02c61ec9349128c80a368f7c3b4ab11a8" dependencies = [ "base64-simd", "bytes", "bytes-utils", "http 0.2.12", - "http 1.2.0", + "http 1.3.1", "http-body 0.4.6", "http-body 1.0.1", "http-body-util", @@ -942,18 +947,18 @@ dependencies = [ [[package]] name = "aws-smithy-xml" -version = "0.60.9" +version = "0.60.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab0b0166827aa700d3dc519f72f8b3a91c35d0b8d042dc5d643a91e6f80648fc" +checksum = "3db87b96cb1b16c024980f133968d52882ca0daaee3a086c6decc500f6c99728" dependencies = [ "xmlparser", ] [[package]] name = "aws-types" -version = "1.3.6" +version = "1.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3873f8deed8927ce8d04487630dc9ff73193bab64742a61d050e57a68dec4125" +checksum = "8a322fec39e4df22777ed3ad8ea868ac2f94cd15e1a55f6ee8d8d6305057689a" dependencies = [ "aws-credential-types", "aws-smithy-async", @@ -973,7 +978,7 @@ dependencies = [ "axum-core", "bytes", "futures-util", - "http 1.2.0", + "http 1.3.1", 
"http-body 1.0.1", "http-body-util", "itoa", @@ -999,7 +1004,7 @@ dependencies = [ "async-trait", "bytes", "futures-util", - "http 1.2.0", + "http 1.3.1", "http-body 1.0.1", "http-body-util", "mime", @@ -1012,9 +1017,9 @@ dependencies = [ [[package]] name = "backtrace" -version = "0.3.74" +version = "0.3.75" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a" +checksum = "6806a6321ec58106fea15becdad98371e28d92ccbc7c8f1b3b6dd724fe8f1002" dependencies = [ "addr2line", "cfg-if", @@ -1067,7 +1072,7 @@ version = "0.69.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.1", "cexpr", "clang-sys", "itertools 0.10.5", @@ -1080,7 +1085,7 @@ dependencies = [ "regex", "rustc-hash 1.1.0", "shlex", - "syn 2.0.100", + "syn 2.0.104", "which", ] @@ -1092,9 +1097,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.8.0" +version = "2.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" +checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967" [[package]] name = "bitvec" @@ -1119,9 +1124,9 @@ dependencies = [ [[package]] name = "blake3" -version = "1.8.1" +version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "389a099b34312839e16420d499a9cad9650541715937ffbdd40d36f49e77eeb3" +checksum = "3888aaa89e4b2a40fca9848e400f6a658a5a3978de7be858e209cafa8be9a4a0" dependencies = [ "arrayref", "arrayvec", @@ -1152,7 +1157,7 @@ dependencies = [ "futures-util", "hex", "home", - "http 1.2.0", + "http 1.3.1", "http-body-util", "hyper", "hyper-named-pipe", @@ -1191,9 +1196,9 @@ dependencies = [ [[package]] name = "borsh" -version = "1.5.5" +version = "1.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5430e3be710b68d984d1391c854eb431a9d548640711faa54eecb1df93db91cc" +checksum = "ad8646f98db542e39fc66e68a20b2144f6a732636df7c2354e74645faaa433ce" dependencies = [ "borsh-derive", "cfg_aliases", @@ -1201,22 +1206,22 @@ dependencies = [ [[package]] name = "borsh-derive" -version = "1.5.5" +version = "1.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8b668d39970baad5356d7c83a86fee3a539e6f93bf6764c97368243e17a0487" +checksum = "fdd1d3c0c2f5833f22386f252fe8ed005c7f59fdcddeef025c01b4c3b9fd9ac3" dependencies = [ "once_cell", "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] name = "brotli" -version = "7.0.0" +version = "8.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc97b8f16f944bba54f0433f07e30be199b6dc2bd25937444bbad560bcea29bd" +checksum = "9991eea70ea4f293524138648e41ee89b0b2b12ddef3b255effa43c8056e0e0d" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -1225,9 +1230,9 @@ dependencies = [ [[package]] name = "brotli-decompressor" -version = "4.0.2" +version = "5.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74fa05ad7d803d413eb8380983b092cbbaf9a85f151b871360e7b00cd7060b37" +checksum = "874bb8112abecc98cbd6d81ea4fa7e94fb9449648c93cc89aa40c81c24d7de03" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -1235,9 +1240,9 @@ dependencies = [ [[package]] name = "bstr" -version = "1.11.3" 
+version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "531a9155a481e2ee699d4f98f43c0ca4ff8ee1bfd55c31e9e98fb29d2b176fe0" +checksum = "234113d19d0d7d613b40e86fb654acf958910802bcceab913a4f9e7cda03b1a4" dependencies = [ "memchr", "regex-automata", @@ -1313,6 +1318,15 @@ dependencies = [ "bzip2-sys", ] +[[package]] +name = "bzip2" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bea8dcd42434048e4f7a304411d9273a411f647446c1234a65ce0554923f4cff" +dependencies = [ + "libbz2-rs-sys", +] + [[package]] name = "bzip2-sys" version = "0.1.13+1.0.8" @@ -1331,9 +1345,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.2.14" +version = "1.2.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c3d1b2e905a3a7b00a6141adb0e4c0bb941d11caf55349d863942a1cc44e3c9" +checksum = "5f4ac86a9e5bc1e2b3449ab9d7d3a6a405e3d1bb28d7b9be8614f55846ae3766" dependencies = [ "jobserver", "libc", @@ -1363,9 +1377,9 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] name = "chrono" -version = "0.4.40" +version = "0.4.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a7964611d71df112cb1730f2ee67324fcf4d0fc6606acbbe9bfe06df124637c" +checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d" dependencies = [ "android-tzdata", "iana-time-zone", @@ -1389,9 +1403,9 @@ dependencies = [ [[package]] name = "chrono-tz-build" -version = "0.4.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e94fea34d77a245229e7746bd2beb786cd2a896f306ff491fb8cecb3074b10a7" +checksum = "8f10f8c9340e31fc120ff885fcdb54a0b48e474bbd77cab557f0c30a3e569402" dependencies = [ "parse-zoneinfo", "phf_codegen", @@ -1432,7 +1446,7 @@ checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" dependencies = [ "glob", "libc", - "libloading 0.8.6", + "libloading 0.8.7", ] [[package]] @@ -1448,9 +1462,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.35" +version = "4.5.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8aa86934b44c19c50f87cc2790e19f54f7a67aedb64101c2e1a2e5ecfb73944" +checksum = "40b6887a1d8685cebccf115538db5c0efe625ccac9696ad45c409d96566e910f" dependencies = [ "clap_builder", "clap_derive", @@ -1458,9 +1472,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.35" +version = "4.5.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2414dbb2dd0695280da6ea9261e327479e9d37b0630f6b53ba2a11c60c679fd9" +checksum = "e0c66c08ce9f0c698cbce5c0279d0bb6ac936d8674174fe48f736533b964f59e" dependencies = [ "anstream", "anstyle", @@ -1470,14 +1484,14 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.32" +version = "4.5.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09176aae279615badda0765c0c0b3f6ed53f4709118af73cf4655d85d1530cd7" +checksum = "d2c7947ae4cc3d851207c1adb5b5e260ff0cca11446b1d6d1423788e442257ce" dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] @@ -1522,9 +1536,9 @@ dependencies = [ [[package]] name = "console" -version = "0.15.10" +version = "0.15.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea3c6ecd8059b57859df5c69830340ed3c41d30e3da0c1cbed90a96ac853041b" +checksum = 
"054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8" dependencies = [ "encode_unicode", "libc", @@ -1558,7 +1572,7 @@ version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" dependencies = [ - "getrandom 0.2.15", + "getrandom 0.2.16", "once_cell", "tiny-keccak", ] @@ -1642,7 +1656,7 @@ dependencies = [ "anes", "cast", "ciborium", - "clap 4.5.35", + "clap 4.5.40", "criterion-plot", "futures", "is-terminal", @@ -1744,19 +1758,25 @@ dependencies = [ [[package]] name = "ctor" -version = "0.2.9" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32a2785755761f3ddc1492979ce1e48d2c00d09311c39e4466429188f3dd6501" +checksum = "a4735f265ba6a1188052ca32d461028a7d1125868be18e287e756019da7607b5" dependencies = [ - "quote", - "syn 2.0.100", + "ctor-proc-macro", + "dtor", ] +[[package]] +name = "ctor-proc-macro" +version = "0.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f211af61d8efdd104f96e57adf5e426ba1bc3ed7a4ead616e15e5881fd79c4d" + [[package]] name = "darling" -version = "0.20.10" +version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989" +checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" dependencies = [ "darling_core", "darling_macro", @@ -1764,27 +1784,27 @@ dependencies = [ [[package]] name = "darling_core" -version = "0.20.10" +version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5" +checksum = "0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e" dependencies = [ "fnv", "ident_case", "proc-macro2", "quote", "strsim", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] name = "darling_macro" -version = "0.20.10" +version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" +checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" dependencies = [ "darling_core", "quote", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] @@ -1809,14 +1829,14 @@ dependencies = [ [[package]] name = "datafusion" -version = "46.0.1" +version = "48.0.0" dependencies = [ "arrow", "arrow-ipc", "arrow-schema", "async-trait", "bytes", - "bzip2 0.5.2", + "bzip2 0.6.0", "chrono", "criterion", "ctor", @@ -1860,7 +1880,7 @@ dependencies = [ "parking_lot", "parquet", "paste", - "rand 0.8.5", + "rand 0.9.1", "rand_distr", "regex", "rstest", @@ -1879,7 +1899,7 @@ dependencies = [ [[package]] name = "datafusion-benchmarks" -version = "46.0.1" +version = "48.0.0" dependencies = [ "arrow", "datafusion", @@ -1891,7 +1911,7 @@ dependencies = [ "mimalloc", "object_store", "parquet", - "rand 0.8.5", + "rand 0.9.1", "serde", "serde_json", "snmalloc-rs", @@ -1903,7 +1923,7 @@ dependencies = [ [[package]] name = "datafusion-catalog" -version = "46.0.1" +version = "48.0.0" dependencies = [ "arrow", "async-trait", @@ -1927,7 +1947,7 @@ dependencies = [ [[package]] name = "datafusion-catalog-listing" -version = "46.0.1" +version = "48.0.0" dependencies = [ "arrow", "async-trait", @@ -1943,20 +1963,19 @@ dependencies = [ "futures", "log", "object_store", - "tempfile", "tokio", ] [[package]] name = "datafusion-cli" -version = "46.0.1" +version = "48.0.0" dependencies = [ 
"arrow", "assert_cmd", "async-trait", "aws-config", "aws-credential-types", - "clap 4.5.35", + "clap 4.5.40", "ctor", "datafusion", "dirs", @@ -1964,6 +1983,7 @@ dependencies = [ "futures", "insta", "insta-cmd", + "log", "mimalloc", "object_store", "parking_lot", @@ -1978,9 +1998,9 @@ dependencies = [ [[package]] name = "datafusion-common" -version = "46.0.1" +version = "48.0.0" dependencies = [ - "ahash 0.8.11", + "ahash 0.8.12", "apache-avro", "arrow", "arrow-ipc", @@ -1988,7 +2008,8 @@ dependencies = [ "chrono", "half", "hashbrown 0.14.5", - "indexmap 2.8.0", + "hex", + "indexmap 2.10.0", "insta", "libc", "log", @@ -1996,7 +2017,7 @@ dependencies = [ "parquet", "paste", "pyo3", - "rand 0.8.5", + "rand 0.9.1", "recursive", "sqlparser", "tokio", @@ -2005,7 +2026,7 @@ dependencies = [ [[package]] name = "datafusion-common-runtime" -version = "46.0.1" +version = "48.0.0" dependencies = [ "futures", "log", @@ -2014,13 +2035,13 @@ dependencies = [ [[package]] name = "datafusion-datasource" -version = "46.0.1" +version = "48.0.0" dependencies = [ "arrow", "async-compression", "async-trait", "bytes", - "bzip2 0.5.2", + "bzip2 0.6.0", "chrono", "criterion", "datafusion-common", @@ -2038,7 +2059,7 @@ dependencies = [ "log", "object_store", "parquet", - "rand 0.8.5", + "rand 0.9.1", "tempfile", "tokio", "tokio-util", @@ -2049,7 +2070,7 @@ dependencies = [ [[package]] name = "datafusion-datasource-avro" -version = "46.0.1" +version = "48.0.0" dependencies = [ "apache-avro", "arrow", @@ -2074,7 +2095,7 @@ dependencies = [ [[package]] name = "datafusion-datasource-csv" -version = "46.0.1" +version = "48.0.0" dependencies = [ "arrow", "async-trait", @@ -2097,7 +2118,7 @@ dependencies = [ [[package]] name = "datafusion-datasource-json" -version = "46.0.1" +version = "48.0.0" dependencies = [ "arrow", "async-trait", @@ -2120,7 +2141,7 @@ dependencies = [ [[package]] name = "datafusion-datasource-parquet" -version = "46.0.1" +version = "48.0.0" dependencies = [ "arrow", "async-trait", @@ -2137,6 +2158,7 @@ dependencies = [ "datafusion-physical-expr-common", "datafusion-physical-optimizer", "datafusion-physical-plan", + "datafusion-pruning", "datafusion-session", "futures", "itertools 0.14.0", @@ -2144,17 +2166,17 @@ dependencies = [ "object_store", "parking_lot", "parquet", - "rand 0.8.5", + "rand 0.9.1", "tokio", ] [[package]] name = "datafusion-doc" -version = "46.0.1" +version = "48.0.0" [[package]] name = "datafusion-examples" -version = "46.0.1" +version = "48.0.0" dependencies = [ "arrow", "arrow-flight", @@ -2184,7 +2206,7 @@ dependencies = [ [[package]] name = "datafusion-execution" -version = "46.0.1" +version = "48.0.0" dependencies = [ "arrow", "chrono", @@ -2192,19 +2214,21 @@ dependencies = [ "datafusion-common", "datafusion-expr", "futures", + "insta", "log", "object_store", "parking_lot", - "rand 0.8.5", + "rand 0.9.1", "tempfile", "url", ] [[package]] name = "datafusion-expr" -version = "46.0.1" +version = "48.0.0" dependencies = [ "arrow", + "async-trait", "chrono", "ctor", "datafusion-common", @@ -2214,7 +2238,8 @@ dependencies = [ "datafusion-functions-window-common", "datafusion-physical-expr-common", "env_logger", - "indexmap 2.8.0", + "indexmap 2.10.0", + "insta", "paste", "recursive", "serde_json", @@ -2223,18 +2248,18 @@ dependencies = [ [[package]] name = "datafusion-expr-common" -version = "46.0.1" +version = "48.0.0" dependencies = [ "arrow", "datafusion-common", - "indexmap 2.8.0", + "indexmap 2.10.0", "itertools 0.14.0", "paste", ] [[package]] name = "datafusion-ffi" -version 
= "46.0.1" +version = "48.0.0" dependencies = [ "abi_stable", "arrow", @@ -2242,7 +2267,9 @@ dependencies = [ "async-ffi", "async-trait", "datafusion", + "datafusion-functions-aggregate-common", "datafusion-proto", + "datafusion-proto-common", "doc-comment", "futures", "log", @@ -2253,7 +2280,7 @@ dependencies = [ [[package]] name = "datafusion-functions" -version = "46.0.1" +version = "48.0.0" dependencies = [ "arrow", "arrow-buffer", @@ -2272,7 +2299,7 @@ dependencies = [ "itertools 0.14.0", "log", "md-5", - "rand 0.8.5", + "rand 0.9.1", "regex", "sha2", "tokio", @@ -2282,9 +2309,9 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate" -version = "46.0.1" +version = "48.0.0" dependencies = [ - "ahash 0.8.11", + "ahash 0.8.12", "arrow", "criterion", "datafusion-common", @@ -2298,25 +2325,25 @@ dependencies = [ "half", "log", "paste", - "rand 0.8.5", + "rand 0.9.1", ] [[package]] name = "datafusion-functions-aggregate-common" -version = "46.0.1" +version = "48.0.0" dependencies = [ - "ahash 0.8.11", + "ahash 0.8.12", "arrow", "criterion", "datafusion-common", "datafusion-expr-common", "datafusion-physical-expr-common", - "rand 0.8.5", + "rand 0.9.1", ] [[package]] name = "datafusion-functions-nested" -version = "46.0.1" +version = "48.0.0" dependencies = [ "arrow", "arrow-ord", @@ -2327,17 +2354,18 @@ dependencies = [ "datafusion-expr", "datafusion-functions", "datafusion-functions-aggregate", + "datafusion-functions-aggregate-common", "datafusion-macros", "datafusion-physical-expr-common", "itertools 0.14.0", "log", "paste", - "rand 0.8.5", + "rand 0.9.1", ] [[package]] name = "datafusion-functions-table" -version = "46.0.1" +version = "48.0.0" dependencies = [ "arrow", "async-trait", @@ -2351,7 +2379,7 @@ dependencies = [ [[package]] name = "datafusion-functions-window" -version = "46.0.1" +version = "48.0.0" dependencies = [ "arrow", "datafusion-common", @@ -2367,7 +2395,7 @@ dependencies = [ [[package]] name = "datafusion-functions-window-common" -version = "46.0.1" +version = "48.0.0" dependencies = [ "datafusion-common", "datafusion-physical-expr-common", @@ -2375,20 +2403,21 @@ dependencies = [ [[package]] name = "datafusion-macros" -version = "46.0.1" +version = "48.0.0" dependencies = [ "datafusion-expr", "quote", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] name = "datafusion-optimizer" -version = "46.0.1" +version = "48.0.0" dependencies = [ "arrow", "async-trait", "chrono", + "criterion", "ctor", "datafusion-common", "datafusion-expr", @@ -2398,7 +2427,7 @@ dependencies = [ "datafusion-physical-expr", "datafusion-sql", "env_logger", - "indexmap 2.8.0", + "indexmap 2.10.0", "insta", "itertools 0.14.0", "log", @@ -2409,9 +2438,9 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" -version = "46.0.1" +version = "48.0.0" dependencies = [ - "ahash 0.8.11", + "ahash 0.8.12", "arrow", "criterion", "datafusion-common", @@ -2422,21 +2451,21 @@ dependencies = [ "datafusion-physical-expr-common", "half", "hashbrown 0.14.5", - "indexmap 2.8.0", + "indexmap 2.10.0", "insta", "itertools 0.14.0", "log", "paste", - "petgraph", - "rand 0.8.5", + "petgraph 0.8.2", + "rand 0.9.1", "rstest", ] [[package]] name = "datafusion-physical-expr-common" -version = "46.0.1" +version = "48.0.0" dependencies = [ - "ahash 0.8.11", + "ahash 0.8.12", "arrow", "datafusion-common", "datafusion-expr-common", @@ -2446,7 +2475,7 @@ dependencies = [ [[package]] name = "datafusion-physical-optimizer" -version = "46.0.1" +version = "48.0.0" dependencies = [ "arrow", 
"datafusion-common", @@ -2457,17 +2486,19 @@ dependencies = [ "datafusion-physical-expr", "datafusion-physical-expr-common", "datafusion-physical-plan", + "datafusion-pruning", "insta", "itertools 0.14.0", "log", "recursive", + "tokio", ] [[package]] name = "datafusion-physical-plan" -version = "46.0.1" +version = "48.0.0" dependencies = [ - "ahash 0.8.11", + "ahash 0.8.12", "arrow", "arrow-ord", "arrow-schema", @@ -2486,13 +2517,13 @@ dependencies = [ "futures", "half", "hashbrown 0.14.5", - "indexmap 2.8.0", + "indexmap 2.10.0", "insta", "itertools 0.14.0", "log", "parking_lot", "pin-project-lite", - "rand 0.8.5", + "rand 0.9.1", "rstest", "rstest_reuse", "tempfile", @@ -2501,7 +2532,7 @@ dependencies = [ [[package]] name = "datafusion-proto" -version = "46.0.1" +version = "48.0.0" dependencies = [ "arrow", "chrono", @@ -2518,13 +2549,12 @@ dependencies = [ "prost", "serde", "serde_json", - "strum 0.27.1", "tokio", ] [[package]] name = "datafusion-proto-common" -version = "46.0.1" +version = "48.0.0" dependencies = [ "arrow", "datafusion-common", @@ -2535,9 +2565,28 @@ dependencies = [ "serde_json", ] +[[package]] +name = "datafusion-pruning" +version = "48.0.0" +dependencies = [ + "arrow", + "arrow-schema", + "datafusion-common", + "datafusion-datasource", + "datafusion-expr", + "datafusion-expr-common", + "datafusion-functions-nested", + "datafusion-physical-expr", + "datafusion-physical-expr-common", + "datafusion-physical-plan", + "insta", + "itertools 0.14.0", + "log", +] + [[package]] name = "datafusion-session" -version = "46.0.1" +version = "48.0.0" dependencies = [ "arrow", "async-trait", @@ -2557,9 +2606,23 @@ dependencies = [ "tokio", ] +[[package]] +name = "datafusion-spark" +version = "48.0.0" +dependencies = [ + "arrow", + "datafusion-catalog", + "datafusion-common", + "datafusion-execution", + "datafusion-expr", + "datafusion-functions", + "datafusion-macros", + "log", +] + [[package]] name = "datafusion-sql" -version = "46.0.1" +version = "48.0.0" dependencies = [ "arrow", "bigdecimal", @@ -2571,7 +2634,7 @@ dependencies = [ "datafusion-functions-nested", "datafusion-functions-window", "env_logger", - "indexmap 2.8.0", + "indexmap 2.10.0", "insta", "log", "paste", @@ -2583,15 +2646,17 @@ dependencies = [ [[package]] name = "datafusion-sqllogictest" -version = "46.0.1" +version = "48.0.0" dependencies = [ "arrow", "async-trait", "bigdecimal", "bytes", "chrono", - "clap 4.5.35", + "clap 4.5.40", "datafusion", + "datafusion-spark", + "datafusion-substrait", "env_logger", "futures", "half", @@ -2614,7 +2679,7 @@ dependencies = [ [[package]] name = "datafusion-substrait" -version = "46.0.1" +version = "48.0.0" dependencies = [ "async-recursion", "async-trait", @@ -2634,7 +2699,7 @@ dependencies = [ [[package]] name = "datafusion-wasmtest" -version = "46.0.1" +version = "48.0.0" dependencies = [ "chrono", "console_error_panic_hook", @@ -2645,7 +2710,7 @@ dependencies = [ "datafusion-optimizer", "datafusion-physical-plan", "datafusion-sql", - "getrandom 0.2.15", + "getrandom 0.3.3", "insta", "object_store", "tokio", @@ -2656,9 +2721,9 @@ dependencies = [ [[package]] name = "deranged" -version = "0.3.11" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" +checksum = "9c9e6a11ca8224451684bc0d7d5a7adbf8f2fd6887261a1cfc3c0432f9d4068e" dependencies = [ "powerfmt", "serde", @@ -2710,7 +2775,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" 
dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] @@ -2721,15 +2786,30 @@ checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" [[package]] name = "docker_credential" -version = "1.3.1" +version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31951f49556e34d90ed28342e1df7e1cb7a229c4cab0aecc627b5d91edd41d07" +checksum = "1d89dfcba45b4afad7450a99b39e751590463e45c04728cf555d36bb66940de8" dependencies = [ "base64 0.21.7", "serde", "serde_json", ] +[[package]] +name = "dtor" +version = "0.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97cbdf2ad6846025e8e25df05171abfb30e3ababa12ee0a0e44b9bbe570633a8" +dependencies = [ + "dtor-proc-macro", +] + +[[package]] +name = "dtor-proc-macro" +version = "0.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7454e41ff9012c00d53cf7f475c5e3afa3b91b7c90568495495e8d9bf47a1055" + [[package]] name = "dunce" version = "1.0.5" @@ -2738,9 +2818,9 @@ checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" [[package]] name = "dyn-clone" -version = "1.0.18" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "feeef44e73baff3a26d371801df019877a9866a8c493d315ab00177843314f35" +checksum = "1c7a8fb8a9fbf66c1f703fe16184d10ca0ee9d23be5b4436400408ba54a95005" [[package]] name = "educe" @@ -2751,14 +2831,14 @@ dependencies = [ "enum-ordinalize", "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] name = "either" -version = "1.13.0" +version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" [[package]] name = "encode_unicode" @@ -2789,7 +2869,7 @@ checksum = "0d28318a75d4aead5c4db25382e8ef717932d0346600cacae6357eb5941bc5ff" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] @@ -2804,9 +2884,9 @@ dependencies = [ [[package]] name = "env_logger" -version = "0.11.7" +version = "0.11.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3716d7a920fb4fac5d84e9d4bce8ceb321e9414b4409da61b07b75c1e3d0697" +checksum = "13c863f0904021b108aa8b2f55046443e6b1ebde8fd4a15c399893aae4fa069f" dependencies = [ "anstream", "anstyle", @@ -2823,9 +2903,9 @@ checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" [[package]] name = "errno" -version = "0.3.10" +version = "0.3.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d" +checksum = "cea14ef9355e3beab063703aa9dab15afd25f0667c341310c1e5274bb1d0da18" dependencies = [ "libc", "windows-sys 0.59.0", @@ -2833,9 +2913,9 @@ dependencies = [ [[package]] name = "error-code" -version = "3.3.1" +version = "3.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5d9305ccc6942a704f4335694ecd3de2ea531b114ac2d51f5f843750787a92f" +checksum = "dea2df4cf52843e0452895c455a1a2cfbb842a1e7329671acf418fdc53ed4c59" [[package]] name = "escape8259" @@ -2845,13 +2925,13 @@ checksum = "5692dd7b5a1978a5aeb0ce83b7655c58ca8efdcb79d21036ea249da95afec2c6" [[package]] name = "etcetera" -version = "0.8.0" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"136d1b5283a1ab77bd9257427ffd09d8667ced0570b6f938942bc7568ed5b943" +checksum = "26c7b13d0780cb82722fd59f6f57f925e143427e4a75313a6c77243bf5326ae6" dependencies = [ "cfg-if", "home", - "windows-sys 0.48.0", + "windows-sys 0.59.0", ] [[package]] @@ -2868,13 +2948,13 @@ checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" [[package]] name = "fd-lock" -version = "4.0.2" +version = "4.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e5768da2206272c81ef0b5e951a41862938a6070da63bcea197899942d3b947" +checksum = "0ce92ff622d6dadf7349484f42c93271a0d49b7cc4d466a936405bacbe10aa78" dependencies = [ "cfg-if", - "rustix 0.38.44", - "windows-sys 0.52.0", + "rustix 1.0.7", + "windows-sys 0.59.0", ] [[package]] @@ -2931,15 +3011,15 @@ version = "25.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1045398c1bfd89168b5fd3f1fc11f6e70b34f6f66300c87d44d3de849463abf1" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.1", "rustc_version", ] [[package]] name = "flate2" -version = "1.1.1" +version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ced92e76e966ca2fd84c8f7aa01a4aea65b0eb6648d72f7c8f3e2764a67fece" +checksum = "4a3d7db9596fecd151c5f638c0ee5d5bd487b6e0ea232e5dc96d5250f6f94b1d" dependencies = [ "crc32fast", "libz-rs-sys", @@ -2963,9 +3043,9 @@ checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" [[package]] name = "foldhash" -version = "0.1.4" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" [[package]] name = "form_urlencoded" @@ -3053,7 +3133,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] @@ -3129,9 +3209,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" dependencies = [ "cfg-if", "js-sys", @@ -3142,14 +3222,16 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.3.1" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43a49c392881ce6d5c3b8cb70f98717b7c07aabbdff06687b9030dbfbe2725f8" +checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" dependencies = [ "cfg-if", + "js-sys", "libc", - "wasi 0.13.3+wasi-0.2.2", - "windows-targets 0.52.6", + "r-efi", + "wasi 0.14.2+wasi-0.2.4", + "wasm-bindgen", ] [[package]] @@ -3179,17 +3261,17 @@ dependencies = [ [[package]] name = "h2" -version = "0.4.8" +version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5017294ff4bb30944501348f6f8e42e6ad28f42c8bbef7a74029aff064a4e3c2" +checksum = "a9421a676d1b147b16b82c9225157dc629087ef8ec4d5e2960f9437a90dac0a5" dependencies = [ "atomic-waker", "bytes", "fnv", "futures-core", "futures-sink", - "http 1.2.0", - "indexmap 2.8.0", + "http 1.3.1", + "indexmap 2.10.0", "slab", "tokio", "tokio-util", @@ -3198,9 +3280,9 @@ dependencies = [ [[package]] name = "half" -version = "2.5.0" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"7db2ff139bba50379da6aa0766b52fdcb62cb5b263009b09ed58ba604e14bbd1" +checksum = "459196ed295495a68f7d7fe1d84f6c4b7ff0e21fe3017b2f283c6fac3ad803c9" dependencies = [ "cfg-if", "crunchy", @@ -3222,15 +3304,15 @@ version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" dependencies = [ - "ahash 0.8.11", + "ahash 0.8.12", "allocator-api2", ] [[package]] name = "hashbrown" -version = "0.15.2" +version = "0.15.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +checksum = "84b26c544d002229e640969970a2e74021aadf6e2f96372b9c58eff97de08eb3" dependencies = [ "allocator-api2", "equivalent", @@ -3254,9 +3336,9 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" [[package]] name = "hermit-abi" -version = "0.4.0" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbf6a919d6cf397374f7dfeeea91d974c7c0a7221d0d0f4f20d859d329e53fcc" +checksum = "f154ce46856750ed433c8649605bf7ed2de3bc35fd9d2a9f30cddd873c80cb08" [[package]] name = "hex" @@ -3295,9 +3377,9 @@ dependencies = [ [[package]] name = "http" -version = "1.2.0" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f16ca2af56261c99fba8bac40a10251ce8188205a4c448fbb745a2e4daa76fea" +checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565" dependencies = [ "bytes", "fnv", @@ -3322,27 +3404,27 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", - "http 1.2.0", + "http 1.3.1", ] [[package]] name = "http-body-util" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" dependencies = [ "bytes", - "futures-util", - "http 1.2.0", + "futures-core", + "http 1.3.1", "http-body 1.0.1", "pin-project-lite", ] [[package]] name = "httparse" -version = "1.10.0" +version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2d708df4e7140240a16cd6ab0ab65c972d7433ab77819ea693fde9c43811e2a" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" [[package]] name = "httpdate" @@ -3352,9 +3434,9 @@ checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" [[package]] name = "humantime" -version = "2.1.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" +checksum = "9b112acc8b3adf4b107a8ec20977da0273a8c386765a3ec0229bd500a1443f9f" [[package]] name = "hyper" @@ -3366,7 +3448,7 @@ dependencies = [ "futures-channel", "futures-util", "h2", - "http 1.2.0", + "http 1.3.1", "http-body 1.0.1", "httparse", "httpdate", @@ -3399,7 +3481,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2d191583f3da1305256f22463b9bb0471acad48a4e534a5218b9963e9c1f59b2" dependencies = [ "futures-util", - "http 1.2.0", + "http 1.3.1", "hyper", "hyper-util", "rustls", @@ -3425,16 +3507,17 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.10" +version = "0.1.12" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "df2dcfbe0677734ab2f3ffa7fa7bfd4706bfdc1ef393f2ee30184aed67e631b4" +checksum = "cf9f1e950e0d9d1d3c47184416723cf29c0d1f93bd8cccf37e4beb6b44f31710" dependencies = [ "bytes", "futures-channel", "futures-util", - "http 1.2.0", + "http 1.3.1", "http-body 1.0.1", "hyper", + "libc", "pin-project-lite", "socket2", "tokio", @@ -3459,16 +3542,17 @@ dependencies = [ [[package]] name = "iana-time-zone" -version = "0.1.61" +version = "0.1.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220" +checksum = "b0c919e5debc312ad217002b8048a17b7d83f80703865bbfcfebb0458b0b27d8" dependencies = [ "android_system_properties", "core-foundation-sys", "iana-time-zone-haiku", "js-sys", + "log", "wasm-bindgen", - "windows-core 0.52.0", + "windows-core", ] [[package]] @@ -3482,21 +3566,22 @@ dependencies = [ [[package]] name = "icu_collections" -version = "1.5.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db2fa452206ebee18c4b5c2274dbf1de17008e874b4dc4f0aea9d01ca79e4526" +checksum = "200072f5d0e3614556f94a9930d5dc3e0662a652823904c3a75dc3b0af7fee47" dependencies = [ "displaydoc", + "potential_utf", "yoke", "zerofrom", "zerovec", ] [[package]] -name = "icu_locid" -version = "1.5.0" +name = "icu_locale_core" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13acbb8371917fc971be86fc8057c41a64b521c184808a698c02acc242dbf637" +checksum = "0cde2700ccaed3872079a65fb1a78f6c0a36c91570f28755dda67bc8f7d9f00a" dependencies = [ "displaydoc", "litemap", @@ -3505,31 +3590,11 @@ dependencies = [ "zerovec", ] -[[package]] -name = "icu_locid_transform" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01d11ac35de8e40fdeda00d9e1e9d92525f3f9d887cdd7aa81d727596788b54e" -dependencies = [ - "displaydoc", - "icu_locid", - "icu_locid_transform_data", - "icu_provider", - "tinystr", - "zerovec", -] - -[[package]] -name = "icu_locid_transform_data" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdc8ff3388f852bede6b579ad4e978ab004f139284d7b28715f773507b946f6e" - [[package]] name = "icu_normalizer" -version = "1.5.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19ce3e0da2ec68599d193c93d088142efd7f9c5d6fc9b803774855747dc6a84f" +checksum = "436880e8e18df4d7bbc06d58432329d6458cc84531f7ac5f024e93deadb37979" dependencies = [ "displaydoc", "icu_collections", @@ -3537,67 +3602,54 @@ dependencies = [ "icu_properties", "icu_provider", "smallvec", - "utf16_iter", - "utf8_iter", - "write16", "zerovec", ] [[package]] name = "icu_normalizer_data" -version = "1.5.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8cafbf7aa791e9b22bec55a167906f9e1215fd475cd22adfcf660e03e989516" +checksum = "00210d6893afc98edb752b664b8890f0ef174c8adbb8d0be9710fa66fbbf72d3" [[package]] name = "icu_properties" -version = "1.5.1" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93d6020766cfc6302c15dbbc9c8778c37e62c14427cb7f6e601d849e092aeef5" +checksum = "016c619c1eeb94efb86809b015c58f479963de65bdb6253345c1a1276f22e32b" dependencies = [ "displaydoc", "icu_collections", - "icu_locid_transform", + "icu_locale_core", "icu_properties_data", "icu_provider", - "tinystr", + "potential_utf", + 
"zerotrie", "zerovec", ] [[package]] name = "icu_properties_data" -version = "1.5.0" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67a8effbc3dd3e4ba1afa8ad918d5684b8868b3b26500753effea8d2eed19569" +checksum = "298459143998310acd25ffe6810ed544932242d3f07083eee1084d83a71bd632" [[package]] name = "icu_provider" -version = "1.5.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ed421c8a8ef78d3e2dbc98a973be2f3770cb42b606e3ab18d6237c4dfde68d9" +checksum = "03c80da27b5f4187909049ee2d72f276f0d9f99a42c306bd0131ecfe04d8e5af" dependencies = [ "displaydoc", - "icu_locid", - "icu_provider_macros", + "icu_locale_core", "stable_deref_trait", "tinystr", "writeable", "yoke", "zerofrom", + "zerotrie", "zerovec", ] -[[package]] -name = "icu_provider_macros" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.100", -] - [[package]] name = "ident_case" version = "1.0.1" @@ -3617,9 +3669,9 @@ dependencies = [ [[package]] name = "idna_adapter" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "daca1df1c957320b2cf139ac61e7bd64fed304c5040df000a745aa1de3b4ef71" +checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344" dependencies = [ "icu_normalizer", "icu_properties", @@ -3638,12 +3690,12 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.8.0" +version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3954d50fe15b02142bf25d3b8bdadb634ec3948f103d04ffe3031bc8fe9d7058" +checksum = "fe4cd85333e22411419a0bcae1297d25e58c9443848b11dc6a86fefe8c78a661" dependencies = [ "equivalent", - "hashbrown 0.15.2", + "hashbrown 0.15.3", "serde", ] @@ -3662,21 +3714,19 @@ dependencies = [ [[package]] name = "indoc" -version = "2.0.5" +version = "2.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" +checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" [[package]] name = "insta" -version = "1.42.2" +version = "1.43.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50259abbaa67d11d2bcafc7ba1d094ed7a0c70e3ce893f0d0997f73558cb3084" +checksum = "154934ea70c58054b556dd430b99a98c2a7ff5309ac9891597e339b5c28f4371" dependencies = [ "console", "globset", - "linked-hash-map", "once_cell", - "pin-project", "regex", "serde", "similar", @@ -3708,9 +3758,9 @@ checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" [[package]] name = "is-terminal" -version = "0.4.15" +version = "0.4.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e19b23d53f35ce9f56aebc7d1bb4e6ac1e9c0db7ac85c8d1760c04379edced37" +checksum = "e04d7f318608d35d4b61ddd75cbdaee86b023ebe2bd5a66ee0915f0bf93095a9" dependencies = [ "hermit-abi", "libc", @@ -3752,15 +3802,15 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.14" +version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674" +checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" [[package]] name = "jiff" -version = "0.2.4" +version = "0.2.14" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "d699bc6dfc879fb1bf9bdff0d4c56f0884fc6f0d0eb0fba397a6d00cd9a6b85e" +checksum = "a194df1107f33c79f4f93d02c80798520551949d59dfad22b6157048a88cca93" dependencies = [ "jiff-static", "log", @@ -3771,21 +3821,22 @@ dependencies = [ [[package]] name = "jiff-static" -version = "0.2.4" +version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d16e75759ee0aa64c57a56acbf43916987b20c77373cb7e808979e02b93c9f9" +checksum = "6c6e1db7ed32c6c71b759497fae34bf7933636f75a251b9e736555da426f6442" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] name = "jobserver" -version = "0.1.32" +version = "0.1.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" +checksum = "38f262f097c174adebe41eb73d66ae9c06b2844fb0da69969647bbddd9b0538a" dependencies = [ + "getrandom 0.3.3", "libc", ] @@ -3875,11 +3926,17 @@ dependencies = [ "static_assertions", ] +[[package]] +name = "libbz2-rs-sys" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "775bf80d5878ab7c2b1080b5351a48b2f737d9f6f8b383574eebcc22be0dfccb" + [[package]] name = "libc" -version = "0.2.171" +version = "0.2.174" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c19937216e9d3aa9956d9bb8dfc0b0c8beb6058fc4f7a4dc4d850edf86a237d6" +checksum = "1171693293099992e19cddea4e8b849964e9846f4acee11b3948bcc337be8776" [[package]] name = "libflate" @@ -3917,25 +3974,25 @@ dependencies = [ [[package]] name = "libloading" -version = "0.8.6" +version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" +checksum = "6a793df0d7afeac54f95b471d3af7f0d4fb975699f972341a4b76988d49cdf0c" dependencies = [ "cfg-if", - "windows-targets 0.52.6", + "windows-targets 0.53.0", ] [[package]] name = "libm" -version = "0.2.11" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa" +checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" [[package]] name = "libmimalloc-sys" -version = "0.1.42" +version = "0.1.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec9d6fac27761dabcd4ee73571cdb06b7022dc99089acbe5435691edffaac0f4" +checksum = "bf88cd67e9de251c1781dbe2f641a1a3ad66eaae831b8a2c38fbdc5ddae16d4d" dependencies = [ "cc", "libc", @@ -3947,9 +4004,9 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.1", "libc", - "redox_syscall 0.5.8", + "redox_syscall 0.5.12", ] [[package]] @@ -3960,25 +4017,19 @@ checksum = "5297962ef19edda4ce33aaa484386e0a5b3d7f2f4e037cbeee00503ef6b29d33" dependencies = [ "anstream", "anstyle", - "clap 4.5.35", + "clap 4.5.40", "escape8259", ] [[package]] name = "libz-rs-sys" -version = "0.4.2" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "902bc563b5d65ad9bba616b490842ef0651066a1a1dc3ce1087113ffcb873c8d" +checksum = "172a788537a2221661b480fee8dc5f96c580eb34fa88764d3205dc356c7e4221" dependencies = [ "zlib-rs", ] -[[package]] -name = "linked-hash-map" -version = "0.5.6" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" - [[package]] name = "linux-raw-sys" version = "0.4.15" @@ -3987,15 +4038,15 @@ checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" [[package]] name = "linux-raw-sys" -version = "0.9.2" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6db9c683daf087dc577b7506e9695b3d556a9f3849903fa28186283afd6809e9" +checksum = "cd945864f07fe9f5371a27ad7b52a172b4b499999f1d97574c9fa68373937e12" [[package]] name = "litemap" -version = "0.7.4" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104" +checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956" [[package]] name = "lock_api" @@ -4013,6 +4064,12 @@ version = "0.4.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" +[[package]] +name = "lru-slab" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" + [[package]] name = "lz4_flex" version = "0.11.3" @@ -4066,9 +4123,9 @@ dependencies = [ [[package]] name = "mimalloc" -version = "0.1.46" +version = "0.1.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "995942f432bbb4822a7e9c3faa87a695185b0d09273ba85f097b54f4e458f2af" +checksum = "b1791cbe101e95af5764f06f20f6760521f7158f69dbf9d6baf941ee1bf6bc40" dependencies = [ "libmimalloc-sys", ] @@ -4117,9 +4174,9 @@ dependencies = [ [[package]] name = "multimap" -version = "0.10.0" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "defc4c55412d89136f966bbb339008b474350e5e6e78d2714439c386b3137a03" +checksum = "1d87ecb2933e8aeadb3e3a02b828fed80a7528047e68b4f424523a0981a3a084" [[package]] name = "nibble_vec" @@ -4132,11 +4189,11 @@ dependencies = [ [[package]] name = "nix" -version = "0.29.0" +version = "0.30.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46" +checksum = "74523f3a35e05aba87a1d978330aef40f67b0304ac79c1c00b294c9830543db6" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.1", "cfg-if", "cfg_aliases", "libc", @@ -4266,11 +4323,21 @@ checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" [[package]] name = "objc2-core-foundation" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "daeaf60f25471d26948a1c2f840e3f7d86f4109e3af4e8e4b5cd70c39690d925" +checksum = "1c10c2894a6fed806ade6027bcd50662746363a9589d3ec9d9bef30a4e4bc166" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.1", +] + +[[package]] +name = "objc2-io-kit" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71c1c64d6120e51cd86033f67176b1cb66780c2efe34dec55176f77befd93c0a" +dependencies = [ + "libc", + "objc2-core-foundation", ] [[package]] @@ -4284,9 +4351,9 @@ dependencies = [ [[package]] name = "object_store" -version = "0.12.0" +version = "0.12.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9ce831b09395f933addbc56d894d889e4b226eba304d4e7adbab591e26daf1e" +checksum = 
"7781f96d79ed0f961a7021424ab01840efbda64ae7a505aaea195efc91eaaec4" dependencies = [ "async-trait", "base64 0.22.1", @@ -4294,7 +4361,7 @@ dependencies = [ "chrono", "form_urlencoded", "futures", - "http 1.2.0", + "http 1.3.1", "http-body-util", "humantime", "hyper", @@ -4303,7 +4370,7 @@ dependencies = [ "parking_lot", "percent-encoding", "quick-xml", - "rand 0.8.5", + "rand 0.9.1", "reqwest", "ring", "rustls-pemfile", @@ -4315,19 +4382,21 @@ dependencies = [ "tracing", "url", "walkdir", + "wasm-bindgen-futures", + "web-time", ] [[package]] name = "once_cell" -version = "1.20.3" +version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" [[package]] name = "oorandom" -version = "11.1.4" +version = "11.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" +checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" [[package]] name = "openssl-probe" @@ -4364,9 +4433,9 @@ checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" [[package]] name = "owo-colors" -version = "4.1.0" +version = "4.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb37767f6569cd834a413442455e0f066d0d522de8630436e2a1761d9726ba56" +checksum = "26995317201fa17f3656c36716aed4a7c81743a9634ac4c99c0eeda495db0cec" [[package]] name = "parking_lot" @@ -4386,18 +4455,18 @@ checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" dependencies = [ "cfg-if", "libc", - "redox_syscall 0.5.8", + "redox_syscall 0.5.12", "smallvec", "windows-targets 0.52.6", ] [[package]] name = "parquet" -version = "55.0.0" +version = "55.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd31a8290ac5b19f09ad77ee7a1e6a541f1be7674ad410547d5f1eef6eef4a9c" +checksum = "b17da4150748086bd43352bc77372efa9b6e3dbd06a04831d2a98c041c225cfa" dependencies = [ - "ahash 0.8.11", + "ahash 0.8.12", "arrow-array", "arrow-buffer", "arrow-cast", @@ -4412,12 +4481,13 @@ dependencies = [ "flate2", "futures", "half", - "hashbrown 0.15.2", + "hashbrown 0.15.3", "lz4_flex", "num", "num-bigint", "object_store", "paste", + "ring", "seq-macro", "simdutf8", "snap", @@ -4449,7 +4519,7 @@ dependencies = [ "regex", "regex-syntax", "structmeta", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] @@ -4517,7 +4587,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" dependencies = [ "fixedbitset", - "indexmap 2.8.0", + "indexmap 2.10.0", +] + +[[package]] +name = "petgraph" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54acf3a685220b533e437e264e4d932cfbdc4cc7ec0cd232ed73c08d03b8a7ca" +dependencies = [ + "fixedbitset", + "hashbrown 0.15.3", + "indexmap 2.10.0", + "serde", ] [[package]] @@ -4560,22 +4642,22 @@ dependencies = [ [[package]] name = "pin-project" -version = "1.1.9" +version = "1.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfe2e71e1471fe07709406bf725f710b02927c9c54b2b5b2ec0e8087d97c327d" +checksum = "677f1add503faace112b9f1373e43e9e054bfdd22ff1a63c1bc485eaec6a6a8a" dependencies = [ "pin-project-internal", ] [[package]] name = "pin-project-internal" -version = "1.1.9" +version = "1.1.10" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6e859e6e5bd50440ab63c47e3ebabc90f26251f7c73c3d3e837b74a1cc3fa67" +checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] @@ -4592,9 +4674,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "pkg-config" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" +checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" [[package]] name = "plotters" @@ -4626,9 +4708,9 @@ dependencies = [ [[package]] name = "portable-atomic" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "280dc24453071f1b63954171985a0b0d30058d287960968b9b2aca264c8d4ee6" +checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e" [[package]] name = "portable-atomic-util" @@ -4648,7 +4730,7 @@ dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] @@ -4664,7 +4746,7 @@ dependencies = [ "hmac", "md-5", "memchr", - "rand 0.9.0", + "rand 0.9.1", "sha2", "stringprep", ] @@ -4682,6 +4764,15 @@ dependencies = [ "postgres-protocol", ] +[[package]] +name = "potential_utf" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5a7c30837279ca13e7c867e9e40053bc68740f988cb07f7ca6df43cc734b585" +dependencies = [ + "zerovec", +] + [[package]] name = "powerfmt" version = "0.2.0" @@ -4690,11 +4781,11 @@ checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" [[package]] name = "ppv-lite86" -version = "0.2.20" +version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" dependencies = [ - "zerocopy 0.7.35", + "zerocopy", ] [[package]] @@ -4729,19 +4820,19 @@ dependencies = [ [[package]] name = "prettyplease" -version = "0.2.31" +version = "0.2.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5316f57387668042f561aae71480de936257848f9c43ce528e311d89a07cadeb" +checksum = "664ec5419c51e34154eec046ebcba56312d5a2fc3b09a06da188e1ad21afadf6" dependencies = [ "proc-macro2", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] name = "proc-macro-crate" -version = "3.2.0" +version = "3.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ecf48c7ca261d60b74ab1a7b20da18bede46776b2e55535cb958eb595c5fa7b" +checksum = "edce586971a4dfaa28950c6f18ed55e0406c1ab88bbce2c6f6293a7aaba73d35" dependencies = [ "toml_edit", ] @@ -4772,9 +4863,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.93" +version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60946a68e5f9d28b0dc1c21bb8a97ee7d018a8b322fa57838ba31cc878e22d99" +checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778" dependencies = [ "unicode-ident", ] @@ -4796,16 +4887,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "be769465445e8c1474e9c5dac2018218498557af32d9ed057325ec9a41ae81bf" dependencies = [ "heck 0.5.0", - "itertools 0.13.0", + "itertools 0.10.5", "log", "multimap", "once_cell", - 
"petgraph", + "petgraph 0.7.1", "prettyplease", "prost", "prost-types", "regex", - "syn 2.0.100", + "syn 2.0.104", "tempfile", ] @@ -4816,10 +4907,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d" dependencies = [ "anyhow", - "itertools 0.13.0", + "itertools 0.10.5", "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] @@ -4842,9 +4933,9 @@ dependencies = [ [[package]] name = "psm" -version = "0.1.25" +version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f58e5423e24c18cc840e1c98370b3993c6649cd1678b4d24318bcf0a083cbe88" +checksum = "6e944464ec8536cd1beb0bbfd96987eb5e3b72f2ecdafdc5c769a37f1fa2ae1f" dependencies = [ "cc", ] @@ -4871,9 +4962,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.24.1" +version = "0.24.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17da310086b068fbdcefbba30aeb3721d5bb9af8db4987d6735b2183ca567229" +checksum = "e5203598f366b11a02b13aa20cab591229ff0a89fd121a308a5df751d5fc9219" dependencies = [ "cfg-if", "indoc", @@ -4889,9 +4980,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.24.1" +version = "0.24.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e27165889bd793000a098bb966adc4300c312497ea25cf7a690a9f0ac5aa5fc1" +checksum = "99636d423fa2ca130fa5acde3059308006d46f98caac629418e53f7ebb1e9999" dependencies = [ "once_cell", "target-lexicon", @@ -4899,9 +4990,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.24.1" +version = "0.24.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05280526e1dbf6b420062f3ef228b78c0c54ba94e157f5cb724a609d0f2faabc" +checksum = "78f9cf92ba9c409279bc3305b5409d90db2d2c22392d443a87df3a1adad59e33" dependencies = [ "libc", "pyo3-build-config", @@ -4909,27 +5000,27 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.24.1" +version = "0.24.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c3ce5686aa4d3f63359a5100c62a127c9f15e8398e5fdeb5deef1fed5cd5f44" +checksum = "0b999cb1a6ce21f9a6b147dcf1be9ffedf02e0043aec74dc390f3007047cecd9" dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] name = "pyo3-macros-backend" -version = "0.24.1" +version = "0.24.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4cf6faa0cbfb0ed08e89beb8103ae9724eb4750e3a78084ba4017cbe94f3855" +checksum = "822ece1c7e1012745607d5cf0bcb2874769f0f7cb34c4cde03b9358eb9ef911a" dependencies = [ "heck 0.5.0", "proc-macro2", "pyo3-build-config", "quote", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] @@ -4940,9 +5031,9 @@ checksum = "5a651516ddc9168ebd67b24afd085a718be02f8858fe406591b013d101ce2f40" [[package]] name = "quick-xml" -version = "0.37.2" +version = "0.37.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "165859e9e55f79d67b96c5d96f4e88b6f2695a1972849c15a6a3f5c59fc2c003" +checksum = "331e97a1af0bf59823e6eadffe373d7b27f485be8748f71471c662c1f269b7fb" dependencies = [ "memchr", "serde", @@ -4950,11 +5041,12 @@ dependencies = [ [[package]] name = "quinn" -version = "0.11.6" +version = "0.11.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62e96808277ec6f97351a2380e6c25114bc9e67037775464979f3037c92d05ef" +checksum = "626214629cda6781b6dc1d316ba307189c85ba657213ce642d9c77670f8202c8" 
dependencies = [ "bytes", + "cfg_aliases", "pin-project-lite", "quinn-proto", "quinn-udp", @@ -4964,17 +5056,19 @@ dependencies = [ "thiserror 2.0.12", "tokio", "tracing", + "web-time", ] [[package]] name = "quinn-proto" -version = "0.11.9" +version = "0.11.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d" +checksum = "49df843a9161c85bb8aae55f101bc0bac8bcafd637a620d9122fd7e0b2f7422e" dependencies = [ "bytes", - "getrandom 0.2.15", - "rand 0.8.5", + "getrandom 0.3.3", + "lru-slab", + "rand 0.9.1", "ring", "rustc-hash 2.1.1", "rustls", @@ -4988,9 +5082,9 @@ dependencies = [ [[package]] name = "quinn-udp" -version = "0.5.10" +version = "0.5.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e46f3055866785f6b92bc6164b76be02ca8f2eb4b002c0354b28cf4c119e5944" +checksum = "ee4e529991f949c5e25755532370b8af5d114acae52326361d68d47af64aa842" dependencies = [ "cfg_aliases", "libc", @@ -5009,6 +5103,12 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "r-efi" +version = "5.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74765f6d916ee2faa39bc8e68e4f3ed8949b48cccdac59983d287a7cb71ce9c5" + [[package]] name = "radium" version = "0.7.0" @@ -5038,13 +5138,12 @@ dependencies = [ [[package]] name = "rand" -version = "0.9.0" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94" +checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97" dependencies = [ "rand_chacha 0.9.0", - "rand_core 0.9.1", - "zerocopy 0.8.18", + "rand_core 0.9.3", ] [[package]] @@ -5064,7 +5163,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" dependencies = [ "ppv-lite86", - "rand_core 0.9.1", + "rand_core 0.9.3", ] [[package]] @@ -5073,27 +5172,26 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom 0.2.15", + "getrandom 0.2.16", ] [[package]] name = "rand_core" -version = "0.9.1" +version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a88e0da7a2c97baa202165137c158d0a2e824ac465d13d81046727b34cb247d3" +checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" dependencies = [ - "getrandom 0.3.1", - "zerocopy 0.8.18", + "getrandom 0.3.3", ] [[package]] name = "rand_distr" -version = "0.4.3" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463" dependencies = [ "num-traits", - "rand 0.8.5", + "rand 0.9.1", ] [[package]] @@ -5133,7 +5231,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" dependencies = [ "quote", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] @@ -5147,11 +5245,11 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.8" +version = "0.5.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03a862b389f93e68874fbf580b9de08dd02facb9a788ebadaf4a3fd33cf58834" +checksum = 
"928fca9cf2aa042393a8325b9ead81d2f0df4cb12e1e24cef072922ccd99c5af" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.1", ] [[package]] @@ -5160,7 +5258,7 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd6f9d3d47bdd2ad6945c5015a226ec6155d0bcdfd8f7cd29f86b71f8de99d2b" dependencies = [ - "getrandom 0.2.15", + "getrandom 0.2.16", "libredox", "thiserror 2.0.12", ] @@ -5206,7 +5304,7 @@ version = "0.10.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ef7fa9ed0256d64a688a3747d0fef7a88851c18a5e1d57f115f38ec2e09366" dependencies = [ - "hashbrown 0.15.2", + "hashbrown 0.15.3", "memchr", ] @@ -5236,16 +5334,16 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.12.12" +version = "0.12.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43e734407157c3c2034e0258f5e4473ddb361b1e85f95a66690d67264d7cd1da" +checksum = "d19c46a6fdd48bc4dab94b6103fccc55d34c67cc0ad04653aad4ea2a07cd7bbb" dependencies = [ "base64 0.22.1", "bytes", "futures-core", "futures-util", "h2", - "http 1.2.0", + "http 1.3.1", "http-body 1.0.1", "http-body-util", "hyper", @@ -5282,13 +5380,13 @@ dependencies = [ [[package]] name = "ring" -version = "0.17.13" +version = "0.17.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70ac5d832aa16abd7d1def883a8545280c20a60f523a370aa3a9617c2b8550ee" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" dependencies = [ "cc", "cfg-if", - "getrandom 0.2.15", + "getrandom 0.2.16", "libc", "untrusted", "windows-sys 0.52.0", @@ -5331,9 +5429,9 @@ checksum = "3582f63211428f83597b51b2ddb88e2a91a9d52d12831f9d08f5e624e8977422" [[package]] name = "rstest" -version = "0.24.0" +version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03e905296805ab93e13c1ec3a03f4b6c4f35e9498a3d5fa96dc626d22c03cd89" +checksum = "6fc39292f8613e913f7df8fa892b8944ceb47c247b78e1b1ae2f09e019be789d" dependencies = [ "futures-timer", "futures-util", @@ -5343,9 +5441,9 @@ dependencies = [ [[package]] name = "rstest_macros" -version = "0.24.0" +version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef0053bbffce09062bee4bcc499b0fbe7a57b879f1efe088d6d8d4c7adcdef9b" +checksum = "1f168d99749d307be9de54d23fd226628d99768225ef08f6ffb52e0182a27746" dependencies = [ "cfg-if", "glob", @@ -5355,7 +5453,7 @@ dependencies = [ "regex", "relative-path", "rustc_version", - "syn 2.0.100", + "syn 2.0.104", "unicode-ident", ] @@ -5367,14 +5465,14 @@ checksum = "b3a8fb4672e840a587a66fc577a5491375df51ddb88f2a2c2a792598c326fe14" dependencies = [ "quote", "rand 0.8.5", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] name = "rust_decimal" -version = "1.37.1" +version = "1.37.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "faa7de2ba56ac291bd90c6b9bece784a52ae1411f9506544b3eae36dd2356d50" +checksum = "b203a6425500a03e0919c42d3c47caca51e79f1132046626d2c8871c5092035d" dependencies = [ "arrayvec", "borsh", @@ -5420,7 +5518,7 @@ version = "0.38.44" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.1", "errno", "libc", "linux-raw-sys 0.4.15", @@ -5429,22 +5527,22 @@ dependencies = [ [[package]] name = "rustix" -version = "1.0.2" +version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"f7178faa4b75a30e269c71e61c353ce2748cf3d76f0c44c393f4e60abf49b825" +checksum = "c71e83d6afe7ff64890ec6b71d6a69bb8a610ab78ce364b3352876bb4c801266" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.1", "errno", "libc", - "linux-raw-sys 0.9.2", + "linux-raw-sys 0.9.4", "windows-sys 0.59.0", ] [[package]] name = "rustls" -version = "0.23.23" +version = "0.23.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "47796c98c480fce5406ef69d1c76378375492c3b0a0de587be0c1d9feb12f395" +checksum = "730944ca083c1c233a75c09f199e973ca499344a2b7ba9e755c457e86fb4a321" dependencies = [ "aws-lc-rs", "once_cell", @@ -5478,18 +5576,19 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.11.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c" +checksum = "229a4a4c221013e7e1f1a043678c5cc39fe5171437c88fb47151a21e6f5b5c79" dependencies = [ "web-time", + "zeroize", ] [[package]] name = "rustls-webpki" -version = "0.102.8" +version = "0.103.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9" +checksum = "e4a72fe2bcf7a6ac6fd7d0b9e5cb68aeb7d4c0a0271730218b3e92d43b4eb435" dependencies = [ "aws-lc-rs", "ring", @@ -5499,17 +5598,17 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.19" +version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7c45b9784283f1b2e7fb61b42047c2fd678ef0960d4f6f1eba131594cc369d4" +checksum = "eded382c5f5f786b989652c49544c4877d9f015cc22e145a5ea8ea66c2921cd2" [[package]] name = "rustyline" -version = "15.0.0" +version = "16.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ee1e066dc922e513bda599c6ccb5f3bb2b0ea5870a579448f2622993f0a9a2f" +checksum = "62fd9ca5ebc709e8535e8ef7c658eb51457987e48c98ead2be482172accc408d" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.1", "cfg-if", "clipboard-win", "fd-lock", @@ -5527,9 +5626,9 @@ dependencies = [ [[package]] name = "ryu" -version = "1.0.19" +version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" +checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" [[package]] name = "same-file" @@ -5570,7 +5669,7 @@ dependencies = [ "proc-macro2", "quote", "serde_derive_internals", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] @@ -5591,7 +5690,7 @@ version = "3.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "271720403f46ca04f7ba6f55d438f8bd878d6b8ca0a1046e8228c4145bcbb316" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.1", "core-foundation", "core-foundation-sys", "libc", @@ -5619,9 +5718,9 @@ dependencies = [ [[package]] name = "seq-macro" -version = "0.3.5" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" +checksum = "1bc711410fbe7399f390ca1c3b60ad0f53f80e95c5eb935e52268a0e2cd49acc" [[package]] name = "serde" @@ -5634,9 +5733,9 @@ dependencies = [ [[package]] name = "serde_bytes" -version = "0.11.15" +version = "0.11.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "387cc504cb06bb40a96c8e04e951fe01854cf6bc921053c954e4a606d9675c6a" +checksum = 
"8437fd221bde2d4ca316d61b90e337e9e702b3820b87d63caa9ba6c02bd06d96" dependencies = [ "serde", ] @@ -5649,7 +5748,7 @@ checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] @@ -5660,7 +5759,7 @@ checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] @@ -5677,13 +5776,13 @@ dependencies = [ [[package]] name = "serde_repr" -version = "0.1.19" +version = "0.1.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c64451ba24fc7a6a2d60fc75dd9c83c90903b19028d4eff35e88fc1e86564e9" +checksum = "175ee3e80ae9982737ca543e96133087cbd9a485eecc3bc4de9c1a37b47ea59c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] @@ -5695,7 +5794,7 @@ dependencies = [ "proc-macro2", "quote", "serde", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] @@ -5720,7 +5819,7 @@ dependencies = [ "chrono", "hex", "indexmap 1.9.3", - "indexmap 2.8.0", + "indexmap 2.10.0", "serde", "serde_derive", "serde_json", @@ -5737,7 +5836,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] @@ -5746,7 +5845,7 @@ version = "0.9.34+deprecated" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" dependencies = [ - "indexmap 2.8.0", + "indexmap 2.10.0", "itoa", "ryu", "serde", @@ -5755,9 +5854,9 @@ dependencies = [ [[package]] name = "sha2" -version = "0.10.8" +version = "0.10.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" +checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" dependencies = [ "cfg-if", "cpufeatures", @@ -5781,9 +5880,9 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" [[package]] name = "signal-hook-registry" -version = "1.4.2" +version = "1.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1" +checksum = "9203b8055f63a2a00e2f593bb0510367fe707d7ff1e5c872de2f537b339e5410" dependencies = [ "libc", ] @@ -5817,9 +5916,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.14.0" +version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fcf8323ef1faaee30a44a340193b1ac6814fd9b7b4e88e9d4519a3e4abe1cfd" +checksum = "8917285742e9f3e1683f0a9c4e6b57960b7314d0b08d30d1ecd426713ee2eee9" [[package]] name = "snap" @@ -5847,9 +5946,9 @@ dependencies = [ [[package]] name = "socket2" -version = "0.5.8" +version = "0.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c970269d99b64e60ec3bd6ad27270092a5394c4e309314b18ae3fe575695fbe8" +checksum = "4f5fd57c80058a56cf5c777ab8a126398ece8e442983605d280a44ce79d0edef" dependencies = [ "libc", "windows-sys 0.52.0", @@ -5857,9 +5956,9 @@ dependencies = [ [[package]] name = "sqllogictest" -version = "0.28.0" +version = "0.28.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17b2f0b80fc250ed3fdd82fc88c0ada5ad62ee1ed5314ac5474acfa52082f518" +checksum = "9fcbf91368a8d6807093d94f274fa4d0978cd78a310fee1d20368c545a606f7a" dependencies = [ "async-trait", "educe", @@ -5899,7 +5998,7 @@ checksum = 
"da5fc6819faabb412da764b99d3b713bb55083c11e7e0c00144d386cd6a1939c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] @@ -5910,9 +6009,9 @@ checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" [[package]] name = "stacker" -version = "0.1.18" +version = "0.1.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d08feb8f695b465baed819b03c128dc23f57a694510ab1f06c77f763975685e" +checksum = "cddb07e32ddb770749da91081d8d0ac3a16f1a569a18b20348cd371f5dead06b" dependencies = [ "cc", "cfg-if", @@ -5953,7 +6052,7 @@ dependencies = [ "proc-macro2", "quote", "structmeta-derive", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] @@ -5964,7 +6063,7 @@ checksum = "152a0b65a590ff6c3da95cabe2353ee04e6167c896b28e3b14478c2636c922fc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] @@ -5997,15 +6096,6 @@ version = "0.26.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" -[[package]] -name = "strum" -version = "0.27.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f64def088c51c9510a8579e3c5d67c65349dcf755e5479ad3d010aa6454e2c32" -dependencies = [ - "strum_macros 0.27.1", -] - [[package]] name = "strum_macros" version = "0.26.4" @@ -6016,27 +6106,14 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.100", -] - -[[package]] -name = "strum_macros" -version = "0.27.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c77a8c5abcaf0f9ce05d62342b7d298c346515365c36b673df4ebe3ced01fde8" -dependencies = [ - "heck 0.5.0", - "proc-macro2", - "quote", - "rustversion", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] name = "subst" -version = "0.3.7" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33e7942675ea19db01ef8cf15a1e6443007208e6c74568bd64162da26d40160d" +checksum = "0a9a86e5144f63c2d18334698269a8bfae6eece345c70b64821ea5b35054ec99" dependencies = [ "memchr", "unicode-width 0.1.14", @@ -6044,9 +6121,9 @@ dependencies = [ [[package]] name = "substrait" -version = "0.55.0" +version = "0.57.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3a359aeb711c1e1944c0c4178bbb2d679d39237ac5bfe28f7e0506e522e5ce6" +checksum = "6bc9c2fe95fd228d6178a0a6f0eadc758231c6df1eeabaa3517dfcd5454e4ef2" dependencies = [ "heck 0.5.0", "pbjson", @@ -6063,7 +6140,7 @@ dependencies = [ "serde", "serde_json", "serde_yaml", - "syn 2.0.100", + "syn 2.0.104", "typify", "walkdir", ] @@ -6087,9 +6164,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.100" +version = "2.0.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b09a44accad81e1ba1cd74a32461ba89dee89095ba17b32f5d03683b1b1fc2a0" +checksum = "17b6f705963418cdb9927482fa304bc562ece2fdd4f616084c50b7023b435a40" dependencies = [ "proc-macro2", "quote", @@ -6107,25 +6184,26 @@ dependencies = [ [[package]] name = "synstructure" -version = "0.13.1" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" +checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] name = "sysinfo" -version = "0.34.2" +version = "0.35.2" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4b93974b3d3aeaa036504b8eefd4c039dced109171c1ae973f1dc63b2c7e4b2" +checksum = "3c3ffa3e4ff2b324a57f7aeb3c349656c7b127c3c189520251a648102a92496e" dependencies = [ "libc", "memchr", "ntapi", "objc2-core-foundation", + "objc2-io-kit", "windows", ] @@ -6143,14 +6221,14 @@ checksum = "e502f78cdbb8ba4718f566c418c52bc729126ffd16baee5baa718cf25dd5a69a" [[package]] name = "tempfile" -version = "3.19.1" +version = "3.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7437ac7763b9b123ccf33c338a5cc1bac6f69b45a136c19bdd8a65e3916435bf" +checksum = "e8a64e3985349f2441a1a9ef0b853f869006c3855f2cda6862a94d26ebb9d6a1" dependencies = [ "fastrand", - "getrandom 0.3.1", + "getrandom 0.3.3", "once_cell", - "rustix 1.0.2", + "rustix 1.0.7", "windows-sys 0.59.0", ] @@ -6168,14 +6246,14 @@ dependencies = [ "chrono-tz", "datafusion-common", "env_logger", - "rand 0.8.5", + "rand 0.9.1", ] [[package]] name = "testcontainers" -version = "0.23.3" +version = "0.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59a4f01f39bb10fc2a5ab23eb0d888b1e2bb168c157f61a1b98e6c501c639c74" +checksum = "23bb7577dca13ad86a78e8271ef5d322f37229ec83b8d98da6d996c588a1ddb1" dependencies = [ "async-trait", "bollard", @@ -6202,9 +6280,9 @@ dependencies = [ [[package]] name = "testcontainers-modules" -version = "0.11.6" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d43ed4e8f58424c3a2c6c56dbea6643c3c23e8666a34df13c54f0a184e6c707" +checksum = "eac95cde96549fc19c6bf19ef34cc42bd56e264c1cb97e700e21555be0ecf9e2" dependencies = [ "testcontainers", ] @@ -6244,7 +6322,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] @@ -6255,7 +6333,7 @@ checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] @@ -6281,9 +6359,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.37" +version = "0.3.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35e7868883861bd0e56d9ac6efcaaca0d6d5d82a2a7ec8209ff492c07cf37b21" +checksum = "8a7619e19bc266e0f9c5e6686659d394bc57973859340060a69221e57dbc0c40" dependencies = [ "deranged", "itoa", @@ -6296,15 +6374,15 @@ dependencies = [ [[package]] name = "time-core" -version = "0.1.2" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" +checksum = "c9e9a38711f559d9e3ce1cdb06dd7c5b8ea546bc90052da6d06bb76da74bb07c" [[package]] name = "time-macros" -version = "0.2.19" +version = "0.2.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2834e6017e3e5e4b9834939793b282bc03b37a3336245fa820e35e233e2a85de" +checksum = "3526739392ec93fd8b359c8e98514cb3e8e021beb4e5f597b00a0221f8ed8a49" dependencies = [ "num-conv", "time-core", @@ -6321,9 +6399,9 @@ dependencies = [ [[package]] name = "tinystr" -version = "0.7.6" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9117f5d4db391c1cf6927e7bea3db74b9a1c1add8f7eda9ffd5364f40f57b82f" +checksum = "5d4f6d1145dcb577acf783d4e601bc1d76a13337bb54e6233add580b07344c8b" dependencies = [ "displaydoc", "zerovec", @@ -6341,9 +6419,9 @@ dependencies = [ [[package]] name = "tinyvec" 
-version = "1.8.1" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "022db8904dfa342efe721985167e9fcd16c29b226db4397ed752a761cfce81e8" +checksum = "09b3661f17e86524eccd4371ab0429194e0d7c008abb45f7a7495b1719463c71" dependencies = [ "tinyvec_macros", ] @@ -6356,9 +6434,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.44.1" +version = "1.45.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f382da615b842244d4b8738c82ed1275e6c5dd90c459a30941cd07080b06c91a" +checksum = "75ef51a33ef1da925cea3e4eb122833cb377c61439ca401b770f54902b806779" dependencies = [ "backtrace", "bytes", @@ -6380,7 +6458,7 @@ checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] @@ -6402,7 +6480,7 @@ dependencies = [ "pin-project-lite", "postgres-protocol", "postgres-types", - "rand 0.9.0", + "rand 0.9.1", "socket2", "tokio", "tokio-util", @@ -6411,9 +6489,9 @@ dependencies = [ [[package]] name = "tokio-rustls" -version = "0.26.1" +version = "0.26.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f6d0975eaace0cf0fcadee4e4aaa5da15b5c079146f2cffb67c113be122bf37" +checksum = "8e727b36a1a0e8b74c376ac2211e40c2c8af09fb4013c60d910495810f008e9b" dependencies = [ "rustls", "tokio", @@ -6447,9 +6525,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.14" +version = "0.7.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b9590b93e6fcc1739458317cccd391ad3955e2bde8913edf6f95f9e65a8f034" +checksum = "66a539a9ad6d5d281510d5bd368c973d636c02dbf8a67300bfb6b950696ad7df" dependencies = [ "bytes", "futures-core", @@ -6460,17 +6538,17 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "0.6.8" +version = "0.6.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" +checksum = "3da5db5a963e24bc68be8b17b6fa82814bb22ee8660f192bb182771d498f09a3" [[package]] name = "toml_edit" -version = "0.22.24" +version = "0.22.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17b4795ff5edd201c7cd6dca065ae59972ce77d1b80fa0a84d94950ece7d1474" +checksum = "310068873db2c5b3e7659d2cc35d21855dbafa50d1ce336397c666e3cb08137e" dependencies = [ - "indexmap 2.8.0", + "indexmap 2.10.0", "toml_datetime", "winnow", ] @@ -6487,7 +6565,7 @@ dependencies = [ "base64 0.22.1", "bytes", "h2", - "http 1.2.0", + "http 1.3.1", "http-body 1.0.1", "http-body-util", "hyper", @@ -6571,7 +6649,7 @@ checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] @@ -6669,7 +6747,7 @@ checksum = "f9534daa9fd3ed0bd911d462a37f172228077e7abf18c18a5f67199d959205f8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] @@ -6680,9 +6758,9 @@ checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" [[package]] name = "typify" -version = "0.3.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e03ba3643450cfd95a1aca2e1938fef63c1c1994489337998aff4ad771f21ef8" +checksum = "fcc5bec3cdff70fd542e579aa2e52967833e543a25fae0d14579043d2e868a50" dependencies = [ "typify-impl", "typify-macro", @@ -6690,9 +6768,9 @@ dependencies = [ 
[[package]] name = "typify-impl" -version = "0.3.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bce48219a2f3154aaa2c56cbf027728b24a3c8fe0a47ed6399781de2b3f3eeaf" +checksum = "b52a67305054e1da6f3d99ad94875dcd0c7c49adbd17b4b64f0eefb7ae5bf8ab" dependencies = [ "heck 0.5.0", "log", @@ -6703,16 +6781,16 @@ dependencies = [ "semver", "serde", "serde_json", - "syn 2.0.100", + "syn 2.0.104", "thiserror 2.0.12", "unicode-ident", ] [[package]] name = "typify-macro" -version = "0.3.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68b5780d745920ed73c5b7447496a9b5c42ed2681a9b70859377aec423ecf02b" +checksum = "0ff5799be156e4f635c348c6051d165e1c59997827155133351a8c4d333d9841" dependencies = [ "proc-macro2", "quote", @@ -6721,7 +6799,7 @@ dependencies = [ "serde", "serde_json", "serde_tokenstream", - "syn 2.0.100", + "syn 2.0.104", "typify-impl", ] @@ -6733,9 +6811,9 @@ checksum = "5c1cb5db39152898a79168971543b1cb5020dff7fe43c8dc468b0885f5e29df5" [[package]] name = "unicode-ident" -version = "1.0.16" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a210d160f08b701c8721ba1c726c11662f877ea6b7094007e1ca9a1041945034" +checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" [[package]] name = "unicode-normalization" @@ -6772,9 +6850,9 @@ checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" [[package]] name = "unindent" -version = "0.2.3" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" +checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" [[package]] name = "unsafe-libyaml" @@ -6806,12 +6884,6 @@ version = "2.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" -[[package]] -name = "utf16_iter" -version = "1.0.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8232dd3cdaed5356e0f716d285e4b40b932ac434100fe9b7e0e8e935b9e6246" - [[package]] name = "utf8_iter" version = "1.0.4" @@ -6826,11 +6898,11 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.16.0" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "458f7a779bf54acc9f347480ac654f68407d3aab21269a6e3c9f922acd9e2da9" +checksum = "3cf4199d1e5d15ddd86a694e4d0dffa9c323ce759fea589f00fef9d81cc1931d" dependencies = [ - "getrandom 0.3.1", + "getrandom 0.3.3", "js-sys", "serde", "wasm-bindgen", @@ -6890,9 +6962,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasi" -version = "0.13.3+wasi-0.2.2" +version = "0.14.2+wasi-0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26816d2e1a4a36a2940b96c5296ce403917633dff8f3440e9b236ed6f6bacad2" +checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3" dependencies = [ "wit-bindgen-rt", ] @@ -6925,7 +6997,7 @@ dependencies = [ "log", "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.104", "wasm-bindgen-shared", ] @@ -6960,7 +7032,7 @@ checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.104", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -6995,7 +7067,7 
@@ checksum = "17d5042cc5fa009658f9a7333ef24291b1291a25b6382dd68862a7f3b969f69b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] @@ -7045,11 +7117,11 @@ dependencies = [ [[package]] name = "whoami" -version = "1.5.2" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "372d5b87f58ec45c384ba03563b03544dc5fadc3983e434b286913f5b4a9bb6d" +checksum = "6994d13118ab492c3c80c1f81928718159254c53c472bf9ce36f8dae4add02a7" dependencies = [ - "redox_syscall 0.5.8", + "redox_syscall 0.5.12", "wasite", "web-sys", ] @@ -7087,55 +7159,70 @@ checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] name = "windows" -version = "0.57.0" +version = "0.61.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12342cb4d8e3b046f3d80effd474a7a02447231330ef77d71daa6fbc40681143" +checksum = "c5ee8f3d025738cb02bad7868bbb5f8a6327501e870bf51f1b455b0a2454a419" dependencies = [ - "windows-core 0.57.0", - "windows-targets 0.52.6", + "windows-collections", + "windows-core", + "windows-future", + "windows-link", + "windows-numerics", ] [[package]] -name = "windows-core" -version = "0.52.0" +name = "windows-collections" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +checksum = "3beeceb5e5cfd9eb1d76b381630e82c4241ccd0d27f1a39ed41b2760b255c5e8" dependencies = [ - "windows-targets 0.52.6", + "windows-core", ] [[package]] name = "windows-core" -version = "0.57.0" +version = "0.61.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2ed2439a290666cd67ecce2b0ffaad89c2a56b976b736e6ece670297897832d" +checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3" dependencies = [ "windows-implement", "windows-interface", - "windows-result 0.1.2", - "windows-targets 0.52.6", + "windows-link", + "windows-result", + "windows-strings 0.4.2", +] + +[[package]] +name = "windows-future" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc6a41e98427b19fe4b73c550f060b59fa592d7d686537eebf9385621bfbad8e" +dependencies = [ + "windows-core", + "windows-link", + "windows-threading", ] [[package]] name = "windows-implement" -version = "0.57.0" +version = "0.60.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9107ddc059d5b6fbfbffdfa7a7fe3e22a226def0b2608f72e9d552763d3e1ad7" +checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] name = "windows-interface" -version = "0.57.0" +version = "0.59.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29bee4b38ea3cde66011baa44dba677c432a78593e202392d1e9070cf2a7fca7" +checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] @@ -7145,51 +7232,51 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76840935b766e1b0a05c0066835fb9ec80071d4c09a16f6bd5f7e655e3c14c38" [[package]] -name = "windows-registry" +name = "windows-numerics" version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0" +checksum = 
"9150af68066c4c5c07ddc0ce30421554771e528bde427614c61038bc2c92c2b1" dependencies = [ - "windows-result 0.2.0", - "windows-strings", - "windows-targets 0.52.6", + "windows-core", + "windows-link", ] [[package]] -name = "windows-result" -version = "0.1.2" +name = "windows-registry" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e383302e8ec8515204254685643de10811af0ed97ea37210dc26fb0032647f8" +checksum = "4286ad90ddb45071efd1a66dfa43eb02dd0dfbae1545ad6cc3c51cf34d7e8ba3" dependencies = [ - "windows-targets 0.52.6", + "windows-result", + "windows-strings 0.3.1", + "windows-targets 0.53.0", ] [[package]] name = "windows-result" -version = "0.2.0" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e" +checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6" dependencies = [ - "windows-targets 0.52.6", + "windows-link", ] [[package]] name = "windows-strings" -version = "0.1.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10" +checksum = "87fa48cc5d406560701792be122a10132491cff9d0aeb23583cc2dcafc847319" dependencies = [ - "windows-result 0.2.0", - "windows-targets 0.52.6", + "windows-link", ] [[package]] -name = "windows-sys" -version = "0.48.0" +name = "windows-strings" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57" dependencies = [ - "windows-targets 0.48.5", + "windows-link", ] [[package]] @@ -7210,21 +7297,6 @@ dependencies = [ "windows-targets 0.52.6", ] -[[package]] -name = "windows-targets" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" -dependencies = [ - "windows_aarch64_gnullvm 0.48.5", - "windows_aarch64_msvc 0.48.5", - "windows_i686_gnu 0.48.5", - "windows_i686_msvc 0.48.5", - "windows_x86_64_gnu 0.48.5", - "windows_x86_64_gnullvm 0.48.5", - "windows_x86_64_msvc 0.48.5", -] - [[package]] name = "windows-targets" version = "0.52.6" @@ -7234,7 +7306,7 @@ dependencies = [ "windows_aarch64_gnullvm 0.52.6", "windows_aarch64_msvc 0.52.6", "windows_i686_gnu 0.52.6", - "windows_i686_gnullvm", + "windows_i686_gnullvm 0.52.6", "windows_i686_msvc 0.52.6", "windows_x86_64_gnu 0.52.6", "windows_x86_64_gnullvm 0.52.6", @@ -7242,10 +7314,29 @@ dependencies = [ ] [[package]] -name = "windows_aarch64_gnullvm" -version = "0.48.5" +name = "windows-targets" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1e4c7e8ceaaf9cb7d7507c974735728ab453b67ef8f18febdd7c11fe59dca8b" +dependencies = [ + "windows_aarch64_gnullvm 0.53.0", + "windows_aarch64_msvc 0.53.0", + "windows_i686_gnu 0.53.0", + "windows_i686_gnullvm 0.53.0", + "windows_i686_msvc 0.53.0", + "windows_x86_64_gnu 0.53.0", + "windows_x86_64_gnullvm 0.53.0", + "windows_x86_64_msvc 0.53.0", +] + +[[package]] +name = "windows-threading" +version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" +checksum = "b66463ad2e0ea3bbf808b7f1d371311c80e115c0b71d60efc142cafbcfb057a6" +dependencies = [ + "windows-link", +] 
[[package]] name = "windows_aarch64_gnullvm" @@ -7254,10 +7345,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" [[package]] -name = "windows_aarch64_msvc" -version = "0.48.5" +name = "windows_aarch64_gnullvm" +version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" +checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" [[package]] name = "windows_aarch64_msvc" @@ -7266,10 +7357,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" [[package]] -name = "windows_i686_gnu" -version = "0.48.5" +name = "windows_aarch64_msvc" +version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" +checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" [[package]] name = "windows_i686_gnu" @@ -7277,6 +7368,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" +[[package]] +name = "windows_i686_gnu" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3" + [[package]] name = "windows_i686_gnullvm" version = "0.52.6" @@ -7284,10 +7381,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" [[package]] -name = "windows_i686_msvc" -version = "0.48.5" +name = "windows_i686_gnullvm" +version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" +checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" [[package]] name = "windows_i686_msvc" @@ -7296,10 +7393,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" [[package]] -name = "windows_x86_64_gnu" -version = "0.48.5" +name = "windows_i686_msvc" +version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" +checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" [[package]] name = "windows_x86_64_gnu" @@ -7308,10 +7405,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" [[package]] -name = "windows_x86_64_gnullvm" -version = "0.48.5" +name = "windows_x86_64_gnu" +version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" +checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" [[package]] name = "windows_x86_64_gnullvm" @@ -7320,10 +7417,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" [[package]] -name = "windows_x86_64_msvc" -version = "0.48.5" +name = "windows_x86_64_gnullvm" +version = "0.53.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" +checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" [[package]] name = "windows_x86_64_msvc" @@ -7331,35 +7428,35 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "windows_x86_64_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" + [[package]] name = "winnow" -version = "0.7.2" +version = "0.7.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59690dea168f2198d1a3b0cac23b8063efcd11012f10ae4698f284808c8ef603" +checksum = "c06928c8748d81b05c9be96aad92e1b6ff01833332f281e8cfca3be4b35fc9ec" dependencies = [ "memchr", ] [[package]] name = "wit-bindgen-rt" -version = "0.33.0" +version = "0.39.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c" +checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.1", ] -[[package]] -name = "write16" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1890f4022759daae28ed4fe62859b1236caebfc61ede2f63ed4e695f3f6d936" - [[package]] name = "writeable" -version = "0.5.5" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" +checksum = "ea2f10b9bb0928dfb1b42b65e1f9e36f7f54dbdf08457afefb38afcdec4fa2bb" [[package]] name = "wyz" @@ -7372,13 +7469,12 @@ dependencies = [ [[package]] name = "xattr" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e105d177a3871454f754b33bb0ee637ecaaac997446375fd3e5d43a2ed00c909" +checksum = "0d65cbf2f12c15564212d48f4e3dfb87923d25d611f2aed18f4cb23f0413d89e" dependencies = [ "libc", - "linux-raw-sys 0.4.15", - "rustix 0.38.44", + "rustix 1.0.7", ] [[package]] @@ -7398,9 +7494,9 @@ dependencies = [ [[package]] name = "yoke" -version = "0.7.5" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "120e6aef9aa629e3d4f52dc8cc43a015c7724194c97dfaf45180d2daf2b77f40" +checksum = "5f41bb01b8226ef4bfd589436a297c53d118f65921786300e427be8d487695cc" dependencies = [ "serde", "stable_deref_trait", @@ -7410,75 +7506,54 @@ dependencies = [ [[package]] name = "yoke-derive" -version = "0.7.5" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" +checksum = "38da3c9736e16c5d3c8c597a9aaa5d1fa565d0532ae05e27c24aa62fb32c0ab6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.104", "synstructure", ] [[package]] name = "zerocopy" -version = "0.7.35" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" -dependencies = [ - "byteorder", - "zerocopy-derive 0.7.35", -] - -[[package]] -name = "zerocopy" -version = "0.8.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79386d31a42a4996e3336b0919ddb90f81112af416270cff95b5f5af22b839c2" -dependencies = [ - 
"zerocopy-derive 0.8.18", -] - -[[package]] -name = "zerocopy-derive" -version = "0.7.35" +version = "0.8.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +checksum = "a1702d9583232ddb9174e01bb7c15a2ab8fb1bc6f227aa1233858c351a3ba0cb" dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.100", + "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.18" +version = "0.8.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76331675d372f91bf8d17e13afbd5fe639200b73d01f0fc748bb059f9cca2db7" +checksum = "28a6e20d751156648aa063f3800b706ee209a32c0b4d9f24be3d980b01be55ef" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] name = "zerofrom" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cff3ee08c995dee1859d998dea82f7374f2826091dd9cd47def953cae446cd2e" +checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5" dependencies = [ "zerofrom-derive", ] [[package]] name = "zerofrom-derive" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808" +checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.104", "synstructure", ] @@ -7488,11 +7563,22 @@ version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" +[[package]] +name = "zerotrie" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36f0bbd478583f79edad978b407914f61b2972f5af6fa089686016be8f9af595" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", +] + [[package]] name = "zerovec" -version = "0.10.4" +version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa2b893d79df23bfb12d5461018d408ea19dfafe76c2c7ef6d4eba614f8ff079" +checksum = "4a05eb080e015ba39cc9e23bbe5e7fb04d5fb040350f99f34e338d5fdd294428" dependencies = [ "yoke", "zerofrom", @@ -7501,20 +7587,20 @@ dependencies = [ [[package]] name = "zerovec-derive" -version = "0.10.3" +version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" +checksum = "5b96237efa0c878c64bd89c436f661be4e46b2f3eff1ebb976f7ef2321d2f58f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.104", ] [[package]] name = "zlib-rs" -version = "0.4.2" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b20717f0917c908dc63de2e44e97f1e6b126ca58d0e391cee86d504eb8fbd05" +checksum = "626bd9fa9734751fc50d6060752170984d7053f5a39061f524cda68023d4db8a" [[package]] name = "zstd" @@ -7527,18 +7613,18 @@ dependencies = [ [[package]] name = "zstd-safe" -version = "7.2.1" +version = "7.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54a3ab4db68cea366acc5c897c7b4d4d1b8994a9cd6e6f841f8964566a419059" +checksum = "8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d" dependencies = [ "zstd-sys", ] [[package]] name = "zstd-sys" -version = "2.0.13+zstd.1.5.6" +version = "2.0.15+zstd.1.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"38ff0f21cfee8f97d94cef41359e0c89aa6113028ab0291aa8ca0038995a95aa" +checksum = "eb81183ddd97d0c74cedf1d50d85c8d08c1b8b68ee863bdee9e706eedba1a237" dependencies = [ "cc", "pkg-config", diff --git a/Cargo.toml b/Cargo.toml index de53b7df50d9..8124abd013f4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,12 +42,14 @@ members = [ "datafusion/physical-expr", "datafusion/physical-expr-common", "datafusion/physical-optimizer", + "datafusion/pruning", "datafusion/physical-plan", "datafusion/proto", "datafusion/proto/gen", "datafusion/proto-common", "datafusion/proto-common/gen", "datafusion/session", + "datafusion/spark", "datafusion/sql", "datafusion/sqllogictest", "datafusion/substrait", @@ -75,7 +77,7 @@ repository = "https://github.com/apache/datafusion" # Define Minimum Supported Rust Version (MSRV) rust-version = "1.82.0" # Define DataFusion version -version = "46.0.1" +version = "48.0.0" [workspace.dependencies] # We turn off default-features for some dependencies here so the workspaces which inherit them can @@ -87,86 +89,91 @@ ahash = { version = "0.8", default-features = false, features = [ "runtime-rng", ] } apache-avro = { version = "0.17", default-features = false } -arrow = { version = "55.0.0", features = [ +arrow = { version = "55.2.0", features = [ "prettyprint", "chrono-tz", ] } -arrow-buffer = { version = "55.0.0", default-features = false } -arrow-flight = { version = "55.0.0", features = [ +arrow-buffer = { version = "55.2.0", default-features = false } +arrow-flight = { version = "55.2.0", features = [ "flight-sql-experimental", ] } -arrow-ipc = { version = "55.0.0", default-features = false, features = [ +arrow-ipc = { version = "55.2.0", default-features = false, features = [ "lz4", + "zstd", ] } -arrow-ord = { version = "55.0.0", default-features = false } -arrow-schema = { version = "55.0.0", default-features = false } +arrow-ord = { version = "55.2.0", default-features = false } +arrow-schema = { version = "55.2.0", default-features = false } async-trait = "0.1.88" bigdecimal = "0.4.8" bytes = "1.10" -chrono = { version = "0.4.38", default-features = false } +chrono = { version = "0.4.41", default-features = false } criterion = "0.5.1" -ctor = "0.2.9" +ctor = "0.4.0" dashmap = "6.0.1" -datafusion = { path = "datafusion/core", version = "46.0.1", default-features = false } -datafusion-catalog = { path = "datafusion/catalog", version = "46.0.1" } -datafusion-catalog-listing = { path = "datafusion/catalog-listing", version = "46.0.1" } -datafusion-common = { path = "datafusion/common", version = "46.0.1", default-features = false } -datafusion-common-runtime = { path = "datafusion/common-runtime", version = "46.0.1" } -datafusion-datasource = { path = "datafusion/datasource", version = "46.0.1", default-features = false } -datafusion-datasource-avro = { path = "datafusion/datasource-avro", version = "46.0.1", default-features = false } -datafusion-datasource-csv = { path = "datafusion/datasource-csv", version = "46.0.1", default-features = false } -datafusion-datasource-json = { path = "datafusion/datasource-json", version = "46.0.1", default-features = false } -datafusion-datasource-parquet = { path = "datafusion/datasource-parquet", version = "46.0.1", default-features = false } -datafusion-doc = { path = "datafusion/doc", version = "46.0.1" } -datafusion-execution = { path = "datafusion/execution", version = "46.0.1" } -datafusion-expr = { path = "datafusion/expr", version = "46.0.1" } -datafusion-expr-common = { path = "datafusion/expr-common", version = "46.0.1" } 
-datafusion-ffi = { path = "datafusion/ffi", version = "46.0.1" } -datafusion-functions = { path = "datafusion/functions", version = "46.0.1" } -datafusion-functions-aggregate = { path = "datafusion/functions-aggregate", version = "46.0.1" } -datafusion-functions-aggregate-common = { path = "datafusion/functions-aggregate-common", version = "46.0.1" } -datafusion-functions-nested = { path = "datafusion/functions-nested", version = "46.0.1" } -datafusion-functions-table = { path = "datafusion/functions-table", version = "46.0.1" } -datafusion-functions-window = { path = "datafusion/functions-window", version = "46.0.1" } -datafusion-functions-window-common = { path = "datafusion/functions-window-common", version = "46.0.1" } -datafusion-macros = { path = "datafusion/macros", version = "46.0.1" } -datafusion-optimizer = { path = "datafusion/optimizer", version = "46.0.1", default-features = false } -datafusion-physical-expr = { path = "datafusion/physical-expr", version = "46.0.1", default-features = false } -datafusion-physical-expr-common = { path = "datafusion/physical-expr-common", version = "46.0.1", default-features = false } -datafusion-physical-optimizer = { path = "datafusion/physical-optimizer", version = "46.0.1" } -datafusion-physical-plan = { path = "datafusion/physical-plan", version = "46.0.1" } -datafusion-proto = { path = "datafusion/proto", version = "46.0.1" } -datafusion-proto-common = { path = "datafusion/proto-common", version = "46.0.1" } -datafusion-session = { path = "datafusion/session", version = "46.0.1" } -datafusion-sql = { path = "datafusion/sql", version = "46.0.1" } +datafusion = { path = "datafusion/core", version = "48.0.0", default-features = false } +datafusion-catalog = { path = "datafusion/catalog", version = "48.0.0" } +datafusion-catalog-listing = { path = "datafusion/catalog-listing", version = "48.0.0" } +datafusion-common = { path = "datafusion/common", version = "48.0.0", default-features = false } +datafusion-common-runtime = { path = "datafusion/common-runtime", version = "48.0.0" } +datafusion-datasource = { path = "datafusion/datasource", version = "48.0.0", default-features = false } +datafusion-datasource-avro = { path = "datafusion/datasource-avro", version = "48.0.0", default-features = false } +datafusion-datasource-csv = { path = "datafusion/datasource-csv", version = "48.0.0", default-features = false } +datafusion-datasource-json = { path = "datafusion/datasource-json", version = "48.0.0", default-features = false } +datafusion-datasource-parquet = { path = "datafusion/datasource-parquet", version = "48.0.0", default-features = false } +datafusion-doc = { path = "datafusion/doc", version = "48.0.0" } +datafusion-execution = { path = "datafusion/execution", version = "48.0.0" } +datafusion-expr = { path = "datafusion/expr", version = "48.0.0" } +datafusion-expr-common = { path = "datafusion/expr-common", version = "48.0.0" } +datafusion-ffi = { path = "datafusion/ffi", version = "48.0.0" } +datafusion-functions = { path = "datafusion/functions", version = "48.0.0" } +datafusion-functions-aggregate = { path = "datafusion/functions-aggregate", version = "48.0.0" } +datafusion-functions-aggregate-common = { path = "datafusion/functions-aggregate-common", version = "48.0.0" } +datafusion-functions-nested = { path = "datafusion/functions-nested", version = "48.0.0" } +datafusion-functions-table = { path = "datafusion/functions-table", version = "48.0.0" } +datafusion-functions-window = { path = "datafusion/functions-window", version = 
"48.0.0" } +datafusion-functions-window-common = { path = "datafusion/functions-window-common", version = "48.0.0" } +datafusion-macros = { path = "datafusion/macros", version = "48.0.0" } +datafusion-optimizer = { path = "datafusion/optimizer", version = "48.0.0", default-features = false } +datafusion-physical-expr = { path = "datafusion/physical-expr", version = "48.0.0", default-features = false } +datafusion-physical-expr-common = { path = "datafusion/physical-expr-common", version = "48.0.0", default-features = false } +datafusion-physical-optimizer = { path = "datafusion/physical-optimizer", version = "48.0.0" } +datafusion-physical-plan = { path = "datafusion/physical-plan", version = "48.0.0" } +datafusion-proto = { path = "datafusion/proto", version = "48.0.0" } +datafusion-proto-common = { path = "datafusion/proto-common", version = "48.0.0" } +datafusion-pruning = { path = "datafusion/pruning", version = "48.0.0" } +datafusion-session = { path = "datafusion/session", version = "48.0.0" } +datafusion-spark = { path = "datafusion/spark", version = "48.0.0" } +datafusion-sql = { path = "datafusion/sql", version = "48.0.0" } +datafusion-substrait = { path = "datafusion/substrait", version = "48.0.0" } doc-comment = "0.3" env_logger = "0.11" futures = "0.3" -half = { version = "2.5.0", default-features = false } +half = { version = "2.6.0", default-features = false } hashbrown = { version = "0.14.5", features = ["raw"] } -indexmap = "2.8.0" +indexmap = "2.10.0" itertools = "0.14" log = "^0.4" -object_store = { version = "0.12.0", default-features = false } +object_store = { version = "0.12.2", default-features = false } parking_lot = "0.12" -parquet = { version = "55.0.0", default-features = false, features = [ +parquet = { version = "55.2.0", default-features = false, features = [ "arrow", "async", "object_store", + "encryption", ] } pbjson = { version = "0.7.0" } pbjson-types = "0.7" # Should match arrow-flight's version of prost. -insta = { version = "1.41.1", features = ["glob", "filters"] } +insta = { version = "1.43.1", features = ["glob", "filters"] } prost = "0.13.1" -rand = "0.8.5" +rand = "0.9" recursive = "0.1.1" regex = "1.8" -rstest = "0.24.0" +rstest = "0.25.0" serde_json = "1" -sqlparser = { version = "0.55.0", features = ["visitor"] } +sqlparser = { version = "0.55.0", default-features = false, features = ["std", "visitor"] } tempfile = "3" -tokio = { version = "1.44", features = ["macros", "rt", "sync"] } +tokio = { version = "1.45", features = ["macros", "rt", "sync"] } url = "2.5.4" [profile.release] @@ -209,7 +216,14 @@ strip = false # Detects large stack-allocated futures that may cause stack overflow crashes (see threshold in clippy.toml) large_futures = "warn" used_underscore_binding = "warn" +or_fun_call = "warn" +unnecessary_lazy_evaluations = "warn" +uninlined_format_args = "warn" [workspace.lints.rust] -unexpected_cfgs = { level = "warn", check-cfg = ["cfg(tarpaulin)"] } +unexpected_cfgs = { level = "warn", check-cfg = [ + 'cfg(datafusion_coop, values("tokio", "tokio_fallback", "per_stream"))', + "cfg(tarpaulin)", + "cfg(tarpaulin_include)", +] } unused_qualifications = "deny" diff --git a/NOTICE.txt b/NOTICE.txt index 21be1a20d554..7f3c80d606c0 100644 --- a/NOTICE.txt +++ b/NOTICE.txt @@ -1,5 +1,5 @@ Apache DataFusion -Copyright 2019-2024 The Apache Software Foundation +Copyright 2019-2025 The Apache Software Foundation This product includes software developed at -The Apache Software Foundation (http://www.apache.org/). 
\ No newline at end of file +The Apache Software Foundation (http://www.apache.org/). diff --git a/README.md b/README.md index 158033d40599..c142d8f366b2 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,7 @@ [![Open Issues][open-issues-badge]][open-issues-url] [![Discord chat][discord-badge]][discord-url] [![Linkedin][linkedin-badge]][linkedin-url] +![Crates.io MSRV][msrv-badge] [crates-badge]: https://img.shields.io/crates/v/datafusion.svg [crates-url]: https://crates.io/crates/datafusion @@ -40,6 +41,7 @@ [open-issues-url]: https://github.com/apache/datafusion/issues [linkedin-badge]: https://img.shields.io/badge/Follow-Linkedin-blue [linkedin-url]: https://www.linkedin.com/company/apache-datafusion/ +[msrv-badge]: https://img.shields.io/crates/msrv/datafusion?label=Min%20Rust%20Version [Website](https://datafusion.apache.org/) | [API Docs](https://docs.rs/datafusion/latest/datafusion/) | @@ -133,20 +135,6 @@ Optional features: [apache avro]: https://avro.apache.org/ [apache parquet]: https://parquet.apache.org/ -## Rust Version Compatibility Policy - -The Rust toolchain releases are tracked at [Rust Versions](https://releases.rs) and follow -[semantic versioning](https://semver.org/). A Rust toolchain release can be identified -by a version string like `1.80.0`, or more generally `major.minor.patch`. - -DataFusion's supports the last 4 stable Rust minor versions released and any such versions released within the last 4 months. - -For example, given the releases `1.78.0`, `1.79.0`, `1.80.0`, `1.80.1` and `1.81.0` DataFusion will support 1.78.0, which is 3 minor versions prior to the most minor recent `1.81`. - -Note: If a Rust hotfix is released for the current MSRV, the MSRV will be updated to the specific minor version that includes all applicable hotfixes preceding other policies. - -DataFusion enforces MSRV policy using a [MSRV CI Check](https://github.com/search?q=repo%3Aapache%2Fdatafusion+rust-version+language%3ATOML+path%3A%2F%5ECargo.toml%2F&type=code) - ## DataFusion API Evolution and Deprecation Guidelines Public methods in Apache DataFusion evolve over time: while we try to maintain a diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml index 063f4dac22d8..f9c198597b74 100644 --- a/benchmarks/Cargo.toml +++ b/benchmarks/Cargo.toml @@ -51,7 +51,7 @@ snmalloc-rs = { version = "0.3", optional = true } structopt = { version = "0.3", default-features = false } test-utils = { path = "../test-utils/", version = "0.1.0" } tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot"] } -tokio-util = { version = "0.7.14" } +tokio-util = { version = "0.7.15" } [dev-dependencies] datafusion-proto = { workspace = true } diff --git a/benchmarks/README.md b/benchmarks/README.md index 86b2e1b3b958..d0f413b2e97b 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -23,7 +23,6 @@ This crate contains benchmarks based on popular public data sets and open source benchmark suites, to help with performance and scalability testing of DataFusion. - ## Other engines The benchmarks measure changes to DataFusion itself, rather than @@ -31,11 +30,11 @@ its performance against other engines. For competitive benchmarking, DataFusion is included in the benchmark setups for several popular benchmarks that compare performance with other engines. 
For example: -* [ClickBench] scripts are in the [ClickBench repo](https://github.com/ClickHouse/ClickBench/tree/main/datafusion) -* [H2o.ai `db-benchmark`] scripts are in [db-benchmark](https://github.com/apache/datafusion/tree/main/benchmarks/src/h2o.rs) +- [ClickBench] scripts are in the [ClickBench repo](https://github.com/ClickHouse/ClickBench/tree/main/datafusion) +- [H2o.ai `db-benchmark`] scripts are in [db-benchmark](https://github.com/apache/datafusion/tree/main/benchmarks/src/h2o.rs) -[ClickBench]: https://github.com/ClickHouse/ClickBench/tree/main -[H2o.ai `db-benchmark`]: https://github.com/h2oai/db-benchmark +[clickbench]: https://github.com/ClickHouse/ClickBench/tree/main +[h2o.ai `db-benchmark`]: https://github.com/h2oai/db-benchmark # Running the benchmarks @@ -65,31 +64,54 @@ Create / download a specific dataset (TPCH) ```shell ./bench.sh data tpch ``` + Data is placed in the `data` subdirectory. ## Running benchmarks Run benchmark for TPC-H dataset + ```shell ./bench.sh run tpch ``` + or for TPC-H dataset scale 10 + ```shell ./bench.sh run tpch10 ``` To run for specific query, for example Q21 + ```shell ./bench.sh run tpch10 21 ``` -## Select join algorithm +## Benchmark with modified configurations + +### Select join algorithm + The benchmark runs with `prefer_hash_join == true` by default, which enforces HASH join algorithm. To run TPCH benchmarks with join other than HASH: + ```shell PREFER_HASH_JOIN=false ./bench.sh run tpch ``` +### Configure with environment variables + +Any [datafusion options](https://datafusion.apache.org/user-guide/configs.html) that are provided as environment variables are +also considered by the benchmarks. +The following configuration runs the TPCH benchmark with datafusion configured to _not_ repartition join keys. + +```shell +DATAFUSION_OPTIMIZER_REPARTITION_JOINS=false ./bench.sh run tpch +``` + +You might want to adjust the results location to avoid overwriting previous results. +Environment configuration that was picked up by datafusion is logged at `info` level. +To verify that datafusion picked up your configuration, run the benchmarks with `RUST_LOG=info` or higher. + ## Comparing performance of main and a branch ```shell @@ -407,7 +429,7 @@ logs. Example -dfbench parquet-filter --path ./data --scale-factor 1.0 +dfbench parquet-filter --path ./data --scale-factor 1.0 generates the synthetic dataset at `./data/logs.parquet`. The size of the dataset can be controlled through the `size_factor` @@ -439,6 +461,7 @@ Iteration 2 returned 1781686 rows in 1947 ms ``` ## Sort + Test performance of sorting large datasets This test sorts a a synthetic dataset generated during the @@ -462,26 +485,39 @@ Additionally, an optional `--limit` flag is available for the sort benchmark. Wh See [`sort_tpch.rs`](src/sort_tpch.rs) for more details. ### Sort TPCH Benchmark Example Runs + 1. Run all queries with default setting: + ```bash cargo run --release --bin dfbench -- sort-tpch -p './datafusion/benchmarks/data/tpch_sf1' -o '/tmp/sort_tpch.json' ``` 2. Run a specific query: + ```bash cargo run --release --bin dfbench -- sort-tpch -p './datafusion/benchmarks/data/tpch_sf1' -o '/tmp/sort_tpch.json' --query 2 ``` 3. Run all queries as TopK queries on presorted data: + ```bash cargo run --release --bin dfbench -- sort-tpch --sorted --limit 10 -p './datafusion/benchmarks/data/tpch_sf1' -o '/tmp/sort_tpch.json' ``` 4. Run all queries with `bench.sh` script: + ```bash ./bench.sh run sort_tpch ``` +### TopK TPCH + +In addition, topk_tpch is available from the bench.sh script: + +```bash +./bench.sh run topk_tpch +``` +
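For reference, the direct `dfbench` invocation behind this option can be sketched as follows; it mirrors the `run_topk_tpch` function added to `bench.sh` later in this diff, and the data and output paths below are placeholders:

```bash
# TopK is the sort-tpch benchmark run with a row limit; bench.sh uses --limit 100
cargo run --release --bin dfbench -- sort-tpch --limit 100 --iterations 5 \
  -p './benchmarks/data/tpch_sf1' -o '/tmp/topk_tpch.json'
```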
## IMDB Run Join Order Benchmark (JOB) on IMDB dataset. @@ -515,59 +551,78 @@ External aggregation benchmarks run several aggregation queries with different m This benchmark is inspired by [DuckDB's external aggregation paper](https://hannes.muehleisen.org/publications/icde2024-out-of-core-kuiper-boncz-muehleisen.pdf), specifically Section VI. ### External Aggregation Example Runs + 1. Run all queries with predefined memory limits: + ```bash # Under 'benchmarks/' directory cargo run --release --bin external_aggr -- benchmark -n 4 --iterations 3 -p '....../data/tpch_sf1' -o '/tmp/aggr.json' ``` 2. Run a query with specific memory limit: + ```bash cargo run --release --bin external_aggr -- benchmark -n 4 --iterations 3 -p '....../data/tpch_sf1' -o '/tmp/aggr.json' --query 1 --memory-limit 30M ``` 3. Run all queries with `bench.sh` script: + ```bash ./bench.sh data external_aggr ./bench.sh run external_aggr ``` +## h2o.ai benchmarks + +The h2o.ai benchmarks are a set of performance tests for groupby and join operations. Beyond the standard h2o benchmark, there is also an extended benchmark for window functions. These benchmarks use synthetic data with configurable sizes (small: 1e7 rows, medium: 1e8 rows, big: 1e9 rows) to evaluate DataFusion's performance across different data scales. + +References: + +- [H2O AI Benchmark](https://duckdb.org/2023/04/14/h2oai.html) +- [Extended window benchmark](https://duckdb.org/2024/06/26/benchmarks-over-time.html#window-functions-benchmark) -## h2o benchmarks for groupby +### h2o benchmarks for groupby + +#### Generate data for h2o benchmarks -### Generate data for h2o benchmarks There are three options for generating data for h2o benchmarks: `small`, `medium`, and `big`. The data is generated in the `data` directory. 1. Generate small data (1e7 rows) + ```bash ./bench.sh data h2o_small ``` - 2. Generate medium data (1e8 rows) + ```bash ./bench.sh data h2o_medium ``` - 3. Generate large data (1e9 rows) + ```bash ./bench.sh data h2o_big ``` -### Run h2o benchmarks +#### Run h2o benchmarks + There are three options for running h2o benchmarks: `small`, `medium`, and `big`. + 1. Run small data benchmark + ```bash ./bench.sh run h2o_small ``` 2. Run medium data benchmark + ```bash ./bench.sh run h2o_medium ``` 3. Run large data benchmark + ```bash ./bench.sh run h2o_big ``` @@ -575,53 +630,53 @@ There are three options for running h2o benchmarks: `small`, `medium`, and `big` 4. Run a specific query with a specific data path For example, to run query 1 with the small data generated above: + ```bash cargo run --release --bin dfbench -- h2o --path ./benchmarks/data/h2o/G1_1e7_1e7_100_0.csv --query 1 ``` -## h2o benchmarks for join +### h2o benchmarks for join -### Generate data for h2o benchmarks There are three options for generating data for h2o benchmarks: `small`, `medium`, and `big`. The data is generated in the `data` directory. -1. Generate small data (4 table files, the largest is 1e7 rows) +Here is an example that generates the `small` dataset and runs the benchmark. To run another +dataset size, adjust the commands in the same way (the other sizes are sketched below). + ```bash +# Generate small data (4 table files, the largest is 1e7 rows) ./bench.sh data h2o_small_join + +# Run the benchmark +./bench.sh run h2o_small_join ```
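The other dataset sizes follow the same pattern; here is a sketch based on the `h2o_medium_join` and `h2o_big_join` options listed in the `bench.sh` usage text later in this diff:

```bash
# Medium join dataset (4 table files, the largest is 1e8 rows)
./bench.sh data h2o_medium_join
./bench.sh run h2o_medium_join

# Big join dataset (4 table files, the largest is 1e9 rows)
./bench.sh data h2o_big_join
./bench.sh run h2o_big_join
```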
+To run a specific query with specific join data paths, provide all 4 table files. -2. Generate medium data (4 table files, the largest is 1e8 rows) -```bash -./bench.sh data h2o_medium_join -``` +For example, to run query 1 with the small data generated above: -3. Generate large data (4 table files, the largest is 1e9 rows) ```bash -./bench.sh data h2o_big_join +cargo run --release --bin dfbench -- h2o --join-paths ./benchmarks/data/h2o/J1_1e7_NA_0.csv,./benchmarks/data/h2o/J1_1e7_1e1_0.csv,./benchmarks/data/h2o/J1_1e7_1e4_0.csv,./benchmarks/data/h2o/J1_1e7_1e7_NA.csv --queries-path ./benchmarks/queries/h2o/join.sql --query 1 ``` -### Run h2o benchmarks -There are three options for running h2o benchmarks: `small`, `medium`, and `big`. -1. Run small data benchmark -```bash -./bench.sh run h2o_small_join -``` +### Extended h2o benchmarks for window -2. Run medium data benchmark -```bash -./bench.sh run h2o_medium_join -``` +This benchmark extends the h2o benchmark suite to evaluate window function performance. The h2o window benchmark uses the same dataset as the h2o join benchmark. There are three options for generating data for h2o benchmarks: `small`, `medium`, and `big`. + +Here is an example that generates the `small` dataset and runs the benchmark. To run another +dataset size, adjust the commands in the same way (see the sketch below). -3. Run large data benchmark ```bash -./bench.sh run h2o_big_join +# Generate small data +./bench.sh data h2o_small_window + +# Run the benchmark +./bench.sh run h2o_small_window ``` -4. Run a specific query with a specific join data paths, the data paths are including 4 table files.
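Likewise, a sketch of the other window dataset sizes, based on the `h2o_medium_window` and `h2o_big_window` options added to `bench.sh` later in this diff (the window benchmarks reuse the h2o join data):

```bash
# Medium window dataset (largest table 1e8 rows)
./bench.sh data h2o_medium_window
./bench.sh run h2o_medium_window

# Big window dataset (largest table 1e9 rows)
./bench.sh data h2o_big_window
./bench.sh run h2o_big_window
```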
+To run a specific query with a specific window data paths, the data paths are including 4 table files (the same as h2o-join dataset) For example, to run query 1 with the small data generated above: + ```bash -cargo run --release --bin dfbench -- h2o --join-paths ./benchmarks/data/h2o/J1_1e7_NA_0.csv,./benchmarks/data/h2o/J1_1e7_1e1_0.csv,./benchmarks/data/h2o/J1_1e7_1e4_0.csv,./benchmarks/data/h2o/J1_1e7_1e7_NA.csv --queries-path ./benchmarks/queries/h2o/join.sql --query 1 +cargo run --release --bin dfbench -- h2o --join-paths ./benchmarks/data/h2o/J1_1e7_NA_0.csv,./benchmarks/data/h2o/J1_1e7_1e1_0.csv,./benchmarks/data/h2o/J1_1e7_1e4_0.csv,./benchmarks/data/h2o/J1_1e7_1e7_NA.csv --queries-path ./benchmarks/queries/h2o/window.sql --query 1 ``` -[1]: http://www.tpc.org/tpch/ -[2]: https://www1.nyc.gov/site/tlc/about/tlc-trip-record-data.page diff --git a/benchmarks/bench.sh b/benchmarks/bench.sh index 5d3ad3446ddb..effce26d1cd2 100755 --- a/benchmarks/bench.sh +++ b/benchmarks/bench.sh @@ -28,6 +28,12 @@ set -e # https://stackoverflow.com/questions/59895/how-do-i-get-the-directory-where-a-bash-script-is-located-from-within-the-script SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +# Execute command and also print it, for debugging purposes +debug_run() { + set -x + "$@" + set +x +} # Set Defaults COMMAND= @@ -43,61 +49,83 @@ usage() { Orchestrates running benchmarks against DataFusion checkouts Usage: -$0 data [benchmark] [query] -$0 run [benchmark] +$0 data [benchmark] +$0 run [benchmark] [query] $0 compare +$0 compare_detail $0 venv -********** +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ Examples: -********** +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ # Create the datasets for all benchmarks in $DATA_DIR ./bench.sh data # Run the 'tpch' benchmark on the datafusion checkout in /source/datafusion DATAFUSION_DIR=/source/datafusion ./bench.sh run tpch -********** -* Commands -********** -data: Generates or downloads data needed for benchmarking -run: Runs the named benchmark -compare: Compares results from benchmark runs -venv: Creates new venv (unless already exists) and installs compare's requirements into it - -********** -* Benchmarks -********** +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +Commands +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +data: Generates or downloads data needed for benchmarking +run: Runs the named benchmark +compare: Compares fastest results from benchmark runs +compare_detail: Compares minimum, average (±stddev), and maximum results from benchmark runs +venv: Creates new venv (unless already exists) and installs compare's requirements into it + +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +Benchmarks +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +# Run all of the following benchmarks all(default): Data/Run/Compare for all benchmarks + +# TPC-H Benchmarks tpch: TPCH inspired benchmark on Scale Factor (SF) 1 (~1GB), single parquet file per table, hash join +tpch_csv: TPCH inspired benchmark on Scale Factor (SF) 1 (~1GB), single csv file per table, hash join tpch_mem: TPCH inspired benchmark on Scale Factor (SF) 1 (~1GB), query from memory tpch10: TPCH inspired benchmark on Scale Factor (SF) 10 (~10GB), single parquet file per table, hash join +tpch_csv10: TPCH inspired benchmark on Scale Factor (SF) 10 (~10GB), single csv file per table, hash join tpch_mem10: TPCH inspired benchmark on Scale Factor (SF) 10 (~10GB), query from memory -cancellation: How 
long cancelling a query takes -parquet: Benchmark of parquet reader's filtering speed -sort: Benchmark of sorting speed -sort_tpch: Benchmark of sorting speed for end-to-end sort queries on TPCH dataset + +# Extended TPC-H Benchmarks +sort_tpch: Benchmark of sorting speed for end-to-end sort queries on TPC-H dataset (SF=1) +topk_tpch: Benchmark of top-k (sorting with limit) queries on TPC-H dataset (SF=1) +external_aggr: External aggregation benchmark on TPC-H dataset (SF=1) + +# ClickBench Benchmarks clickbench_1: ClickBench queries against a single parquet file clickbench_partitioned: ClickBench queries against a partitioned (100 files) parquet clickbench_extended: ClickBench \"inspired\" queries against a single parquet (DataFusion specific) -external_aggr: External aggregation benchmark + +# H2O.ai Benchmarks (Group By, Join, Window) h2o_small: h2oai benchmark with small dataset (1e7 rows) for groupby, default file format is csv h2o_medium: h2oai benchmark with medium dataset (1e8 rows) for groupby, default file format is csv h2o_big: h2oai benchmark with large dataset (1e9 rows) for groupby, default file format is csv h2o_small_join: h2oai benchmark with small dataset (1e7 rows) for join, default file format is csv h2o_medium_join: h2oai benchmark with medium dataset (1e8 rows) for join, default file format is csv h2o_big_join: h2oai benchmark with large dataset (1e9 rows) for join, default file format is csv +h2o_small_window: Extended h2oai benchmark with small dataset (1e7 rows) for window, default file format is csv +h2o_medium_window: Extended h2oai benchmark with medium dataset (1e8 rows) for window, default file format is csv +h2o_big_window: Extended h2oai benchmark with large dataset (1e9 rows) for window, default file format is csv + +# Join Order Benchmark (IMDB) imdb: Join Order Benchmark (JOB) using the IMDB dataset converted to parquet -********** -* Supported Configuration (Environment Variables) -********** +# Micro-Benchmarks (specific operators and features) +cancellation: How long cancelling a query takes +parquet: Benchmark of parquet reader's filtering speed +sort: Benchmark of sorting speed + +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +Supported Configuration (Environment Variables) +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ DATA_DIR directory to store datasets CARGO_COMMAND command that runs the benchmark binary DATAFUSION_DIR directory to use (default $DATAFUSION_DIR) RESULTS_NAME folder where the benchmark files are stored PREFER_HASH_JOIN Prefer hash join algorithm (default true) VENV_PATH Python venv to use for compare and venv commands (default ./venv, override by /bin/activate) +DATAFUSION_* Set the given datafusion configuration " exit 1 } @@ -204,6 +232,16 @@ main() { h2o_big_join) data_h2o_join "BIG" "CSV" ;; + # h2o window benchmark uses the same data as the h2o join + h2o_small_window) + data_h2o_join "SMALL" "CSV" + ;; + h2o_medium_window) + data_h2o_join "MEDIUM" "CSV" + ;; + h2o_big_window) + data_h2o_join "BIG" "CSV" + ;; external_aggr) # same data as for tpch data_tpch "1" @@ -212,6 +250,10 @@ main() { # same data as for tpch data_tpch "1" ;; + topk_tpch) + # same data as for tpch + data_tpch "1" + ;; *) echo "Error: unknown benchmark '$BENCHMARK' for data generation" usage @@ -226,10 +268,15 @@ main() { RESULTS_NAME=${RESULTS_NAME:-"${BRANCH_NAME}"} RESULTS_DIR=${RESULTS_DIR:-"$SCRIPT_DIR/results/$RESULTS_NAME"} + # Optional query filter to run specific query + QUERY=${ARG3} + QUERY_ARG=$([ -n "$QUERY" ] && echo 
"--query ${QUERY}" || echo "") + echo "***************************" echo "DataFusion Benchmark Script" echo "COMMAND: ${COMMAND}" echo "BENCHMARK: ${BENCHMARK}" + echo "QUERY: ${QUERY:-All}" echo "DATAFUSION_DIR: ${DATAFUSION_DIR}" echo "BRANCH_NAME: ${BRANCH_NAME}" echo "DATA_DIR: ${DATA_DIR}" @@ -244,9 +291,11 @@ main() { mkdir -p "${DATA_DIR}" case "$BENCHMARK" in all) - run_tpch "1" + run_tpch "1" "parquet" + run_tpch "1" "csv" run_tpch_mem "1" - run_tpch "10" + run_tpch "10" "parquet" + run_tpch "10" "csv" run_tpch_mem "10" run_cancellation run_parquet @@ -264,13 +313,19 @@ main() { run_external_aggr ;; tpch) - run_tpch "1" + run_tpch "1" "parquet" + ;; + tpch_csv) + run_tpch "1" "csv" ;; tpch_mem) run_tpch_mem "1" ;; tpch10) - run_tpch "10" + run_tpch "10" "parquet" + ;; + tpch_csv10) + run_tpch "10" "csv" ;; tpch_mem10) run_tpch_mem "10" @@ -314,12 +369,24 @@ main() { h2o_big_join) run_h2o_join "BIG" "CSV" "join" ;; + h2o_small_window) + run_h2o_window "SMALL" "CSV" "window" + ;; + h2o_medium_window) + run_h2o_window "MEDIUM" "CSV" "window" + ;; + h2o_big_window) + run_h2o_window "BIG" "CSV" "window" + ;; external_aggr) run_external_aggr ;; sort_tpch) run_sort_tpch ;; + topk_tpch) + run_topk_tpch + ;; *) echo "Error: unknown benchmark '$BENCHMARK' for run" usage @@ -331,6 +398,9 @@ main() { compare) compare_benchmarks "$ARG2" "$ARG3" ;; + compare_detail) + compare_benchmarks "$ARG2" "$ARG3" "--detailed" + ;; venv) setup_venv ;; @@ -396,6 +466,17 @@ data_tpch() { $CARGO_COMMAND --bin tpch -- convert --input "${TPCH_DIR}" --output "${TPCH_DIR}" --format parquet popd > /dev/null fi + + # Create 'csv' files from tbl + FILE="${TPCH_DIR}/csv/supplier" + if test -d "${FILE}"; then + echo " csv files exist ($FILE exists)." + else + echo " creating csv files using benchmark binary ..." + pushd "${SCRIPT_DIR}" > /dev/null + $CARGO_COMMAND --bin tpch -- convert --input "${TPCH_DIR}" --output "${TPCH_DIR}/csv" --format csv + popd > /dev/null + fi } # Runs the tpch benchmark @@ -410,12 +491,9 @@ run_tpch() { RESULTS_FILE="${RESULTS_DIR}/tpch_sf${SCALE_FACTOR}.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running tpch benchmark..." - # Optional query filter to run specific query - QUERY=$([ -n "$ARG3" ] && echo "--query $ARG3" || echo "") - # debug the target command - set -x - $CARGO_COMMAND --bin tpch -- benchmark datafusion --iterations 5 --path "${TPCH_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" --format parquet -o "${RESULTS_FILE}" $QUERY - set +x + + FORMAT=$2 + debug_run $CARGO_COMMAND --bin tpch -- benchmark datafusion --iterations 5 --path "${TPCH_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" --format ${FORMAT} -o "${RESULTS_FILE}" ${QUERY_ARG} } # Runs the tpch in memory @@ -430,13 +508,8 @@ run_tpch_mem() { RESULTS_FILE="${RESULTS_DIR}/tpch_mem_sf${SCALE_FACTOR}.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running tpch_mem benchmark..." 
- # Optional query filter to run specific query - QUERY=$([ -n "$ARG3" ] && echo "--query $ARG3" || echo "") - # debug the target command - set -x # -m means in memory - $CARGO_COMMAND --bin tpch -- benchmark datafusion --iterations 5 --path "${TPCH_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" -m --format parquet -o "${RESULTS_FILE}" $QUERY - set +x + debug_run $CARGO_COMMAND --bin tpch -- benchmark datafusion --iterations 5 --path "${TPCH_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" -m --format parquet -o "${RESULTS_FILE}" ${QUERY_ARG} } # Runs the cancellation benchmark @@ -444,7 +517,7 @@ run_cancellation() { RESULTS_FILE="${RESULTS_DIR}/cancellation.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running cancellation benchmark..." - $CARGO_COMMAND --bin dfbench -- cancellation --iterations 5 --path "${DATA_DIR}/cancellation" -o "${RESULTS_FILE}" + debug_run $CARGO_COMMAND --bin dfbench -- cancellation --iterations 5 --path "${DATA_DIR}/cancellation" -o "${RESULTS_FILE}" } # Runs the parquet filter benchmark @@ -452,7 +525,7 @@ run_parquet() { RESULTS_FILE="${RESULTS_DIR}/parquet.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running parquet filter benchmark..." - $CARGO_COMMAND --bin parquet -- filter --path "${DATA_DIR}" --scale-factor 1.0 --iterations 5 -o "${RESULTS_FILE}" + debug_run $CARGO_COMMAND --bin parquet -- filter --path "${DATA_DIR}" --scale-factor 1.0 --iterations 5 -o "${RESULTS_FILE}" } # Runs the sort benchmark @@ -460,7 +533,7 @@ run_sort() { RESULTS_FILE="${RESULTS_DIR}/sort.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running sort benchmark..." - $CARGO_COMMAND --bin parquet -- sort --path "${DATA_DIR}" --scale-factor 1.0 --iterations 5 -o "${RESULTS_FILE}" + debug_run $CARGO_COMMAND --bin parquet -- sort --path "${DATA_DIR}" --scale-factor 1.0 --iterations 5 -o "${RESULTS_FILE}" } @@ -514,7 +587,7 @@ run_clickbench_1() { RESULTS_FILE="${RESULTS_DIR}/clickbench_1.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running clickbench (1 file) benchmark..." - $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits.parquet" --queries-path "${SCRIPT_DIR}/queries/clickbench/queries.sql" -o "${RESULTS_FILE}" + debug_run $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits.parquet" --queries-path "${SCRIPT_DIR}/queries/clickbench/queries" -o "${RESULTS_FILE}" ${QUERY_ARG} } # Runs the clickbench benchmark with the partitioned parquet files @@ -522,7 +595,7 @@ run_clickbench_partitioned() { RESULTS_FILE="${RESULTS_DIR}/clickbench_partitioned.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running clickbench (partitioned, 100 files) benchmark..." - $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits_partitioned" --queries-path "${SCRIPT_DIR}/queries/clickbench/queries.sql" -o "${RESULTS_FILE}" + debug_run $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits_partitioned" --queries-path "${SCRIPT_DIR}/queries/clickbench/queries" -o "${RESULTS_FILE}" ${QUERY_ARG} } # Runs the clickbench "extended" benchmark with a single large parquet file @@ -530,7 +603,7 @@ run_clickbench_extended() { RESULTS_FILE="${RESULTS_DIR}/clickbench_extended.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running clickbench (1 file) extended benchmark..." 
- $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits.parquet" --queries-path "${SCRIPT_DIR}/queries/clickbench/extended.sql" -o "${RESULTS_FILE}" + debug_run $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits.parquet" --queries-path "${SCRIPT_DIR}/queries/clickbench/extended" -o "${RESULTS_FILE}" ${QUERY_ARG} } # Downloads the csv.gz files IMDB datasets from Peter Boncz's homepage(one of the JOB paper authors) @@ -645,7 +718,7 @@ run_imdb() { RESULTS_FILE="${RESULTS_DIR}/imdb.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running imdb benchmark..." - $CARGO_COMMAND --bin imdb -- benchmark datafusion --iterations 5 --path "${IMDB_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" --format parquet -o "${RESULTS_FILE}" + debug_run $CARGO_COMMAND --bin imdb -- benchmark datafusion --iterations 5 --path "${IMDB_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" --format parquet -o "${RESULTS_FILE}" ${QUERY_ARG} } data_h2o() { @@ -800,6 +873,7 @@ data_h2o_join() { deactivate } +# Runner for h2o groupby benchmark run_h2o() { # Default values for size and data format SIZE=${1:-"SMALL"} @@ -835,14 +909,16 @@ run_h2o() { QUERY_FILE="${SCRIPT_DIR}/queries/h2o/${RUN_Type}.sql" # Run the benchmark using the dynamically constructed file path and query file - $CARGO_COMMAND --bin dfbench -- h2o \ + debug_run $CARGO_COMMAND --bin dfbench -- h2o \ --iterations 3 \ --path "${H2O_DIR}/${FILE_NAME}" \ --queries-path "${QUERY_FILE}" \ - -o "${RESULTS_FILE}" + -o "${RESULTS_FILE}" \ + ${QUERY_ARG} } -run_h2o_join() { +# Utility function to run h2o join/window benchmark +h2o_runner() { # Default values for size and data format SIZE=${1:-"SMALL"} DATA_FORMAT=${2:-"CSV"} @@ -851,10 +927,10 @@ run_h2o_join() { # Data directory and results file path H2O_DIR="${DATA_DIR}/h2o" - RESULTS_FILE="${RESULTS_DIR}/h2o_join.json" + RESULTS_FILE="${RESULTS_DIR}/h2o_${RUN_Type}.json" echo "RESULTS_FILE: ${RESULTS_FILE}" - echo "Running h2o join benchmark..." + echo "Running h2o ${RUN_Type} benchmark..." # Set the file name based on the size case "$SIZE" in @@ -882,14 +958,25 @@ run_h2o_join() { ;; esac - # Set the query file name based on the RUN_Type + # Set the query file name based on the RUN_Type QUERY_FILE="${SCRIPT_DIR}/queries/h2o/${RUN_Type}.sql" - $CARGO_COMMAND --bin dfbench -- h2o \ + debug_run $CARGO_COMMAND --bin dfbench -- h2o \ --iterations 3 \ --join-paths "${H2O_DIR}/${X_TABLE_FILE_NAME},${H2O_DIR}/${SMALL_TABLE_FILE_NAME},${H2O_DIR}/${MEDIUM_TABLE_FILE_NAME},${H2O_DIR}/${LARGE_TABLE_FILE_NAME}" \ --queries-path "${QUERY_FILE}" \ - -o "${RESULTS_FILE}" + -o "${RESULTS_FILE}" \ + ${QUERY_ARG} +} + +# Runners for h2o join benchmark +run_h2o_join() { + h2o_runner "$1" "$2" "join" +} + +# Runners for h2o join benchmark +run_h2o_window() { + h2o_runner "$1" "$2" "window" } # Runs the external aggregation benchmark @@ -905,7 +992,7 @@ run_external_aggr() { # number-of-partitions), and by default `--partitions` is set to number of # CPU cores, we set a constant number of partitions to prevent this # benchmark to fail on some machines. - $CARGO_COMMAND --bin external_aggr -- benchmark --partitions 4 --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}" + debug_run $CARGO_COMMAND --bin external_aggr -- benchmark --partitions 4 --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}" ${QUERY_ARG} } # Runs the sort integration benchmark @@ -915,7 +1002,17 @@ run_sort_tpch() { echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running sort tpch benchmark..." 
- $CARGO_COMMAND --bin dfbench -- sort-tpch --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}" + debug_run $CARGO_COMMAND --bin dfbench -- sort-tpch --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}" ${QUERY_ARG} +} + +# Runs the sort tpch integration benchmark with limit 100 (topk) +run_topk_tpch() { + TPCH_DIR="${DATA_DIR}/tpch_sf1" + RESULTS_FILE="${RESULTS_DIR}/run_topk_tpch.json" + echo "RESULTS_FILE: ${RESULTS_FILE}" + echo "Running topk tpch benchmark..." + + $CARGO_COMMAND --bin dfbench -- sort-tpch --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}" --limit 100 ${QUERY_ARG} } @@ -923,6 +1020,8 @@ compare_benchmarks() { BASE_RESULTS_DIR="${SCRIPT_DIR}/results" BRANCH1="$1" BRANCH2="$2" + OPTS="$3" + if [ -z "$BRANCH1" ] ; then echo " not specified. Available branches:" ls -1 "${BASE_RESULTS_DIR}" @@ -943,7 +1042,7 @@ compare_benchmarks() { echo "--------------------" echo "Benchmark ${BENCH}" echo "--------------------" - PATH=$VIRTUAL_ENV/bin:$PATH python3 "${SCRIPT_DIR}"/compare.py "${RESULTS_FILE1}" "${RESULTS_FILE2}" + PATH=$VIRTUAL_ENV/bin:$PATH python3 "${SCRIPT_DIR}"/compare.py $OPTS "${RESULTS_FILE1}" "${RESULTS_FILE2}" else echo "Note: Skipping ${RESULTS_FILE1} as ${RESULTS_FILE2} does not exist" fi diff --git a/benchmarks/compare.py b/benchmarks/compare.py index 4b609c744d50..7e51a38a92c2 100755 --- a/benchmarks/compare.py +++ b/benchmarks/compare.py @@ -18,7 +18,9 @@ from __future__ import annotations +import argparse import json +import math from dataclasses import dataclass from typing import Dict, List, Any from pathlib import Path @@ -47,6 +49,7 @@ class QueryRun: query: int iterations: List[QueryResult] start_time: int + success: bool = True @classmethod def load_from(cls, data: Dict[str, Any]) -> QueryRun: @@ -54,17 +57,57 @@ def load_from(cls, data: Dict[str, Any]) -> QueryRun: query=data["query"], iterations=[QueryResult(**iteration) for iteration in data["iterations"]], start_time=data["start_time"], + success=data.get("success", True), ) @property - def execution_time(self) -> float: + def min_execution_time(self) -> float: assert len(self.iterations) >= 1 - # Use minimum execution time to account for variations / other - # things the system was doing return min(iteration.elapsed for iteration in self.iterations) + @property + def max_execution_time(self) -> float: + assert len(self.iterations) >= 1 + + return max(iteration.elapsed for iteration in self.iterations) + + + @property + def mean_execution_time(self) -> float: + assert len(self.iterations) >= 1 + + total = sum(iteration.elapsed for iteration in self.iterations) + return total / len(self.iterations) + + + @property + def stddev_execution_time(self) -> float: + assert len(self.iterations) >= 1 + + mean = self.mean_execution_time + squared_diffs = [(iteration.elapsed - mean) ** 2 for iteration in self.iterations] + variance = sum(squared_diffs) / len(self.iterations) + return math.sqrt(variance) + + def execution_time_report(self, detailed = False) -> tuple[float, str]: + if detailed: + mean_execution_time = self.mean_execution_time + return ( + mean_execution_time, + f"{self.min_execution_time:.2f} / {mean_execution_time :.2f} ±{self.stddev_execution_time:.2f} / {self.max_execution_time:.2f} ms" + ) + else: + # Use minimum execution time to account for variations / other + # things the system was doing + min_execution_time = self.min_execution_time + return ( + min_execution_time, + f"{min_execution_time :.2f} ms" + ) + + @dataclass class Context: benchmark_version: str @@ -106,6 
+149,7 @@ def compare( baseline_path: Path, comparison_path: Path, noise_threshold: float, + detailed: bool, ) -> None: baseline = BenchmarkRun.load_from_file(baseline_path) comparison = BenchmarkRun.load_from_file(comparison_path) @@ -125,16 +169,34 @@ def compare( faster_count = 0 slower_count = 0 no_change_count = 0 + failure_count = 0 total_baseline_time = 0 total_comparison_time = 0 for baseline_result, comparison_result in zip(baseline.queries, comparison.queries): assert baseline_result.query == comparison_result.query - - total_baseline_time += baseline_result.execution_time - total_comparison_time += comparison_result.execution_time - - change = comparison_result.execution_time / baseline_result.execution_time + + base_failed = not baseline_result.success + comp_failed = not comparison_result.success + # If a query fails, its execution time is excluded from the performance comparison + if base_failed or comp_failed: + change_text = "incomparable" + failure_count += 1 + table.add_row( + f"Q{baseline_result.query}", + "FAIL" if base_failed else baseline_result.execution_time_report(detailed)[1], + "FAIL" if comp_failed else comparison_result.execution_time_report(detailed)[1], + change_text, + ) + continue + + baseline_value, baseline_text = baseline_result.execution_time_report(detailed) + comparison_value, comparison_text = comparison_result.execution_time_report(detailed) + + total_baseline_time += baseline_value + total_comparison_time += comparison_value + + change = comparison_value / baseline_value if (1.0 - noise_threshold) <= change <= (1.0 + noise_threshold): change_text = "no change" @@ -148,16 +210,20 @@ def compare( table.add_row( f"Q{baseline_result.query}", - f"{baseline_result.execution_time:.2f}ms", - f"{comparison_result.execution_time:.2f}ms", + baseline_text, + comparison_text, change_text, ) console.print(table) # Calculate averages - avg_baseline_time = total_baseline_time / len(baseline.queries) - avg_comparison_time = total_comparison_time / len(comparison.queries) + avg_baseline_time = 0.0 + avg_comparison_time = 0.0 + if len(baseline.queries) - failure_count > 0: + avg_baseline_time = total_baseline_time / (len(baseline.queries) - failure_count) + if len(comparison.queries) - failure_count > 0: + avg_comparison_time = total_comparison_time / (len(comparison.queries) - failure_count) # Summary table summary_table = Table(show_header=True, header_style="bold magenta") @@ -171,6 +237,7 @@ def compare( summary_table.add_row("Queries Faster", str(faster_count)) summary_table.add_row("Queries Slower", str(slower_count)) summary_table.add_row("Queries with No Change", str(no_change_count)) + summary_table.add_row("Queries with Failure", str(failure_count)) console.print(summary_table) @@ -193,10 +260,16 @@ def main() -> None: default=0.05, help="The threshold for statistically insignificant results (+/- %5).", ) + compare_parser.add_argument( + "--detailed", + action=argparse.BooleanOptionalAction, + default=False, + help="Show detailed result comparison instead of minimum runtime.", + ) options = parser.parse_args() - compare(options.baseline_path, options.comparison_path, options.noise_threshold) + compare(options.baseline_path, options.comparison_path, options.noise_threshold, options.detailed) diff --git a/benchmarks/queries/clickbench/README.md b/benchmarks/queries/clickbench/README.md index 2032427e1ef2..877ea0e0c319 100644 --- a/benchmarks/queries/clickbench/README.md +++ b/benchmarks/queries/clickbench/README.md @@ -5,17 +5,18 @@ This directory contains 
queries for the ClickBench benchmark https://benchmark.c ClickBench is focused on aggregation and filtering performance (though it has no Joins) ## Files: -* `queries.sql` - Actual ClickBench queries, downloaded from the [ClickBench repository] -* `extended.sql` - "Extended" DataFusion specific queries. -[ClickBench repository]: https://github.com/ClickHouse/ClickBench/blob/main/datafusion/queries.sql +- `queries/*.sql` - Actual ClickBench queries, downloaded from the [ClickBench repository](https://raw.githubusercontent.com/ClickHouse/ClickBench/main/datafusion/queries.sql) and split by the `update_queries.sh` script. +- `extended/*.sql` - "Extended" DataFusion specific queries. -## "Extended" Queries +[clickbench repository]: https://github.com/ClickHouse/ClickBench/blob/main/datafusion/queries.sql + +## "Extended" Queries The "extended" queries are not part of the official ClickBench benchmark. Instead they are used to test other DataFusion features that are not covered by -the standard benchmark. Each description below is for the corresponding line in -`extended.sql` (line 1 is `Q0`, line 2 is `Q1`, etc.) +the standard benchmark. Each description below is for the corresponding file in +`extended` ### Q0: Data Exploration @@ -25,7 +26,7 @@ the standard benchmark. Each description below is for the corresponding line in distinct string columns. ```sql -SELECT COUNT(DISTINCT "SearchPhrase"), COUNT(DISTINCT "MobilePhone"), COUNT(DISTINCT "MobilePhoneModel") +SELECT COUNT(DISTINCT "SearchPhrase"), COUNT(DISTINCT "MobilePhone"), COUNT(DISTINCT "MobilePhoneModel") FROM hits; ``` @@ -35,7 +36,6 @@ FROM hits; **Important Query Properties**: multiple `COUNT DISTINCT`s. All three are small strings (length either 1 or 2). - ```sql SELECT COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTINCT "BrowserLanguage") FROM hits; @@ -43,21 +43,20 @@ FROM hits; ### Q2: Top 10 analysis -**Question**: "Find the top 10 "browser country" by number of distinct "social network"s, -including the distinct counts of "hit color", "browser language", +**Question**: "Find the top 10 "browser country" by number of distinct "social network"s, +including the distinct counts of "hit color", "browser language", and "social action"." **Important Query Properties**: GROUP BY short, string, multiple `COUNT DISTINCT`s. There are several small strings (length either 1 or 2). ```sql SELECT "BrowserCountry", COUNT(DISTINCT "SocialNetwork"), COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserLanguage"), COUNT(DISTINCT "SocialAction") -FROM hits -GROUP BY 1 -ORDER BY 2 DESC +FROM hits +GROUP BY 1 +ORDER BY 2 DESC LIMIT 10; ``` - ### Q3: What is the income distribution for users in specific regions **Question**: "What regions and social networks have the highest variance of parameter price?" 
@@ -65,17 +64,17 @@ LIMIT 10; **Important Query Properties**: STDDEV and VAR aggregation functions, GROUP BY multiple small ints ```sql -SELECT "SocialSourceNetworkID", "RegionID", COUNT(*), AVG("Age"), AVG("ParamPrice"), STDDEV("ParamPrice") as s, VAR("ParamPrice") -FROM 'hits.parquet' -GROUP BY "SocialSourceNetworkID", "RegionID" +SELECT "SocialSourceNetworkID", "RegionID", COUNT(*), AVG("Age"), AVG("ParamPrice"), STDDEV("ParamPrice") as s, VAR("ParamPrice") +FROM 'hits.parquet' +GROUP BY "SocialSourceNetworkID", "RegionID" HAVING s IS NOT NULL -ORDER BY s DESC +ORDER BY s DESC LIMIT 10; ``` ### Q4: Response start time distribution analysis (median) -**Question**: Find the WatchIDs with the highest median "ResponseStartTiming" without Java enabled +**Question**: Find the WatchIDs with the highest median "ResponseStartTiming" without Java enabled **Important Query Properties**: MEDIAN, functions, high cardinality grouping that skips intermediate aggregation @@ -102,17 +101,16 @@ Results look like +-------------+---------------------+---+------+------+------+ ``` - ### Q5: Response start time distribution analysis (p95) -**Question**: Find the WatchIDs with the highest p95 "ResponseStartTiming" without Java enabled +**Question**: Find the WatchIDs with the highest p95 "ResponseStartTiming" without Java enabled **Important Query Properties**: APPROX_PERCENTILE_CONT, functions, high cardinality grouping that skips intermediate aggregation Note this query is somewhat synthetic as "WatchID" is almost unique (there are a few duplicates) ```sql -SELECT "ClientIP", "WatchID", COUNT(*) c, MIN("ResponseStartTiming") tmin, APPROX_PERCENTILE_CONT("ResponseStartTiming", 0.95) tp95, MAX("ResponseStartTiming") tmax +SELECT "ClientIP", "WatchID", COUNT(*) c, MIN("ResponseStartTiming") tmin, APPROX_PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY "ResponseStartTiming") tp95, MAX("ResponseStartTiming") tmax FROM 'hits.parquet' WHERE "JavaEnable" = 0 -- filters to 32M of 100M rows GROUP BY "ClientIP", "WatchID" @@ -122,6 +120,7 @@ LIMIT 10; ``` Results look like + ``` +-------------+---------------------+---+------+------+------+ | ClientIP | WatchID | c | tmin | tp95 | tmax | @@ -132,6 +131,7 @@ Results look like ``` ### Q6: How many social shares meet complex multi-stage filtering criteria? + **Question**: What is the count of sharing actions from iPhone mobile users on specific social networks, within common timezones, participating in seasonal campaigns, with high screen resolutions and closely matched UTM parameters? 
**Important Query Properties**: Simple filter with high-selectivity, Costly string matching, A large number of filters with high overhead are positioned relatively later in the process @@ -150,20 +150,89 @@ WHERE -- Stage 3: Heavy computations (expensive) AND regexp_match("Referer", '\/campaign\/(spring|summer)_promo') IS NOT NULL -- Find campaign-specific referrers - AND CASE - WHEN split_part(split_part("URL", 'resolution=', 2), '&', 1) ~ '^\d+$' - THEN split_part(split_part("URL", 'resolution=', 2), '&', 1)::INT - ELSE 0 + AND CASE + WHEN split_part(split_part("URL", 'resolution=', 2), '&', 1) ~ '^\d+$' + THEN split_part(split_part("URL", 'resolution=', 2), '&', 1)::INT + ELSE 0 END > 1920 -- Extract and validate resolution parameter - AND levenshtein("UTMSource", "UTMCampaign") < 3 -- Verify UTM parameter similarity + AND levenshtein(CAST("UTMSource" AS STRING), CAST("UTMCampaign" AS STRING)) < 3 -- Verify UTM parameter similarity ``` + The result is empty, since the rows have already been filtered by `"SocialAction" = 'share'`. +### Q7: Device Resolution and Refresh Behavior Analysis + +**Question**: Identify the top 10 WatchIDs with the highest resolution range (min/max "ResolutionWidth") and total refresh count ("IsRefresh") in descending WatchID order + +**Important Query Properties**: Primitive aggregation functions, group by single primitive column, high cardinality grouping + +```sql +SELECT "WatchID", MIN("ResolutionWidth") as wmin, MAX("ResolutionWidth") as wmax, SUM("IsRefresh") as srefresh +FROM hits +GROUP BY "WatchID" +ORDER BY "WatchID" DESC +LIMIT 10; +``` + +Results look like + +``` ++---------------------+------+------+----------+ +| WatchID | wmin | wmax | srefresh | ++---------------------+------+------+----------+ +| 9223372033328793741 | 1368 | 1368 | 0 | +| 9223371941779979288 | 1479 | 1479 | 0 | +| 9223371906781104763 | 1638 | 1638 | 0 | +| 9223371803397398692 | 1990 | 1990 | 0 | +| 9223371799215233959 | 1638 | 1638 | 0 | +| 9223371785975219972 | 0 | 0 | 0 | +| 9223371776706839366 | 1368 | 1368 | 0 | +| 9223371740707848038 | 1750 | 1750 | 0 | +| 9223371715190479830 | 1368 | 1368 | 0 | +| 9223371620124912624 | 1828 | 1828 | 0 | ++---------------------+------+------+----------+ +``` + +### Q8: Average Latency and Response Time Analysis + +**Question**: Which combinations of operating system, region, and user agent exhibit the highest average latency? For each of these combinations, also report the average response time.
+ +**Important Query Properties**: Multiple average of Duration, high cardinality grouping + +```sql +SELECT "RegionID", "UserAgent", "OS", AVG(to_timestamp("ResponseEndTiming")-to_timestamp("ResponseStartTiming")) as avg_response_time, AVG(to_timestamp("ResponseEndTiming")-to_timestamp("ConnectTiming")) as avg_latency +FROM hits +GROUP BY "RegionID", "UserAgent", "OS" +ORDER BY avg_latency DESC +LIMIT 10; +``` + +Results look like + +``` ++----------+-----------+-----+------------------------------------------+------------------------------------------+ +| RegionID | UserAgent | OS | avg_response_time | avg_latency | ++----------+-----------+-----+------------------------------------------+------------------------------------------+ +| 22934 | 5 | 126 | 0 days 8 hours 20 mins 0.000000000 secs | 0 days 8 hours 20 mins 0.000000000 secs | +| 22735 | 82 | 74 | 0 days 8 hours 20 mins 0.000000000 secs | 0 days 8 hours 20 mins 0.000000000 secs | +| 21687 | 32 | 49 | 0 days 8 hours 20 mins 0.000000000 secs | 0 days 8 hours 20 mins 0.000000000 secs | +| 18518 | 82 | 77 | 0 days 8 hours 20 mins 0.000000000 secs | 0 days 8 hours 20 mins 0.000000000 secs | +| 14006 | 7 | 126 | 0 days 7 hours 58 mins 20.000000000 secs | 0 days 8 hours 20 mins 0.000000000 secs | +| 9803 | 82 | 77 | 0 days 8 hours 20 mins 0.000000000 secs | 0 days 8 hours 20 mins 0.000000000 secs | +| 107108 | 82 | 77 | 0 days 8 hours 20 mins 0.000000000 secs | 0 days 8 hours 20 mins 0.000000000 secs | +| 111626 | 7 | 44 | 0 days 7 hours 23 mins 12.500000000 secs | 0 days 8 hours 0 mins 47.000000000 secs | +| 17716 | 56 | 44 | 0 days 6 hours 48 mins 44.500000000 secs | 0 days 7 hours 35 mins 47.000000000 secs | +| 13631 | 82 | 45 | 0 days 7 hours 23 mins 1.000000000 secs | 0 days 7 hours 23 mins 1.000000000 secs | ++----------+-----------+-----+------------------------------------------+------------------------------------------+ +10 row(s) fetched. +Elapsed 30.195 seconds. 
+``` ## Data Notes Here are some interesting statistics about the data used in the queries Max length of `"SearchPhrase"` is 1113 characters + ```sql > select min(length("SearchPhrase")) as "SearchPhrase_len_min", max(length("SearchPhrase")) "SearchPhrase_len_max" from 'hits.parquet' limit 10; +----------------------+----------------------+ @@ -173,8 +242,8 @@ Max length of `"SearchPhrase"` is 1113 characters +----------------------+----------------------+ ``` - Here is the schema of the data + ```sql > describe 'hits.parquet'; +-----------------------+-----------+-------------+ diff --git a/benchmarks/queries/clickbench/extended.sql b/benchmarks/queries/clickbench/extended.sql deleted file mode 100644 index ef3a409c9c02..000000000000 --- a/benchmarks/queries/clickbench/extended.sql +++ /dev/null @@ -1,7 +0,0 @@ -SELECT COUNT(DISTINCT "SearchPhrase"), COUNT(DISTINCT "MobilePhone"), COUNT(DISTINCT "MobilePhoneModel") FROM hits; -SELECT COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTINCT "BrowserLanguage") FROM hits; -SELECT "BrowserCountry", COUNT(DISTINCT "SocialNetwork"), COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserLanguage"), COUNT(DISTINCT "SocialAction") FROM hits GROUP BY 1 ORDER BY 2 DESC LIMIT 10; -SELECT "SocialSourceNetworkID", "RegionID", COUNT(*), AVG("Age"), AVG("ParamPrice"), STDDEV("ParamPrice") as s, VAR("ParamPrice") FROM hits GROUP BY "SocialSourceNetworkID", "RegionID" HAVING s IS NOT NULL ORDER BY s DESC LIMIT 10; -SELECT "ClientIP", "WatchID", COUNT(*) c, MIN("ResponseStartTiming") tmin, MEDIAN("ResponseStartTiming") tmed, MAX("ResponseStartTiming") tmax FROM hits WHERE "JavaEnable" = 0 GROUP BY "ClientIP", "WatchID" HAVING c > 1 ORDER BY tmed DESC LIMIT 10; -SELECT "ClientIP", "WatchID", COUNT(*) c, MIN("ResponseStartTiming") tmin, APPROX_PERCENTILE_CONT("ResponseStartTiming", 0.95) tp95, MAX("ResponseStartTiming") tmax FROM 'hits' WHERE "JavaEnable" = 0 GROUP BY "ClientIP", "WatchID" HAVING c > 1 ORDER BY tp95 DESC LIMIT 10; -SELECT COUNT(*) AS ShareCount FROM hits WHERE "IsMobile" = 1 AND "MobilePhoneModel" LIKE 'iPhone%' AND "SocialAction" = 'share' AND "SocialSourceNetworkID" IN (5, 12) AND "ClientTimeZone" BETWEEN -5 AND 5 AND regexp_match("Referer", '\/campaign\/(spring|summer)_promo') IS NOT NULL AND CASE WHEN split_part(split_part("URL", 'resolution=', 2), '&', 1) ~ '^\d+$' THEN split_part(split_part("URL", 'resolution=', 2), '&', 1)::INT ELSE 0 END > 1920 AND levenshtein("UTMSource", "UTMCampaign") < 3; \ No newline at end of file diff --git a/benchmarks/queries/clickbench/extended/q0.sql b/benchmarks/queries/clickbench/extended/q0.sql new file mode 100644 index 000000000000..a1e55b5b25ac --- /dev/null +++ b/benchmarks/queries/clickbench/extended/q0.sql @@ -0,0 +1 @@ +SELECT COUNT(DISTINCT "SearchPhrase"), COUNT(DISTINCT "MobilePhone"), COUNT(DISTINCT "MobilePhoneModel") FROM hits; diff --git a/benchmarks/queries/clickbench/extended/q1.sql b/benchmarks/queries/clickbench/extended/q1.sql new file mode 100644 index 000000000000..84fac921c8cb --- /dev/null +++ b/benchmarks/queries/clickbench/extended/q1.sql @@ -0,0 +1 @@ +SELECT COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTINCT "BrowserLanguage") FROM hits; diff --git a/benchmarks/queries/clickbench/extended/q2.sql b/benchmarks/queries/clickbench/extended/q2.sql new file mode 100644 index 000000000000..9832ce44d4cb --- /dev/null +++ b/benchmarks/queries/clickbench/extended/q2.sql @@ -0,0 +1 @@ +SELECT "BrowserCountry", COUNT(DISTINCT "SocialNetwork"), 
COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserLanguage"), COUNT(DISTINCT "SocialAction") FROM hits GROUP BY 1 ORDER BY 2 DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/extended/q3.sql b/benchmarks/queries/clickbench/extended/q3.sql new file mode 100644 index 000000000000..d1661bc216e5 --- /dev/null +++ b/benchmarks/queries/clickbench/extended/q3.sql @@ -0,0 +1 @@ +SELECT "SocialSourceNetworkID", "RegionID", COUNT(*), AVG("Age"), AVG("ParamPrice"), STDDEV("ParamPrice") as s, VAR("ParamPrice") FROM hits GROUP BY "SocialSourceNetworkID", "RegionID" HAVING s IS NOT NULL ORDER BY s DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/extended/q4.sql b/benchmarks/queries/clickbench/extended/q4.sql new file mode 100644 index 000000000000..bd54956a2bcd --- /dev/null +++ b/benchmarks/queries/clickbench/extended/q4.sql @@ -0,0 +1 @@ +SELECT "ClientIP", "WatchID", COUNT(*) c, MIN("ResponseStartTiming") tmin, MEDIAN("ResponseStartTiming") tmed, MAX("ResponseStartTiming") tmax FROM hits WHERE "JavaEnable" = 0 GROUP BY "ClientIP", "WatchID" HAVING c > 1 ORDER BY tmed DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/extended/q5.sql b/benchmarks/queries/clickbench/extended/q5.sql new file mode 100644 index 000000000000..9de2f517d09b --- /dev/null +++ b/benchmarks/queries/clickbench/extended/q5.sql @@ -0,0 +1 @@ +SELECT "ClientIP", "WatchID", COUNT(*) c, MIN("ResponseStartTiming") tmin, APPROX_PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY "ResponseStartTiming") tp95, MAX("ResponseStartTiming") tmax FROM 'hits' WHERE "JavaEnable" = 0 GROUP BY "ClientIP", "WatchID" HAVING c > 1 ORDER BY tp95 DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/extended/q6.sql b/benchmarks/queries/clickbench/extended/q6.sql new file mode 100644 index 000000000000..091e8867c7ef --- /dev/null +++ b/benchmarks/queries/clickbench/extended/q6.sql @@ -0,0 +1 @@ +SELECT COUNT(*) AS ShareCount FROM hits WHERE "IsMobile" = 1 AND "MobilePhoneModel" LIKE 'iPhone%' AND "SocialAction" = 'share' AND "SocialSourceNetworkID" IN (5, 12) AND "ClientTimeZone" BETWEEN -5 AND 5 AND regexp_match("Referer", '\/campaign\/(spring|summer)_promo') IS NOT NULL AND CASE WHEN split_part(split_part("URL", 'resolution=', 2), '&', 1) ~ '^\d+$' THEN split_part(split_part("URL", 'resolution=', 2), '&', 1)::INT ELSE 0 END > 1920 AND levenshtein(CAST("UTMSource" AS STRING), CAST("UTMCampaign" AS STRING)) < 3; diff --git a/benchmarks/queries/clickbench/extended/q7.sql b/benchmarks/queries/clickbench/extended/q7.sql new file mode 100644 index 000000000000..ddaff7f8804f --- /dev/null +++ b/benchmarks/queries/clickbench/extended/q7.sql @@ -0,0 +1 @@ +SELECT "WatchID", MIN("ResolutionWidth") as wmin, MAX("ResolutionWidth") as wmax, SUM("IsRefresh") as srefresh FROM hits GROUP BY "WatchID" ORDER BY "WatchID" DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries.sql b/benchmarks/queries/clickbench/queries.sql deleted file mode 100644 index 9a183cd6e259..000000000000 --- a/benchmarks/queries/clickbench/queries.sql +++ /dev/null @@ -1,43 +0,0 @@ -SELECT COUNT(*) FROM hits; -SELECT COUNT(*) FROM hits WHERE "AdvEngineID" <> 0; -SELECT SUM("AdvEngineID"), COUNT(*), AVG("ResolutionWidth") FROM hits; -SELECT AVG("UserID") FROM hits; -SELECT COUNT(DISTINCT "UserID") FROM hits; -SELECT COUNT(DISTINCT "SearchPhrase") FROM hits; -SELECT MIN("EventDate"), MAX("EventDate") FROM hits; -SELECT "AdvEngineID", COUNT(*) FROM hits WHERE "AdvEngineID" <> 0 GROUP BY "AdvEngineID" ORDER BY COUNT(*) DESC; -SELECT "RegionID", COUNT(DISTINCT 
"UserID") AS u FROM hits GROUP BY "RegionID" ORDER BY u DESC LIMIT 10; -SELECT "RegionID", SUM("AdvEngineID"), COUNT(*) AS c, AVG("ResolutionWidth"), COUNT(DISTINCT "UserID") FROM hits GROUP BY "RegionID" ORDER BY c DESC LIMIT 10; -SELECT "MobilePhoneModel", COUNT(DISTINCT "UserID") AS u FROM hits WHERE "MobilePhoneModel" <> '' GROUP BY "MobilePhoneModel" ORDER BY u DESC LIMIT 10; -SELECT "MobilePhone", "MobilePhoneModel", COUNT(DISTINCT "UserID") AS u FROM hits WHERE "MobilePhoneModel" <> '' GROUP BY "MobilePhone", "MobilePhoneModel" ORDER BY u DESC LIMIT 10; -SELECT "SearchPhrase", COUNT(*) AS c FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY c DESC LIMIT 10; -SELECT "SearchPhrase", COUNT(DISTINCT "UserID") AS u FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY u DESC LIMIT 10; -SELECT "SearchEngineID", "SearchPhrase", COUNT(*) AS c FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchEngineID", "SearchPhrase" ORDER BY c DESC LIMIT 10; -SELECT "UserID", COUNT(*) FROM hits GROUP BY "UserID" ORDER BY COUNT(*) DESC LIMIT 10; -SELECT "UserID", "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", "SearchPhrase" ORDER BY COUNT(*) DESC LIMIT 10; -SELECT "UserID", "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", "SearchPhrase" LIMIT 10; -SELECT "UserID", extract(minute FROM to_timestamp_seconds("EventTime")) AS m, "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", m, "SearchPhrase" ORDER BY COUNT(*) DESC LIMIT 10; -SELECT "UserID" FROM hits WHERE "UserID" = 435090932899640449; -SELECT COUNT(*) FROM hits WHERE "URL" LIKE '%google%'; -SELECT "SearchPhrase", MIN("URL"), COUNT(*) AS c FROM hits WHERE "URL" LIKE '%google%' AND "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY c DESC LIMIT 10; -SELECT "SearchPhrase", MIN("URL"), MIN("Title"), COUNT(*) AS c, COUNT(DISTINCT "UserID") FROM hits WHERE "Title" LIKE '%Google%' AND "URL" NOT LIKE '%.google.%' AND "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY c DESC LIMIT 10; -SELECT * FROM hits WHERE "URL" LIKE '%google%' ORDER BY "EventTime" LIMIT 10; -SELECT "SearchPhrase" FROM hits WHERE "SearchPhrase" <> '' ORDER BY "EventTime" LIMIT 10; -SELECT "SearchPhrase" FROM hits WHERE "SearchPhrase" <> '' ORDER BY "SearchPhrase" LIMIT 10; -SELECT "SearchPhrase" FROM hits WHERE "SearchPhrase" <> '' ORDER BY "EventTime", "SearchPhrase" LIMIT 10; -SELECT "CounterID", AVG(length("URL")) AS l, COUNT(*) AS c FROM hits WHERE "URL" <> '' GROUP BY "CounterID" HAVING COUNT(*) > 100000 ORDER BY l DESC LIMIT 25; -SELECT REGEXP_REPLACE("Referer", '^https?://(?:www\.)?([^/]+)/.*$', '\1') AS k, AVG(length("Referer")) AS l, COUNT(*) AS c, MIN("Referer") FROM hits WHERE "Referer" <> '' GROUP BY k HAVING COUNT(*) > 100000 ORDER BY l DESC LIMIT 25; -SELECT SUM("ResolutionWidth"), SUM("ResolutionWidth" + 1), SUM("ResolutionWidth" + 2), SUM("ResolutionWidth" + 3), SUM("ResolutionWidth" + 4), SUM("ResolutionWidth" + 5), SUM("ResolutionWidth" + 6), SUM("ResolutionWidth" + 7), SUM("ResolutionWidth" + 8), SUM("ResolutionWidth" + 9), SUM("ResolutionWidth" + 10), SUM("ResolutionWidth" + 11), SUM("ResolutionWidth" + 12), SUM("ResolutionWidth" + 13), SUM("ResolutionWidth" + 14), SUM("ResolutionWidth" + 15), SUM("ResolutionWidth" + 16), SUM("ResolutionWidth" + 17), SUM("ResolutionWidth" + 18), SUM("ResolutionWidth" + 19), SUM("ResolutionWidth" + 20), SUM("ResolutionWidth" + 21), SUM("ResolutionWidth" + 22), SUM("ResolutionWidth" + 23), SUM("ResolutionWidth" + 24), SUM("ResolutionWidth" + 25), SUM("ResolutionWidth" + 26), 
SUM("ResolutionWidth" + 27), SUM("ResolutionWidth" + 28), SUM("ResolutionWidth" + 29), SUM("ResolutionWidth" + 30), SUM("ResolutionWidth" + 31), SUM("ResolutionWidth" + 32), SUM("ResolutionWidth" + 33), SUM("ResolutionWidth" + 34), SUM("ResolutionWidth" + 35), SUM("ResolutionWidth" + 36), SUM("ResolutionWidth" + 37), SUM("ResolutionWidth" + 38), SUM("ResolutionWidth" + 39), SUM("ResolutionWidth" + 40), SUM("ResolutionWidth" + 41), SUM("ResolutionWidth" + 42), SUM("ResolutionWidth" + 43), SUM("ResolutionWidth" + 44), SUM("ResolutionWidth" + 45), SUM("ResolutionWidth" + 46), SUM("ResolutionWidth" + 47), SUM("ResolutionWidth" + 48), SUM("ResolutionWidth" + 49), SUM("ResolutionWidth" + 50), SUM("ResolutionWidth" + 51), SUM("ResolutionWidth" + 52), SUM("ResolutionWidth" + 53), SUM("ResolutionWidth" + 54), SUM("ResolutionWidth" + 55), SUM("ResolutionWidth" + 56), SUM("ResolutionWidth" + 57), SUM("ResolutionWidth" + 58), SUM("ResolutionWidth" + 59), SUM("ResolutionWidth" + 60), SUM("ResolutionWidth" + 61), SUM("ResolutionWidth" + 62), SUM("ResolutionWidth" + 63), SUM("ResolutionWidth" + 64), SUM("ResolutionWidth" + 65), SUM("ResolutionWidth" + 66), SUM("ResolutionWidth" + 67), SUM("ResolutionWidth" + 68), SUM("ResolutionWidth" + 69), SUM("ResolutionWidth" + 70), SUM("ResolutionWidth" + 71), SUM("ResolutionWidth" + 72), SUM("ResolutionWidth" + 73), SUM("ResolutionWidth" + 74), SUM("ResolutionWidth" + 75), SUM("ResolutionWidth" + 76), SUM("ResolutionWidth" + 77), SUM("ResolutionWidth" + 78), SUM("ResolutionWidth" + 79), SUM("ResolutionWidth" + 80), SUM("ResolutionWidth" + 81), SUM("ResolutionWidth" + 82), SUM("ResolutionWidth" + 83), SUM("ResolutionWidth" + 84), SUM("ResolutionWidth" + 85), SUM("ResolutionWidth" + 86), SUM("ResolutionWidth" + 87), SUM("ResolutionWidth" + 88), SUM("ResolutionWidth" + 89) FROM hits; -SELECT "SearchEngineID", "ClientIP", COUNT(*) AS c, SUM("IsRefresh"), AVG("ResolutionWidth") FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchEngineID", "ClientIP" ORDER BY c DESC LIMIT 10; -SELECT "WatchID", "ClientIP", COUNT(*) AS c, SUM("IsRefresh"), AVG("ResolutionWidth") FROM hits WHERE "SearchPhrase" <> '' GROUP BY "WatchID", "ClientIP" ORDER BY c DESC LIMIT 10; -SELECT "WatchID", "ClientIP", COUNT(*) AS c, SUM("IsRefresh"), AVG("ResolutionWidth") FROM hits GROUP BY "WatchID", "ClientIP" ORDER BY c DESC LIMIT 10; -SELECT "URL", COUNT(*) AS c FROM hits GROUP BY "URL" ORDER BY c DESC LIMIT 10; -SELECT 1, "URL", COUNT(*) AS c FROM hits GROUP BY 1, "URL" ORDER BY c DESC LIMIT 10; -SELECT "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3, COUNT(*) AS c FROM hits GROUP BY "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3 ORDER BY c DESC LIMIT 10; -SELECT "URL", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "DontCountHits" = 0 AND "IsRefresh" = 0 AND "URL" <> '' GROUP BY "URL" ORDER BY PageViews DESC LIMIT 10; -SELECT "Title", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "DontCountHits" = 0 AND "IsRefresh" = 0 AND "Title" <> '' GROUP BY "Title" ORDER BY PageViews DESC LIMIT 10; -SELECT "URL", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "IsRefresh" = 0 AND "IsLink" <> 0 AND "IsDownload" = 0 GROUP BY "URL" ORDER BY PageViews DESC LIMIT 10 OFFSET 1000; -SELECT "TraficSourceID", "SearchEngineID", "AdvEngineID", CASE WHEN ("SearchEngineID" = 
0 AND "AdvEngineID" = 0) THEN "Referer" ELSE '' END AS Src, "URL" AS Dst, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "IsRefresh" = 0 GROUP BY "TraficSourceID", "SearchEngineID", "AdvEngineID", Src, Dst ORDER BY PageViews DESC LIMIT 10 OFFSET 1000; -SELECT "URLHash", "EventDate", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "IsRefresh" = 0 AND "TraficSourceID" IN (-1, 6) AND "RefererHash" = 3594120000172545465 GROUP BY "URLHash", "EventDate" ORDER BY PageViews DESC LIMIT 10 OFFSET 100; -SELECT "WindowClientWidth", "WindowClientHeight", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "IsRefresh" = 0 AND "DontCountHits" = 0 AND "URLHash" = 2868770270353813622 GROUP BY "WindowClientWidth", "WindowClientHeight" ORDER BY PageViews DESC LIMIT 10 OFFSET 10000; -SELECT DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) AS M, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-14' AND "EventDate" <= '2013-07-15' AND "IsRefresh" = 0 AND "DontCountHits" = 0 GROUP BY DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) ORDER BY DATE_TRUNC('minute', M) LIMIT 10 OFFSET 1000; diff --git a/benchmarks/queries/clickbench/queries/q0.sql b/benchmarks/queries/clickbench/queries/q0.sql new file mode 100644 index 000000000000..c70aa7a844d7 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q0.sql @@ -0,0 +1 @@ +SELECT COUNT(*) FROM hits; diff --git a/benchmarks/queries/clickbench/queries/q1.sql b/benchmarks/queries/clickbench/queries/q1.sql new file mode 100644 index 000000000000..283a5c3cc82b --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q1.sql @@ -0,0 +1 @@ +SELECT COUNT(*) FROM hits WHERE "AdvEngineID" <> 0; diff --git a/benchmarks/queries/clickbench/queries/q10.sql b/benchmarks/queries/clickbench/queries/q10.sql new file mode 100644 index 000000000000..dd44e5c49368 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q10.sql @@ -0,0 +1 @@ +SELECT "MobilePhoneModel", COUNT(DISTINCT "UserID") AS u FROM hits WHERE "MobilePhoneModel" <> '' GROUP BY "MobilePhoneModel" ORDER BY u DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q11.sql b/benchmarks/queries/clickbench/queries/q11.sql new file mode 100644 index 000000000000..9349d450699c --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q11.sql @@ -0,0 +1 @@ +SELECT "MobilePhone", "MobilePhoneModel", COUNT(DISTINCT "UserID") AS u FROM hits WHERE "MobilePhoneModel" <> '' GROUP BY "MobilePhone", "MobilePhoneModel" ORDER BY u DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q12.sql b/benchmarks/queries/clickbench/queries/q12.sql new file mode 100644 index 000000000000..908af6314988 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q12.sql @@ -0,0 +1 @@ +SELECT "SearchPhrase", COUNT(*) AS c FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY c DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q13.sql b/benchmarks/queries/clickbench/queries/q13.sql new file mode 100644 index 000000000000..46e1e6b4a74d --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q13.sql @@ -0,0 +1 @@ +SELECT "SearchPhrase", COUNT(DISTINCT "UserID") AS u FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY u DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q14.sql 
b/benchmarks/queries/clickbench/queries/q14.sql new file mode 100644 index 000000000000..d6c5118168f0 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q14.sql @@ -0,0 +1 @@ +SELECT "SearchEngineID", "SearchPhrase", COUNT(*) AS c FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchEngineID", "SearchPhrase" ORDER BY c DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q15.sql b/benchmarks/queries/clickbench/queries/q15.sql new file mode 100644 index 000000000000..f5b4e511a886 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q15.sql @@ -0,0 +1 @@ +SELECT "UserID", COUNT(*) FROM hits GROUP BY "UserID" ORDER BY COUNT(*) DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q16.sql b/benchmarks/queries/clickbench/queries/q16.sql new file mode 100644 index 000000000000..38e44b684941 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q16.sql @@ -0,0 +1 @@ +SELECT "UserID", "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", "SearchPhrase" ORDER BY COUNT(*) DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q17.sql b/benchmarks/queries/clickbench/queries/q17.sql new file mode 100644 index 000000000000..1a97cdd36a24 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q17.sql @@ -0,0 +1 @@ +SELECT "UserID", "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", "SearchPhrase" LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q18.sql b/benchmarks/queries/clickbench/queries/q18.sql new file mode 100644 index 000000000000..5aeeedf78ee0 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q18.sql @@ -0,0 +1 @@ +SELECT "UserID", extract(minute FROM to_timestamp_seconds("EventTime")) AS m, "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", m, "SearchPhrase" ORDER BY COUNT(*) DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q19.sql b/benchmarks/queries/clickbench/queries/q19.sql new file mode 100644 index 000000000000..e388497dd1ec --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q19.sql @@ -0,0 +1 @@ +SELECT "UserID" FROM hits WHERE "UserID" = 435090932899640449; diff --git a/benchmarks/queries/clickbench/queries/q2.sql b/benchmarks/queries/clickbench/queries/q2.sql new file mode 100644 index 000000000000..9938e3081dd2 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q2.sql @@ -0,0 +1 @@ +SELECT SUM("AdvEngineID"), COUNT(*), AVG("ResolutionWidth") FROM hits; diff --git a/benchmarks/queries/clickbench/queries/q20.sql b/benchmarks/queries/clickbench/queries/q20.sql new file mode 100644 index 000000000000..a7e6995c1f1b --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q20.sql @@ -0,0 +1 @@ +SELECT COUNT(*) FROM hits WHERE "URL" LIKE '%google%'; diff --git a/benchmarks/queries/clickbench/queries/q21.sql b/benchmarks/queries/clickbench/queries/q21.sql new file mode 100644 index 000000000000..d857899d136c --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q21.sql @@ -0,0 +1 @@ +SELECT "SearchPhrase", MIN("URL"), COUNT(*) AS c FROM hits WHERE "URL" LIKE '%google%' AND "SearchPhrase" <> '' GROUP BY "SearchPhrase" ORDER BY c DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q22.sql b/benchmarks/queries/clickbench/queries/q22.sql new file mode 100644 index 000000000000..8ac4f099c484 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q22.sql @@ -0,0 +1 @@ +SELECT "SearchPhrase", MIN("URL"), MIN("Title"), COUNT(*) AS c, COUNT(DISTINCT "UserID") FROM hits WHERE "Title" LIKE '%Google%' AND "URL" NOT LIKE '%.google.%' AND "SearchPhrase" <> '' GROUP BY 
"SearchPhrase" ORDER BY c DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q23.sql b/benchmarks/queries/clickbench/queries/q23.sql new file mode 100644 index 000000000000..3623b0fed806 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q23.sql @@ -0,0 +1 @@ +SELECT * FROM hits WHERE "URL" LIKE '%google%' ORDER BY "EventTime" LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q24.sql b/benchmarks/queries/clickbench/queries/q24.sql new file mode 100644 index 000000000000..cee774aafe53 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q24.sql @@ -0,0 +1 @@ +SELECT "SearchPhrase" FROM hits WHERE "SearchPhrase" <> '' ORDER BY "EventTime" LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q25.sql b/benchmarks/queries/clickbench/queries/q25.sql new file mode 100644 index 000000000000..048b4cd9d3e2 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q25.sql @@ -0,0 +1 @@ +SELECT "SearchPhrase" FROM hits WHERE "SearchPhrase" <> '' ORDER BY "SearchPhrase" LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q26.sql b/benchmarks/queries/clickbench/queries/q26.sql new file mode 100644 index 000000000000..104e8d50ecb0 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q26.sql @@ -0,0 +1 @@ +SELECT "SearchPhrase" FROM hits WHERE "SearchPhrase" <> '' ORDER BY "EventTime", "SearchPhrase" LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q27.sql b/benchmarks/queries/clickbench/queries/q27.sql new file mode 100644 index 000000000000..c84cad9296e0 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q27.sql @@ -0,0 +1 @@ +SELECT "CounterID", AVG(length("URL")) AS l, COUNT(*) AS c FROM hits WHERE "URL" <> '' GROUP BY "CounterID" HAVING COUNT(*) > 100000 ORDER BY l DESC LIMIT 25; diff --git a/benchmarks/queries/clickbench/queries/q28.sql b/benchmarks/queries/clickbench/queries/q28.sql new file mode 100644 index 000000000000..8c5a51877f32 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q28.sql @@ -0,0 +1 @@ +SELECT REGEXP_REPLACE("Referer", '^https?://(?:www\.)?([^/]+)/.*$', '\1') AS k, AVG(length("Referer")) AS l, COUNT(*) AS c, MIN("Referer") FROM hits WHERE "Referer" <> '' GROUP BY k HAVING COUNT(*) > 100000 ORDER BY l DESC LIMIT 25; diff --git a/benchmarks/queries/clickbench/queries/q29.sql b/benchmarks/queries/clickbench/queries/q29.sql new file mode 100644 index 000000000000..bfff2509062d --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q29.sql @@ -0,0 +1 @@ +SELECT SUM("ResolutionWidth"), SUM("ResolutionWidth" + 1), SUM("ResolutionWidth" + 2), SUM("ResolutionWidth" + 3), SUM("ResolutionWidth" + 4), SUM("ResolutionWidth" + 5), SUM("ResolutionWidth" + 6), SUM("ResolutionWidth" + 7), SUM("ResolutionWidth" + 8), SUM("ResolutionWidth" + 9), SUM("ResolutionWidth" + 10), SUM("ResolutionWidth" + 11), SUM("ResolutionWidth" + 12), SUM("ResolutionWidth" + 13), SUM("ResolutionWidth" + 14), SUM("ResolutionWidth" + 15), SUM("ResolutionWidth" + 16), SUM("ResolutionWidth" + 17), SUM("ResolutionWidth" + 18), SUM("ResolutionWidth" + 19), SUM("ResolutionWidth" + 20), SUM("ResolutionWidth" + 21), SUM("ResolutionWidth" + 22), SUM("ResolutionWidth" + 23), SUM("ResolutionWidth" + 24), SUM("ResolutionWidth" + 25), SUM("ResolutionWidth" + 26), SUM("ResolutionWidth" + 27), SUM("ResolutionWidth" + 28), SUM("ResolutionWidth" + 29), SUM("ResolutionWidth" + 30), SUM("ResolutionWidth" + 31), SUM("ResolutionWidth" + 32), SUM("ResolutionWidth" + 33), SUM("ResolutionWidth" + 34), SUM("ResolutionWidth" + 35), SUM("ResolutionWidth" + 
36), SUM("ResolutionWidth" + 37), SUM("ResolutionWidth" + 38), SUM("ResolutionWidth" + 39), SUM("ResolutionWidth" + 40), SUM("ResolutionWidth" + 41), SUM("ResolutionWidth" + 42), SUM("ResolutionWidth" + 43), SUM("ResolutionWidth" + 44), SUM("ResolutionWidth" + 45), SUM("ResolutionWidth" + 46), SUM("ResolutionWidth" + 47), SUM("ResolutionWidth" + 48), SUM("ResolutionWidth" + 49), SUM("ResolutionWidth" + 50), SUM("ResolutionWidth" + 51), SUM("ResolutionWidth" + 52), SUM("ResolutionWidth" + 53), SUM("ResolutionWidth" + 54), SUM("ResolutionWidth" + 55), SUM("ResolutionWidth" + 56), SUM("ResolutionWidth" + 57), SUM("ResolutionWidth" + 58), SUM("ResolutionWidth" + 59), SUM("ResolutionWidth" + 60), SUM("ResolutionWidth" + 61), SUM("ResolutionWidth" + 62), SUM("ResolutionWidth" + 63), SUM("ResolutionWidth" + 64), SUM("ResolutionWidth" + 65), SUM("ResolutionWidth" + 66), SUM("ResolutionWidth" + 67), SUM("ResolutionWidth" + 68), SUM("ResolutionWidth" + 69), SUM("ResolutionWidth" + 70), SUM("ResolutionWidth" + 71), SUM("ResolutionWidth" + 72), SUM("ResolutionWidth" + 73), SUM("ResolutionWidth" + 74), SUM("ResolutionWidth" + 75), SUM("ResolutionWidth" + 76), SUM("ResolutionWidth" + 77), SUM("ResolutionWidth" + 78), SUM("ResolutionWidth" + 79), SUM("ResolutionWidth" + 80), SUM("ResolutionWidth" + 81), SUM("ResolutionWidth" + 82), SUM("ResolutionWidth" + 83), SUM("ResolutionWidth" + 84), SUM("ResolutionWidth" + 85), SUM("ResolutionWidth" + 86), SUM("ResolutionWidth" + 87), SUM("ResolutionWidth" + 88), SUM("ResolutionWidth" + 89) FROM hits; diff --git a/benchmarks/queries/clickbench/queries/q3.sql b/benchmarks/queries/clickbench/queries/q3.sql new file mode 100644 index 000000000000..db818fa013ef --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q3.sql @@ -0,0 +1 @@ +SELECT AVG("UserID") FROM hits; diff --git a/benchmarks/queries/clickbench/queries/q30.sql b/benchmarks/queries/clickbench/queries/q30.sql new file mode 100644 index 000000000000..8b4bf19b7f9c --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q30.sql @@ -0,0 +1 @@ +SELECT "SearchEngineID", "ClientIP", COUNT(*) AS c, SUM("IsRefresh"), AVG("ResolutionWidth") FROM hits WHERE "SearchPhrase" <> '' GROUP BY "SearchEngineID", "ClientIP" ORDER BY c DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q31.sql b/benchmarks/queries/clickbench/queries/q31.sql new file mode 100644 index 000000000000..5ab49a38b804 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q31.sql @@ -0,0 +1 @@ +SELECT "WatchID", "ClientIP", COUNT(*) AS c, SUM("IsRefresh"), AVG("ResolutionWidth") FROM hits WHERE "SearchPhrase" <> '' GROUP BY "WatchID", "ClientIP" ORDER BY c DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q32.sql b/benchmarks/queries/clickbench/queries/q32.sql new file mode 100644 index 000000000000..d00bc12405ed --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q32.sql @@ -0,0 +1 @@ +SELECT "WatchID", "ClientIP", COUNT(*) AS c, SUM("IsRefresh"), AVG("ResolutionWidth") FROM hits GROUP BY "WatchID", "ClientIP" ORDER BY c DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q33.sql b/benchmarks/queries/clickbench/queries/q33.sql new file mode 100644 index 000000000000..45d491d1c30b --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q33.sql @@ -0,0 +1 @@ +SELECT "URL", COUNT(*) AS c FROM hits GROUP BY "URL" ORDER BY c DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q34.sql b/benchmarks/queries/clickbench/queries/q34.sql new file mode 100644 index 
000000000000..7e878804de06 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q34.sql @@ -0,0 +1 @@ +SELECT 1, "URL", COUNT(*) AS c FROM hits GROUP BY 1, "URL" ORDER BY c DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q35.sql b/benchmarks/queries/clickbench/queries/q35.sql new file mode 100644 index 000000000000..c03da84fb19e --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q35.sql @@ -0,0 +1 @@ +SELECT "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3, COUNT(*) AS c FROM hits GROUP BY "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3 ORDER BY c DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q36.sql b/benchmarks/queries/clickbench/queries/q36.sql new file mode 100644 index 000000000000..b76dce5cab9e --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q36.sql @@ -0,0 +1 @@ +SELECT "URL", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "DontCountHits" = 0 AND "IsRefresh" = 0 AND "URL" <> '' GROUP BY "URL" ORDER BY PageViews DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q37.sql b/benchmarks/queries/clickbench/queries/q37.sql new file mode 100644 index 000000000000..49017e3a5f1d --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q37.sql @@ -0,0 +1 @@ +SELECT "Title", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "DontCountHits" = 0 AND "IsRefresh" = 0 AND "Title" <> '' GROUP BY "Title" ORDER BY PageViews DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q38.sql b/benchmarks/queries/clickbench/queries/q38.sql new file mode 100644 index 000000000000..b0cb6814bd85 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q38.sql @@ -0,0 +1 @@ +SELECT "URL", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "IsRefresh" = 0 AND "IsLink" <> 0 AND "IsDownload" = 0 GROUP BY "URL" ORDER BY PageViews DESC LIMIT 10 OFFSET 1000; diff --git a/benchmarks/queries/clickbench/queries/q39.sql b/benchmarks/queries/clickbench/queries/q39.sql new file mode 100644 index 000000000000..8327eb9bd572 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q39.sql @@ -0,0 +1 @@ +SELECT "TraficSourceID", "SearchEngineID", "AdvEngineID", CASE WHEN ("SearchEngineID" = 0 AND "AdvEngineID" = 0) THEN "Referer" ELSE '' END AS Src, "URL" AS Dst, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "IsRefresh" = 0 GROUP BY "TraficSourceID", "SearchEngineID", "AdvEngineID", Src, Dst ORDER BY PageViews DESC LIMIT 10 OFFSET 1000; diff --git a/benchmarks/queries/clickbench/queries/q4.sql b/benchmarks/queries/clickbench/queries/q4.sql new file mode 100644 index 000000000000..027310ad7526 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q4.sql @@ -0,0 +1 @@ +SELECT COUNT(DISTINCT "UserID") FROM hits; diff --git a/benchmarks/queries/clickbench/queries/q40.sql b/benchmarks/queries/clickbench/queries/q40.sql new file mode 100644 index 000000000000..d30d7c414271 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q40.sql @@ -0,0 +1 @@ +SELECT "URLHash", "EventDate", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "IsRefresh" = 0 AND "TraficSourceID" IN (-1, 6) AND "RefererHash" = 3594120000172545465 GROUP BY "URLHash", "EventDate" ORDER BY 
PageViews DESC LIMIT 10 OFFSET 100; diff --git a/benchmarks/queries/clickbench/queries/q41.sql b/benchmarks/queries/clickbench/queries/q41.sql new file mode 100644 index 000000000000..0e9a51a7f54c --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q41.sql @@ -0,0 +1 @@ +SELECT "WindowClientWidth", "WindowClientHeight", COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-01' AND "EventDate" <= '2013-07-31' AND "IsRefresh" = 0 AND "DontCountHits" = 0 AND "URLHash" = 2868770270353813622 GROUP BY "WindowClientWidth", "WindowClientHeight" ORDER BY PageViews DESC LIMIT 10 OFFSET 10000; diff --git a/benchmarks/queries/clickbench/queries/q42.sql b/benchmarks/queries/clickbench/queries/q42.sql new file mode 100644 index 000000000000..dcad5daa1b67 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q42.sql @@ -0,0 +1 @@ +SELECT DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) AS M, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate" >= '2013-07-14' AND "EventDate" <= '2013-07-15' AND "IsRefresh" = 0 AND "DontCountHits" = 0 GROUP BY DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) ORDER BY DATE_TRUNC('minute', M) LIMIT 10 OFFSET 1000; diff --git a/benchmarks/queries/clickbench/queries/q5.sql b/benchmarks/queries/clickbench/queries/q5.sql new file mode 100644 index 000000000000..35b17097d87c --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q5.sql @@ -0,0 +1 @@ +SELECT COUNT(DISTINCT "SearchPhrase") FROM hits; diff --git a/benchmarks/queries/clickbench/queries/q6.sql b/benchmarks/queries/clickbench/queries/q6.sql new file mode 100644 index 000000000000..684103643652 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q6.sql @@ -0,0 +1 @@ +SELECT MIN("EventDate"), MAX("EventDate") FROM hits; diff --git a/benchmarks/queries/clickbench/queries/q7.sql b/benchmarks/queries/clickbench/queries/q7.sql new file mode 100644 index 000000000000..ab8528c1b141 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q7.sql @@ -0,0 +1 @@ +SELECT "AdvEngineID", COUNT(*) FROM hits WHERE "AdvEngineID" <> 0 GROUP BY "AdvEngineID" ORDER BY COUNT(*) DESC; diff --git a/benchmarks/queries/clickbench/queries/q8.sql b/benchmarks/queries/clickbench/queries/q8.sql new file mode 100644 index 000000000000..e5691bb66f81 --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q8.sql @@ -0,0 +1 @@ +SELECT "RegionID", COUNT(DISTINCT "UserID") AS u FROM hits GROUP BY "RegionID" ORDER BY u DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/queries/q9.sql b/benchmarks/queries/clickbench/queries/q9.sql new file mode 100644 index 000000000000..42c22d96852d --- /dev/null +++ b/benchmarks/queries/clickbench/queries/q9.sql @@ -0,0 +1 @@ +SELECT "RegionID", SUM("AdvEngineID"), COUNT(*) AS c, AVG("ResolutionWidth"), COUNT(DISTINCT "UserID") FROM hits GROUP BY "RegionID" ORDER BY c DESC LIMIT 10; diff --git a/benchmarks/queries/clickbench/update_queries.sh b/benchmarks/queries/clickbench/update_queries.sh new file mode 100755 index 000000000000..d7db7359aa39 --- /dev/null +++ b/benchmarks/queries/clickbench/update_queries.sh @@ -0,0 +1,80 @@ +#!/usr/bin/env bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This script is meant for developers of DataFusion -- it downloads the +# official ClickBench queries from the upstream ClickBench repository and +# splits them into one file per query so they can be run individually by +# the benchmark binary. + +# Script to download ClickBench queries and split them into individual files + +set -e # Exit on any error + +# URL for the raw file (not the GitHub page) +URL="https://raw.githubusercontent.com/ClickHouse/ClickBench/main/datafusion/queries.sql" + +# Temporary file to store downloaded content +TEMP_FILE="queries.sql" + +TARGET_DIR="queries" + +# Download the file +echo "Downloading queries from $URL..." +if command -v curl &> /dev/null; then + curl -s -o "$TEMP_FILE" "$URL" +elif command -v wget &> /dev/null; then + wget -q -O "$TEMP_FILE" "$URL" +else + echo "Error: Neither curl nor wget is available. Please install one of them." + exit 1 +fi + +# Check if download was successful +if [ ! -f "$TEMP_FILE" ] || [ ! -s "$TEMP_FILE" ]; then + echo "Error: Failed to download or file is empty" + exit 1 +fi + +# Initialize counter +counter=0 + +# Ensure the target directory exists +if [ ! -d ${TARGET_DIR} ]; then + mkdir -p ${TARGET_DIR} +fi + +# Read the file line by line and create individual query files +mapfile -t lines < $TEMP_FILE +for line in "${lines[@]}"; do + # Skip empty lines + if [ -n "$line" ]; then + # Create the filename from the query counter + filename="q${counter}.sql" + + # Write the line to the individual file + echo "$line" > "${TARGET_DIR}/$filename" + + echo "Created ${TARGET_DIR}/$filename" + + # Increment counter + (( counter += 1 )) + fi +done + +# Clean up temporary file +rm "$TEMP_FILE" \ No newline at end of file diff --git a/benchmarks/queries/h2o/groupby.sql b/benchmarks/queries/h2o/groupby.sql index c2101ef8ada2..4fae7a13810d 100644 --- a/benchmarks/queries/h2o/groupby.sql +++ b/benchmarks/queries/h2o/groupby.sql @@ -1,10 +1,19 @@ SELECT id1, SUM(v1) AS v1 FROM x GROUP BY id1; + SELECT id1, id2, SUM(v1) AS v1 FROM x GROUP BY id1, id2; + SELECT id3, SUM(v1) AS v1, AVG(v3) AS v3 FROM x GROUP BY id3; + SELECT id4, AVG(v1) AS v1, AVG(v2) AS v2, AVG(v3) AS v3 FROM x GROUP BY id4; + SELECT id6, SUM(v1) AS v1, SUM(v2) AS v2, SUM(v3) AS v3 FROM x GROUP BY id6; + SELECT id4, id5, MEDIAN(v3) AS median_v3, STDDEV(v3) AS sd_v3 FROM x GROUP BY id4, id5; + SELECT id3, MAX(v1) - MIN(v2) AS range_v1_v2 FROM x GROUP BY id3; + SELECT id6, largest2_v3 FROM (SELECT id6, v3 AS largest2_v3, ROW_NUMBER() OVER (PARTITION BY id6 ORDER BY v3 DESC) AS order_v3 FROM x WHERE v3 IS NOT NULL) sub_query WHERE order_v3 <= 2; + SELECT id2, id4, POWER(CORR(v1, v2), 2) AS r2 FROM x GROUP BY id2, id4; -SELECT id1, id2, id3, id4, id5, id6, SUM(v3) AS v3, COUNT(*) AS count FROM x GROUP BY id1, id2, id3, id4, id5, id6; + +SELECT id1, id2, id3, id4, id5, id6, SUM(v3) AS v3, COUNT(*) AS count FROM x GROUP BY id1, id2, id3, id4, id5, id6; \ No newline at end of file diff --git a/benchmarks/queries/h2o/join.sql b/benchmarks/queries/h2o/join.sql index 8546b9292dbb..84cd661fdd59 100644 --- a/benchmarks/queries/h2o/join.sql +++ b/benchmarks/queries/h2o/join.sql
@@ -1,5 +1,9 @@ SELECT x.id1, x.id2, x.id3, x.id4 as xid4, small.id4 as smallid4, x.id5, x.id6, x.v1, small.v2 FROM x INNER JOIN small ON x.id1 = small.id1; + SELECT x.id1 as xid1, medium.id1 as mediumid1, x.id2, x.id3, x.id4 as xid4, medium.id4 as mediumid4, x.id5 as xid5, medium.id5 as mediumid5, x.id6, x.v1, medium.v2 FROM x INNER JOIN medium ON x.id2 = medium.id2; + SELECT x.id1 as xid1, medium.id1 as mediumid1, x.id2, x.id3, x.id4 as xid4, medium.id4 as mediumid4, x.id5 as xid5, medium.id5 as mediumid5, x.id6, x.v1, medium.v2 FROM x LEFT JOIN medium ON x.id2 = medium.id2; + SELECT x.id1 as xid1, medium.id1 as mediumid1, x.id2, x.id3, x.id4 as xid4, medium.id4 as mediumid4, x.id5 as xid5, medium.id5 as mediumid5, x.id6, x.v1, medium.v2 FROM x JOIN medium ON x.id5 = medium.id5; -SELECT x.id1 as xid1, large.id1 as largeid1, x.id2 as xid2, large.id2 as largeid2, x.id3, x.id4 as xid4, large.id4 as largeid4, x.id5 as xid5, large.id5 as largeid5, x.id6 as xid6, large.id6 as largeid6, x.v1, large.v2 FROM x JOIN large ON x.id3 = large.id3; + +SELECT x.id1 as xid1, large.id1 as largeid1, x.id2 as xid2, large.id2 as largeid2, x.id3, x.id4 as xid4, large.id4 as largeid4, x.id5 as xid5, large.id5 as largeid5, x.id6 as xid6, large.id6 as largeid6, x.v1, large.v2 FROM x JOIN large ON x.id3 = large.id3; \ No newline at end of file diff --git a/benchmarks/queries/h2o/window.sql b/benchmarks/queries/h2o/window.sql new file mode 100644 index 000000000000..071540927a4c --- /dev/null +++ b/benchmarks/queries/h2o/window.sql @@ -0,0 +1,112 @@ +-- Basic Window +SELECT + id1, + id2, + id3, + v2, + sum(v2) OVER () AS window_basic +FROM large; + +-- Sorted Window +SELECT + id1, + id2, + id3, + v2, + first_value(v2) OVER (ORDER BY id3) AS first_order_by, + row_number() OVER (ORDER BY id3) AS row_number_order_by +FROM large; + +-- PARTITION BY +SELECT + id1, + id2, + id3, + v2, + sum(v2) OVER (PARTITION BY id1) AS sum_by_id1, + sum(v2) OVER (PARTITION BY id2) AS sum_by_id2, + sum(v2) OVER (PARTITION BY id3) AS sum_by_id3 +FROM large; + +-- PARTITION BY ORDER BY +SELECT + id1, + id2, + id3, + v2, + first_value(v2) OVER (PARTITION BY id2 ORDER BY id3) AS first_by_id2_ordered_by_id3 +FROM large; + +-- Lead and Lag +SELECT + id1, + id2, + id3, + v2, + first_value(v2) OVER (ORDER BY id3 ROWS BETWEEN 1 PRECEDING AND 1 PRECEDING) AS my_lag, + first_value(v2) OVER (ORDER BY id3 ROWS BETWEEN 1 FOLLOWING AND 1 FOLLOWING) AS my_lead +FROM large; + +-- Moving Averages +SELECT + id1, + id2, + id3, + v2, + avg(v2) OVER (ORDER BY id3 ROWS BETWEEN 100 PRECEDING AND CURRENT ROW) AS my_moving_average +FROM large; + +-- Rolling Sum +SELECT + id1, + id2, + id3, + v2, + sum(v2) OVER (ORDER BY id3 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS my_rolling_sum +FROM large; + +-- RANGE BETWEEN +SELECT + id1, + id2, + id3, + v2, + sum(v2) OVER (ORDER BY v2 RANGE BETWEEN 3 PRECEDING AND CURRENT ROW) AS my_range_between +FROM large; + +-- First PARTITION BY ROWS BETWEEN +SELECT + id1, + id2, + id3, + v2, + first_value(v2) OVER (PARTITION BY id2 ORDER BY id3 ROWS BETWEEN 1 PRECEDING AND 1 PRECEDING) AS my_lag_by_id2, + first_value(v2) OVER (PARTITION BY id2 ORDER BY id3 ROWS BETWEEN 1 FOLLOWING AND 1 FOLLOWING) AS my_lead_by_id2 +FROM large; + +-- Moving Averages PARTITION BY +SELECT + id1, + id2, + id3, + v2, + avg(v2) OVER (PARTITION BY id2 ORDER BY id3 ROWS BETWEEN 100 PRECEDING AND CURRENT ROW) AS my_moving_average_by_id2 +FROM large; + +-- Rolling Sum PARTITION BY +SELECT + id1, + id2, + id3, + v2, + sum(v2) OVER (PARTITION BY id2 
ORDER BY id3 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS my_rolling_sum_by_id2 +FROM large; + +-- RANGE BETWEEN PARTITION BY +SELECT + id1, + id2, + id3, + v2, + sum(v2) OVER (PARTITION BY id2 ORDER BY v2 RANGE BETWEEN 3 PRECEDING AND CURRENT ROW) AS my_range_between_by_id2 +FROM large; \ No newline at end of file diff --git a/benchmarks/queries/q10.sql b/benchmarks/queries/q10.sql index cf45e43485fb..8613fd496283 100644 --- a/benchmarks/queries/q10.sql +++ b/benchmarks/queries/q10.sql @@ -28,4 +28,5 @@ group by c_address, c_comment order by - revenue desc; \ No newline at end of file + revenue desc +limit 20; diff --git a/benchmarks/queries/q18.sql b/benchmarks/queries/q18.sql index 835de28a57be..ba7ee7f716cf 100644 --- a/benchmarks/queries/q18.sql +++ b/benchmarks/queries/q18.sql @@ -29,4 +29,5 @@ group by o_totalprice order by o_totalprice desc, - o_orderdate; \ No newline at end of file + o_orderdate +limit 100; diff --git a/benchmarks/queries/q2.sql b/benchmarks/queries/q2.sql index f66af210205e..68e478f65d3f 100644 --- a/benchmarks/queries/q2.sql +++ b/benchmarks/queries/q2.sql @@ -40,4 +40,5 @@ order by s_acctbal desc, n_name, s_name, - p_partkey; \ No newline at end of file + p_partkey +limit 100; diff --git a/benchmarks/queries/q21.sql b/benchmarks/queries/q21.sql index 9d2fe32cee22..b95e7b0dfca0 100644 --- a/benchmarks/queries/q21.sql +++ b/benchmarks/queries/q21.sql @@ -36,4 +36,5 @@ group by s_name order by numwait desc, - s_name; \ No newline at end of file + s_name +limit 100; diff --git a/benchmarks/queries/q3.sql b/benchmarks/queries/q3.sql index 7dbc6d9ef678..e5fa9e38664c 100644 --- a/benchmarks/queries/q3.sql +++ b/benchmarks/queries/q3.sql @@ -19,4 +19,5 @@ group by o_shippriority order by revenue desc, - o_orderdate; \ No newline at end of file + o_orderdate +limit 10; diff --git a/benchmarks/src/bin/dfbench.rs b/benchmarks/src/bin/dfbench.rs index 06337cb75888..41b64063c099 100644 --- a/benchmarks/src/bin/dfbench.rs +++ b/benchmarks/src/bin/dfbench.rs @@ -60,11 +60,11 @@ pub async fn main() -> Result<()> { Options::Cancellation(opt) => opt.run().await, Options::Clickbench(opt) => opt.run().await, Options::H2o(opt) => opt.run().await, - Options::Imdb(opt) => opt.run().await, + Options::Imdb(opt) => Box::pin(opt.run()).await, Options::ParquetFilter(opt) => opt.run().await, Options::Sort(opt) => opt.run().await, Options::SortTpch(opt) => opt.run().await, - Options::Tpch(opt) => opt.run().await, + Options::Tpch(opt) => Box::pin(opt.run()).await, Options::TpchConvert(opt) => opt.run().await, } } diff --git a/benchmarks/src/bin/external_aggr.rs b/benchmarks/src/bin/external_aggr.rs index 578f71f8275d..0e519367badb 100644 --- a/benchmarks/src/bin/external_aggr.rs +++ b/benchmarks/src/bin/external_aggr.rs @@ -40,7 +40,7 @@ use datafusion::execution::SessionStateBuilder; use datafusion::physical_plan::display::DisplayableExecutionPlan; use datafusion::physical_plan::{collect, displayable}; use datafusion::prelude::*; -use datafusion_benchmarks::util::{BenchmarkRun, CommonOpt}; +use datafusion_benchmarks::util::{BenchmarkRun, CommonOpt, QueryResult}; use datafusion_common::instant::Instant; use datafusion_common::utils::get_available_parallelism; use datafusion_common::{exec_err, DEFAULT_PARQUET_EXTENSION}; @@ -77,11 +77,6 @@ struct ExternalAggrConfig { output_path: Option<PathBuf>, } -struct QueryResult { - elapsed: std::time::Duration, - row_count: usize, -} - /// Query Memory Limits /// Map query id to predefined memory limits /// @@ -189,7 +184,7 @@ impl ExternalAggrConfig {
) -> Result<Vec<QueryResult>> { let query_name = format!("Q{query_id}({})", human_readable_size(mem_limit as usize)); - let config = self.common.config(); + let config = self.common.config()?; let memory_pool: Arc<dyn MemoryPool> = match mem_pool_type { "fair" => Arc::new(FairSpillPool::new(mem_limit as usize)), "greedy" => Arc::new(GreedyMemoryPool::new(mem_limit as usize)), @@ -335,7 +330,7 @@ impl ExternalAggrConfig { fn partitions(&self) -> usize { self.common .partitions - .unwrap_or(get_available_parallelism()) + .unwrap_or_else(get_available_parallelism) } } diff --git a/benchmarks/src/bin/imdb.rs b/benchmarks/src/bin/imdb.rs index 13421f8a89a9..5ce99928df66 100644 --- a/benchmarks/src/bin/imdb.rs +++ b/benchmarks/src/bin/imdb.rs @@ -53,7 +53,7 @@ pub async fn main() -> Result<()> { env_logger::init(); match ImdbOpt::from_args() { ImdbOpt::Benchmark(BenchmarkSubCommandOpt::DataFusionBenchmark(opt)) => { - opt.run().await + Box::pin(opt.run()).await } ImdbOpt::Convert(opt) => opt.run().await, } diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs index 3270b082cfb4..ca2bb8e57c0e 100644 --- a/benchmarks/src/bin/tpch.rs +++ b/benchmarks/src/bin/tpch.rs @@ -58,7 +58,7 @@ async fn main() -> Result<()> { env_logger::init(); match TpchOpt::from_args() { TpchOpt::Benchmark(BenchmarkSubCommandOpt::DataFusionBenchmark(opt)) => { - opt.run().await + Box::pin(opt.run()).await } TpchOpt::Convert(opt) => opt.run().await, } diff --git a/benchmarks/src/cancellation.rs b/benchmarks/src/cancellation.rs index f5740bdc96e0..fcf03fbc5455 100644 --- a/benchmarks/src/cancellation.rs +++ b/benchmarks/src/cancellation.rs @@ -38,7 +38,7 @@ use futures::TryStreamExt; use object_store::ObjectStore; use parquet::arrow::async_writer::ParquetObjectWriter; use parquet::arrow::AsyncArrowWriter; -use rand::distributions::Alphanumeric; +use rand::distr::Alphanumeric; use rand::rngs::ThreadRng; use rand::Rng; use structopt::StructOpt; @@ -237,7 +237,7 @@ fn find_files_on_disk(data_dir: impl AsRef<Path>) -> Result<Vec<PathBuf>> { let path = file.unwrap().path(); if path .extension() - .map(|ext| (ext == "parquet")) + .map(|ext| ext == "parquet") .unwrap_or(false) { Some(path) @@ -312,15 +312,15 @@ async fn generate_data( } fn random_data(column_type: &DataType, rows: usize) -> Arc<dyn Array> { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let values = (0..rows).map(|_| random_value(&mut rng, column_type)); ScalarValue::iter_to_array(values).unwrap() } fn random_value(rng: &mut ThreadRng, column_type: &DataType) -> ScalarValue { match column_type { - DataType::Float64 => ScalarValue::Float64(Some(rng.gen())), - DataType::Boolean => ScalarValue::Boolean(Some(rng.gen())), + DataType::Float64 => ScalarValue::Float64(Some(rng.random())), + DataType::Boolean => ScalarValue::Boolean(Some(rng.random())), DataType::Utf8 => ScalarValue::Utf8(Some( rng.sample_iter(&Alphanumeric) .take(10) diff --git a/benchmarks/src/clickbench.rs b/benchmarks/src/clickbench.rs index 923c2bdd7cdf..8d1847b1b874 100644 --- a/benchmarks/src/clickbench.rs +++ b/benchmarks/src/clickbench.rs @@ -15,10 +15,11 @@ // specific language governing permissions and limitations // under the License.
-use std::path::Path; -use std::path::PathBuf; +use std::fs; +use std::io::ErrorKind; +use std::path::{Path, PathBuf}; -use crate::util::{BenchmarkRun, CommonOpt}; +use crate::util::{BenchmarkRun, CommonOpt, QueryResult}; use datafusion::{ error::{DataFusionError, Result}, prelude::SessionContext, @@ -56,12 +57,12 @@ pub struct RunOpt { )] path: PathBuf, - /// Path to queries.sql (single file) + /// Path to queries directory #[structopt( parse(from_os_str), short = "r", long = "queries-path", - default_value = "benchmarks/queries/clickbench/queries.sql" + default_value = "benchmarks/queries/clickbench/queries" )] queries_path: PathBuf, @@ -70,53 +71,51 @@ pub struct RunOpt { output_path: Option<PathBuf>, } -struct AllQueries { - queries: Vec<String>, +/// Get the SQL file path +pub fn get_query_path(query_dir: &Path, query: usize) -> PathBuf { + let mut query_path = query_dir.to_path_buf(); + query_path.push(format!("q{query}.sql")); + query_path } -impl AllQueries { - fn try_new(path: &Path) -> Result<Self> { - // ClickBench has all queries in a single file identified by line number - let all_queries = std::fs::read_to_string(path) - .map_err(|e| exec_datafusion_err!("Could not open {path:?}: {e}"))?; - Ok(Self { - queries: all_queries.lines().map(|s| s.to_string()).collect(), - }) +/// Get the SQL statement from the specified query file +pub fn get_query_sql(query_path: &Path) -> Result<Option<String>> { + if fs::exists(query_path)? { + Ok(Some(fs::read_to_string(query_path)?)) + } else { + Ok(None) } +} + +impl RunOpt { + pub async fn run(self) -> Result<()> { + println!("Running benchmarks with the following options: {self:?}"); - /// Returns the text of query `query_id` - fn get_query(&self, query_id: usize) -> Result<&str> { - self.queries - .get(query_id) - .ok_or_else(|| { - let min_id = self.min_query_id(); - let max_id = self.max_query_id(); + let query_dir_metadata = fs::metadata(&self.queries_path).map_err(|e| { + if e.kind() == ErrorKind::NotFound { exec_datafusion_err!( - "Invalid query id {query_id}. Must be between {min_id} and {max_id}" + "Query path '{}' does not exist.", + &self.queries_path.to_str().unwrap() ) - }) - .map(|s| s.as_str()) - } + } else { + DataFusionError::External(Box::new(e)) + } + })?; - fn min_query_id(&self) -> usize { - 0 - } + if !query_dir_metadata.is_dir() { + return Err(exec_datafusion_err!( + "Query path '{}' is not a directory.", + &self.queries_path.to_str().unwrap() + )); + } - fn max_query_id(&self) -> usize { - self.queries.len() - 1 - } -} -impl RunOpt { - pub async fn run(self) -> Result<()> { - println!("Running benchmarks with the following options: {self:?}"); - let queries = AllQueries::try_new(self.queries_path.as_path())?; let query_range = match self.query { Some(query_id) => query_id..=query_id, - None => queries.min_query_id()..=queries.max_query_id(), + None => 0..=usize::MAX, }; // configure parquet options - let mut config = self.common.config(); + let mut config = self.common.config()?; { let parquet_options = &mut config.options_mut().execution.parquet; // The hits_partitioned dataset specifies string columns @@ -128,36 +127,67 @@ impl RunOpt { let ctx = SessionContext::new_with_config_rt(config, rt_builder.build_arc()?); self.register_hits(&ctx).await?; - let iterations = self.common.iterations; let mut benchmark_run = BenchmarkRun::new(); for query_id in query_range { - let mut millis = Vec::with_capacity(iterations); + let query_path = get_query_path(&self.queries_path, query_id); + let Some(sql) = get_query_sql(&query_path)?
else { + if self.query.is_some() { + return Err(exec_datafusion_err!( + "Could not load query file '{}'.", + &query_path.to_str().unwrap() + )); + } + break; + }; benchmark_run.start_new_case(&format!("Query {query_id}")); - let sql = queries.get_query(query_id)?; - println!("Q{query_id}: {sql}"); - - for i in 0..iterations { - let start = Instant::now(); - let results = ctx.sql(sql).await?.collect().await?; - let elapsed = start.elapsed(); - let ms = elapsed.as_secs_f64() * 1000.0; - millis.push(ms); - let row_count: usize = results.iter().map(|b| b.num_rows()).sum(); - println!( - "Query {query_id} iteration {i} took {ms:.1} ms and returned {row_count} rows" - ); - benchmark_run.write_iter(elapsed, row_count); - } - if self.common.debug { - ctx.sql(sql).await?.explain(false, false)?.show().await?; + let query_run = self.benchmark_query(&sql, query_id, &ctx).await; + match query_run { + Ok(query_results) => { + for iter in query_results { + benchmark_run.write_iter(iter.elapsed, iter.row_count); + } + } + Err(e) => { + benchmark_run.mark_failed(); + eprintln!("Query {query_id} failed: {e}"); + } } - let avg = millis.iter().sum::() / millis.len() as f64; - println!("Query {query_id} avg time: {avg:.2} ms"); } benchmark_run.maybe_write_json(self.output_path.as_ref())?; + benchmark_run.maybe_print_failures(); Ok(()) } + async fn benchmark_query( + &self, + sql: &str, + query_id: usize, + ctx: &SessionContext, + ) -> Result> { + println!("Q{query_id}: {sql}"); + + let mut millis = Vec::with_capacity(self.iterations()); + let mut query_results = vec![]; + for i in 0..self.iterations() { + let start = Instant::now(); + let results = ctx.sql(sql).await?.collect().await?; + let elapsed = start.elapsed(); + let ms = elapsed.as_secs_f64() * 1000.0; + millis.push(ms); + let row_count: usize = results.iter().map(|b| b.num_rows()).sum(); + println!( + "Query {query_id} iteration {i} took {ms:.1} ms and returned {row_count} rows" + ); + query_results.push(QueryResult { elapsed, row_count }) + } + if self.common.debug { + ctx.sql(sql).await?.explain(false, false)?.show().await?; + } + let avg = millis.iter().sum::() / millis.len() as f64; + println!("Query {query_id} avg time: {avg:.2} ms"); + Ok(query_results) + } + /// Registers the `hits.parquet` as a table named `hits` async fn register_hits(&self, ctx: &SessionContext) -> Result<()> { let options = Default::default(); @@ -171,4 +201,8 @@ impl RunOpt { ) }) } + + fn iterations(&self) -> usize { + self.common.iterations + } } diff --git a/benchmarks/src/h2o.rs b/benchmarks/src/h2o.rs index cc463e70d74a..23dba07f426d 100644 --- a/benchmarks/src/h2o.rs +++ b/benchmarks/src/h2o.rs @@ -15,9 +15,16 @@ // specific language governing permissions and limitations // under the License. +//! H2O benchmark implementation for groupby, join and window operations +//! Reference: +//! - [H2O AI Benchmark](https://duckdb.org/2023/04/14/h2oai.html) +//! 
- [Extended window function benchmark](https://duckdb.org/2024/06/26/benchmarks-over-time.html#window-functions-benchmark) + use crate::util::{BenchmarkRun, CommonOpt}; use datafusion::{error::Result, prelude::SessionContext}; -use datafusion_common::{exec_datafusion_err, instant::Instant, DataFusionError}; +use datafusion_common::{ + exec_datafusion_err, instant::Instant, internal_err, DataFusionError, +}; use std::path::{Path, PathBuf}; use structopt::StructOpt; @@ -77,19 +84,28 @@ impl RunOpt { None => queries.min_query_id()..=queries.max_query_id(), }; - let config = self.common.config(); + let config = self.common.config()?; let rt_builder = self.common.runtime_env_builder()?; let ctx = SessionContext::new_with_config_rt(config, rt_builder.build_arc()?); - if self.queries_path.to_str().unwrap().contains("join") { + // Register tables depending on which h2o benchmark is being run + // (groupby/join/window) + if self.queries_path.to_str().unwrap().ends_with("groupby.sql") { + self.register_data(&ctx).await?; + } else if self.queries_path.to_str().unwrap().ends_with("join.sql") { let join_paths: Vec<&str> = self.join_paths.split(',').collect(); let table_name: Vec<&str> = vec!["x", "small", "medium", "large"]; for (i, path) in join_paths.iter().enumerate() { ctx.register_csv(table_name[i], path, Default::default()) .await?; } - } else if self.queries_path.to_str().unwrap().contains("groupby") { - self.register_data(&ctx).await?; + } else if self.queries_path.to_str().unwrap().ends_with("window.sql") { + // Only register the 'large' table in h2o-join dataset + let h2o_join_large_path = self.join_paths.split(',').nth(3).unwrap(); + ctx.register_csv("large", h2o_join_large_path, Default::default()) + .await?; + } else { + return internal_err!("Invalid query file path"); } let iterations = self.common.iterations; @@ -171,7 +187,7 @@ impl AllQueries { .map_err(|e| exec_datafusion_err!("Could not open {path:?}: {e}"))?; Ok(Self { - queries: all_queries.lines().map(|s| s.to_string()).collect(), + queries: all_queries.split("\n\n").map(|s| s.to_string()).collect(), }) } diff --git a/benchmarks/src/imdb/run.rs b/benchmarks/src/imdb/run.rs index d7d7a56d0540..0d9bdf536d10 100644 --- a/benchmarks/src/imdb/run.rs +++ b/benchmarks/src/imdb/run.rs @@ -19,7 +19,7 @@ use std::path::PathBuf; use std::sync::Arc; use super::{get_imdb_table_schema, get_query_sql, IMDB_TABLES}; -use crate::util::{BenchmarkRun, CommonOpt}; +use crate::util::{BenchmarkRun, CommonOpt, QueryResult}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::{self, pretty_format_batches}; @@ -303,7 +303,7 @@ impl RunOpt { async fn benchmark_query(&self, query_id: usize) -> Result> { let mut config = self .common - .config() + .config()? 
.with_collect_statistics(!self.disable_statistics); config.options_mut().optimizer.prefer_hash_join = self.prefer_hash_join; let rt_builder = self.common.runtime_env_builder()?; @@ -471,15 +471,10 @@ impl RunOpt { fn partitions(&self) -> usize { self.common .partitions - .unwrap_or(get_available_parallelism()) + .unwrap_or_else(get_available_parallelism) } } -struct QueryResult { - elapsed: std::time::Duration, - row_count: usize, -} - #[cfg(test)] // Only run with "ci" mode when we have the data #[cfg(feature = "ci")] @@ -514,7 +509,7 @@ mod tests { let common = CommonOpt { iterations: 1, partitions: Some(2), - batch_size: 8192, + batch_size: Some(8192), mem_pool_type: "fair".to_string(), memory_limit: None, sort_spill_reservation_bytes: None, @@ -550,7 +545,7 @@ mod tests { let common = CommonOpt { iterations: 1, partitions: Some(2), - batch_size: 8192, + batch_size: Some(8192), mem_pool_type: "fair".to_string(), memory_limit: None, sort_spill_reservation_bytes: None, diff --git a/benchmarks/src/sort.rs b/benchmarks/src/sort.rs index 9cf09c57205a..cbbd3b54ea9e 100644 --- a/benchmarks/src/sort.rs +++ b/benchmarks/src/sort.rs @@ -70,28 +70,31 @@ impl RunOpt { let sort_cases = vec![ ( "sort utf8", - LexOrdering::new(vec![PhysicalSortExpr { + [PhysicalSortExpr { expr: col("request_method", &schema)?, options: Default::default(), - }]), + }] + .into(), ), ( "sort int", - LexOrdering::new(vec![PhysicalSortExpr { + [PhysicalSortExpr { expr: col("response_bytes", &schema)?, options: Default::default(), - }]), + }] + .into(), ), ( "sort decimal", - LexOrdering::new(vec![PhysicalSortExpr { + [PhysicalSortExpr { expr: col("decimal_price", &schema)?, options: Default::default(), - }]), + }] + .into(), ), ( "sort integer tuple", - LexOrdering::new(vec![ + [ PhysicalSortExpr { expr: col("request_bytes", &schema)?, options: Default::default(), @@ -100,11 +103,12 @@ impl RunOpt { expr: col("response_bytes", &schema)?, options: Default::default(), }, - ]), + ] + .into(), ), ( "sort utf8 tuple", - LexOrdering::new(vec![ + [ // sort utf8 tuple PhysicalSortExpr { expr: col("service", &schema)?, @@ -122,11 +126,12 @@ impl RunOpt { expr: col("image", &schema)?, options: Default::default(), }, - ]), + ] + .into(), ), ( "sort mixed tuple", - LexOrdering::new(vec![ + [ PhysicalSortExpr { expr: col("service", &schema)?, options: Default::default(), @@ -139,7 +144,8 @@ impl RunOpt { expr: col("decimal_price", &schema)?, options: Default::default(), }, - ]), + ] + .into(), ), ]; for (title, expr) in sort_cases { @@ -149,7 +155,7 @@ impl RunOpt { let config = SessionConfig::new().with_target_partitions( self.common .partitions - .unwrap_or(get_available_parallelism()), + .unwrap_or_else(get_available_parallelism), ); let ctx = SessionContext::new_with_config(config); let (rows, elapsed) = diff --git a/benchmarks/src/sort_tpch.rs b/benchmarks/src/sort_tpch.rs index 176234eca541..21897f5bf2d7 100644 --- a/benchmarks/src/sort_tpch.rs +++ b/benchmarks/src/sort_tpch.rs @@ -40,7 +40,7 @@ use datafusion_common::instant::Instant; use datafusion_common::utils::get_available_parallelism; use datafusion_common::DEFAULT_PARQUET_EXTENSION; -use crate::util::{BenchmarkRun, CommonOpt}; +use crate::util::{BenchmarkRun, CommonOpt, QueryResult}; #[derive(Debug, StructOpt)] pub struct RunOpt { @@ -74,11 +74,6 @@ pub struct RunOpt { limit: Option, } -struct QueryResult { - elapsed: std::time::Duration, - row_count: usize, -} - impl RunOpt { const SORT_TABLES: [&'static str; 1] = ["lineitem"]; @@ -179,7 +174,7 @@ impl RunOpt { /// If 
query is specified from command line, run only that query. /// Otherwise, run all queries. pub async fn run(&self) -> Result<()> { - let mut benchmark_run = BenchmarkRun::new(); + let mut benchmark_run: BenchmarkRun = BenchmarkRun::new(); let query_range = match self.query { Some(query_id) => query_id..=query_id, @@ -189,20 +184,28 @@ impl RunOpt { for query_id in query_range { benchmark_run.start_new_case(&format!("{query_id}")); - let query_results = self.benchmark_query(query_id).await?; - for iter in query_results { - benchmark_run.write_iter(iter.elapsed, iter.row_count); + let query_results = self.benchmark_query(query_id).await; + match query_results { + Ok(query_results) => { + for iter in query_results { + benchmark_run.write_iter(iter.elapsed, iter.row_count); + } + } + Err(e) => { + benchmark_run.mark_failed(); + eprintln!("Query {query_id} failed: {e}"); + } } } benchmark_run.maybe_write_json(self.output_path.as_ref())?; - + benchmark_run.maybe_print_failures(); Ok(()) } /// Benchmark query `query_id` in `SORT_QUERIES` async fn benchmark_query(&self, query_id: usize) -> Result> { - let config = self.common.config(); + let config = self.common.config()?; let rt_builder = self.common.runtime_env_builder()?; let state = SessionStateBuilder::new() .with_config(config) @@ -294,7 +297,7 @@ impl RunOpt { let mut stream = execute_stream(physical_plan.clone(), state.task_ctx())?; while let Some(batch) = stream.next().await { - row_count += batch.unwrap().num_rows(); + row_count += batch?.num_rows(); } if debug { @@ -352,6 +355,6 @@ impl RunOpt { fn partitions(&self) -> usize { self.common .partitions - .unwrap_or(get_available_parallelism()) + .unwrap_or_else(get_available_parallelism) } } diff --git a/benchmarks/src/tpch/run.rs b/benchmarks/src/tpch/run.rs index 752a5a1a6ba0..88960d7c7d16 100644 --- a/benchmarks/src/tpch/run.rs +++ b/benchmarks/src/tpch/run.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use super::{ get_query_sql, get_tbl_tpch_table_schema, get_tpch_table_schema, TPCH_TABLES, }; -use crate::util::{BenchmarkRun, CommonOpt}; +use crate::util::{BenchmarkRun, CommonOpt, QueryResult}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::{self, pretty_format_batches}; @@ -109,51 +109,64 @@ impl RunOpt { }; let mut benchmark_run = BenchmarkRun::new(); - for query_id in query_range { - benchmark_run.start_new_case(&format!("Query {query_id}")); - let query_run = self.benchmark_query(query_id).await?; - for iter in query_run { - benchmark_run.write_iter(iter.elapsed, iter.row_count); - } - } - benchmark_run.maybe_write_json(self.output_path.as_ref())?; - Ok(()) - } - - async fn benchmark_query(&self, query_id: usize) -> Result> { let mut config = self .common - .config() + .config()? 
.with_collect_statistics(!self.disable_statistics); config.options_mut().optimizer.prefer_hash_join = self.prefer_hash_join; let rt_builder = self.common.runtime_env_builder()?; let ctx = SessionContext::new_with_config_rt(config, rt_builder.build_arc()?); - // register tables self.register_tables(&ctx).await?; + for query_id in query_range { + benchmark_run.start_new_case(&format!("Query {query_id}")); + let query_run = self.benchmark_query(query_id, &ctx).await; + match query_run { + Ok(query_results) => { + for iter in query_results { + benchmark_run.write_iter(iter.elapsed, iter.row_count); + } + } + Err(e) => { + benchmark_run.mark_failed(); + eprintln!("Query {query_id} failed: {e}"); + } + } + } + benchmark_run.maybe_write_json(self.output_path.as_ref())?; + benchmark_run.maybe_print_failures(); + Ok(()) + } + + async fn benchmark_query( + &self, + query_id: usize, + ctx: &SessionContext, + ) -> Result> { let mut millis = vec![]; // run benchmark let mut query_results = vec![]; + + let sql = &get_query_sql(query_id)?; + for i in 0..self.iterations() { let start = Instant::now(); - let sql = &get_query_sql(query_id)?; - // query 15 is special, with 3 statements. the second statement is the one from which we // want to capture the results let mut result = vec![]; if query_id == 15 { for (n, query) in sql.iter().enumerate() { if n == 1 { - result = self.execute_query(&ctx, query).await?; + result = self.execute_query(ctx, query).await?; } else { - self.execute_query(&ctx, query).await?; + self.execute_query(ctx, query).await?; } } } else { for query in sql { - result = self.execute_query(&ctx, query).await?; + result = self.execute_query(ctx, query).await?; } } @@ -261,7 +274,7 @@ impl RunOpt { (Arc::new(format), path, ".tbl") } "csv" => { - let path = format!("{path}/{table}"); + let path = format!("{path}/csv/{table}"); let format = CsvFormat::default() .with_delimiter(b',') .with_has_header(true); @@ -313,15 +326,10 @@ impl RunOpt { fn partitions(&self) -> usize { self.common .partitions - .unwrap_or(get_available_parallelism()) + .unwrap_or_else(get_available_parallelism) } } -struct QueryResult { - elapsed: std::time::Duration, - row_count: usize, -} - #[cfg(test)] // Only run with "ci" mode when we have the data #[cfg(feature = "ci")] @@ -355,7 +363,7 @@ mod tests { let common = CommonOpt { iterations: 1, partitions: Some(2), - batch_size: 8192, + batch_size: Some(8192), mem_pool_type: "fair".to_string(), memory_limit: None, sort_spill_reservation_bytes: None, @@ -392,7 +400,7 @@ mod tests { let common = CommonOpt { iterations: 1, partitions: Some(2), - batch_size: 8192, + batch_size: Some(8192), mem_pool_type: "fair".to_string(), memory_limit: None, sort_spill_reservation_bytes: None, diff --git a/benchmarks/src/util/mod.rs b/benchmarks/src/util/mod.rs index 95c6e5f53d0f..420d52401c4e 100644 --- a/benchmarks/src/util/mod.rs +++ b/benchmarks/src/util/mod.rs @@ -22,4 +22,4 @@ mod run; pub use access_log::AccessLogOpt; pub use options::CommonOpt; -pub use run::{BenchQuery, BenchmarkRun}; +pub use run::{BenchQuery, BenchmarkRun, QueryResult}; diff --git a/benchmarks/src/util/options.rs b/benchmarks/src/util/options.rs index a1cf31525dd9..6627a287dfcd 100644 --- a/benchmarks/src/util/options.rs +++ b/benchmarks/src/util/options.rs @@ -19,13 +19,13 @@ use std::{num::NonZeroUsize, sync::Arc}; use datafusion::{ execution::{ - disk_manager::DiskManagerConfig, + disk_manager::DiskManagerBuilder, memory_pool::{FairSpillPool, GreedyMemoryPool, MemoryPool, TrackConsumersPool}, 
        runtime_env::RuntimeEnvBuilder,
     },
     prelude::SessionConfig,
 };
-use datafusion_common::{utils::get_available_parallelism, DataFusionError, Result};
+use datafusion_common::{DataFusionError, Result};
 use structopt::StructOpt;

 // Common benchmark options (don't use doc comments otherwise this doc
@@ -41,8 +41,8 @@ pub struct CommonOpt {
     pub partitions: Option<usize>,

     /// Batch size when reading CSV or Parquet files
-    #[structopt(short = "s", long = "batch-size", default_value = "8192")]
-    pub batch_size: usize,
+    #[structopt(short = "s", long = "batch-size")]
+    pub batch_size: Option<usize>,

     /// The memory pool type to use, should be one of "fair" or "greedy"
     #[structopt(long = "mem-pool-type", default_value = "fair")]
@@ -65,21 +65,25 @@ impl CommonOpt {
     /// Return an appropriately configured `SessionConfig`
-    pub fn config(&self) -> SessionConfig {
-        self.update_config(SessionConfig::new())
+    pub fn config(&self) -> Result<SessionConfig> {
+        SessionConfig::from_env().map(|config| self.update_config(config))
     }

     /// Modify the existing config appropriately
-    pub fn update_config(&self, config: SessionConfig) -> SessionConfig {
-        let mut config = config
-            .with_target_partitions(
-                self.partitions.unwrap_or(get_available_parallelism()),
-            )
-            .with_batch_size(self.batch_size);
+    pub fn update_config(&self, mut config: SessionConfig) -> SessionConfig {
+        if let Some(batch_size) = self.batch_size {
+            config = config.with_batch_size(batch_size);
+        }
+
+        if let Some(partitions) = self.partitions {
+            config = config.with_target_partitions(partitions);
+        }
+
         if let Some(sort_spill_reservation_bytes) = self.sort_spill_reservation_bytes {
             config = config
                 .with_sort_spill_reservation_bytes(sort_spill_reservation_bytes);
         }
+
         config
     }
@@ -106,7 +110,7 @@ impl CommonOpt {
             };
             rt_builder = rt_builder
                 .with_memory_pool(pool)
-                .with_disk_manager(DiskManagerConfig::NewOs);
+                .with_disk_manager_builder(DiskManagerBuilder::default());
         }
         Ok(rt_builder)
     }
@@ -118,15 +122,14 @@ fn parse_memory_limit(limit: &str) -> Result<usize, String> {
     let (number, unit) = limit.split_at(limit.len() - 1);
     let number: f64 = number
         .parse()
-        .map_err(|_| format!("Failed to parse number from memory limit '{}'", limit))?;
+        .map_err(|_| format!("Failed to parse number from memory limit '{limit}'"))?;

     match unit {
         "K" => Ok((number * 1024.0) as usize),
         "M" => Ok((number * 1024.0 * 1024.0) as usize),
         "G" => Ok((number * 1024.0 * 1024.0 * 1024.0) as usize),
         _ => Err(format!(
-            "Unsupported unit '{}' in memory limit '{}'",
-            unit, limit
+            "Unsupported unit '{unit}' in memory limit '{limit}'"
         )),
     }
 }
diff --git a/benchmarks/src/util/run.rs b/benchmarks/src/util/run.rs
index 13969f4d3949..764ea648ff72 100644
--- a/benchmarks/src/util/run.rs
+++ b/benchmarks/src/util/run.rs
@@ -90,8 +90,13 @@ pub struct BenchQuery {
     iterations: Vec<QueryIter>,
     #[serde(serialize_with = "serialize_start_time")]
     start_time: SystemTime,
+    success: bool,
+}
+/// Internal representation of a single benchmark query iteration result.
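// Editor's note (self-contained usage sketch, not part of the patch): the
// `parse_memory_limit` helper in benchmarks/src/util/options.rs above accepts
// a number with a K/M/G suffix and returns a size in bytes. A standalone copy
// of that logic with example inputs, to show the expected conversions:
fn parse_memory_limit(limit: &str) -> Result<usize, String> {
    let (number, unit) = limit.split_at(limit.len() - 1);
    let number: f64 = number
        .parse()
        .map_err(|_| format!("Failed to parse number from memory limit '{limit}'"))?;
    match unit {
        "K" => Ok((number * 1024.0) as usize),
        "M" => Ok((number * 1024.0 * 1024.0) as usize),
        "G" => Ok((number * 1024.0 * 1024.0 * 1024.0) as usize),
        _ => Err(format!("Unsupported unit '{unit}' in memory limit '{limit}'")),
    }
}

fn main() {
    // e.g. `--memory-limit 512M` or `--memory-limit 1.5G` on the benchmark CLI
    assert_eq!(parse_memory_limit("512M"), Ok(536_870_912));
    assert_eq!(parse_memory_limit("1.5G"), Ok(1_610_612_736));
}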
+pub struct QueryResult { + pub elapsed: Duration, + pub row_count: usize, } - /// collects benchmark run data and then serializes it at the end pub struct BenchmarkRun { context: RunContext, @@ -120,6 +125,7 @@ impl BenchmarkRun { query: id.to_owned(), iterations: vec![], start_time: SystemTime::now(), + success: true, }); if let Some(c) = self.current_case.as_mut() { *c += 1; @@ -138,6 +144,28 @@ impl BenchmarkRun { } } + /// Print the names of failed queries, if any + pub fn maybe_print_failures(&self) { + let failed_queries: Vec<&str> = self + .queries + .iter() + .filter_map(|q| (!q.success).then_some(q.query.as_str())) + .collect(); + + if !failed_queries.is_empty() { + println!("Failed Queries: {}", failed_queries.join(", ")); + } + } + + /// Mark current query + pub fn mark_failed(&mut self) { + if let Some(idx) = self.current_case { + self.queries[idx].success = false; + } else { + unreachable!("Cannot mark failure: no current case"); + } + } + /// Stringify data into formatted json pub fn to_json(&self) -> String { let mut output = HashMap::<&str, Value>::new(); diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index 566aafb319bf..63662e56ca75 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -37,9 +37,9 @@ backtrace = ["datafusion/backtrace"] [dependencies] arrow = { workspace = true } async-trait = { workspace = true } -aws-config = "1.6.1" +aws-config = "1.8.0" aws-credential-types = "1.2.0" -clap = { version = "4.5.35", features = ["derive", "cargo"] } +clap = { version = "4.5.40", features = ["derive", "cargo"] } datafusion = { workspace = true, features = [ "avro", "crypto_expressions", @@ -55,12 +55,13 @@ datafusion = { workspace = true, features = [ dirs = "6.0.0" env_logger = { workspace = true } futures = { workspace = true } +log = { workspace = true } mimalloc = { version = "0.1", default-features = false } object_store = { workspace = true, features = ["aws", "gcp", "http"] } parking_lot = { workspace = true } parquet = { workspace = true, default-features = false } regex = { workspace = true } -rustyline = "15.0" +rustyline = "16.0" tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot", "signal"] } url = { workspace = true } diff --git a/datafusion-cli/src/catalog.rs b/datafusion-cli/src/catalog.rs index ceb72dbc546b..fd83b52de299 100644 --- a/datafusion-cli/src/catalog.rs +++ b/datafusion-cli/src/catalog.rs @@ -200,6 +200,7 @@ impl SchemaProvider for DynamicObjectStoreSchemaProvider { table_url.scheme(), url, &state.default_table_options(), + false, ) .await?; state.runtime_env().register_object_store(url, store); @@ -229,7 +230,6 @@ pub fn substitute_tilde(cur: String) -> String { } #[cfg(test)] mod tests { - use super::*; use datafusion::catalog::SchemaProvider; @@ -337,8 +337,7 @@ mod tests { #[cfg(not(target_os = "windows"))] #[test] fn test_substitute_tilde() { - use std::env; - use std::path::MAIN_SEPARATOR; + use std::{env, path::PathBuf}; let original_home = home_dir(); let test_home_path = if cfg!(windows) { "C:\\Users\\user" @@ -350,17 +349,16 @@ mod tests { test_home_path, ); let input = "~/Code/datafusion/benchmarks/data/tpch_sf1/part/part-0.parquet"; - let expected = format!( - "{}{}Code{}datafusion{}benchmarks{}data{}tpch_sf1{}part{}part-0.parquet", - test_home_path, - MAIN_SEPARATOR, - MAIN_SEPARATOR, - MAIN_SEPARATOR, - MAIN_SEPARATOR, - MAIN_SEPARATOR, - MAIN_SEPARATOR, - MAIN_SEPARATOR - ); + let expected = PathBuf::from(test_home_path) + .join("Code") + 
.join("datafusion") + .join("benchmarks") + .join("data") + .join("tpch_sf1") + .join("part") + .join("part-0.parquet") + .to_string_lossy() + .to_string(); let actual = substitute_tilde(input.to_string()); assert_eq!(actual, expected); match original_home { diff --git a/datafusion-cli/src/command.rs b/datafusion-cli/src/command.rs index fc7d1a2617cf..77bc8d3d2000 100644 --- a/datafusion-cli/src/command.rs +++ b/datafusion-cli/src/command.rs @@ -64,21 +64,28 @@ impl Command { let command_batch = all_commands_info(); let schema = command_batch.schema(); let num_rows = command_batch.num_rows(); - print_options.print_batches(schema, &[command_batch], now, num_rows) + let task_ctx = ctx.task_ctx(); + let config = &task_ctx.session_config().options().format; + print_options.print_batches( + schema, + &[command_batch], + now, + num_rows, + config, + ) } Self::ListTables => { exec_and_print(ctx, print_options, "SHOW TABLES".into()).await } Self::DescribeTableStmt(name) => { - exec_and_print(ctx, print_options, format!("SHOW COLUMNS FROM {}", name)) + exec_and_print(ctx, print_options, format!("SHOW COLUMNS FROM {name}")) .await } Self::Include(filename) => { if let Some(filename) = filename { let file = File::open(filename).map_err(|e| { DataFusionError::Execution(format!( - "Error opening {:?} {}", - filename, e + "Error opening {filename:?} {e}" )) })?; exec_from_lines(ctx, &mut BufReader::new(file), print_options) @@ -108,7 +115,7 @@ impl Command { Self::SearchFunctions(function) => { if let Ok(func) = function.parse::() { let details = func.function_details()?; - println!("{}", details); + println!("{details}"); Ok(()) } else { exec_err!("{function} is not a supported function") diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index 0f4d70c1cca9..ce190c1b40d3 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -26,28 +26,28 @@ use crate::{ object_storage::get_object_store, print_options::{MaxRows, PrintOptions}, }; -use futures::StreamExt; -use std::collections::HashMap; -use std::fs::File; -use std::io::prelude::*; -use std::io::BufReader; - use datafusion::common::instant::Instant; use datafusion::common::{plan_datafusion_err, plan_err}; use datafusion::config::ConfigFileType; use datafusion::datasource::listing::ListingTableUrl; use datafusion::error::{DataFusionError, Result}; +use datafusion::execution::memory_pool::MemoryConsumer; use datafusion::logical_expr::{DdlStatement, LogicalPlan}; use datafusion::physical_plan::execution_plan::EmissionType; +use datafusion::physical_plan::spill::get_record_batch_memory_size; use datafusion::physical_plan::{execute_stream, ExecutionPlanProperties}; use datafusion::sql::parser::{DFParser, Statement}; -use datafusion::sql::sqlparser::dialect::dialect_from_str; - -use datafusion::execution::memory_pool::MemoryConsumer; -use datafusion::physical_plan::spill::get_record_batch_memory_size; use datafusion::sql::sqlparser; +use datafusion::sql::sqlparser::dialect::dialect_from_str; +use futures::StreamExt; +use log::warn; +use object_store::Error::Generic; use rustyline::error::ReadlineError; use rustyline::Editor; +use std::collections::HashMap; +use std::fs::File; +use std::io::prelude::*; +use std::io::BufReader; use tokio::signal; /// run and execute SQL statements and commands, against a context with the given print options @@ -200,7 +200,7 @@ pub async fn exec_from_repl( break; } Err(err) => { - eprintln!("Unknown error happened {:?}", err); + eprintln!("Unknown error happened {err:?}"); break; } } @@ 
-216,7 +216,8 @@ pub(super) async fn exec_and_print( ) -> Result<()> { let now = Instant::now(); let task_ctx = ctx.task_ctx(); - let dialect = &task_ctx.session_config().options().sql_parser.dialect; + let options = task_ctx.session_config().options(); + let dialect = &options.sql_parser.dialect; let dialect = dialect_from_str(dialect).ok_or_else(|| { plan_datafusion_err!( "Unsupported SQL dialect: {dialect}. Available dialects: \ @@ -230,10 +231,21 @@ pub(super) async fn exec_and_print( let adjusted = AdjustedPrintOptions::new(print_options.clone()).with_statement(&statement); - let plan = create_plan(ctx, statement).await?; + let plan = create_plan(ctx, statement.clone(), false).await?; let adjusted = adjusted.with_plan(&plan); - let df = ctx.execute_logical_plan(plan).await?; + let df = match ctx.execute_logical_plan(plan).await { + Ok(df) => df, + Err(DataFusionError::ObjectStore(Generic { store, source: _ })) + if "S3".eq_ignore_ascii_case(store) + && matches!(&statement, Statement::CreateExternalTable(_)) => + { + warn!("S3 region is incorrect, auto-detecting the correct region (this may be slow). Consider updating your region configuration."); + let plan = create_plan(ctx, statement, true).await?; + ctx.execute_logical_plan(plan).await? + } + Err(e) => return Err(e), + }; let physical_plan = df.create_physical_plan().await?; // Track memory usage for the query result if it's bounded @@ -250,7 +262,9 @@ pub(super) async fn exec_and_print( // As the input stream comes, we can generate results. // However, memory safety is not guaranteed. let stream = execute_stream(physical_plan, task_ctx.clone())?; - print_options.print_stream(stream, now).await?; + print_options + .print_stream(stream, now, &options.format) + .await?; } else { // Bounded stream; collected results size is limited by the maxrows option let schema = physical_plan.schema(); @@ -273,9 +287,13 @@ pub(super) async fn exec_and_print( } row_count += curr_num_rows; } - adjusted - .into_inner() - .print_batches(schema, &results, now, row_count)?; + adjusted.into_inner().print_batches( + schema, + &results, + now, + row_count, + &options.format, + )?; reservation.free(); } } @@ -341,6 +359,7 @@ fn config_file_type_from_str(ext: &str) -> Option { async fn create_plan( ctx: &dyn CliSessionContext, statement: Statement, + resolve_region: bool, ) -> Result { let mut plan = ctx.session_state().statement_to_plan(statement).await?; @@ -355,6 +374,7 @@ async fn create_plan( &cmd.location, &cmd.options, format, + resolve_region, ) .await?; } @@ -367,6 +387,7 @@ async fn create_plan( ©_to.output_url, ©_to.options, format, + false, ) .await?; } @@ -405,6 +426,7 @@ pub(crate) async fn register_object_store_and_config_extensions( location: &String, options: &HashMap, format: Option, + resolve_region: bool, ) -> Result<()> { // Parse the location URL to extract the scheme and other components let table_path = ListingTableUrl::parse(location)?; @@ -426,8 +448,14 @@ pub(crate) async fn register_object_store_and_config_extensions( table_options.alter_with_string_hash_map(options)?; // Retrieve the appropriate object store based on the scheme, URL, and modified table options - let store = - get_object_store(&ctx.session_state(), scheme, url, &table_options).await?; + let store = get_object_store( + &ctx.session_state(), + scheme, + url, + &table_options, + resolve_region, + ) + .await?; // Register the retrieved object store in the session context's runtime environment ctx.register_object_store(url, store); @@ -455,6 +483,7 @@ mod tests { 
&cmd.location, &cmd.options, format, + false, ) .await?; } else { @@ -481,6 +510,7 @@ mod tests { &cmd.output_url, &cmd.options, format, + false, ) .await?; } else { @@ -523,11 +553,11 @@ mod tests { ) })?; for location in locations { - let sql = format!("copy (values (1,2)) to '{}' STORED AS PARQUET;", location); + let sql = format!("copy (values (1,2)) to '{location}' STORED AS PARQUET;"); let statements = DFParser::parse_sql_with_dialect(&sql, dialect.as_ref())?; for statement in statements { //Should not fail - let mut plan = create_plan(&ctx, statement).await?; + let mut plan = create_plan(&ctx, statement, false).await?; if let LogicalPlan::Copy(copy_to) = &mut plan { assert_eq!(copy_to.output_url, location); assert_eq!(copy_to.file_type.get_ext(), "parquet".to_string()); diff --git a/datafusion-cli/src/functions.rs b/datafusion-cli/src/functions.rs index 13d2d5fd3547..f07dac649df9 100644 --- a/datafusion-cli/src/functions.rs +++ b/datafusion-cli/src/functions.rs @@ -205,7 +205,7 @@ pub fn display_all_functions() -> Result<()> { let array = StringArray::from( ALL_FUNCTIONS .iter() - .map(|f| format!("{}", f)) + .map(|f| format!("{f}")) .collect::>(), ); let schema = Schema::new(vec![Field::new("Function", DataType::Utf8, false)]); @@ -322,7 +322,7 @@ pub struct ParquetMetadataFunc {} impl TableFunctionImpl for ParquetMetadataFunc { fn call(&self, exprs: &[Expr]) -> Result> { let filename = match exprs.first() { - Some(Expr::Literal(ScalarValue::Utf8(Some(s)))) => s, // single quote: parquet_metadata('x.parquet') + Some(Expr::Literal(ScalarValue::Utf8(Some(s)), _)) => s, // single quote: parquet_metadata('x.parquet') Some(Expr::Column(Column { name, .. })) => name, // double quote: parquet_metadata("x.parquet") _ => { return plan_err!( diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs index dad2d15f01a1..fdecb185e33e 100644 --- a/datafusion-cli/src/main.rs +++ b/datafusion-cli/src/main.rs @@ -17,15 +17,17 @@ use std::collections::HashMap; use std::env; +use std::num::NonZeroUsize; use std::path::Path; use std::process::ExitCode; use std::sync::{Arc, LazyLock}; use datafusion::error::{DataFusionError, Result}; use datafusion::execution::context::SessionConfig; -use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, MemoryPool}; +use datafusion::execution::memory_pool::{ + FairSpillPool, GreedyMemoryPool, MemoryPool, TrackConsumersPool, +}; use datafusion::execution::runtime_env::RuntimeEnvBuilder; -use datafusion::execution::DiskManager; use datafusion::prelude::SessionContext; use datafusion_cli::catalog::DynamicObjectStoreCatalog; use datafusion_cli::functions::ParquetMetadataFunc; @@ -40,7 +42,7 @@ use datafusion_cli::{ use clap::Parser; use datafusion::common::config_err; use datafusion::config::ConfigOptions; -use datafusion::execution::disk_manager::DiskManagerConfig; +use datafusion::execution::disk_manager::{DiskManagerBuilder, DiskManagerMode}; use mimalloc::MiMalloc; #[global_allocator] @@ -118,6 +120,13 @@ struct Args { )] mem_pool_type: PoolType, + #[clap( + long, + help = "The number of top memory consumers to display when query fails due to memory exhaustion. 
To disable memory consumer tracking, set this value to 0", + default_value = "3" + )] + top_memory_consumers: usize, + #[clap( long, help = "The max number of rows to display for 'Table' format\n[possible values: numbers(0/10/...), inf(no limit)]", @@ -154,7 +163,7 @@ async fn main_inner() -> Result<()> { let args = Args::parse(); if !args.quiet { - println!("DataFusion CLI v{}", DATAFUSION_CLI_VERSION); + println!("DataFusion CLI v{DATAFUSION_CLI_VERSION}"); } if let Some(ref path) = args.data_path { @@ -169,22 +178,31 @@ async fn main_inner() -> Result<()> { if let Some(memory_limit) = args.memory_limit { // set memory pool type let pool: Arc = match args.mem_pool_type { - PoolType::Fair => Arc::new(FairSpillPool::new(memory_limit)), - PoolType::Greedy => Arc::new(GreedyMemoryPool::new(memory_limit)), + PoolType::Fair if args.top_memory_consumers == 0 => { + Arc::new(FairSpillPool::new(memory_limit)) + } + PoolType::Fair => Arc::new(TrackConsumersPool::new( + FairSpillPool::new(memory_limit), + NonZeroUsize::new(args.top_memory_consumers).unwrap(), + )), + PoolType::Greedy if args.top_memory_consumers == 0 => { + Arc::new(GreedyMemoryPool::new(memory_limit)) + } + PoolType::Greedy => Arc::new(TrackConsumersPool::new( + GreedyMemoryPool::new(memory_limit), + NonZeroUsize::new(args.top_memory_consumers).unwrap(), + )), }; + rt_builder = rt_builder.with_memory_pool(pool) } // set disk limit if let Some(disk_limit) = args.disk_limit { - let disk_manager = DiskManager::try_new(DiskManagerConfig::NewOs)?; - - let disk_manager = Arc::try_unwrap(disk_manager) - .expect("DiskManager should be a single instance") - .with_max_temp_directory_size(disk_limit.try_into().unwrap())?; - - let disk_config = DiskManagerConfig::new_existing(Arc::new(disk_manager)); - rt_builder = rt_builder.with_disk_manager(disk_config); + let builder = DiskManagerBuilder::default() + .with_mode(DiskManagerMode::OsTmpDirectory) + .with_max_temp_directory_size(disk_limit.try_into().unwrap()); + rt_builder = rt_builder.with_disk_manager_builder(builder); } let runtime_env = rt_builder.build_arc()?; @@ -265,6 +283,11 @@ fn get_session_config(args: &Args) -> Result { config_options.explain.format = String::from("tree"); } + // in the CLI, we want to show NULL values rather the empty strings + if env::var_os("DATAFUSION_FORMAT_NULL").is_none() { + config_options.format.null = String::from("NULL"); + } + let session_config = SessionConfig::from(config_options).with_information_schema(true); Ok(session_config) @@ -274,7 +297,7 @@ fn parse_valid_file(dir: &str) -> Result { if Path::new(dir).is_file() { Ok(dir.to_string()) } else { - Err(format!("Invalid file '{}'", dir)) + Err(format!("Invalid file '{dir}'")) } } @@ -282,14 +305,14 @@ fn parse_valid_data_dir(dir: &str) -> Result { if Path::new(dir).is_dir() { Ok(dir.to_string()) } else { - Err(format!("Invalid data directory '{}'", dir)) + Err(format!("Invalid data directory '{dir}'")) } } fn parse_batch_size(size: &str) -> Result { match size.parse::() { Ok(size) if size > 0 => Ok(size), - _ => Err(format!("Invalid batch size '{}'", size)), + _ => Err(format!("Invalid batch size '{size}'")), } } @@ -346,20 +369,20 @@ fn parse_size_string(size: &str, label: &str) -> Result { let num_str = caps.get(1).unwrap().as_str(); let num = num_str .parse::() - .map_err(|_| format!("Invalid numeric value in {} '{}'", label, size))?; + .map_err(|_| format!("Invalid numeric value in {label} '{size}'"))?; let suffix = caps.get(2).map(|m| m.as_str()).unwrap_or("b"); let unit = BYTE_SUFFIXES 
.get(suffix) - .ok_or_else(|| format!("Invalid {} '{}'", label, size))?; + .ok_or_else(|| format!("Invalid {label} '{size}'"))?; let total_bytes = usize::try_from(unit.multiplier()) .ok() .and_then(|multiplier| num.checked_mul(multiplier)) - .ok_or_else(|| format!("{} '{}' is too large", label, size))?; + .ok_or_else(|| format!("{label} '{size}' is too large"))?; Ok(total_bytes) } else { - Err(format!("Invalid {} '{}'", label, size)) + Err(format!("Invalid {label} '{size}'")) } } diff --git a/datafusion-cli/src/object_storage.rs b/datafusion-cli/src/object_storage.rs index c31310093ac6..176dfdd4ceed 100644 --- a/datafusion-cli/src/object_storage.rs +++ b/datafusion-cli/src/object_storage.rs @@ -16,6 +16,7 @@ // under the License. use std::any::Any; +use std::error::Error; use std::fmt::{Debug, Display}; use std::sync::Arc; @@ -28,16 +29,31 @@ use datafusion::execution::context::SessionState; use async_trait::async_trait; use aws_config::BehaviorVersion; -use aws_credential_types::provider::ProvideCredentials; -use object_store::aws::{AmazonS3Builder, AwsCredential}; +use aws_credential_types::provider::error::CredentialsError; +use aws_credential_types::provider::{ProvideCredentials, SharedCredentialsProvider}; +use log::debug; +use object_store::aws::{AmazonS3Builder, AmazonS3ConfigKey, AwsCredential}; use object_store::gcp::GoogleCloudStorageBuilder; use object_store::http::HttpBuilder; use object_store::{ClientOptions, CredentialProvider, ObjectStore}; use url::Url; +#[cfg(not(test))] +use object_store::aws::resolve_bucket_region; + +// Provide a local mock when running tests so we don't make network calls +#[cfg(test)] +async fn resolve_bucket_region( + _bucket: &str, + _client_options: &ClientOptions, +) -> object_store::Result { + Ok("eu-central-1".to_string()) +} + pub async fn get_s3_object_store_builder( url: &Url, aws_options: &AwsOptions, + resolve_region: bool, ) -> Result { let AwsOptions { access_key_id, @@ -46,6 +62,7 @@ pub async fn get_s3_object_store_builder( region, endpoint, allow_http, + skip_signature, } = aws_options; let bucket_name = get_bucket_name(url)?; @@ -54,6 +71,7 @@ pub async fn get_s3_object_store_builder( if let (Some(access_key_id), Some(secret_access_key)) = (access_key_id, secret_access_key) { + debug!("Using explicitly provided S3 access_key_id and secret_access_key"); builder = builder .with_access_key_id(access_key_id) .with_secret_access_key(secret_access_key); @@ -62,29 +80,37 @@ pub async fn get_s3_object_store_builder( builder = builder.with_token(session_token); } } else { - let config = aws_config::defaults(BehaviorVersion::latest()).load().await; - if let Some(region) = config.region() { - builder = builder.with_region(region.to_string()); + debug!("Using AWS S3 SDK to determine credentials"); + let CredentialsFromConfig { + region, + credentials, + } = CredentialsFromConfig::try_new().await?; + if let Some(region) = region { + builder = builder.with_region(region); + } + if let Some(credentials) = credentials { + let credentials = Arc::new(S3CredentialProvider { credentials }); + builder = builder.with_credentials(credentials); + } else { + debug!("No credentials found, defaulting to skip signature "); + builder = builder.with_skip_signature(true); } - - let credentials = config - .credentials_provider() - .ok_or_else(|| { - DataFusionError::ObjectStore(object_store::Error::Generic { - store: "S3", - source: "Failed to get S3 credentials from the environment".into(), - }) - })? 
- .clone(); - - let credentials = Arc::new(S3CredentialProvider { credentials }); - builder = builder.with_credentials(credentials); } if let Some(region) = region { builder = builder.with_region(region); } + // If the region is not set or auto_detect_region is true, resolve the region. + if builder + .get_config_value(&AmazonS3ConfigKey::Region) + .is_none() + || resolve_region + { + let region = resolve_bucket_region(bucket_name, &ClientOptions::new()).await?; + builder = builder.with_region(region); + } + if let Some(endpoint) = endpoint { // Make a nicer error if the user hasn't allowed http and the endpoint // is http as the default message is "URL scheme is not allowed" @@ -105,9 +131,71 @@ pub async fn get_s3_object_store_builder( builder = builder.with_allow_http(*allow_http); } + if let Some(skip_signature) = skip_signature { + builder = builder.with_skip_signature(*skip_signature); + } + Ok(builder) } +/// Credentials from the AWS SDK +struct CredentialsFromConfig { + region: Option, + credentials: Option, +} + +impl CredentialsFromConfig { + /// Attempt find AWS S3 credentials via the AWS SDK + pub async fn try_new() -> Result { + let config = aws_config::defaults(BehaviorVersion::latest()).load().await; + let region = config.region().map(|r| r.to_string()); + + let credentials = config + .credentials_provider() + .ok_or_else(|| { + DataFusionError::ObjectStore(object_store::Error::Generic { + store: "S3", + source: "Failed to get S3 credentials aws_config".into(), + }) + })? + .clone(); + + // The credential provider is lazy, so it does not fetch credentials + // until they are needed. To ensure that the credentials are valid, + // we can call `provide_credentials` here. + let credentials = match credentials.provide_credentials().await { + Ok(_) => Some(credentials), + Err(CredentialsError::CredentialsNotLoaded(_)) => { + debug!("Could not use AWS SDK to get credentials"); + None + } + // other errors like `CredentialsError::InvalidConfiguration` + // should be returned to the user so they can be fixed + Err(e) => { + // Pass back underlying error to the user, including underlying source + let source_message = if let Some(source) = e.source() { + format!(": {source}") + } else { + String::new() + }; + + let message = format!( + "Error getting credentials from provider: {e}{source_message}", + ); + + return Err(DataFusionError::ObjectStore(object_store::Error::Generic { + store: "S3", + source: message.into(), + })); + } + }; + Ok(Self { + region, + credentials, + }) + } +} + #[derive(Debug)] struct S3CredentialProvider { credentials: aws_credential_types::provider::SharedCredentialsProvider, @@ -219,6 +307,11 @@ pub struct AwsOptions { pub endpoint: Option, /// Allow HTTP (otherwise will always use https) pub allow_http: Option, + /// Do not fetch credentials and do not sign requests + /// + /// This can be useful when interacting with public S3 buckets that deny + /// authorized requests + pub skip_signature: Option, } impl ExtensionOptions for AwsOptions { @@ -256,6 +349,9 @@ impl ExtensionOptions for AwsOptions { "allow_http" => { self.allow_http.set(rem, value)?; } + "skip_signature" | "nosign" => { + self.skip_signature.set(rem, value)?; + } _ => { return config_err!("Config value \"{}\" not found on AwsOptions", rem); } @@ -397,6 +493,7 @@ pub(crate) async fn get_object_store( scheme: &str, url: &Url, table_options: &TableOptions, + resolve_region: bool, ) -> Result, DataFusionError> { let store: Arc = match scheme { "s3" => { @@ -405,7 +502,8 @@ pub(crate) async fn 
get_object_store( "Given table options incompatible with the 's3' scheme" ); }; - let builder = get_s3_object_store_builder(url, options).await?; + let builder = + get_s3_object_store_builder(url, options, resolve_region).await?; Arc::new(builder.build()?) } "oss" => { @@ -461,7 +559,6 @@ mod tests { use super::*; - use datafusion::common::plan_err; use datafusion::{ datasource::listing::ListingTableUrl, logical_expr::{DdlStatement, LogicalPlan}, @@ -470,6 +567,63 @@ mod tests { use object_store::{aws::AmazonS3ConfigKey, gcp::GoogleConfigKey}; + #[tokio::test] + async fn s3_object_store_builder_default() -> Result<()> { + let location = "s3://bucket/path/FAKE/file.parquet"; + + // No options + let table_url = ListingTableUrl::parse(location)?; + let scheme = table_url.scheme(); + let sql = + format!("CREATE EXTERNAL TABLE test STORED AS PARQUET LOCATION '{location}'"); + + let ctx = SessionContext::new(); + ctx.register_table_options_extension_from_scheme(scheme); + let table_options = get_table_options(&ctx, &sql).await; + let aws_options = table_options.extensions.get::().unwrap(); + let builder = + get_s3_object_store_builder(table_url.as_ref(), aws_options, false).await?; + + // If the environment variables are set (as they are in CI) use them + let expected_access_key_id = std::env::var("AWS_ACCESS_KEY_ID").ok(); + let expected_secret_access_key = std::env::var("AWS_SECRET_ACCESS_KEY").ok(); + let expected_region = Some( + std::env::var("AWS_REGION").unwrap_or_else(|_| "eu-central-1".to_string()), + ); + let expected_endpoint = std::env::var("AWS_ENDPOINT").ok(); + + // get the actual configuration information, then assert_eq! + assert_eq!( + builder.get_config_value(&AmazonS3ConfigKey::AccessKeyId), + expected_access_key_id + ); + assert_eq!( + builder.get_config_value(&AmazonS3ConfigKey::SecretAccessKey), + expected_secret_access_key + ); + // Default is to skip signature when no credentials are provided + let expected_skip_signature = + if expected_access_key_id.is_none() && expected_secret_access_key.is_none() { + Some(String::from("true")) + } else { + Some(String::from("false")) + }; + assert_eq!( + builder.get_config_value(&AmazonS3ConfigKey::Region), + expected_region + ); + assert_eq!( + builder.get_config_value(&AmazonS3ConfigKey::Endpoint), + expected_endpoint + ); + assert_eq!(builder.get_config_value(&AmazonS3ConfigKey::Token), None); + assert_eq!( + builder.get_config_value(&AmazonS3ConfigKey::SkipSignature), + expected_skip_signature + ); + Ok(()) + } + #[tokio::test] async fn s3_object_store_builder() -> Result<()> { // "fake" is uppercase to ensure the values are not lowercased when parsed @@ -493,29 +647,27 @@ mod tests { ); let ctx = SessionContext::new(); - let mut plan = ctx.state().create_logical_plan(&sql).await?; - - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { - ctx.register_table_options_extension_from_scheme(scheme); - let mut table_options = ctx.state().default_table_options(); - table_options.alter_with_string_hash_map(&cmd.options)?; - let aws_options = table_options.extensions.get::().unwrap(); - let builder = - get_s3_object_store_builder(table_url.as_ref(), aws_options).await?; - // get the actual configuration information, then assert_eq! 
- let config = [ - (AmazonS3ConfigKey::AccessKeyId, access_key_id), - (AmazonS3ConfigKey::SecretAccessKey, secret_access_key), - (AmazonS3ConfigKey::Region, region), - (AmazonS3ConfigKey::Endpoint, endpoint), - (AmazonS3ConfigKey::Token, session_token), - ]; - for (key, value) in config { - assert_eq!(value, builder.get_config_value(&key).unwrap()); - } - } else { - return plan_err!("LogicalPlan is not a CreateExternalTable"); + ctx.register_table_options_extension_from_scheme(scheme); + let table_options = get_table_options(&ctx, &sql).await; + let aws_options = table_options.extensions.get::().unwrap(); + let builder = + get_s3_object_store_builder(table_url.as_ref(), aws_options, false).await?; + // get the actual configuration information, then assert_eq! + let config = [ + (AmazonS3ConfigKey::AccessKeyId, access_key_id), + (AmazonS3ConfigKey::SecretAccessKey, secret_access_key), + (AmazonS3ConfigKey::Region, region), + (AmazonS3ConfigKey::Endpoint, endpoint), + (AmazonS3ConfigKey::Token, session_token), + ]; + for (key, value) in config { + assert_eq!(value, builder.get_config_value(&key).unwrap()); } + // Should not skip signature when credentials are provided + assert_eq!( + builder.get_config_value(&AmazonS3ConfigKey::SkipSignature), + Some("false".into()) + ); Ok(()) } @@ -538,21 +690,15 @@ mod tests { ); let ctx = SessionContext::new(); - let mut plan = ctx.state().create_logical_plan(&sql).await?; - - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { - ctx.register_table_options_extension_from_scheme(scheme); - let mut table_options = ctx.state().default_table_options(); - table_options.alter_with_string_hash_map(&cmd.options)?; - let aws_options = table_options.extensions.get::().unwrap(); - let err = get_s3_object_store_builder(table_url.as_ref(), aws_options) - .await - .unwrap_err(); - - assert_eq!(err.to_string().lines().next().unwrap_or_default(), "Invalid or Unsupported Configuration: Invalid endpoint: http://endpoint33. HTTP is not allowed for S3 endpoints. To allow HTTP, set 'aws.allow_http' to true"); - } else { - return plan_err!("LogicalPlan is not a CreateExternalTable"); - } + ctx.register_table_options_extension_from_scheme(scheme); + + let table_options = get_table_options(&ctx, &sql).await; + let aws_options = table_options.extensions.get::().unwrap(); + let err = get_s3_object_store_builder(table_url.as_ref(), aws_options, false) + .await + .unwrap_err(); + + assert_eq!(err.to_string().lines().next().unwrap_or_default(), "Invalid or Unsupported Configuration: Invalid endpoint: http://endpoint33. HTTP is not allowed for S3 endpoints. 
To allow HTTP, set 'aws.allow_http' to true"); // Now add `allow_http` to the options and check if it works let sql = format!( @@ -563,19 +709,59 @@ mod tests { 'aws.allow_http' 'true'\ ) LOCATION '{location}'" ); + let table_options = get_table_options(&ctx, &sql).await; - let mut plan = ctx.state().create_logical_plan(&sql).await?; + let aws_options = table_options.extensions.get::().unwrap(); + // ensure this isn't an error + get_s3_object_store_builder(table_url.as_ref(), aws_options, false).await?; - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { - ctx.register_table_options_extension_from_scheme(scheme); - let mut table_options = ctx.state().default_table_options(); - table_options.alter_with_string_hash_map(&cmd.options)?; - let aws_options = table_options.extensions.get::().unwrap(); - // ensure this isn't an error - get_s3_object_store_builder(table_url.as_ref(), aws_options).await?; - } else { - return plan_err!("LogicalPlan is not a CreateExternalTable"); - } + Ok(()) + } + + #[tokio::test] + async fn s3_object_store_builder_resolves_region_when_none_provided() -> Result<()> { + let expected_region = "eu-central-1"; + let location = "s3://test-bucket/path/file.parquet"; + + let table_url = ListingTableUrl::parse(location)?; + let aws_options = AwsOptions { + region: None, // No region specified - should auto-detect + ..Default::default() + }; + + let builder = + get_s3_object_store_builder(table_url.as_ref(), &aws_options, false).await?; + + // Verify that the region was auto-detected in test environment + assert_eq!( + builder.get_config_value(&AmazonS3ConfigKey::Region), + Some(expected_region.to_string()) + ); + + Ok(()) + } + + #[tokio::test] + async fn s3_object_store_builder_overrides_region_when_resolve_region_enabled( + ) -> Result<()> { + let original_region = "us-east-1"; + let expected_region = "eu-central-1"; // This should be the auto-detected region + let location = "s3://test-bucket/path/file.parquet"; + + let table_url = ListingTableUrl::parse(location)?; + let aws_options = AwsOptions { + region: Some(original_region.to_string()), // Explicit region provided + ..Default::default() + }; + + let builder = + get_s3_object_store_builder(table_url.as_ref(), &aws_options, true).await?; + + // Verify that the region was overridden by auto-detection + assert_eq!( + builder.get_config_value(&AmazonS3ConfigKey::Region), + Some(expected_region.to_string()) + ); Ok(()) } @@ -592,25 +778,19 @@ mod tests { let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}', 'aws.oss.endpoint' '{endpoint}') LOCATION '{location}'"); let ctx = SessionContext::new(); - let mut plan = ctx.state().create_logical_plan(&sql).await?; - - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { - ctx.register_table_options_extension_from_scheme(scheme); - let mut table_options = ctx.state().default_table_options(); - table_options.alter_with_string_hash_map(&cmd.options)?; - let aws_options = table_options.extensions.get::().unwrap(); - let builder = get_oss_object_store_builder(table_url.as_ref(), aws_options)?; - // get the actual configuration information, then assert_eq! 
- let config = [ - (AmazonS3ConfigKey::AccessKeyId, access_key_id), - (AmazonS3ConfigKey::SecretAccessKey, secret_access_key), - (AmazonS3ConfigKey::Endpoint, endpoint), - ]; - for (key, value) in config { - assert_eq!(value, builder.get_config_value(&key).unwrap()); - } - } else { - return plan_err!("LogicalPlan is not a CreateExternalTable"); + ctx.register_table_options_extension_from_scheme(scheme); + let table_options = get_table_options(&ctx, &sql).await; + + let aws_options = table_options.extensions.get::().unwrap(); + let builder = get_oss_object_store_builder(table_url.as_ref(), aws_options)?; + // get the actual configuration information, then assert_eq! + let config = [ + (AmazonS3ConfigKey::AccessKeyId, access_key_id), + (AmazonS3ConfigKey::SecretAccessKey, secret_access_key), + (AmazonS3ConfigKey::Endpoint, endpoint), + ]; + for (key, value) in config { + assert_eq!(value, builder.get_config_value(&key).unwrap()); } Ok(()) @@ -629,30 +809,40 @@ mod tests { let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS('gcp.service_account_path' '{service_account_path}', 'gcp.service_account_key' '{service_account_key}', 'gcp.application_credentials_path' '{application_credentials_path}') LOCATION '{location}'"); let ctx = SessionContext::new(); - let mut plan = ctx.state().create_logical_plan(&sql).await?; - - if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { - ctx.register_table_options_extension_from_scheme(scheme); - let mut table_options = ctx.state().default_table_options(); - table_options.alter_with_string_hash_map(&cmd.options)?; - let gcp_options = table_options.extensions.get::().unwrap(); - let builder = get_gcs_object_store_builder(table_url.as_ref(), gcp_options)?; - // get the actual configuration information, then assert_eq! - let config = [ - (GoogleConfigKey::ServiceAccount, service_account_path), - (GoogleConfigKey::ServiceAccountKey, service_account_key), - ( - GoogleConfigKey::ApplicationCredentials, - application_credentials_path, - ), - ]; - for (key, value) in config { - assert_eq!(value, builder.get_config_value(&key).unwrap()); - } - } else { - return plan_err!("LogicalPlan is not a CreateExternalTable"); + ctx.register_table_options_extension_from_scheme(scheme); + let table_options = get_table_options(&ctx, &sql).await; + + let gcp_options = table_options.extensions.get::().unwrap(); + let builder = get_gcs_object_store_builder(table_url.as_ref(), gcp_options)?; + // get the actual configuration information, then assert_eq! + let config = [ + (GoogleConfigKey::ServiceAccount, service_account_path), + (GoogleConfigKey::ServiceAccountKey, service_account_key), + ( + GoogleConfigKey::ApplicationCredentials, + application_credentials_path, + ), + ]; + for (key, value) in config { + assert_eq!(value, builder.get_config_value(&key).unwrap()); } Ok(()) } + + /// Plans the `CREATE EXTERNAL TABLE` SQL statement and returns the + /// resulting resolved `CreateExternalTable` command. 
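// Editor's note (illustrative aside, not part of the patch): the new
// `aws.skip_signature` option (alias `nosign`) added to AwsOptions above lets
// datafusion-cli read public S3 buckets without credentials, and the builder
// now defaults to skipping request signing when the AWS SDK finds none. A
// hypothetical statement a user might issue (bucket name and function are
// made up for illustration):
fn example_unsigned_s3_ddl() -> String {
    let location = "s3://some-public-bucket/data/";
    format!(
        "CREATE EXTERNAL TABLE hits STORED AS PARQUET \
         OPTIONS('aws.skip_signature' 'true') LOCATION '{location}'"
    )
}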
+ async fn get_table_options(ctx: &SessionContext, sql: &str) -> TableOptions { + let mut plan = ctx.state().create_logical_plan(sql).await.unwrap(); + + let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan else { + panic!("plan is not a CreateExternalTable"); + }; + + let mut table_options = ctx.state().default_table_options(); + table_options + .alter_with_string_hash_map(&cmd.options) + .unwrap(); + table_options + } } diff --git a/datafusion-cli/src/pool_type.rs b/datafusion-cli/src/pool_type.rs index 269790b61f5a..a2164cc3c739 100644 --- a/datafusion-cli/src/pool_type.rs +++ b/datafusion-cli/src/pool_type.rs @@ -33,7 +33,7 @@ impl FromStr for PoolType { match s { "Greedy" | "greedy" => Ok(PoolType::Greedy), "Fair" | "fair" => Ok(PoolType::Fair), - _ => Err(format!("Invalid memory pool type '{}'", s)), + _ => Err(format!("Invalid memory pool type '{s}'")), } } } diff --git a/datafusion-cli/src/print_format.rs b/datafusion-cli/src/print_format.rs index 1fc949593512..1d6a8396aee7 100644 --- a/datafusion-cli/src/print_format.rs +++ b/datafusion-cli/src/print_format.rs @@ -26,7 +26,7 @@ use arrow::datatypes::SchemaRef; use arrow::json::{ArrayWriter, LineDelimitedWriter}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches_with_options; -use datafusion::common::format::DEFAULT_CLI_FORMAT_OPTIONS; +use datafusion::config::FormatOptions; use datafusion::error::Result; /// Allow records to be printed in different formats @@ -110,7 +110,10 @@ fn format_batches_with_maxrows( writer: &mut W, batches: &[RecordBatch], maxrows: MaxRows, + format_options: &FormatOptions, ) -> Result<()> { + let options: arrow::util::display::FormatOptions = format_options.try_into()?; + match maxrows { MaxRows::Limited(maxrows) => { // Filter batches to meet the maxrows condition @@ -131,22 +134,19 @@ fn format_batches_with_maxrows( } } - let formatted = pretty_format_batches_with_options( - &filtered_batches, - &DEFAULT_CLI_FORMAT_OPTIONS, - )?; + let formatted = + pretty_format_batches_with_options(&filtered_batches, &options)?; if over_limit { - let mut formatted_str = format!("{}", formatted); + let mut formatted_str = format!("{formatted}"); formatted_str = keep_only_maxrows(&formatted_str, maxrows); - writeln!(writer, "{}", formatted_str)?; + writeln!(writer, "{formatted_str}")?; } else { - writeln!(writer, "{}", formatted)?; + writeln!(writer, "{formatted}")?; } } MaxRows::Unlimited => { - let formatted = - pretty_format_batches_with_options(batches, &DEFAULT_CLI_FORMAT_OPTIONS)?; - writeln!(writer, "{}", formatted)?; + let formatted = pretty_format_batches_with_options(batches, &options)?; + writeln!(writer, "{formatted}")?; } } @@ -162,6 +162,7 @@ impl PrintFormat { batches: &[RecordBatch], maxrows: MaxRows, with_header: bool, + format_options: &FormatOptions, ) -> Result<()> { // filter out any empty batches let batches: Vec<_> = batches @@ -170,7 +171,7 @@ impl PrintFormat { .cloned() .collect(); if batches.is_empty() { - return self.print_empty(writer, schema); + return self.print_empty(writer, schema, format_options); } match self { @@ -182,7 +183,7 @@ impl PrintFormat { if maxrows == MaxRows::Limited(0) { return Ok(()); } - format_batches_with_maxrows(writer, &batches, maxrows) + format_batches_with_maxrows(writer, &batches, maxrows, format_options) } Self::Json => batches_to_json!(ArrayWriter, writer, &batches), Self::NdJson => batches_to_json!(LineDelimitedWriter, writer, &batches), @@ -194,16 +195,18 @@ impl PrintFormat { &self, writer: &mut W, 
schema: SchemaRef, + format_options: &FormatOptions, ) -> Result<()> { match self { // Print column headers for Table format Self::Table if !schema.fields().is_empty() => { + let format_options: arrow::util::display::FormatOptions = + format_options.try_into()?; + let empty_batch = RecordBatch::new_empty(schema); - let formatted = pretty_format_batches_with_options( - &[empty_batch], - &DEFAULT_CLI_FORMAT_OPTIONS, - )?; - writeln!(writer, "{}", formatted)?; + let formatted = + pretty_format_batches_with_options(&[empty_batch], &format_options)?; + writeln!(writer, "{formatted}")?; } _ => {} } @@ -644,6 +647,7 @@ mod tests { &self.batches, self.maxrows, with_header, + &FormatOptions::default(), ) .unwrap(); String::from_utf8(buffer).unwrap() diff --git a/datafusion-cli/src/print_options.rs b/datafusion-cli/src/print_options.rs index 9557e783e8a7..56d787b0fe08 100644 --- a/datafusion-cli/src/print_options.rs +++ b/datafusion-cli/src/print_options.rs @@ -29,6 +29,7 @@ use datafusion::common::DataFusionError; use datafusion::error::Result; use datafusion::physical_plan::RecordBatchStream; +use datafusion::config::FormatOptions; use futures::StreamExt; #[derive(Debug, Clone, PartialEq, Copy)] @@ -51,7 +52,7 @@ impl FromStr for MaxRows { } else { match maxrows.parse::() { Ok(nrows) => Ok(Self::Limited(nrows)), - _ => Err(format!("Invalid maxrows {}. Valid inputs are natural numbers or \'none\', \'inf\', or \'infinite\' for no limit.", maxrows)), + _ => Err(format!("Invalid maxrows {maxrows}. Valid inputs are natural numbers or \'none\', \'inf\', or \'infinite\' for no limit.")), } } } @@ -103,12 +104,19 @@ impl PrintOptions { batches: &[RecordBatch], query_start_time: Instant, row_count: usize, + format_options: &FormatOptions, ) -> Result<()> { let stdout = std::io::stdout(); let mut writer = stdout.lock(); - self.format - .print_batches(&mut writer, schema, batches, self.maxrows, true)?; + self.format.print_batches( + &mut writer, + schema, + batches, + self.maxrows, + true, + format_options, + )?; let formatted_exec_details = get_execution_details_formatted( row_count, @@ -132,6 +140,7 @@ impl PrintOptions { &self, mut stream: Pin>, query_start_time: Instant, + format_options: &FormatOptions, ) -> Result<()> { if self.format == PrintFormat::Table { return Err(DataFusionError::External( @@ -154,6 +163,7 @@ impl PrintOptions { &[batch], MaxRows::Unlimited, with_header, + format_options, )?; with_header = false; } diff --git a/datafusion-cli/tests/cli_integration.rs b/datafusion-cli/tests/cli_integration.rs index 9ac09955512b..108651281dfc 100644 --- a/datafusion-cli/tests/cli_integration.rs +++ b/datafusion-cli/tests/cli_integration.rs @@ -69,6 +69,10 @@ fn init() { // can choose the old explain format too ["--command", "EXPLAIN FORMAT indent SELECT 123"], )] +#[case::change_format_version( + "change_format_version", + ["--file", "tests/sql/types_format.sql", "-q"], +)] #[test] fn cli_quick_test<'a>( #[case] snapshot_name: &'a str, @@ -118,6 +122,42 @@ fn test_cli_format<'a>(#[case] format: &'a str) { assert_cmd_snapshot!(cmd); } +#[rstest] +#[case("no_track", ["--top-memory-consumers", "0"])] +#[case("top2", ["--top-memory-consumers", "2"])] +#[case("top3_default", [])] +#[test] +fn test_cli_top_memory_consumers<'a>( + #[case] snapshot_name: &str, + #[case] top_memory_consumers: impl IntoIterator, +) { + let mut settings = make_settings(); + + settings.set_snapshot_suffix(snapshot_name); + + settings.add_filter( + r"[^\s]+\#\d+\(can spill: (true|false)\) consumed .*?B", + "Consumer(can spill: 
bool) consumed XB", + ); + settings.add_filter( + r"Error: Failed to allocate additional .*? for .*? with .*? already allocated for this reservation - .*? remain available for the total pool", + "Error: Failed to allocate ", + ); + settings.add_filter( + r"Resources exhausted: Failed to allocate additional .*? for .*? with .*? already allocated for this reservation - .*? remain available for the total pool", + "Resources exhausted: Failed to allocate", + ); + + let _bound = settings.bind_to_scope(); + + let mut cmd = cli(); + let sql = "select * from generate_series(1,500000) as t1(v1) order by v1;"; + cmd.args(["--memory-limit", "10M", "--command", sql]); + cmd.args(top_memory_consumers); + + assert_cmd_snapshot!(cmd); +} + #[tokio::test] async fn test_cli() { if env::var("TEST_STORAGE_INTEGRATION").is_err() { @@ -157,16 +197,48 @@ async fn test_aws_options() { STORED AS CSV LOCATION 's3://data/cars.csv' OPTIONS( - 'aws.access_key_id' '{}', - 'aws.secret_access_key' '{}', - 'aws.endpoint' '{}', + 'aws.access_key_id' '{access_key_id}', + 'aws.secret_access_key' '{secret_access_key}', + 'aws.endpoint' '{endpoint_url}', 'aws.allow_http' 'true' ); SELECT * FROM CARS limit 1; -"#, - access_key_id, secret_access_key, endpoint_url +"# ); assert_cmd_snapshot!(cli().env_clear().pass_stdin(input)); } + +#[tokio::test] +async fn test_aws_region_auto_resolution() { + if env::var("TEST_STORAGE_INTEGRATION").is_err() { + eprintln!("Skipping external storages integration tests"); + return; + } + + let mut settings = make_settings(); + settings.add_filter(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z", "[TIME]"); + let _bound = settings.bind_to_scope(); + + let bucket = "s3://clickhouse-public-datasets/hits_compatible/athena_partitioned/hits_1.parquet"; + let region = "us-east-1"; + + let input = format!( + r#"CREATE EXTERNAL TABLE hits +STORED AS PARQUET +LOCATION '{bucket}' +OPTIONS( + 'aws.region' '{region}', + 'aws.skip_signature' true +); + +SELECT COUNT(*) FROM hits; +"# + ); + + assert_cmd_snapshot!(cli() + .env("RUST_LOG", "warn") + .env_remove("AWS_ENDPOINT") + .pass_stdin(input)); +} diff --git a/datafusion-cli/tests/snapshots/aws_region_auto_resolution.snap b/datafusion-cli/tests/snapshots/aws_region_auto_resolution.snap new file mode 100644 index 000000000000..cd6d918b78d9 --- /dev/null +++ b/datafusion-cli/tests/snapshots/aws_region_auto_resolution.snap @@ -0,0 +1,29 @@ +--- +source: datafusion-cli/tests/cli_integration.rs +info: + program: datafusion-cli + args: [] + env: + AWS_ENDPOINT: "" + RUST_LOG: warn + stdin: "CREATE EXTERNAL TABLE hits\nSTORED AS PARQUET\nLOCATION 's3://clickhouse-public-datasets/hits_compatible/athena_partitioned/hits_1.parquet'\nOPTIONS(\n 'aws.region' 'us-east-1',\n 'aws.skip_signature' true\n);\n\nSELECT COUNT(*) FROM hits;\n" +--- +success: true +exit_code: 0 +----- stdout ----- +[CLI_VERSION] +0 row(s) fetched. +[ELAPSED] + ++----------+ +| count(*) | ++----------+ +| 1000000 | ++----------+ +1 row(s) fetched. +[ELAPSED] + +\q + +----- stderr ----- +[[TIME] WARN datafusion_cli::exec] S3 region is incorrect, auto-detecting the correct region (this may be slow). Consider updating your region configuration. 
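
The CLI memory-limit and object-store tests above all follow the same insta snapshot pattern: build a `Settings` object, register regex filters that scrub nondeterministic output (consumer byte counts, timestamps, endpoints), bind the settings to the current scope, and then snapshot the full CLI invocation. Below is a minimal sketch of that pattern (not part of the diff), assuming the `make_settings()` and `cli()` helpers already defined in `cli_integration.rs`; the filter regex shown here is illustrative only.

use insta_cmd::assert_cmd_snapshot;

#[test]
fn snapshot_with_normalized_memory_output() {
    // `make_settings()` and `cli()` are the existing helpers in this test file;
    // the filter below is an illustrative stand-in for the real ones above.
    let mut settings = make_settings();

    // Rewrite nondeterministic byte counts so the snapshot stays stable
    // across runs and machines.
    settings.add_filter(r"consumed \d+(\.\d+)?\s?[KMG]?B", "consumed XB");

    // Filters apply to every snapshot asserted while `_bound` is in scope.
    let _bound = settings.bind_to_scope();

    let mut cmd = cli();
    cmd.args(["--memory-limit", "10M", "--command", "select 1;"]);
    assert_cmd_snapshot!(cmd);
}
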
diff --git a/datafusion-cli/tests/snapshots/cli_quick_test@change_format_version.snap b/datafusion-cli/tests/snapshots/cli_quick_test@change_format_version.snap new file mode 100644 index 000000000000..74059b2a6103 --- /dev/null +++ b/datafusion-cli/tests/snapshots/cli_quick_test@change_format_version.snap @@ -0,0 +1,20 @@ +--- +source: datafusion-cli/tests/cli_integration.rs +info: + program: datafusion-cli + args: + - "--file" + - tests/sql/types_format.sql + - "-q" +--- +success: true +exit_code: 0 +----- stdout ----- ++-----------+ +| Int64(54) | +| Int64 | ++-----------+ +| 54 | ++-----------+ + +----- stderr ----- diff --git a/datafusion-cli/tests/snapshots/cli_top_memory_consumers@no_track.snap b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@no_track.snap new file mode 100644 index 000000000000..89b646a531f8 --- /dev/null +++ b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@no_track.snap @@ -0,0 +1,21 @@ +--- +source: datafusion-cli/tests/cli_integration.rs +info: + program: datafusion-cli + args: + - "--memory-limit" + - 10M + - "--command" + - "select * from generate_series(1,500000) as t1(v1) order by v1;" + - "--top-memory-consumers" + - "0" +--- +success: false +exit_code: 1 +----- stdout ----- +[CLI_VERSION] +Error: Not enough memory to continue external sort. Consider increasing the memory limit, or decreasing sort_spill_reservation_bytes +caused by +Resources exhausted: Failed to allocate + +----- stderr ----- diff --git a/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top2.snap b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top2.snap new file mode 100644 index 000000000000..ed925a6f6461 --- /dev/null +++ b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top2.snap @@ -0,0 +1,24 @@ +--- +source: datafusion-cli/tests/cli_integration.rs +info: + program: datafusion-cli + args: + - "--memory-limit" + - 10M + - "--command" + - "select * from generate_series(1,500000) as t1(v1) order by v1;" + - "--top-memory-consumers" + - "2" +--- +success: false +exit_code: 1 +----- stdout ----- +[CLI_VERSION] +Error: Not enough memory to continue external sort. Consider increasing the memory limit, or decreasing sort_spill_reservation_bytes +caused by +Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: + Consumer(can spill: bool) consumed XB, + Consumer(can spill: bool) consumed XB. +Error: Failed to allocate + +----- stderr ----- diff --git a/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top3_default.snap b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top3_default.snap new file mode 100644 index 000000000000..f35e3b117178 --- /dev/null +++ b/datafusion-cli/tests/snapshots/cli_top_memory_consumers@top3_default.snap @@ -0,0 +1,23 @@ +--- +source: datafusion-cli/tests/cli_integration.rs +info: + program: datafusion-cli + args: + - "--memory-limit" + - 10M + - "--command" + - "select * from generate_series(1,500000) as t1(v1) order by v1;" +--- +success: false +exit_code: 1 +----- stdout ----- +[CLI_VERSION] +Error: Not enough memory to continue external sort. Consider increasing the memory limit, or decreasing sort_spill_reservation_bytes +caused by +Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: + Consumer(can spill: bool) consumed XB, + Consumer(can spill: bool) consumed XB, + Consumer(can spill: bool) consumed XB. 
+Error: Failed to allocate + +----- stderr ----- diff --git a/datafusion-cli/tests/sql/types_format.sql b/datafusion-cli/tests/sql/types_format.sql new file mode 100644 index 000000000000..637929c980a1 --- /dev/null +++ b/datafusion-cli/tests/sql/types_format.sql @@ -0,0 +1,3 @@ +set datafusion.format.types_info to true; + +select 54 \ No newline at end of file diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index 2ba1673d97b9..b31708a5c1cc 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -77,7 +77,7 @@ tonic = "0.12.1" tracing = { version = "0.1" } tracing-subscriber = { version = "0.3" } url = { workspace = true } -uuid = "1.16" +uuid = "1.17" [target.'cfg(not(target_os = "windows"))'.dev-dependencies] -nix = { version = "0.29.0", features = ["fs"] } +nix = { version = "0.30.1", features = ["fs"] } diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index 3ba4c77cd84c..285762bb57e7 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -65,6 +65,7 @@ cargo run --example dataframe - [`flight_sql_server.rs`](examples/flight/flight_sql_server.rs): Run DataFusion as a standalone process and execute SQL queries from JDBC clients - [`function_factory.rs`](examples/function_factory.rs): Register `CREATE FUNCTION` handler to implement SQL macros - [`optimizer_rule.rs`](examples/optimizer_rule.rs): Use a custom OptimizerRule to replace certain predicates +- [`parquet_encrypted.rs`](examples/parquet_encrypted.rs): Read and write encrypted Parquet files using DataFusion - [`parquet_index.rs`](examples/parquet_index.rs): Create an secondary index over several parquet files and use it to speed up queries - [`parquet_exec_visitor.rs`](examples/parquet_exec_visitor.rs): Extract statistics by visiting an ExecutionPlan after execution - [`parse_sql_expr.rs`](examples/parse_sql_expr.rs): Parse SQL text into DataFusion `Expr`. diff --git a/datafusion-examples/examples/advanced_parquet_index.rs b/datafusion-examples/examples/advanced_parquet_index.rs index 03ef3d66f9d7..efaee23366a1 100644 --- a/datafusion-examples/examples/advanced_parquet_index.rs +++ b/datafusion-examples/examples/advanced_parquet_index.rs @@ -495,7 +495,7 @@ impl TableProvider for IndexTableProvider { ParquetSource::default() // provide the predicate so the DataSourceExec can try and prune // row groups internally - .with_predicate(Arc::clone(&schema), predicate) + .with_predicate(predicate) // provide the factory to create parquet reader without re-reading metadata .with_parquet_file_reader_factory(Arc::new(reader_factory)), ); diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs index 9cda726db719..7b1d3e94b2ef 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -25,6 +25,7 @@ use arrow::array::{ }; use arrow::datatypes::{ArrowNativeTypeOp, ArrowPrimitiveType, Float64Type, UInt32Type}; use arrow::record_batch::RecordBatch; +use arrow_schema::FieldRef; use datafusion::common::{cast::as_float64_array, ScalarValue}; use datafusion::error::Result; use datafusion::logical_expr::{ @@ -92,10 +93,10 @@ impl AggregateUDFImpl for GeoMeanUdaf { } /// This is the description of the state. accumulator's state() must match the types here. 
- fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ - Field::new("prod", args.return_type.clone(), true), - Field::new("n", DataType::UInt32, true), + Field::new("prod", args.return_type().clone(), true).into(), + Field::new("n", DataType::UInt32, true).into(), ]) } @@ -401,7 +402,7 @@ impl AggregateUDFImpl for SimplifiedGeoMeanUdaf { unimplemented!("should not be invoked") } - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { unimplemented!("should not be invoked") } @@ -482,7 +483,7 @@ async fn main() -> Result<()> { ctx.register_udaf(udf.clone()); let sql_df = ctx - .sql(&format!("SELECT {}(a) FROM t GROUP BY b", udf_name)) + .sql(&format!("SELECT {udf_name}(a) FROM t GROUP BY b")) .await?; sql_df.show().await?; diff --git a/datafusion-examples/examples/advanced_udwf.rs b/datafusion-examples/examples/advanced_udwf.rs index 8330e783319d..f7316ddc1bec 100644 --- a/datafusion-examples/examples/advanced_udwf.rs +++ b/datafusion-examples/examples/advanced_udwf.rs @@ -23,6 +23,7 @@ use arrow::{ array::{ArrayRef, AsArray, Float64Array}, datatypes::Float64Type, }; +use arrow_schema::FieldRef; use datafusion::common::ScalarValue; use datafusion::error::Result; use datafusion::functions_aggregate::average::avg_udaf; @@ -87,8 +88,8 @@ impl WindowUDFImpl for SmoothItUdf { Ok(Box::new(MyPartitionEvaluator::new())) } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - Ok(Field::new(field_args.name(), DataType::Float64, true)) + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Field::new(field_args.name(), DataType::Float64, true).into()) } } @@ -190,7 +191,7 @@ impl WindowUDFImpl for SimplifySmoothItUdf { /// default implementation will not be called (left as `todo!()`) fn simplify(&self) -> Option { let simplify = |window_function: WindowFunction, _: &dyn SimplifyInfo| { - Ok(Expr::WindowFunction(WindowFunction { + Ok(Expr::from(WindowFunction { fun: WindowFunctionDefinition::AggregateUDF(avg_udaf()), params: WindowFunctionParams { args: window_function.params.args, @@ -205,8 +206,8 @@ impl WindowUDFImpl for SimplifySmoothItUdf { Some(Box::new(simplify)) } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - Ok(Field::new(field_args.name(), DataType::Float64, true)) + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Field::new(field_args.name(), DataType::Float64, true).into()) } } diff --git a/datafusion-examples/examples/async_udf.rs b/datafusion-examples/examples/async_udf.rs new file mode 100644 index 000000000000..3037a971dfd9 --- /dev/null +++ b/datafusion-examples/examples/async_udf.rs @@ -0,0 +1,267 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayIter, ArrayRef, AsArray, Int64Array, RecordBatch, StringArray}; +use arrow::compute::kernels::cmp::eq; +use arrow_schema::{DataType, Field, Schema}; +use async_trait::async_trait; +use datafusion::common::error::Result; +use datafusion::common::types::{logical_int64, logical_string}; +use datafusion::common::utils::take_function_args; +use datafusion::common::{internal_err, not_impl_err}; +use datafusion::config::ConfigOptions; +use datafusion::logical_expr::async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl}; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + TypeSignatureClass, Volatility, +}; +use datafusion::logical_expr_common::signature::Coercion; +use datafusion::physical_expr_common::datum::apply_cmp; +use datafusion::prelude::SessionContext; +use log::trace; +use std::any::Any; +use std::sync::Arc; + +#[tokio::main] +async fn main() -> Result<()> { + let ctx: SessionContext = SessionContext::new(); + + let async_upper = AsyncUpper::new(); + let udf = AsyncScalarUDF::new(Arc::new(async_upper)); + ctx.register_udf(udf.into_scalar_udf()); + let async_equal = AsyncEqual::new(); + let udf = AsyncScalarUDF::new(Arc::new(async_equal)); + ctx.register_udf(udf.into_scalar_udf()); + ctx.register_batch("animal", animal()?)?; + + // use Async UDF in the projection + // +---------------+----------------------------------------------------------------------------------------+ + // | plan_type | plan | + // +---------------+----------------------------------------------------------------------------------------+ + // | logical_plan | Projection: async_equal(a.id, Int64(1)) | + // | | SubqueryAlias: a | + // | | TableScan: animal projection=[id] | + // | physical_plan | ProjectionExec: expr=[__async_fn_0@1 as async_equal(a.id,Int64(1))] | + // | | AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=async_equal(id@0, 1))] | + // | | CoalesceBatchesExec: target_batch_size=8192 | + // | | DataSourceExec: partitions=1, partition_sizes=[1] | + // | | | + // +---------------+----------------------------------------------------------------------------------------+ + ctx.sql("explain select async_equal(a.id, 1) from animal a") + .await? + .show() + .await?; + + // +----------------------------+ + // | async_equal(a.id,Int64(1)) | + // +----------------------------+ + // | true | + // | false | + // | false | + // | false | + // | false | + // +----------------------------+ + ctx.sql("select async_equal(a.id, 1) from animal a") + .await? 
+ .show() + .await?; + + // use Async UDF in the filter + // +---------------+--------------------------------------------------------------------------------------------+ + // | plan_type | plan | + // +---------------+--------------------------------------------------------------------------------------------+ + // | logical_plan | SubqueryAlias: a | + // | | Filter: async_equal(animal.id, Int64(1)) | + // | | TableScan: animal projection=[id, name] | + // | physical_plan | CoalesceBatchesExec: target_batch_size=8192 | + // | | FilterExec: __async_fn_0@2, projection=[id@0, name@1] | + // | | RepartitionExec: partitioning=RoundRobinBatch(12), input_partitions=1 | + // | | AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=async_equal(id@0, 1))] | + // | | CoalesceBatchesExec: target_batch_size=8192 | + // | | DataSourceExec: partitions=1, partition_sizes=[1] | + // | | | + // +---------------+--------------------------------------------------------------------------------------------+ + ctx.sql("explain select * from animal a where async_equal(a.id, 1)") + .await? + .show() + .await?; + + // +----+------+ + // | id | name | + // +----+------+ + // | 1 | cat | + // +----+------+ + ctx.sql("select * from animal a where async_equal(a.id, 1)") + .await? + .show() + .await?; + + Ok(()) +} + +fn animal() -> Result { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8, false), + ])); + + let id_array = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5])); + let name_array = Arc::new(StringArray::from(vec![ + "cat", "dog", "fish", "bird", "snake", + ])); + + Ok(RecordBatch::try_new(schema, vec![id_array, name_array])?) +} + +#[derive(Debug)] +pub struct AsyncUpper { + signature: Signature, +} + +impl Default for AsyncUpper { + fn default() -> Self { + Self::new() + } +} + +impl AsyncUpper { + pub fn new() -> Self { + Self { + signature: Signature::new( + TypeSignature::Coercible(vec![Coercion::Exact { + desired_type: TypeSignatureClass::Native(logical_string()), + }]), + Volatility::Volatile, + ), + } + } +} + +#[async_trait] +impl ScalarUDFImpl for AsyncUpper { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "async_upper" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + not_impl_err!("AsyncUpper can only be called from async contexts") + } +} + +#[async_trait] +impl AsyncScalarUDFImpl for AsyncUpper { + fn ideal_batch_size(&self) -> Option { + Some(10) + } + + async fn invoke_async_with_args( + &self, + args: ScalarFunctionArgs, + _option: &ConfigOptions, + ) -> Result { + trace!("Invoking async_upper with args: {:?}", args); + let value = &args.args[0]; + let result = match value { + ColumnarValue::Array(array) => { + let string_array = array.as_string::(); + let iter = ArrayIter::new(string_array); + let result = iter + .map(|string| string.map(|s| s.to_uppercase())) + .collect::(); + Arc::new(result) as ArrayRef + } + _ => return internal_err!("Expected a string argument, got {:?}", value), + }; + Ok(result) + } +} + +#[derive(Debug)] +struct AsyncEqual { + signature: Signature, +} + +impl Default for AsyncEqual { + fn default() -> Self { + Self::new() + } +} + +impl AsyncEqual { + pub fn new() -> Self { + Self { + signature: Signature::new( + TypeSignature::Coercible(vec![ + Coercion::Exact { + desired_type: 
TypeSignatureClass::Native(logical_int64()), + }, + Coercion::Exact { + desired_type: TypeSignatureClass::Native(logical_int64()), + }, + ]), + Volatility::Volatile, + ), + } + } +} + +#[async_trait] +impl ScalarUDFImpl for AsyncEqual { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "async_equal" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Boolean) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + not_impl_err!("AsyncEqual can only be called from async contexts") + } +} + +#[async_trait] +impl AsyncScalarUDFImpl for AsyncEqual { + async fn invoke_async_with_args( + &self, + args: ScalarFunctionArgs, + _option: &ConfigOptions, + ) -> Result { + let [arg1, arg2] = take_function_args(self.name(), &args.args)?; + apply_cmp(arg1, arg2, eq)?.to_array(args.number_rows) + } +} diff --git a/datafusion-examples/examples/catalog.rs b/datafusion-examples/examples/catalog.rs index 655438b78b9f..229867cdfc5b 100644 --- a/datafusion-examples/examples/catalog.rs +++ b/datafusion-examples/examples/catalog.rs @@ -309,7 +309,7 @@ fn prepare_example_data() -> Result { 3,baz"#; for i in 0..5 { - let mut file = File::create(path.join(format!("{}.csv", i)))?; + let mut file = File::create(path.join(format!("{i}.csv")))?; file.write_all(content.as_bytes())?; } diff --git a/datafusion-examples/examples/custom_file_format.rs b/datafusion-examples/examples/custom_file_format.rs index 165d82627061..e9a4d71b1633 100644 --- a/datafusion-examples/examples/custom_file_format.rs +++ b/datafusion-examples/examples/custom_file_format.rs @@ -21,28 +21,24 @@ use arrow::{ array::{AsArray, RecordBatch, StringArray, UInt8Array}, datatypes::{DataType, Field, Schema, SchemaRef, UInt64Type}, }; -use datafusion::physical_expr::LexRequirement; -use datafusion::physical_expr::PhysicalExpr; use datafusion::{ catalog::Session, common::{GetExt, Statistics}, -}; -use datafusion::{ - datasource::physical_plan::FileSource, execution::session_state::SessionStateBuilder, -}; -use datafusion::{ datasource::{ file_format::{ csv::CsvFormatFactory, file_compression_type::FileCompressionType, FileFormat, FileFormatFactory, }, - physical_plan::{FileScanConfig, FileSinkConfig}, + physical_plan::{FileScanConfig, FileSinkConfig, FileSource}, MemTable, }, error::Result, + execution::session_state::SessionStateBuilder, + physical_expr_common::sort_expr::LexRequirement, physical_plan::ExecutionPlan, prelude::SessionContext, }; + use object_store::{ObjectMeta, ObjectStore}; use tempfile::tempdir; @@ -112,11 +108,8 @@ impl FileFormat for TSVFileFormat { &self, state: &dyn Session, conf: FileScanConfig, - filters: Option<&Arc>, ) -> Result> { - self.csv_file_format - .create_physical_plan(state, conf, filters) - .await + self.csv_file_format.create_physical_plan(state, conf).await } async fn create_writer_physical_plan( diff --git a/datafusion-examples/examples/dataframe.rs b/datafusion-examples/examples/dataframe.rs index 6f61c164f41d..57a28aeca0de 100644 --- a/datafusion-examples/examples/dataframe.rs +++ b/datafusion-examples/examples/dataframe.rs @@ -15,8 +15,9 @@ // specific language governing permissions and limitations // under the License. 
-use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray}; +use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray, StringViewArray}; use datafusion::arrow::datatypes::{DataType, Field, Schema}; +use datafusion::catalog::MemTable; use datafusion::common::config::CsvOptions; use datafusion::common::parsers::CompressionTypeVariant; use datafusion::common::DataFusionError; @@ -63,6 +64,7 @@ async fn main() -> Result<()> { read_parquet(&ctx).await?; read_csv(&ctx).await?; read_memory(&ctx).await?; + read_memory_macro().await?; write_out(&ctx).await?; register_aggregate_test_data("t1", &ctx).await?; register_aggregate_test_data("t2", &ctx).await?; @@ -144,7 +146,7 @@ async fn read_csv(ctx: &SessionContext) -> Result<()> { // and using the `enable_url_table` refer to local files directly let dyn_ctx = ctx.clone().enable_url_table(); let csv_df = dyn_ctx - .sql(&format!("SELECT rating, unixtime FROM '{}'", file_path)) + .sql(&format!("SELECT rating, unixtime FROM '{file_path}'")) .await?; csv_df.show().await?; @@ -173,16 +175,40 @@ async fn read_memory(ctx: &SessionContext) -> Result<()> { Ok(()) } +/// Use the DataFrame API to: +/// 1. Read in-memory data. +async fn read_memory_macro() -> Result<()> { + // create a DataFrame using macro + let df = dataframe!( + "a" => ["a", "b", "c", "d"], + "b" => [1, 10, 10, 100] + )?; + // print the results + df.show().await?; + + // create empty DataFrame using macro + let df_empty = dataframe!()?; + df_empty.show().await?; + + Ok(()) +} + /// Use the DataFrame API to: /// 1. Write out a DataFrame to a table /// 2. Write out a DataFrame to a parquet file /// 3. Write out a DataFrame to a csv file /// 4. Write out a DataFrame to a json file async fn write_out(ctx: &SessionContext) -> std::result::Result<(), DataFusionError> { - let mut df = ctx.sql("values ('a'), ('b'), ('c')").await.unwrap(); - - // Ensure the column names and types match the target table - df = df.with_column_renamed("column1", "tablecol1").unwrap(); + let array = StringViewArray::from(vec!["a", "b", "c"]); + let schema = Arc::new(Schema::new(vec![Field::new( + "tablecol1", + DataType::Utf8View, + false, + )])); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(array)])?; + let mem_table = MemTable::try_new(schema.clone(), vec![vec![batch]])?; + ctx.register_table("initial_data", Arc::new(mem_table))?; + let df = ctx.table("initial_data").await?; ctx.sql( "create external table diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index b61a350a5a9c..92cf33f4fdf6 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -65,7 +65,7 @@ async fn main() -> Result<()> { let expr2 = Expr::BinaryExpr(BinaryExpr::new( Box::new(col("a")), Operator::Plus, - Box::new(Expr::Literal(ScalarValue::Int32(Some(5)))), + Box::new(Expr::Literal(ScalarValue::Int32(Some(5)), None)), )); assert_eq!(expr, expr2); @@ -147,8 +147,7 @@ fn evaluate_demo() -> Result<()> { ])) as _; assert!( matches!(&result, ColumnarValue::Array(r) if r == &expected_result), - "result: {:?}", - result + "result: {result:?}" ); Ok(()) @@ -424,7 +423,7 @@ fn boundary_analysis_in_conjuctions_demo() -> Result<()> { // // But `AND` conjunctions are easier to reason with because their interval // arithmetic follows naturally from set intersection operations, let us - // now look at an example that is a tad more complicated `OR` conjunctions. + // now look at an example that is a tad more complicated `OR` disjunctions. 
// The expression we will look at is `age > 60 OR age <= 18`. let age_greater_than_60_less_than_18 = @@ -435,7 +434,7 @@ fn boundary_analysis_in_conjuctions_demo() -> Result<()> { // // Initial range: [14, 79] as described in our column statistics. // - // From the left-hand side and right-hand side of our `OR` conjunctions + // From the left-hand side and right-hand side of our `OR` disjunctions // we end up with two ranges, instead of just one. // // - age > 60: [61, 79] @@ -446,7 +445,8 @@ fn boundary_analysis_in_conjuctions_demo() -> Result<()> { let physical_expr = SessionContext::new() .create_physical_expr(age_greater_than_60_less_than_18, &df_schema)?; - // Since we don't handle interval arithmetic for `OR` operator this will error out. + // However, analysis only supports a single interval, so we don't yet deal + // with the multiple possibilities of the `OR` disjunctions. let analysis = analyze( &physical_expr, AnalysisContext::new(initial_boundaries), diff --git a/datafusion-examples/examples/flight/flight_sql_server.rs b/datafusion-examples/examples/flight/flight_sql_server.rs index 54e8de7177cb..5a573ed52320 100644 --- a/datafusion-examples/examples/flight/flight_sql_server.rs +++ b/datafusion-examples/examples/flight/flight_sql_server.rs @@ -115,6 +115,7 @@ impl FlightSqlServiceImpl { Ok(uuid) } + #[allow(clippy::result_large_err)] fn get_ctx(&self, req: &Request) -> Result, Status> { // get the token from the authorization header on Request let auth = req @@ -140,6 +141,7 @@ impl FlightSqlServiceImpl { } } + #[allow(clippy::result_large_err)] fn get_plan(&self, handle: &str) -> Result { if let Some(plan) = self.statements.get(handle) { Ok(plan.clone()) @@ -148,6 +150,7 @@ impl FlightSqlServiceImpl { } } + #[allow(clippy::result_large_err)] fn get_result(&self, handle: &str) -> Result, Status> { if let Some(result) = self.results.get(handle) { Ok(result.clone()) @@ -195,11 +198,13 @@ impl FlightSqlServiceImpl { .unwrap() } + #[allow(clippy::result_large_err)] fn remove_plan(&self, handle: &str) -> Result<(), Status> { self.statements.remove(&handle.to_string()); Ok(()) } + #[allow(clippy::result_large_err)] fn remove_result(&self, handle: &str) -> Result<(), Status> { self.results.remove(&handle.to_string()); Ok(()) diff --git a/datafusion-examples/examples/function_factory.rs b/datafusion-examples/examples/function_factory.rs index 06367f5c09e3..21da35963345 100644 --- a/datafusion-examples/examples/function_factory.rs +++ b/datafusion-examples/examples/function_factory.rs @@ -150,10 +150,6 @@ impl ScalarUDFImpl for ScalarFunctionWrapper { Ok(ExprSimplifyResult::Simplified(replacement)) } - fn aliases(&self) -> &[String] { - &[] - } - fn output_ordering(&self, _input: &[ExprProperties]) -> Result { Ok(SortProperties::Unordered) } @@ -189,8 +185,7 @@ impl ScalarFunctionWrapper { if let Some(value) = placeholder.strip_prefix('$') { Ok(value.parse().map(|v: usize| v - 1).map_err(|e| { DataFusionError::Execution(format!( - "Placeholder `{}` parsing error: {}!", - placeholder, e + "Placeholder `{placeholder}` parsing error: {e}!" )) })?) 
} else { diff --git a/datafusion-examples/examples/optimizer_rule.rs b/datafusion-examples/examples/optimizer_rule.rs index 63f17484809e..176b1a69808c 100644 --- a/datafusion-examples/examples/optimizer_rule.rs +++ b/datafusion-examples/examples/optimizer_rule.rs @@ -171,7 +171,7 @@ fn is_binary_eq(binary_expr: &BinaryExpr) -> bool { /// Return true if the expression is a literal or column reference fn is_lit_or_col(expr: &Expr) -> bool { - matches!(expr, Expr::Column(_) | Expr::Literal(_)) + matches!(expr, Expr::Column(_) | Expr::Literal(_, _)) } /// A simple user defined filter function diff --git a/datafusion-examples/examples/parquet_encrypted.rs b/datafusion-examples/examples/parquet_encrypted.rs new file mode 100644 index 000000000000..e9e239b7a1c3 --- /dev/null +++ b/datafusion-examples/examples/parquet_encrypted.rs @@ -0,0 +1,119 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::common::DataFusionError; +use datafusion::config::TableParquetOptions; +use datafusion::dataframe::{DataFrame, DataFrameWriteOptions}; +use datafusion::logical_expr::{col, lit}; +use datafusion::parquet::encryption::decrypt::FileDecryptionProperties; +use datafusion::parquet::encryption::encrypt::FileEncryptionProperties; +use datafusion::prelude::{ParquetReadOptions, SessionContext}; +use tempfile::TempDir; + +#[tokio::main] +async fn main() -> datafusion::common::Result<()> { + // The SessionContext is the main high level API for interacting with DataFusion + let ctx = SessionContext::new(); + + // Find the local path of "alltypes_plain.parquet" + let testdata = datafusion::test_util::parquet_test_data(); + let filename = &format!("{testdata}/alltypes_plain.parquet"); + + // Read the sample parquet file + let parquet_df = ctx + .read_parquet(filename, ParquetReadOptions::default()) + .await?; + + // Show information from the dataframe + println!( + "===============================================================================" + ); + println!("Original Parquet DataFrame:"); + query_dataframe(&parquet_df).await?; + + // Setup encryption and decryption properties + let (encrypt, decrypt) = setup_encryption(&parquet_df)?; + + // Create a temporary file location for the encrypted parquet file + let tmp_dir = TempDir::new()?; + let tempfile = tmp_dir.path().join("alltypes_plain-encrypted.parquet"); + let tempfile_str = tempfile.into_os_string().into_string().unwrap(); + + // Write encrypted parquet + let mut options = TableParquetOptions::default(); + options.crypto.file_encryption = Some((&encrypt).into()); + parquet_df + .write_parquet( + tempfile_str.as_str(), + DataFrameWriteOptions::new().with_single_file_output(true), + Some(options), + ) + .await?; + + // Read encrypted parquet + let ctx: SessionContext = SessionContext::new(); + let 
read_options = + ParquetReadOptions::default().file_decryption_properties((&decrypt).into()); + + let encrypted_parquet_df = ctx.read_parquet(tempfile_str, read_options).await?; + + // Show information from the dataframe + println!("\n\n==============================================================================="); + println!("Encrypted Parquet DataFrame:"); + query_dataframe(&encrypted_parquet_df).await?; + + Ok(()) +} + +// Show information from the dataframe +async fn query_dataframe(df: &DataFrame) -> Result<(), DataFusionError> { + // show its schema using 'describe' + println!("Schema:"); + df.clone().describe().await?.show().await?; + + // Select three columns and filter the results + // so that only rows where id > 1 are returned + println!("\nSelected rows and columns:"); + df.clone() + .select_columns(&["id", "bool_col", "timestamp_col"])? + .filter(col("id").gt(lit(5)))? + .show() + .await?; + + Ok(()) +} + +// Setup encryption and decryption properties +fn setup_encryption( + parquet_df: &DataFrame, +) -> Result<(FileEncryptionProperties, FileDecryptionProperties), DataFusionError> { + let schema = parquet_df.schema(); + let footer_key = b"0123456789012345".to_vec(); // 128bit/16 + let column_key = b"1234567890123450".to_vec(); // 128bit/16 + + let mut encrypt = FileEncryptionProperties::builder(footer_key.clone()); + let mut decrypt = FileDecryptionProperties::builder(footer_key.clone()); + + for field in schema.fields().iter() { + encrypt = encrypt.with_column_key(field.name().as_str(), column_key.clone()); + decrypt = decrypt.with_column_key(field.name().as_str(), column_key.clone()); + } + + let encrypt = encrypt.build()?; + let decrypt = decrypt.build()?; + Ok((encrypt, decrypt)) +} diff --git a/datafusion-examples/examples/parquet_index.rs b/datafusion-examples/examples/parquet_index.rs index 7d6ce4d86af1..e5ae3cc86bfe 100644 --- a/datafusion-examples/examples/parquet_index.rs +++ b/datafusion-examples/examples/parquet_index.rs @@ -23,6 +23,7 @@ use arrow::datatypes::{Int32Type, SchemaRef}; use arrow::util::pretty::pretty_format_batches; use async_trait::async_trait; use datafusion::catalog::Session; +use datafusion::common::pruning::PruningStatistics; use datafusion::common::{ internal_datafusion_err, DFSchema, DataFusionError, Result, ScalarValue, }; @@ -39,7 +40,7 @@ use datafusion::parquet::arrow::{ arrow_reader::ParquetRecordBatchReaderBuilder, ArrowWriter, }; use datafusion::physical_expr::PhysicalExpr; -use datafusion::physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; +use datafusion::physical_optimizer::pruning::PruningPredicate; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::*; use std::any::Any; @@ -242,8 +243,7 @@ impl TableProvider for IndexTableProvider { let files = self.index.get_files(predicate.clone())?; let object_store_url = ObjectStoreUrl::parse("file://")?; - let source = - Arc::new(ParquetSource::default().with_predicate(self.schema(), predicate)); + let source = Arc::new(ParquetSource::default().with_predicate(predicate)); let mut file_scan_config_builder = FileScanConfigBuilder::new(object_store_url, self.schema(), source) .with_projection(projection.cloned()) diff --git a/datafusion-examples/examples/planner_api.rs b/datafusion-examples/examples/planner_api.rs index 4943e593bd0b..55aec7b0108a 100644 --- a/datafusion-examples/examples/planner_api.rs +++ b/datafusion-examples/examples/planner_api.rs @@ -96,7 +96,7 @@ async fn to_physical_plan_step_by_step_demo( ctx.state().config_options(), |_, _| (), )?; 
- println!("Analyzed logical plan:\n\n{:?}\n\n", analyzed_logical_plan); + println!("Analyzed logical plan:\n\n{analyzed_logical_plan:?}\n\n"); // Optimize the analyzed logical plan let optimized_logical_plan = ctx.state().optimizer().optimize( @@ -104,10 +104,7 @@ async fn to_physical_plan_step_by_step_demo( &ctx.state(), |_, _| (), )?; - println!( - "Optimized logical plan:\n\n{:?}\n\n", - optimized_logical_plan - ); + println!("Optimized logical plan:\n\n{optimized_logical_plan:?}\n\n"); // Create the physical plan let physical_plan = ctx diff --git a/datafusion-examples/examples/pruning.rs b/datafusion-examples/examples/pruning.rs index 4c802bcdbda0..b2d2fa13b7ed 100644 --- a/datafusion-examples/examples/pruning.rs +++ b/datafusion-examples/examples/pruning.rs @@ -20,10 +20,11 @@ use std::sync::Arc; use arrow::array::{ArrayRef, BooleanArray, Int32Array}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion::common::pruning::PruningStatistics; use datafusion::common::{DFSchema, ScalarValue}; use datafusion::execution::context::ExecutionProps; use datafusion::physical_expr::create_physical_expr; -use datafusion::physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; +use datafusion::physical_optimizer::pruning::PruningPredicate; use datafusion::prelude::*; /// This example shows how to use DataFusion's `PruningPredicate` to prove diff --git a/datafusion-examples/examples/simple_udtf.rs b/datafusion-examples/examples/simple_udtf.rs index d2b2d1bf9655..b65ffb8d7174 100644 --- a/datafusion-examples/examples/simple_udtf.rs +++ b/datafusion-examples/examples/simple_udtf.rs @@ -133,7 +133,8 @@ struct LocalCsvTableFunc {} impl TableFunctionImpl for LocalCsvTableFunc { fn call(&self, exprs: &[Expr]) -> Result> { - let Some(Expr::Literal(ScalarValue::Utf8(Some(ref path)))) = exprs.first() else { + let Some(Expr::Literal(ScalarValue::Utf8(Some(ref path)), _)) = exprs.first() + else { return plan_err!("read_csv requires at least one string argument"); }; @@ -145,7 +146,7 @@ impl TableFunctionImpl for LocalCsvTableFunc { let info = SimplifyContext::new(&execution_props); let expr = ExprSimplifier::new(info).simplify(expr.clone())?; - if let Expr::Literal(ScalarValue::Int64(Some(limit))) = expr { + if let Expr::Literal(ScalarValue::Int64(Some(limit)), _) = expr { Ok(limit as usize) } else { plan_err!("Limit must be an integer") diff --git a/datafusion-examples/examples/sql_dialect.rs b/datafusion-examples/examples/sql_dialect.rs index 12141847ca36..20b515506f3b 100644 --- a/datafusion-examples/examples/sql_dialect.rs +++ b/datafusion-examples/examples/sql_dialect.rs @@ -17,10 +17,10 @@ use std::fmt::Display; -use datafusion::error::Result; +use datafusion::error::{DataFusionError, Result}; use datafusion::sql::{ parser::{CopyToSource, CopyToStatement, DFParser, DFParserBuilder, Statement}, - sqlparser::{keywords::Keyword, parser::ParserError, tokenizer::Token}, + sqlparser::{keywords::Keyword, tokenizer::Token}, }; /// This example demonstrates how to use the DFParser to parse a statement in a custom way @@ -34,8 +34,8 @@ async fn main() -> Result<()> { let my_statement = my_parser.parse_statement()?; match my_statement { - MyStatement::DFStatement(s) => println!("df: {}", s), - MyStatement::MyCopyTo(s) => println!("my_copy: {}", s), + MyStatement::DFStatement(s) => println!("df: {s}"), + MyStatement::MyCopyTo(s) => println!("my_copy: {s}"), } Ok(()) @@ -62,7 +62,7 @@ impl<'a> MyParser<'a> { /// This is the entry point to our parser -- it handles `COPY` statements 
specially /// but otherwise delegates to the existing DataFusion parser. - pub fn parse_statement(&mut self) -> Result { + pub fn parse_statement(&mut self) -> Result { if self.is_copy() { self.df_parser.parser.next_token(); // COPY let df_statement = self.df_parser.parse_copy()?; @@ -87,8 +87,8 @@ enum MyStatement { impl Display for MyStatement { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - MyStatement::DFStatement(s) => write!(f, "{}", s), - MyStatement::MyCopyTo(s) => write!(f, "{}", s), + MyStatement::DFStatement(s) => write!(f, "{s}"), + MyStatement::MyCopyTo(s) => write!(f, "{s}"), } } } diff --git a/datafusion-examples/examples/thread_pools.rs b/datafusion-examples/examples/thread_pools.rs new file mode 100644 index 000000000000..bba56b2932ab --- /dev/null +++ b/datafusion-examples/examples/thread_pools.rs @@ -0,0 +1,350 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This example shows how to use separate thread pools (tokio [`Runtime`]))s to +//! run the IO and CPU intensive parts of DataFusion plans. +//! +//! # Background +//! +//! DataFusion, by default, plans and executes all operations (both CPU and IO) +//! on the same thread pool. This makes it fast and easy to get started, but +//! can cause issues when running at scale, especially when fetching and operating +//! on data directly from remote sources. +//! +//! Specifically, without configuration such as in this example, DataFusion +//! plans and executes everything the same thread pool (Tokio Runtime), including +//! any I/O, such as reading Parquet files from remote object storage +//! (e.g. AWS S3), catalog access, and CPU intensive work. Running this diverse +//! workload can lead to issues described in the [Architecture section] such as +//! throttled network bandwidth (due to congestion control) and increased +//! latencies or timeouts while processing network messages. +//! +//! [Architecture section]: https://docs.rs/datafusion/latest/datafusion/index.html#thread-scheduling-cpu--io-thread-pools-and-tokio-runtimes + +use arrow::util::pretty::pretty_format_batches; +use datafusion::common::runtime::JoinSet; +use datafusion::error::Result; +use datafusion::execution::SendableRecordBatchStream; +use datafusion::prelude::*; +use futures::stream::StreamExt; +use object_store::client::SpawnedReqwestConnector; +use object_store::http::HttpBuilder; +use std::sync::Arc; +use tokio::runtime::Handle; +use tokio::sync::Notify; +use url::Url; + +/// Normally, you don't need to worry about the details of the tokio +/// [`Runtime`], but for this example it is important to understand how the +/// [`Runtime`]s work. 
+///
+/// Each thread has a "current" runtime that is installed in a thread local
+/// variable and used by the `tokio::spawn` function.
+///
+/// The `#[tokio::main]` macro creates a [`Runtime`] and installs it
+/// as the "current" runtime in a thread local variable, on which any `async`
+/// [`Future`]s, [`Stream`]s and [`Task`]s are run.
+///
+/// This example uses the runtime created by [`tokio::main`] to do I/O and spawn
+/// CPU intensive tasks on a separate [`Runtime`], mirroring the common pattern
+/// when using Rust libraries such as `tonic`. Using a separate `Runtime` for
+/// CPU bound tasks will often be simpler in larger applications, even though it
+/// makes this example slightly more complex.
+#[tokio::main]
+async fn main() -> Result<()> {
+    // The first two examples read local files. Enabling the URL table feature
+    // lets us treat filenames as tables in SQL.
+    let ctx = SessionContext::new().enable_url_table();
+    let sql = format!(
+        "SELECT * FROM '{}/alltypes_plain.parquet'",
+        datafusion::test_util::parquet_test_data()
+    );
+
+    // Run a query on the current runtime. Calling `await` means the future
+    // (in this case the `async` function and all work spawned by DataFusion
+    // plans) runs on the current runtime.
+    same_runtime(&ctx, &sql).await?;
+
+    // Run the same query but this time on a different runtime.
+    //
+    // Since we call `await` here, the `async` function itself runs on the
+    // current runtime, but internally `different_runtime_basic` executes the
+    // DataFusion plan on a different Runtime.
+    different_runtime_basic(ctx, sql).await?;
+
+    // Run the same query on a different runtime, including remote IO.
+    //
+    // NOTE: This is best practice for production systems.
+    different_runtime_advanced().await?;
+
+    Ok(())
+}
+
+/// Run queries directly on the current tokio `Runtime`
+///
+/// This is how most examples in DataFusion are written and works well for
+/// development, local query processing, and non-latency-sensitive workloads.
+async fn same_runtime(ctx: &SessionContext, sql: &str) -> Result<()> {
+    // `.sql` is an async function as it may also do network
+    // I/O, for example to contact a remote catalog or do an object store LIST.
+    let df = ctx.sql(sql).await?;
+
+    // While many examples call `collect` or `show()`, those methods buffer the
+    // results. Internally DataFusion generates output one RecordBatch at a time.
+
+    // Calling `execute_stream` returns a `SendableRecordBatchStream`. Depending
+    // on the plan, this may also do network I/O, for example to begin reading a
+    // parquet file from a remote object store.
+    let mut stream: SendableRecordBatchStream = df.execute_stream().await?;
+
+    // `next()` drives the plan, incrementally producing new `RecordBatch`es
+    // using the current runtime.
+    //
+    // Perhaps somewhat non obviously, calling `next()` can also result in other
+    // tasks being spawned on the current runtime (e.g. for `RepartitionExec` to
+    // read data from each of its input partitions in parallel).
+    //
+    // Executing the plan using this pattern intermixes any IO and CPU intensive
+    // work on the same Runtime.
+    while let Some(batch) = stream.next().await {
+        println!("{}", pretty_format_batches(&[batch?]).unwrap());
+    }
+    Ok(())
+}
+
+/// Run queries on a **different** Runtime dedicated to CPU bound work
+///
+/// This example is suitable for running DataFusion plans against local data
+/// sources (e.g. files) and returning results to an async destination, as might
+/// be done to return query results to a remote client.
+/// +/// Production systems which also read data locally or require very low latency +/// should follow the recommendations on [`different_runtime_advanced`] when +/// processing data from a remote source such as object storage. +async fn different_runtime_basic(ctx: SessionContext, sql: String) -> Result<()> { + // Since we are already in the context of runtime (installed by + // #[tokio::main]), we need a new Runtime (threadpool) for CPU bound tasks + let cpu_runtime = CpuRuntime::try_new()?; + + // Prepare a task that runs the plan on cpu_runtime and sends + // the results back to the original runtime via a channel. + let (tx, mut rx) = tokio::sync::mpsc::channel(2); + let driver_task = async move { + // Plan the query (which might require CPU work to evaluate statistics) + let df = ctx.sql(&sql).await?; + let mut stream: SendableRecordBatchStream = df.execute_stream().await?; + + // Calling `next()` to drive the plan in this task drives the + // execution from the cpu runtime the other thread pool + // + // NOTE any IO run by this plan (for example, reading from an + // `ObjectStore`) will be done on this new thread pool as well. + while let Some(batch) = stream.next().await { + if tx.send(batch).await.is_err() { + // error means dropped receiver, so nothing will get results anymore + return Ok(()); + } + } + Ok(()) as Result<()> + }; + + // Run the driver task on the cpu runtime. Use a JoinSet to + // ensure the spawned task is canceled on error/drop + let mut join_set = JoinSet::new(); + join_set.spawn_on(driver_task, cpu_runtime.handle()); + + // Retrieve the results in the original (IO) runtime. This requires only + // minimal work (pass pointers around). + while let Some(batch) = rx.recv().await { + println!("{}", pretty_format_batches(&[batch?])?); + } + + // wait for completion of the driver task + drain_join_set(join_set).await; + + Ok(()) +} + +/// Run CPU intensive work on a different runtime but do IO operations (object +/// store access) on the current runtime. +async fn different_runtime_advanced() -> Result<()> { + // In this example, we will query a file via https, reading + // the data directly from the plan + + // The current runtime (created by tokio::main) is used for IO + // + // Note this handle should be used for *ALL* remote IO operations in your + // systems, including remote catalog access, which is not included in this + // example. + let cpu_runtime = CpuRuntime::try_new()?; + let io_handle = Handle::current(); + + let ctx = SessionContext::new(); + + // By default, the HttpStore use the same runtime that calls `await` for IO + // operations. This means that if the DataFusion plan is called from the + // cpu_runtime, the HttpStore IO operations will *also* run on the CPU + // runtime, which will error. + // + // To avoid this, we use a `SpawnedReqwestConnector` to configure the + // `ObjectStore` to run the HTTP requests on the IO runtime. + let base_url = Url::parse("https://github.com").unwrap(); + let http_store = HttpBuilder::new() + .with_url(base_url.clone()) + // Use the io_runtime to run the HTTP requests. Without this line, + // you will see an error such as: + // A Tokio 1.x context was found, but IO is disabled. + .with_http_connector(SpawnedReqwestConnector::new(io_handle)) + .build()?; + + // Tell DataFusion to process `http://` urls with this wrapped object store + ctx.register_object_store(&base_url, Arc::new(http_store)); + + // As above, plan and execute the query on the cpu runtime. 
+ let (tx, mut rx) = tokio::sync::mpsc::channel(2); + let driver_task = async move { + // Plan / execute the query + let url = "https://github.com/apache/arrow-testing/raw/master/data/csv/aggregate_test_100.csv"; + let df = ctx + .sql(&format!("SELECT c1,c2,c3 FROM '{url}' LIMIT 5")) + .await?; + + let mut stream: SendableRecordBatchStream = df.execute_stream().await?; + + // Note you can do other non trivial CPU work on the results of the + // stream before sending it back to the original runtime. For example, + // calling a FlightDataEncoder to convert the results to flight messages + // to send over the network + + // send results, as above + while let Some(batch) = stream.next().await { + if tx.send(batch).await.is_err() { + return Ok(()); + } + } + Ok(()) as Result<()> + }; + + let mut join_set = JoinSet::new(); + join_set.spawn_on(driver_task, cpu_runtime.handle()); + while let Some(batch) = rx.recv().await { + println!("{}", pretty_format_batches(&[batch?])?); + } + + Ok(()) +} + +/// Waits for all tasks in the JoinSet to complete and reports any errors that +/// occurred. +/// +/// If we don't do this, any errors that occur in the task (such as IO errors) +/// are not reported. +async fn drain_join_set(mut join_set: JoinSet>) { + // retrieve any errors from the tasks + while let Some(result) = join_set.join_next().await { + match result { + Ok(Ok(())) => {} // task completed successfully + Ok(Err(e)) => eprintln!("Task failed: {e}"), // task failed + Err(e) => eprintln!("JoinSet error: {e}"), // JoinSet error + } + } +} + +/// Creates a Tokio [`Runtime`] for use with CPU bound tasks +/// +/// Tokio forbids dropping `Runtime`s in async contexts, so creating a separate +/// `Runtime` correctly is somewhat tricky. This structure manages the creation +/// and shutdown of a separate thread. +/// +/// # Notes +/// On drop, the thread will wait for all remaining tasks to complete. +/// +/// Depending on your application, more sophisticated shutdown logic may be +/// required, such as ensuring that no new tasks are added to the runtime. +/// +/// # Credits +/// This code is derived from code originally written for [InfluxDB 3.0] +/// +/// [InfluxDB 3.0]: https://github.com/influxdata/influxdb3_core/tree/6fcbb004232738d55655f32f4ad2385523d10696/executor +struct CpuRuntime { + /// Handle is the tokio structure for interacting with a Runtime. + handle: Handle, + /// Signal to start shutting down + notify_shutdown: Arc, + /// When thread is active, is Some + thread_join_handle: Option>, +} + +impl Drop for CpuRuntime { + fn drop(&mut self) { + // Notify the thread to shutdown. + self.notify_shutdown.notify_one(); + // In a production system you also need to ensure your code stops adding + // new tasks to the underlying runtime after this point to allow the + // thread to complete its work and exit cleanly. 
+        if let Some(thread_join_handle) = self.thread_join_handle.take() {
+            // If the thread is still running, we wait for it to finish
+            print!("Shutting down CPU runtime thread...");
+            if let Err(e) = thread_join_handle.join() {
+                eprintln!("Error joining CPU runtime thread: {e:?}",);
+            } else {
+                println!("CPU runtime thread shut down successfully.");
+            }
+        }
+    }
+}
+
+impl CpuRuntime {
+    /// Create a new Tokio Runtime for CPU bound tasks
+    pub fn try_new() -> Result<Self> {
+        let cpu_runtime = tokio::runtime::Builder::new_multi_thread()
+            .enable_time()
+            .build()?;
+        let handle = cpu_runtime.handle().clone();
+        let notify_shutdown = Arc::new(Notify::new());
+        let notify_shutdown_captured = Arc::clone(&notify_shutdown);
+
+        // The cpu_runtime runs and is dropped on a separate thread
+        let thread_join_handle = std::thread::spawn(move || {
+            cpu_runtime.block_on(async move {
+                notify_shutdown_captured.notified().await;
+            });
+            // Note: cpu_runtime is dropped here, which will wait for all tasks
+            // to complete
+        });
+
+        Ok(Self {
+            handle,
+            notify_shutdown,
+            thread_join_handle: Some(thread_join_handle),
+        })
+    }
+
+    /// Return a handle suitable for spawning CPU bound tasks
+    ///
+    /// # Notes
+    ///
+    /// If a task spawned on this handle attempts to do IO, it will error with a
+    /// message such as:
+    ///
+    /// ```text
+    /// A Tokio 1.x context was found, but IO is disabled.
+    /// ```
+    pub fn handle(&self) -> &Handle {
+        &self.handle
+    }
+}
diff --git a/datafusion/catalog-listing/Cargo.toml b/datafusion/catalog-listing/Cargo.toml
index 734580202232..b88461e7ebcb 100644
--- a/datafusion/catalog-listing/Cargo.toml
+++ b/datafusion/catalog-listing/Cargo.toml
@@ -48,7 +48,6 @@ object_store = { workspace = true }
 tokio = { workspace = true }
 
 [dev-dependencies]
-tempfile = { workspace = true }
 
 [lints]
 workspace = true
diff --git a/datafusion/catalog-listing/README.md b/datafusion/catalog-listing/README.md
index b4760c413d60..c8d1cf13b4ff 100644
--- a/datafusion/catalog-listing/README.md
+++ b/datafusion/catalog-listing/README.md
@@ -25,6 +25,12 @@ This crate is a submodule of DataFusion with [ListingTable], an implementation
 of [TableProvider] based on files in a directory (either locally or on remote
 object storage such as S3).
 
+Most projects should use the [`datafusion`] crate directly, which re-exports
+this module. If you are already using the [`datafusion`] crate, there is no
+reason to use this crate directly in your project as well.
+
+[df]: https://crates.io/crates/datafusion
 [df]: https://crates.io/crates/datafusion
 [listingtable]: https://docs.rs/datafusion/latest/datafusion/datasource/listing/struct.ListingTable.html
 [tableprovider]: https://docs.rs/datafusion/latest/datafusion/datasource/trait.TableProvider.html
+[`datafusion`]: https://crates.io/crates/datafusion
diff --git a/datafusion/catalog-listing/src/helpers.rs b/datafusion/catalog-listing/src/helpers.rs
index 8efb74d4ea1e..00e9c71df348 100644
--- a/datafusion/catalog-listing/src/helpers.rs
+++ b/datafusion/catalog-listing/src/helpers.rs
@@ -61,7 +61,7 @@ pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool {
                 Ok(TreeNodeRecursion::Stop)
             }
         }
-        Expr::Literal(_)
+        Expr::Literal(_, _)
         | Expr::Alias(_)
         | Expr::OuterReferenceColumn(_, _)
         | Expr::ScalarVariable(_, _)
@@ -346,8 +346,8 @@ fn populate_partition_values<'a>(
     {
         match op {
             Operator::Eq => match (left.as_ref(), right.as_ref()) {
-                (Expr::Column(Column { ref name, .. }), Expr::Literal(val))
-                | (Expr::Literal(val), Expr::Column(Column { ref name, ..
})) => { + (Expr::Column(Column { ref name, .. }), Expr::Literal(val, _)) + | (Expr::Literal(val, _), Expr::Column(Column { ref name, .. })) => { if partition_values .insert(name, PartitionValue::Single(val.to_string())) .is_some() @@ -507,11 +507,7 @@ where Some((name, val)) if name == pn => part_values.push(val), _ => { debug!( - "Ignoring file: file_path='{}', table_path='{}', part='{}', partition_col='{}'", - file_path, - table_path, - part, - pn, + "Ignoring file: file_path='{file_path}', table_path='{table_path}', part='{part}', partition_col='{pn}'", ); return None; } @@ -988,7 +984,7 @@ mod tests { assert_eq!( evaluate_partition_prefix( partitions, - &[col("a").eq(Expr::Literal(ScalarValue::Date32(Some(3))))], + &[col("a").eq(Expr::Literal(ScalarValue::Date32(Some(3)), None))], ), Some(Path::from("a=1970-01-04")), ); @@ -997,9 +993,10 @@ mod tests { assert_eq!( evaluate_partition_prefix( partitions, - &[col("a").eq(Expr::Literal(ScalarValue::Date64(Some( - 4 * 24 * 60 * 60 * 1000 - )))),], + &[col("a").eq(Expr::Literal( + ScalarValue::Date64(Some(4 * 24 * 60 * 60 * 1000)), + None + )),], ), Some(Path::from("a=1970-01-05")), ); diff --git a/datafusion/catalog/README.md b/datafusion/catalog/README.md index 5b201e736fdc..d4870e28f338 100644 --- a/datafusion/catalog/README.md +++ b/datafusion/catalog/README.md @@ -23,4 +23,9 @@ This crate is a submodule of DataFusion that provides catalog management functionality, including catalogs, schemas, and tables. +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + [df]: https://crates.io/crates/datafusion +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/catalog/src/information_schema.rs b/datafusion/catalog/src/information_schema.rs index 7948c0299d39..83b6d64ef47b 100644 --- a/datafusion/catalog/src/information_schema.rs +++ b/datafusion/catalog/src/information_schema.rs @@ -30,6 +30,7 @@ use arrow::{ use async_trait::async_trait; use datafusion_common::config::{ConfigEntry, ConfigOptions}; use datafusion_common::error::Result; +use datafusion_common::types::NativeType; use datafusion_common::DataFusionError; use datafusion_execution::TaskContext; use datafusion_expr::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF}; @@ -37,7 +38,7 @@ use datafusion_expr::{TableType, Volatility}; use datafusion_physical_plan::stream::RecordBatchStreamAdapter; use datafusion_physical_plan::streaming::PartitionStream; use datafusion_physical_plan::SendableRecordBatchStream; -use std::collections::{HashMap, HashSet}; +use std::collections::{BTreeSet, HashMap, HashSet}; use std::fmt::Debug; use std::{any::Any, sync::Arc}; @@ -102,12 +103,14 @@ impl InformationSchemaConfig { // schema name may not exist in the catalog, so we need to check if let Some(schema) = catalog.schema(&schema_name) { for table_name in schema.table_names() { - if let Some(table) = schema.table(&table_name).await? { + if let Some(table_type) = + schema.table_type(&table_name).await? 
+ { builder.add_table( &catalog_name, &schema_name, &table_name, - table.table_type(), + table_type, ); } } @@ -403,58 +406,63 @@ impl InformationSchemaConfig { /// returns a tuple of (arg_types, return_type) fn get_udf_args_and_return_types( udf: &Arc, -) -> Result, Option)>> { +) -> Result, Option)>> { let signature = udf.signature(); let arg_types = signature.type_signature.get_example_types(); if arg_types.is_empty() { - Ok(vec![(vec![], None)]) + Ok(vec![(vec![], None)].into_iter().collect::>()) } else { Ok(arg_types .into_iter() .map(|arg_types| { // only handle the function which implemented [`ScalarUDFImpl::return_type`] method - let return_type = udf.return_type(&arg_types).ok().map(|t| t.to_string()); + let return_type = udf + .return_type(&arg_types) + .map(|t| remove_native_type_prefix(NativeType::from(t))) + .ok(); let arg_types = arg_types .into_iter() - .map(|t| t.to_string()) + .map(|t| remove_native_type_prefix(NativeType::from(t))) .collect::>(); (arg_types, return_type) }) - .collect::>()) + .collect::>()) } } fn get_udaf_args_and_return_types( udaf: &Arc, -) -> Result, Option)>> { +) -> Result, Option)>> { let signature = udaf.signature(); let arg_types = signature.type_signature.get_example_types(); if arg_types.is_empty() { - Ok(vec![(vec![], None)]) + Ok(vec![(vec![], None)].into_iter().collect::>()) } else { Ok(arg_types .into_iter() .map(|arg_types| { // only handle the function which implemented [`ScalarUDFImpl::return_type`] method - let return_type = - udaf.return_type(&arg_types).ok().map(|t| t.to_string()); + let return_type = udaf + .return_type(&arg_types) + .ok() + .map(|t| remove_native_type_prefix(NativeType::from(t))); let arg_types = arg_types .into_iter() - .map(|t| t.to_string()) + .map(|t| remove_native_type_prefix(NativeType::from(t))) .collect::>(); (arg_types, return_type) }) - .collect::>()) + .collect::>()) } } fn get_udwf_args_and_return_types( udwf: &Arc, -) -> Result, Option)>> { +) -> Result, Option)>> { let signature = udwf.signature(); let arg_types = signature.type_signature.get_example_types(); if arg_types.is_empty() { - Ok(vec![(vec![], None)]) + Ok(vec![(vec![], None)].into_iter().collect::>()) } else { Ok(arg_types .into_iter() @@ -462,14 +470,19 @@ fn get_udwf_args_and_return_types( // only handle the function which implemented [`ScalarUDFImpl::return_type`] method let arg_types = arg_types .into_iter() - .map(|t| t.to_string()) + .map(|t| remove_native_type_prefix(NativeType::from(t))) .collect::>(); (arg_types, None) }) - .collect::>()) + .collect::>()) } } +#[inline] +fn remove_native_type_prefix(native_type: NativeType) -> String { + format!("{native_type:?}") +} + #[async_trait] impl SchemaProvider for InformationSchemaProvider { fn as_any(&self) -> &dyn Any { @@ -1348,3 +1361,92 @@ impl PartitionStream for InformationSchemaParameters { )) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::CatalogProvider; + + #[tokio::test] + async fn make_tables_uses_table_type() { + let config = InformationSchemaConfig { + catalog_list: Arc::new(Fixture), + }; + let mut builder = InformationSchemaTablesBuilder { + catalog_names: StringBuilder::new(), + schema_names: StringBuilder::new(), + table_names: StringBuilder::new(), + table_types: StringBuilder::new(), + schema: Arc::new(Schema::empty()), + }; + + assert!(config.make_tables(&mut builder).await.is_ok()); + + assert_eq!("BASE TABLE", builder.table_types.finish().value(0)); + } + + #[derive(Debug)] + struct Fixture; + + #[async_trait] + impl SchemaProvider for Fixture { + // 
InformationSchemaConfig::make_tables should use this. + async fn table_type(&self, _: &str) -> Result> { + Ok(Some(TableType::Base)) + } + + // InformationSchemaConfig::make_tables used this before `table_type` + // existed but should not, as it may be expensive. + async fn table(&self, _: &str) -> Result>> { + panic!("InformationSchemaConfig::make_tables called SchemaProvider::table instead of table_type") + } + + fn as_any(&self) -> &dyn Any { + unimplemented!("not required for these tests") + } + + fn table_names(&self) -> Vec { + vec!["atable".to_string()] + } + + fn table_exist(&self, _: &str) -> bool { + unimplemented!("not required for these tests") + } + } + + impl CatalogProviderList for Fixture { + fn as_any(&self) -> &dyn Any { + unimplemented!("not required for these tests") + } + + fn register_catalog( + &self, + _: String, + _: Arc, + ) -> Option> { + unimplemented!("not required for these tests") + } + + fn catalog_names(&self) -> Vec { + vec!["acatalog".to_string()] + } + + fn catalog(&self, _: &str) -> Option> { + Some(Arc::new(Self)) + } + } + + impl CatalogProvider for Fixture { + fn as_any(&self) -> &dyn Any { + unimplemented!("not required for these tests") + } + + fn schema_names(&self) -> Vec { + vec!["aschema".to_string()] + } + + fn schema(&self, _: &str) -> Option> { + Some(Arc::new(Self)) + } + } +} diff --git a/datafusion/catalog/src/listing_schema.rs b/datafusion/catalog/src/listing_schema.rs index cc2c2ee606b3..2e4eac964b18 100644 --- a/datafusion/catalog/src/listing_schema.rs +++ b/datafusion/catalog/src/listing_schema.rs @@ -25,9 +25,7 @@ use std::sync::{Arc, Mutex}; use crate::{SchemaProvider, TableProvider, TableProviderFactory}; use crate::Session; -use datafusion_common::{ - Constraints, DFSchema, DataFusionError, HashMap, TableReference, -}; +use datafusion_common::{DFSchema, DataFusionError, HashMap, TableReference}; use datafusion_expr::CreateExternalTable; use async_trait::async_trait; @@ -143,7 +141,7 @@ impl ListingSchemaProvider { order_exprs: vec![], unbounded: false, options: Default::default(), - constraints: Constraints::empty(), + constraints: Default::default(), column_defaults: Default::default(), }, ) diff --git a/datafusion/catalog/src/memory/table.rs b/datafusion/catalog/src/memory/table.rs index 81243e2c4889..e996e1974d9e 100644 --- a/datafusion/catalog/src/memory/table.rs +++ b/datafusion/catalog/src/memory/table.rs @@ -23,25 +23,22 @@ use std::fmt::Debug; use std::sync::Arc; use crate::TableProvider; -use datafusion_common::error::Result; -use datafusion_expr::Expr; -use datafusion_expr::TableType; -use datafusion_physical_expr::create_physical_sort_exprs; -use datafusion_physical_plan::repartition::RepartitionExec; -use datafusion_physical_plan::{ - common, ExecutionPlan, ExecutionPlanProperties, Partitioning, -}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; +use datafusion_common::error::Result; use datafusion_common::{not_impl_err, plan_err, Constraints, DFSchema, SchemaExt}; use datafusion_common_runtime::JoinSet; -use datafusion_datasource::memory::MemSink; -use datafusion_datasource::memory::MemorySourceConfig; +use datafusion_datasource::memory::{MemSink, MemorySourceConfig}; use datafusion_datasource::sink::DataSinkExec; use datafusion_datasource::source::DataSourceExec; use datafusion_expr::dml::InsertOp; -use datafusion_expr::SortExpr; +use datafusion_expr::{Expr, SortExpr, TableType}; +use datafusion_physical_expr::{create_physical_sort_exprs, LexOrdering}; +use 
datafusion_physical_plan::repartition::RepartitionExec; +use datafusion_physical_plan::{ + common, ExecutionPlan, ExecutionPlanProperties, Partitioning, +}; use datafusion_session::Session; use async_trait::async_trait; @@ -89,7 +86,7 @@ impl MemTable { .into_iter() .map(|e| Arc::new(RwLock::new(e))) .collect::>(), - constraints: Constraints::empty(), + constraints: Constraints::default(), column_defaults: HashMap::new(), sort_order: Arc::new(Mutex::new(vec![])), }) @@ -239,16 +236,13 @@ impl TableProvider for MemTable { if !sort_order.is_empty() { let df_schema = DFSchema::try_from(self.schema.as_ref().clone())?; - let file_sort_order = sort_order - .iter() - .map(|sort_exprs| { - create_physical_sort_exprs( - sort_exprs, - &df_schema, - state.execution_props(), - ) - }) - .collect::>>()?; + let eqp = state.execution_props(); + let mut file_sort_order = vec![]; + for sort_exprs in sort_order.iter() { + let physical_exprs = + create_physical_sort_exprs(sort_exprs, &df_schema, eqp)?; + file_sort_order.extend(LexOrdering::new(physical_exprs)); + } source = source.try_with_sort_information(file_sort_order)?; } diff --git a/datafusion/catalog/src/schema.rs b/datafusion/catalog/src/schema.rs index 5b37348fd742..9ba55256f182 100644 --- a/datafusion/catalog/src/schema.rs +++ b/datafusion/catalog/src/schema.rs @@ -26,6 +26,7 @@ use std::sync::Arc; use crate::table::TableProvider; use datafusion_common::Result; +use datafusion_expr::TableType; /// Represents a schema, comprising a number of named tables. /// @@ -54,6 +55,14 @@ pub trait SchemaProvider: Debug + Sync + Send { name: &str, ) -> Result>, DataFusionError>; + /// Retrieves the type of a specific table from the schema by name, if it exists, otherwise + /// returns `None`. Implementations for which this operation is cheap but [Self::table] is + /// expensive can override this to improve operations that only need the type, e.g. + /// `SELECT * FROM information_schema.tables`. + async fn table_type(&self, name: &str) -> Result> { + self.table(name).await.map(|o| o.map(|t| t.table_type())) + } + /// If supported by the implementation, adds a new table named `name` to /// this schema. 
/// diff --git a/datafusion/catalog/src/stream.rs b/datafusion/catalog/src/stream.rs index fbfab513229e..99c432b738e5 100644 --- a/datafusion/catalog/src/stream.rs +++ b/datafusion/catalog/src/stream.rs @@ -256,7 +256,7 @@ impl StreamConfig { Self { source, order: vec![], - constraints: Constraints::empty(), + constraints: Constraints::default(), } } @@ -350,15 +350,10 @@ impl TableProvider for StreamTable { input: Arc, _insert_op: InsertOp, ) -> Result> { - let ordering = match self.0.order.first() { - Some(x) => { - let schema = self.0.source.schema(); - let orders = create_ordering(schema, std::slice::from_ref(x))?; - let ordering = orders.into_iter().next().unwrap(); - Some(ordering.into_iter().map(Into::into).collect()) - } - None => None, - }; + let schema = self.0.source.schema(); + let orders = create_ordering(schema, &self.0.order)?; + // It is sufficient to pass only one of the equivalent orderings: + let ordering = orders.into_iter().next().map(Into::into); Ok(Arc::new(DataSinkExec::new( input, diff --git a/datafusion/common-runtime/Cargo.toml b/datafusion/common-runtime/Cargo.toml index 5e7816b669de..7ddc021e640c 100644 --- a/datafusion/common-runtime/Cargo.toml +++ b/datafusion/common-runtime/Cargo.toml @@ -43,4 +43,4 @@ log = { workspace = true } tokio = { workspace = true } [dev-dependencies] -tokio = { version = "1.44", features = ["rt", "rt-multi-thread", "time"] } +tokio = { version = "1.45", features = ["rt", "rt-multi-thread", "time"] } diff --git a/datafusion/common-runtime/README.md b/datafusion/common-runtime/README.md index 77100e52603c..bd0d4954b845 100644 --- a/datafusion/common-runtime/README.md +++ b/datafusion/common-runtime/README.md @@ -23,4 +23,9 @@ This crate is a submodule of DataFusion that provides common utilities. +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + [df]: https://crates.io/crates/datafusion +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index 74e99163955e..b356f249b79b 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -55,15 +55,17 @@ apache-avro = { version = "0.17", default-features = false, features = [ arrow = { workspace = true } arrow-ipc = { workspace = true } base64 = "0.22.1" +chrono = { workspace = true } half = { workspace = true } hashbrown = { workspace = true } +hex = "0.4.3" indexmap = { workspace = true } -libc = "0.2.171" +libc = "0.2.174" log = { workspace = true } object_store = { workspace = true, optional = true } parquet = { workspace = true, optional = true, default-features = true } paste = "1.0.15" -pyo3 = { version = "0.24.0", optional = true } +pyo3 = { version = "0.24.2", optional = true } recursive = { workspace = true, optional = true } sqlparser = { workspace = true } tokio = { workspace = true } diff --git a/datafusion/common/README.md b/datafusion/common/README.md index 524ab4420d2a..e4d6b772658d 100644 --- a/datafusion/common/README.md +++ b/datafusion/common/README.md @@ -23,4 +23,9 @@ This crate is a submodule of DataFusion that provides common data types and utilities. +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. 
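+
+For example, items from this crate are typically reached through the facade
+crate's re-export (the paths below are illustrative):
+
+```rust,ignore
+// via the facade crate (preferred)
+use datafusion::common::Result;
+
+// rather than depending on `datafusion-common` directly
+// use datafusion_common::Result;
+```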
+ [df]: https://crates.io/crates/datafusion +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/common/src/column.rs b/datafusion/common/src/column.rs index 50a4e257d1c9..b3acaeee5a54 100644 --- a/datafusion/common/src/column.rs +++ b/datafusion/common/src/column.rs @@ -130,8 +130,8 @@ impl Column { /// where `"foo.BAR"` would be parsed to a reference to column named `foo.BAR` pub fn from_qualified_name(flat_name: impl Into) -> Self { let flat_name = flat_name.into(); - Self::from_idents(parse_identifiers_normalized(&flat_name, false)).unwrap_or( - Self { + Self::from_idents(parse_identifiers_normalized(&flat_name, false)).unwrap_or_else( + || Self { relation: None, name: flat_name, spans: Spans::new(), @@ -142,8 +142,8 @@ impl Column { /// Deserialize a fully qualified name string into a column preserving column text case pub fn from_qualified_name_ignore_case(flat_name: impl Into) -> Self { let flat_name = flat_name.into(); - Self::from_idents(parse_identifiers_normalized(&flat_name, true)).unwrap_or( - Self { + Self::from_idents(parse_identifiers_normalized(&flat_name, true)).unwrap_or_else( + || Self { relation: None, name: flat_name, spans: Spans::new(), diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 1c746a4e9840..6618c6aeec28 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -17,16 +17,24 @@ //! Runtime configuration, via [`ConfigOptions`] +use arrow_ipc::CompressionType; + +use crate::error::_config_err; +use crate::parsers::CompressionTypeVariant; +use crate::utils::get_available_parallelism; +use crate::{DataFusionError, Result}; use std::any::Any; use std::collections::{BTreeMap, HashMap}; use std::error::Error; use std::fmt::{self, Display}; use std::str::FromStr; -use crate::error::_config_err; -use crate::parsers::CompressionTypeVariant; -use crate::utils::get_available_parallelism; -use crate::{DataFusionError, Result}; +#[cfg(feature = "parquet")] +use hex; +#[cfg(feature = "parquet")] +use parquet::encryption::decrypt::FileDecryptionProperties; +#[cfg(feature = "parquet")] +use parquet::encryption::encrypt::FileEncryptionProperties; /// A macro that wraps a configuration struct and automatically derives /// [`Default`] and [`ConfigField`] for it, allowing it to be used @@ -188,7 +196,6 @@ macro_rules! config_namespace { } } } - config_namespace! { /// Options related to catalog and directory scanning /// @@ -260,10 +267,10 @@ config_namespace! { /// string length and thus DataFusion can not enforce such limits. pub support_varchar_with_length: bool, default = true - /// If true, `VARCHAR` is mapped to `Utf8View` during SQL planning. - /// If false, `VARCHAR` is mapped to `Utf8` during SQL planning. - /// Default is false. - pub map_varchar_to_utf8view: bool, default = false + /// If true, string types (VARCHAR, CHAR, Text, and String) are mapped to `Utf8View` during SQL planning. + /// If false, they are mapped to `Utf8`. + /// Default is true. + pub map_string_types_to_utf8view: bool, default = true /// When set to true, the source locations relative to the original SQL /// query (i.e. [`Span`](https://docs.rs/sqlparser/latest/sqlparser/tokenizer/struct.Span.html)) will be collected @@ -275,6 +282,61 @@ config_namespace! 
{ } } +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] +pub enum SpillCompression { + Zstd, + Lz4Frame, + #[default] + Uncompressed, +} + +impl FromStr for SpillCompression { + type Err = DataFusionError; + + fn from_str(s: &str) -> Result { + match s.to_ascii_lowercase().as_str() { + "zstd" => Ok(Self::Zstd), + "lz4_frame" => Ok(Self::Lz4Frame), + "uncompressed" | "" => Ok(Self::Uncompressed), + other => Err(DataFusionError::Configuration(format!( + "Invalid Spill file compression type: {other}. Expected one of: zstd, lz4_frame, uncompressed" + ))), + } + } +} + +impl ConfigField for SpillCompression { + fn visit(&self, v: &mut V, key: &str, description: &'static str) { + v.some(key, self, description) + } + + fn set(&mut self, _: &str, value: &str) -> Result<()> { + *self = SpillCompression::from_str(value)?; + Ok(()) + } +} + +impl Display for SpillCompression { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let str = match self { + Self::Zstd => "zstd", + Self::Lz4Frame => "lz4_frame", + Self::Uncompressed => "uncompressed", + }; + write!(f, "{str}") + } +} + +impl From for Option { + fn from(c: SpillCompression) -> Self { + match c { + SpillCompression::Zstd => Some(CompressionType::ZSTD), + SpillCompression::Lz4Frame => Some(CompressionType::LZ4_FRAME), + SpillCompression::Uncompressed => None, + } + } +} + config_namespace! { /// Options related to query execution /// @@ -293,20 +355,22 @@ config_namespace! { /// target batch size is determined by the configuration setting pub coalesce_batches: bool, default = true - /// Should DataFusion collect statistics after listing files - pub collect_statistics: bool, default = false + /// Should DataFusion collect statistics when first creating a table. + /// Has no effect after the table is created. Applies to the default + /// `ListingTableProvider` in DataFusion. Defaults to true. + pub collect_statistics: bool, default = true /// Number of partitions for query execution. Increasing partitions can increase /// concurrency. /// /// Defaults to the number of CPU cores on the system - pub target_partitions: usize, default = get_available_parallelism() + pub target_partitions: usize, transform = ExecutionOptions::normalized_parallelism, default = get_available_parallelism() /// The default time zone /// /// Some functions, e.g. `EXTRACT(HOUR from SOME_TIME)`, shift the underlying datetime /// according to this time zone, and then extract the hour - pub time_zone: Option, default = Some("+00:00".into()) + pub time_zone: String, default = "+00:00".into() /// Parquet options pub parquet: ParquetOptions, default = Default::default() @@ -316,7 +380,7 @@ config_namespace! { /// This is mostly use to plan `UNION` children in parallel. /// /// Defaults to the number of CPU cores on the system - pub planning_concurrency: usize, default = get_available_parallelism() + pub planning_concurrency: usize, transform = ExecutionOptions::normalized_parallelism, default = get_available_parallelism() /// When set to true, skips verifying that the schema produced by /// planning the input of `LogicalPlan::Aggregate` exactly matches the @@ -329,6 +393,16 @@ config_namespace! { /// the new schema verification step. pub skip_physical_aggregate_schema_check: bool, default = false + /// Sets the compression codec used when spilling data to disk. + /// + /// Since datafusion writes spill files using the Arrow IPC Stream format, + /// only codecs supported by the Arrow IPC Stream Writer are allowed. 
+ /// Valid values are: uncompressed, lz4_frame, zstd. + /// Note: lz4_frame offers faster (de)compression, but typically results in + /// larger spill files. In contrast, zstd achieves + /// higher compression ratios at the cost of slower (de)compression speed. + pub spill_compression: SpillCompression, default = SpillCompression::Uncompressed + /// Specifies the reserved memory for each spillable sort operation to /// facilitate an in-memory merge. /// @@ -405,6 +479,13 @@ config_namespace! { /// in joins can reduce memory usage when joining large /// tables with a highly-selective join filter, but is also slightly slower. pub enforce_batch_size_in_joins: bool, default = false + + /// Size (bytes) of data buffer DataFusion uses when writing output files. + /// This affects the size of the data chunks that are uploaded to remote + /// object stores (e.g. AWS S3). If very large (>= 100 GiB) output files are being + /// written, it may be necessary to increase this size to avoid errors from + /// the remote end point. + pub objectstore_writer_buffer_size: usize, default = 10 * 1024 * 1024 } } @@ -467,6 +548,9 @@ config_namespace! { /// nanosecond resolution. pub coerce_int96: Option, transform = str::to_lowercase, default = None + /// (reading) Use any available bloom filters when reading parquet files + pub bloom_filter_on_read: bool, default = true + // The following options affect writing to parquet files // and map to parquet::file::properties::WriterProperties @@ -542,9 +626,6 @@ config_namespace! { /// default parquet writer setting pub encoding: Option, transform = str::to_lowercase, default = None - /// (writing) Use any available bloom filters when reading parquet files - pub bloom_filter_on_read: bool, default = true - /// (writing) Write bloom filters for all columns when creating parquet files pub bloom_filter_on_write: bool, default = false @@ -586,6 +667,17 @@ config_namespace! { } } +config_namespace! { + /// Options for configuring Parquet Modular Encryption + pub struct ParquetEncryptionOptions { + /// Optional file decryption properties + pub file_decryption: Option, default = None + + /// Optional file encryption properties + pub file_encryption: Option, default = None + } +} + config_namespace! { /// Options related to query optimization /// @@ -606,6 +698,13 @@ config_namespace! { /// during aggregations, if possible pub enable_topk_aggregation: bool, default = true + /// When set to true attempts to push down dynamic filters generated by operators into the file scan phase. + /// For example, for a query such as `SELECT * FROM t ORDER BY timestamp DESC LIMIT 10`, the optimizer + /// will attempt to push down the current top 10 timestamps that the TopK operator references into the file scans. + /// This means that if we already have 10 timestamps in the year 2025 + /// any files that only have timestamps in the year 2024 can be skipped / pruned at various stages in the scan. + pub enable_dynamic_filter_pushdown: bool, default = true + /// When set to true, the optimizer will insert filters before a join between /// a nullable and non-nullable column to filter out nulls on the nullable side. This /// filter can add additional overhead when the file format does not fully support @@ -632,13 +731,20 @@ config_namespace! { /// long runner execution, all types of joins may encounter out-of-memory errors. pub allow_symmetric_joins_without_pruning: bool, default = true - /// When set to `true`, file groups will be repartitioned to achieve maximum parallelism. 
- /// Currently Parquet and CSV formats are supported. + /// When set to `true`, datasource partitions will be repartitioned to achieve maximum parallelism. + /// This applies to both in-memory partitions and FileSource's file groups (1 group is 1 partition). + /// + /// For FileSources, only Parquet and CSV formats are currently supported. /// - /// If set to `true`, all files will be repartitioned evenly (i.e., a single large file + /// If set to `true` for a FileSource, all files will be repartitioned evenly (i.e., a single large file /// might be partitioned into smaller chunks) for parallel scanning. - /// If set to `false`, different files will be read in parallel, but repartitioning won't + /// If set to `false` for a FileSource, different files will be read in parallel, but repartitioning won't /// happen within a single file. + /// + /// If set to `true` for an in-memory source, all memtable's partitions will have their batches + /// repartitioned evenly to the desired number of `target_partitions`. Repartitioning can change + /// the total number of partitions and batches per partition, but does not slice the initial + /// record tables provided to the MemTable on creation. pub repartition_file_scans: bool, default = true /// Should DataFusion repartition data using the partitions keys to execute window @@ -739,6 +845,72 @@ config_namespace! { } } +impl ExecutionOptions { + /// Returns the correct parallelism based on the provided `value`. + /// If `value` is `"0"`, returns the default available parallelism, computed with + /// `get_available_parallelism`. Otherwise, returns `value`. + fn normalized_parallelism(value: &str) -> String { + if value.parse::() == Ok(0) { + get_available_parallelism().to_string() + } else { + value.to_owned() + } + } +} + +config_namespace! { + /// Options controlling the format of output when printing record batches + /// Copies [`arrow::util::display::FormatOptions`] + pub struct FormatOptions { + /// If set to `true` any formatting errors will be written to the output + /// instead of being converted into a [`std::fmt::Error`] + pub safe: bool, default = true + /// Format string for nulls + pub null: String, default = "".into() + /// Date format for date arrays + pub date_format: Option, default = Some("%Y-%m-%d".to_string()) + /// Format for DateTime arrays + pub datetime_format: Option, default = Some("%Y-%m-%dT%H:%M:%S%.f".to_string()) + /// Timestamp format for timestamp arrays + pub timestamp_format: Option, default = Some("%Y-%m-%dT%H:%M:%S%.f".to_string()) + /// Timestamp format for timestamp with timezone arrays. When `None`, ISO 8601 format is used. + pub timestamp_tz_format: Option, default = None + /// Time format for time arrays + pub time_format: Option, default = Some("%H:%M:%S%.f".to_string()) + /// Duration format. Can be either `"pretty"` or `"ISO8601"` + pub duration_format: String, transform = str::to_lowercase, default = "pretty".into() + /// Show types in visual representation batches + pub types_info: bool, default = false + } +} + +impl<'a> TryInto> for &'a FormatOptions { + type Error = DataFusionError; + fn try_into(self) -> Result> { + let duration_format = match self.duration_format.as_str() { + "pretty" => arrow::util::display::DurationFormat::Pretty, + "iso8601" => arrow::util::display::DurationFormat::ISO8601, + _ => { + return _config_err!( + "Invalid duration format: {}. 
Valid values are pretty or iso8601", + self.duration_format + ) + } + }; + + Ok(arrow::util::display::FormatOptions::new() + .with_display_error(self.safe) + .with_null(&self.null) + .with_date_format(self.date_format.as_deref()) + .with_datetime_format(self.datetime_format.as_deref()) + .with_timestamp_format(self.timestamp_format.as_deref()) + .with_timestamp_tz_format(self.timestamp_tz_format.as_deref()) + .with_time_format(self.time_format.as_deref()) + .with_duration_format(duration_format) + .with_types_info(self.types_info)) + } +} + /// A key value pair, with a corresponding description #[derive(Debug)] pub struct ConfigEntry { @@ -768,6 +940,8 @@ pub struct ConfigOptions { pub explain: ExplainOptions, /// Optional extensions registered using [`Extensions::insert`] pub extensions: Extensions, + /// Formatting options when printing batches + pub format: FormatOptions, } impl ConfigField for ConfigOptions { @@ -780,6 +954,7 @@ impl ConfigField for ConfigOptions { "optimizer" => self.optimizer.set(rem, value), "explain" => self.explain.set(rem, value), "sql_parser" => self.sql_parser.set(rem, value), + "format" => self.format.set(rem, value), _ => _config_err!("Config value \"{key}\" not found on ConfigOptions"), } } @@ -790,6 +965,7 @@ impl ConfigField for ConfigOptions { self.optimizer.visit(v, "datafusion.optimizer", ""); self.explain.visit(v, "datafusion.explain", ""); self.sql_parser.visit(v, "datafusion.sql_parser", ""); + self.format.visit(v, "datafusion.format", ""); } } @@ -851,7 +1027,9 @@ impl ConfigOptions { for key in keys.0 { let env = key.to_uppercase().replace('.', "_"); if let Some(var) = std::env::var_os(env) { - ret.set(&key, var.to_string_lossy().as_ref())?; + let value = var.to_string_lossy(); + log::info!("Set {key} to {value} from the environment variable"); + ret.set(&key, value.as_ref())?; } } @@ -1087,7 +1265,10 @@ impl ConfigField for Option { } } -fn default_transform(input: &str) -> Result +/// Default transformation to parse a [`ConfigField`] for a string. +/// +/// This uses [`FromStr`] to parse the data. +pub fn default_config_transform(input: &str) -> Result where T: FromStr, ::Err: Sync + Send + Error + 'static, @@ -1104,19 +1285,45 @@ where }) } +/// Macro that generates [`ConfigField`] for a given type. +/// +/// # Usage +/// This always requires [`Display`] to be implemented for the given type. +/// +/// There are two ways to invoke this macro. The first one uses +/// [`default_config_transform`]/[`FromStr`] to parse the data: +/// +/// ```ignore +/// config_field(MyType); +/// ``` +/// +/// Note that the parsing error MUST implement [`std::error::Error`]! +/// +/// Or you can specify how you want to parse an [`str`] into the type: +/// +/// ```ignore +/// fn parse_it(s: &str) -> Result { +/// ... +/// } +/// +/// config_field( +/// MyType, +/// value => parse_it(value) +/// ); +/// ``` #[macro_export] macro_rules! 
config_field { ($t:ty) => { - config_field!($t, value => default_transform(value)?); + config_field!($t, value => $crate::config::default_config_transform(value)?); }; ($t:ty, $arg:ident => $transform:expr) => { - impl ConfigField for $t { - fn visit(&self, v: &mut V, key: &str, description: &'static str) { + impl $crate::config::ConfigField for $t { + fn visit(&self, v: &mut V, key: &str, description: &'static str) { v.some(key, self, description) } - fn set(&mut self, _: &str, $arg: &str) -> Result<()> { + fn set(&mut self, _: &str, $arg: &str) -> $crate::error::Result<()> { *self = $transform; Ok(()) } @@ -1125,7 +1332,7 @@ macro_rules! config_field { } config_field!(String); -config_field!(bool, value => default_transform(value.to_lowercase().as_str())?); +config_field!(bool, value => default_config_transform(value.to_lowercase().as_str())?); config_field!(usize); config_field!(f64); config_field!(u64); @@ -1138,8 +1345,7 @@ impl ConfigField for u8 { fn set(&mut self, key: &str, value: &str) -> Result<()> { if value.is_empty() { return Err(DataFusionError::Configuration(format!( - "Input string for {} key is empty", - key + "Input string for {key} key is empty" ))); } // Check if the string is a valid number @@ -1151,8 +1357,7 @@ impl ConfigField for u8 { // Check if the first character is ASCII (single byte) if bytes.len() > 1 || !value.chars().next().unwrap().is_ascii() { return Err(DataFusionError::Configuration(format!( - "Error parsing {} as u8. Non-ASCII string provided", - value + "Error parsing {value} as u8. Non-ASCII string provided" ))); } *self = bytes[0]; @@ -1631,6 +1836,24 @@ pub struct TableParquetOptions { /// ) /// ``` pub key_value_metadata: HashMap>, + /// Options for configuring Parquet modular encryption + /// See ConfigFileEncryptionProperties and ConfigFileDecryptionProperties in datafusion/common/src/config.rs + /// These can be set via 'format.crypto', for example: + /// ```sql + /// OPTIONS ( + /// 'format.crypto.file_encryption.encrypt_footer' 'true', + /// 'format.crypto.file_encryption.footer_key_as_hex' '30313233343536373839303132333435', -- b"0123456789012345" */ + /// 'format.crypto.file_encryption.column_key_as_hex::double_field' '31323334353637383930313233343530', -- b"1234567890123450" + /// 'format.crypto.file_encryption.column_key_as_hex::float_field' '31323334353637383930313233343531', -- b"1234567890123451" + /// -- Same for decryption + /// 'format.crypto.file_decryption.footer_key_as_hex' '30313233343536373839303132333435', -- b"0123456789012345" + /// 'format.crypto.file_decryption.column_key_as_hex::double_field' '31323334353637383930313233343530', -- b"1234567890123450" + /// 'format.crypto.file_decryption.column_key_as_hex::float_field' '31323334353637383930313233343531', -- b"1234567890123451" + /// ) + /// ``` + /// See datafusion-cli/tests/sql/encrypted_parquet.sql for a more complete example. + /// Note that keys must be provided as in hex format since these are binary strings. 
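+    ///
+    /// The same options can also be set programmatically on [`TableOptions`]
+    /// (a sketch; see the `parquet_table_encryption` test in this file for a
+    /// complete round trip):
+    ///
+    /// ```ignore
+    /// let mut table_config = TableOptions::new();
+    /// table_config.set_config_format(ConfigFileType::PARQUET);
+    /// table_config.parquet.set(
+    ///     "crypto.file_encryption.footer_key_as_hex",
+    ///     &hex::encode(b"0123456789012345"),
+    /// )?;
+    /// ```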
+ pub crypto: ParquetEncryptionOptions, } impl TableParquetOptions { @@ -1658,7 +1881,9 @@ impl ConfigField for TableParquetOptions { fn visit(&self, v: &mut V, key_prefix: &str, description: &'static str) { self.global.visit(v, key_prefix, description); self.column_specific_options - .visit(v, key_prefix, description) + .visit(v, key_prefix, description); + self.crypto + .visit(v, &format!("{key_prefix}.crypto"), description); } fn set(&mut self, key: &str, value: &str) -> Result<()> { @@ -1679,6 +1904,8 @@ impl ConfigField for TableParquetOptions { }; self.key_value_metadata.insert(k, Some(value.into())); Ok(()) + } else if let Some(crypto_feature) = key.strip_prefix("crypto.") { + self.crypto.set(crypto_feature, value) } else if key.contains("::") { self.column_specific_options.set(key, value) } else { @@ -1829,6 +2056,322 @@ config_namespace_with_hashmap! { } } +#[derive(Clone, Debug, PartialEq)] +pub struct ConfigFileEncryptionProperties { + /// Should the parquet footer be encrypted + /// default is true + pub encrypt_footer: bool, + /// Key to use for the parquet footer encoded in hex format + pub footer_key_as_hex: String, + /// Metadata information for footer key + pub footer_key_metadata_as_hex: String, + /// HashMap of column names --> (key in hex format, metadata) + pub column_encryption_properties: HashMap, + /// AAD prefix string uniquely identifies the file and prevents file swapping + pub aad_prefix_as_hex: String, + /// If true, store the AAD prefix in the file + /// default is false + pub store_aad_prefix: bool, +} + +// Setup to match EncryptionPropertiesBuilder::new() +impl Default for ConfigFileEncryptionProperties { + fn default() -> Self { + ConfigFileEncryptionProperties { + encrypt_footer: true, + footer_key_as_hex: String::new(), + footer_key_metadata_as_hex: String::new(), + column_encryption_properties: Default::default(), + aad_prefix_as_hex: String::new(), + store_aad_prefix: false, + } + } +} + +config_namespace_with_hashmap! { + pub struct ColumnEncryptionProperties { + /// Per column encryption key + pub column_key_as_hex: String, default = "".to_string() + /// Per column encryption key metadata + pub column_metadata_as_hex: Option, default = None + } +} + +impl ConfigField for ConfigFileEncryptionProperties { + fn visit(&self, v: &mut V, key_prefix: &str, _description: &'static str) { + let key = format!("{key_prefix}.encrypt_footer"); + let desc = "Encrypt the footer"; + self.encrypt_footer.visit(v, key.as_str(), desc); + + let key = format!("{key_prefix}.footer_key_as_hex"); + let desc = "Key to use for the parquet footer"; + self.footer_key_as_hex.visit(v, key.as_str(), desc); + + let key = format!("{key_prefix}.footer_key_metadata_as_hex"); + let desc = "Metadata to use for the parquet footer"; + self.footer_key_metadata_as_hex.visit(v, key.as_str(), desc); + + let key = format!("{key_prefix}.aad_prefix_as_hex"); + let desc = "AAD prefix to use"; + self.aad_prefix_as_hex.visit(v, key.as_str(), desc); + + let key = format!("{key_prefix}.store_aad_prefix"); + let desc = "If true, store the AAD prefix"; + self.store_aad_prefix.visit(v, key.as_str(), desc); + + self.aad_prefix_as_hex.visit(v, key.as_str(), desc); + } + + fn set(&mut self, key: &str, value: &str) -> Result<()> { + // Any hex encoded values must be pre-encoded using + // hex::encode() before calling set. 
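+        // For example (illustrative): a 16 byte key b"0123456789012345" must
+        // be passed as hex::encode(b"0123456789012345"), i.e. the string
+        // "30313233343536373839303132333435".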
+ + if key.contains("::") { + // Handle any column specific properties + return self.column_encryption_properties.set(key, value); + }; + + let (key, rem) = key.split_once('.').unwrap_or((key, "")); + match key { + "encrypt_footer" => self.encrypt_footer.set(rem, value.as_ref()), + "footer_key_as_hex" => self.footer_key_as_hex.set(rem, value.as_ref()), + "footer_key_metadata_as_hex" => { + self.footer_key_metadata_as_hex.set(rem, value.as_ref()) + } + "aad_prefix_as_hex" => self.aad_prefix_as_hex.set(rem, value.as_ref()), + "store_aad_prefix" => self.store_aad_prefix.set(rem, value.as_ref()), + _ => _config_err!( + "Config value \"{}\" not found on ConfigFileEncryptionProperties", + key + ), + } + } +} + +#[cfg(feature = "parquet")] +impl From for FileEncryptionProperties { + fn from(val: ConfigFileEncryptionProperties) -> Self { + let mut fep = FileEncryptionProperties::builder( + hex::decode(val.footer_key_as_hex).unwrap(), + ) + .with_plaintext_footer(!val.encrypt_footer) + .with_aad_prefix_storage(val.store_aad_prefix); + + if !val.footer_key_metadata_as_hex.is_empty() { + fep = fep.with_footer_key_metadata( + hex::decode(&val.footer_key_metadata_as_hex) + .expect("Invalid footer key metadata"), + ); + } + + for (column_name, encryption_props) in val.column_encryption_properties.iter() { + let encryption_key = hex::decode(&encryption_props.column_key_as_hex) + .expect("Invalid column encryption key"); + let key_metadata = encryption_props + .column_metadata_as_hex + .as_ref() + .map(|x| hex::decode(x).expect("Invalid column metadata")); + match key_metadata { + Some(key_metadata) => { + fep = fep.with_column_key_and_metadata( + column_name, + encryption_key, + key_metadata, + ); + } + None => { + fep = fep.with_column_key(column_name, encryption_key); + } + } + } + + if !val.aad_prefix_as_hex.is_empty() { + let aad_prefix: Vec = + hex::decode(&val.aad_prefix_as_hex).expect("Invalid AAD prefix"); + fep = fep.with_aad_prefix(aad_prefix); + } + fep.build().unwrap() + } +} + +#[cfg(feature = "parquet")] +impl From<&FileEncryptionProperties> for ConfigFileEncryptionProperties { + fn from(f: &FileEncryptionProperties) -> Self { + let (column_names_vec, column_keys_vec, column_metas_vec) = f.column_keys(); + + let mut column_encryption_properties: HashMap< + String, + ColumnEncryptionProperties, + > = HashMap::new(); + + for (i, column_name) in column_names_vec.iter().enumerate() { + let column_key_as_hex = hex::encode(&column_keys_vec[i]); + let column_metadata_as_hex: Option = + column_metas_vec.get(i).map(hex::encode); + column_encryption_properties.insert( + column_name.clone(), + ColumnEncryptionProperties { + column_key_as_hex, + column_metadata_as_hex, + }, + ); + } + let mut aad_prefix: Vec = Vec::new(); + if let Some(prefix) = f.aad_prefix() { + aad_prefix = prefix.clone(); + } + ConfigFileEncryptionProperties { + encrypt_footer: f.encrypt_footer(), + footer_key_as_hex: hex::encode(f.footer_key()), + footer_key_metadata_as_hex: f + .footer_key_metadata() + .map(hex::encode) + .unwrap_or_default(), + column_encryption_properties, + aad_prefix_as_hex: hex::encode(aad_prefix), + store_aad_prefix: f.store_aad_prefix(), + } + } +} + +#[derive(Clone, Debug, PartialEq)] +pub struct ConfigFileDecryptionProperties { + /// Binary string to use for the parquet footer encoded in hex format + pub footer_key_as_hex: String, + /// HashMap of column names --> key in hex format + pub column_decryption_properties: HashMap, + /// AAD prefix string uniquely identifies the file and prevents file swapping 
+ pub aad_prefix_as_hex: String, + /// If true, then verify signature for files with plaintext footers. + /// default = true + pub footer_signature_verification: bool, +} + +config_namespace_with_hashmap! { + pub struct ColumnDecryptionProperties { + /// Per column encryption key + pub column_key_as_hex: String, default = "".to_string() + } +} + +// Setup to match DecryptionPropertiesBuilder::new() +impl Default for ConfigFileDecryptionProperties { + fn default() -> Self { + ConfigFileDecryptionProperties { + footer_key_as_hex: String::new(), + column_decryption_properties: Default::default(), + aad_prefix_as_hex: String::new(), + footer_signature_verification: true, + } + } +} + +impl ConfigField for ConfigFileDecryptionProperties { + fn visit(&self, v: &mut V, key_prefix: &str, _description: &'static str) { + let key = format!("{key_prefix}.footer_key_as_hex"); + let desc = "Key to use for the parquet footer"; + self.footer_key_as_hex.visit(v, key.as_str(), desc); + + let key = format!("{key_prefix}.aad_prefix_as_hex"); + let desc = "AAD prefix to use"; + self.aad_prefix_as_hex.visit(v, key.as_str(), desc); + + let key = format!("{key_prefix}.footer_signature_verification"); + let desc = "If true, verify the footer signature"; + self.footer_signature_verification + .visit(v, key.as_str(), desc); + + self.column_decryption_properties.visit(v, key_prefix, desc); + } + + fn set(&mut self, key: &str, value: &str) -> Result<()> { + // Any hex encoded values must be pre-encoded using + // hex::encode() before calling set. + + if key.contains("::") { + // Handle any column specific properties + return self.column_decryption_properties.set(key, value); + }; + + let (key, rem) = key.split_once('.').unwrap_or((key, "")); + match key { + "footer_key_as_hex" => self.footer_key_as_hex.set(rem, value.as_ref()), + "aad_prefix_as_hex" => self.aad_prefix_as_hex.set(rem, value.as_ref()), + "footer_signature_verification" => { + self.footer_signature_verification.set(rem, value.as_ref()) + } + _ => _config_err!( + "Config value \"{}\" not found on ConfigFileEncryptionProperties", + key + ), + } + } +} + +#[cfg(feature = "parquet")] +impl From for FileDecryptionProperties { + fn from(val: ConfigFileDecryptionProperties) -> Self { + let mut column_names: Vec<&str> = Vec::new(); + let mut column_keys: Vec> = Vec::new(); + + for (col_name, decryption_properties) in val.column_decryption_properties.iter() { + column_names.push(col_name.as_str()); + column_keys.push( + hex::decode(&decryption_properties.column_key_as_hex) + .expect("Invalid column decryption key"), + ); + } + + let mut fep = FileDecryptionProperties::builder( + hex::decode(val.footer_key_as_hex).expect("Invalid footer key"), + ) + .with_column_keys(column_names, column_keys) + .unwrap(); + + if !val.footer_signature_verification { + fep = fep.disable_footer_signature_verification(); + } + + if !val.aad_prefix_as_hex.is_empty() { + let aad_prefix = + hex::decode(&val.aad_prefix_as_hex).expect("Invalid AAD prefix"); + fep = fep.with_aad_prefix(aad_prefix); + } + + fep.build().unwrap() + } +} + +#[cfg(feature = "parquet")] +impl From<&FileDecryptionProperties> for ConfigFileDecryptionProperties { + fn from(f: &FileDecryptionProperties) -> Self { + let (column_names_vec, column_keys_vec) = f.column_keys(); + let mut column_decryption_properties: HashMap< + String, + ColumnDecryptionProperties, + > = HashMap::new(); + for (i, column_name) in column_names_vec.iter().enumerate() { + let props = ColumnDecryptionProperties { + column_key_as_hex: 
hex::encode(column_keys_vec[i].clone()), + }; + column_decryption_properties.insert(column_name.clone(), props); + } + + let mut aad_prefix: Vec = Vec::new(); + if let Some(prefix) = f.aad_prefix() { + aad_prefix = prefix.clone(); + } + ConfigFileDecryptionProperties { + footer_key_as_hex: hex::encode( + f.footer_key(None).unwrap_or_default().as_ref(), + ), + column_decryption_properties, + aad_prefix_as_hex: hex::encode(aad_prefix), + footer_signature_verification: f.check_plaintext_footer_integrity(), + } + } +} + config_namespace! { /// Options controlling CSV format pub struct CsvOptions { @@ -1982,11 +2525,11 @@ config_namespace! { } } -pub trait FormatOptionsExt: Display {} +pub trait OutputFormatExt: Display {} #[derive(Debug, Clone, PartialEq)] #[allow(clippy::large_enum_variant)] -pub enum FormatOptions { +pub enum OutputFormat { CSV(CsvOptions), JSON(JsonOptions), #[cfg(feature = "parquet")] @@ -1995,29 +2538,28 @@ pub enum FormatOptions { ARROW, } -impl Display for FormatOptions { +impl Display for OutputFormat { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let out = match self { - FormatOptions::CSV(_) => "csv", - FormatOptions::JSON(_) => "json", + OutputFormat::CSV(_) => "csv", + OutputFormat::JSON(_) => "json", #[cfg(feature = "parquet")] - FormatOptions::PARQUET(_) => "parquet", - FormatOptions::AVRO => "avro", - FormatOptions::ARROW => "arrow", + OutputFormat::PARQUET(_) => "parquet", + OutputFormat::AVRO => "avro", + OutputFormat::ARROW => "arrow", }; - write!(f, "{}", out) + write!(f, "{out}") } } #[cfg(test)] mod tests { - use std::any::Any; - use std::collections::HashMap; - use crate::config::{ ConfigEntry, ConfigExtension, ConfigField, ConfigFileType, ExtensionOptions, Extensions, TableOptions, }; + use std::any::Any; + use std::collections::HashMap; #[derive(Default, Debug, Clone)] pub struct TestExtensionConfig { @@ -2146,6 +2688,129 @@ mod tests { ); } + #[cfg(feature = "parquet")] + #[test] + fn parquet_table_encryption() { + use crate::config::{ + ConfigFileDecryptionProperties, ConfigFileEncryptionProperties, + }; + use parquet::encryption::decrypt::FileDecryptionProperties; + use parquet::encryption::encrypt::FileEncryptionProperties; + + let footer_key = b"0123456789012345".to_vec(); // 128bit/16 + let column_names = vec!["double_field", "float_field"]; + let column_keys = + vec![b"1234567890123450".to_vec(), b"1234567890123451".to_vec()]; + + let file_encryption_properties = + FileEncryptionProperties::builder(footer_key.clone()) + .with_column_keys(column_names.clone(), column_keys.clone()) + .unwrap() + .build() + .unwrap(); + + let decryption_properties = FileDecryptionProperties::builder(footer_key.clone()) + .with_column_keys(column_names.clone(), column_keys.clone()) + .unwrap() + .build() + .unwrap(); + + // Test round-trip + let config_encrypt: ConfigFileEncryptionProperties = + (&file_encryption_properties).into(); + let encryption_properties_built: FileEncryptionProperties = + config_encrypt.clone().into(); + assert_eq!(file_encryption_properties, encryption_properties_built); + + let config_decrypt: ConfigFileDecryptionProperties = + (&decryption_properties).into(); + let decryption_properties_built: FileDecryptionProperties = + config_decrypt.clone().into(); + assert_eq!(decryption_properties, decryption_properties_built); + + /////////////////////////////////////////////////////////////////////////////////// + // Test encryption config + + // Display original encryption config + // println!("{:#?}", config_encrypt); + + let mut 
table_config = TableOptions::new(); + table_config.set_config_format(ConfigFileType::PARQUET); + table_config + .parquet + .set( + "crypto.file_encryption.encrypt_footer", + config_encrypt.encrypt_footer.to_string().as_str(), + ) + .unwrap(); + table_config + .parquet + .set( + "crypto.file_encryption.footer_key_as_hex", + config_encrypt.footer_key_as_hex.as_str(), + ) + .unwrap(); + + for (i, col_name) in column_names.iter().enumerate() { + let key = format!("crypto.file_encryption.column_key_as_hex::{col_name}"); + let value = hex::encode(column_keys[i].clone()); + table_config + .parquet + .set(key.as_str(), value.as_str()) + .unwrap(); + } + + // Print matching final encryption config + // println!("{:#?}", table_config.parquet.crypto.file_encryption); + + assert_eq!( + table_config.parquet.crypto.file_encryption, + Some(config_encrypt) + ); + + /////////////////////////////////////////////////////////////////////////////////// + // Test decryption config + + // Display original decryption config + // println!("{:#?}", config_decrypt); + + let mut table_config = TableOptions::new(); + table_config.set_config_format(ConfigFileType::PARQUET); + table_config + .parquet + .set( + "crypto.file_decryption.footer_key_as_hex", + config_decrypt.footer_key_as_hex.as_str(), + ) + .unwrap(); + + for (i, col_name) in column_names.iter().enumerate() { + let key = format!("crypto.file_decryption.column_key_as_hex::{col_name}"); + let value = hex::encode(column_keys[i].clone()); + table_config + .parquet + .set(key.as_str(), value.as_str()) + .unwrap(); + } + + // Print matching final decryption config + // println!("{:#?}", table_config.parquet.crypto.file_decryption); + + assert_eq!( + table_config.parquet.crypto.file_decryption, + Some(config_decrypt.clone()) + ); + + // Set config directly + let mut table_config = TableOptions::new(); + table_config.set_config_format(ConfigFileType::PARQUET); + table_config.parquet.crypto.file_decryption = Some(config_decrypt.clone()); + assert_eq!( + table_config.parquet.crypto.file_decryption, + Some(config_decrypt.clone()) + ); + } + #[cfg(feature = "parquet")] #[test] fn parquet_table_options_config_entry() { diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 66a26a18c0dc..804e14bf72fb 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -472,7 +472,7 @@ impl DFSchema { let matches = self.qualified_fields_with_unqualified_name(name); match matches.len() { 0 => Err(unqualified_field_not_found(name, self)), - 1 => Ok((matches[0].0, (matches[0].1))), + 1 => Ok((matches[0].0, matches[0].1)), _ => { // When `matches` size > 1, it doesn't necessarily mean an `ambiguous name` problem. // Because name may generate from Alias/... . It means that it don't own qualifier. @@ -515,14 +515,6 @@ impl DFSchema { Ok(self.field(idx)) } - /// Find the field with the given qualified column - pub fn field_from_column(&self, column: &Column) -> Result<&Field> { - match &column.relation { - Some(r) => self.field_with_qualified_name(r, &column.name), - None => self.field_with_unqualified_name(&column.name), - } - } - /// Find the field with the given qualified column pub fn qualified_field_from_column( &self, @@ -969,16 +961,28 @@ impl Display for DFSchema { /// widely used in the DataFusion codebase. pub trait ExprSchema: std::fmt::Debug { /// Is this column reference nullable? 
- fn nullable(&self, col: &Column) -> Result; + fn nullable(&self, col: &Column) -> Result { + Ok(self.field_from_column(col)?.is_nullable()) + } /// What is the datatype of this column? - fn data_type(&self, col: &Column) -> Result<&DataType>; + fn data_type(&self, col: &Column) -> Result<&DataType> { + Ok(self.field_from_column(col)?.data_type()) + } /// Returns the column's optional metadata. - fn metadata(&self, col: &Column) -> Result<&HashMap>; + fn metadata(&self, col: &Column) -> Result<&HashMap> { + Ok(self.field_from_column(col)?.metadata()) + } /// Return the column's datatype and nullability - fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)>; + fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> { + let field = self.field_from_column(col)?; + Ok((field.data_type(), field.is_nullable())) + } + + // Return the column's field + fn field_from_column(&self, col: &Column) -> Result<&Field>; } // Implement `ExprSchema` for `Arc` @@ -998,24 +1002,18 @@ impl + std::fmt::Debug> ExprSchema for P { fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> { self.as_ref().data_type_and_nullable(col) } -} - -impl ExprSchema for DFSchema { - fn nullable(&self, col: &Column) -> Result { - Ok(self.field_from_column(col)?.is_nullable()) - } - - fn data_type(&self, col: &Column) -> Result<&DataType> { - Ok(self.field_from_column(col)?.data_type()) - } - fn metadata(&self, col: &Column) -> Result<&HashMap> { - Ok(self.field_from_column(col)?.metadata()) + fn field_from_column(&self, col: &Column) -> Result<&Field> { + self.as_ref().field_from_column(col) } +} - fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> { - let field = self.field_from_column(col)?; - Ok((field.data_type(), field.is_nullable())) +impl ExprSchema for DFSchema { + fn field_from_column(&self, col: &Column) -> Result<&Field> { + match &col.relation { + Some(r) => self.field_with_qualified_name(r, &col.name), + None => self.field_with_unqualified_name(&col.name), + } } } @@ -1090,7 +1088,7 @@ impl SchemaExt for Schema { pub fn qualified_name(qualifier: Option<&TableReference>, name: &str) -> String { match qualifier { - Some(q) => format!("{}.{}", q, name), + Some(q) => format!("{q}.{name}"), None => name.to_string(), } } diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index c50ec64759d5..b4a537fdce7e 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -59,7 +59,7 @@ pub enum DataFusionError { ParquetError(ParquetError), /// Error when reading Avro data. #[cfg(feature = "avro")] - AvroError(AvroError), + AvroError(Box), /// Error when reading / writing to / from an object_store (e.g. 
S3 or LocalFile) #[cfg(feature = "object_store")] ObjectStore(object_store::Error), @@ -311,7 +311,7 @@ impl From for DataFusionError { #[cfg(feature = "avro")] impl From for DataFusionError { fn from(e: AvroError) -> Self { - DataFusionError::AvroError(e) + DataFusionError::AvroError(Box::new(e)) } } @@ -397,7 +397,7 @@ impl Error for DataFusionError { impl From for io::Error { fn from(e: DataFusionError) -> Self { - io::Error::new(io::ErrorKind::Other, e) + io::Error::other(e) } } @@ -526,7 +526,7 @@ impl DataFusionError { pub fn message(&self) -> Cow { match *self { DataFusionError::ArrowError(ref desc, ref backtrace) => { - let backtrace = backtrace.clone().unwrap_or("".to_owned()); + let backtrace = backtrace.clone().unwrap_or_else(|| "".to_owned()); Cow::Owned(format!("{desc}{backtrace}")) } #[cfg(feature = "parquet")] @@ -535,7 +535,8 @@ impl DataFusionError { DataFusionError::AvroError(ref desc) => Cow::Owned(desc.to_string()), DataFusionError::IoError(ref desc) => Cow::Owned(desc.to_string()), DataFusionError::SQL(ref desc, ref backtrace) => { - let backtrace: String = backtrace.clone().unwrap_or("".to_owned()); + let backtrace: String = + backtrace.clone().unwrap_or_else(|| "".to_owned()); Cow::Owned(format!("{desc:?}{backtrace}")) } DataFusionError::Configuration(ref desc) => Cow::Owned(desc.to_string()), @@ -547,7 +548,7 @@ impl DataFusionError { DataFusionError::Plan(ref desc) => Cow::Owned(desc.to_string()), DataFusionError::SchemaError(ref desc, ref backtrace) => { let backtrace: &str = - &backtrace.as_ref().clone().unwrap_or("".to_owned()); + &backtrace.as_ref().clone().unwrap_or_else(|| "".to_owned()); Cow::Owned(format!("{desc}{backtrace}")) } DataFusionError::Execution(ref desc) => Cow::Owned(desc.to_string()), @@ -759,23 +760,33 @@ macro_rules! make_error { /// Macro wraps `$ERR` to add backtrace feature #[macro_export] macro_rules! $NAME_DF_ERR { - ($d($d args:expr),*) => { - $crate::DataFusionError::$ERR( + ($d($d args:expr),* $d(; diagnostic=$d DIAG:expr)?) => {{ + let err =$crate::DataFusionError::$ERR( ::std::format!( "{}{}", ::std::format!($d($d args),*), $crate::DataFusionError::get_back_trace(), ).into() - ) + ); + $d ( + let err = err.with_diagnostic($d DIAG); + )? + err } } + } /// Macro wraps Err(`$ERR`) to add backtrace feature #[macro_export] macro_rules! $NAME_ERR { - ($d($d args:expr),*) => { - Err($crate::[<_ $NAME_DF_ERR>]!($d($d args),*)) - } + ($d($d args:expr),* $d(; diagnostic = $d DIAG:expr)?) => {{ + let err = $crate::[<_ $NAME_DF_ERR>]!($d($d args),*); + $d ( + let err = err.with_diagnostic($d DIAG); + )? + Err(err) + + }} } @@ -816,54 +827,80 @@ make_error!(resources_err, resources_datafusion_err, ResourcesExhausted); // Exposes a macro to create `DataFusionError::SQL` with optional backtrace #[macro_export] macro_rules! sql_datafusion_err { - ($ERR:expr) => { - DataFusionError::SQL($ERR, Some(DataFusionError::get_back_trace())) - }; + ($ERR:expr $(; diagnostic = $DIAG:expr)?) => {{ + let err = DataFusionError::SQL($ERR, Some(DataFusionError::get_back_trace())); + $( + let err = err.with_diagnostic($DIAG); + )? + err + }}; } // Exposes a macro to create `Err(DataFusionError::SQL)` with optional backtrace #[macro_export] macro_rules! sql_err { - ($ERR:expr) => { - Err(datafusion_common::sql_datafusion_err!($ERR)) - }; + ($ERR:expr $(; diagnostic = $DIAG:expr)?) => {{ + let err = datafusion_common::sql_datafusion_err!($ERR); + $( + let err = err.with_diagnostic($DIAG); + )? 
+ Err(err) + }}; } // Exposes a macro to create `DataFusionError::ArrowError` with optional backtrace #[macro_export] macro_rules! arrow_datafusion_err { - ($ERR:expr) => { - DataFusionError::ArrowError($ERR, Some(DataFusionError::get_back_trace())) - }; + ($ERR:expr $(; diagnostic = $DIAG:expr)?) => {{ + let err = DataFusionError::ArrowError($ERR, Some(DataFusionError::get_back_trace())); + $( + let err = err.with_diagnostic($DIAG); + )? + err + }}; } // Exposes a macro to create `Err(DataFusionError::ArrowError)` with optional backtrace #[macro_export] macro_rules! arrow_err { - ($ERR:expr) => { - Err(datafusion_common::arrow_datafusion_err!($ERR)) - }; + ($ERR:expr $(; diagnostic = $DIAG:expr)?) => { + { + let err = datafusion_common::arrow_datafusion_err!($ERR); + $( + let err = err.with_diagnostic($DIAG); + )? + Err(err) + }}; } // Exposes a macro to create `DataFusionError::SchemaError` with optional backtrace #[macro_export] macro_rules! schema_datafusion_err { - ($ERR:expr) => { - $crate::error::DataFusionError::SchemaError( + ($ERR:expr $(; diagnostic = $DIAG:expr)?) => {{ + let err = $crate::error::DataFusionError::SchemaError( $ERR, Box::new(Some($crate::error::DataFusionError::get_back_trace())), - ) - }; + ); + $( + let err = err.with_diagnostic($DIAG); + )? + err + }}; } // Exposes a macro to create `Err(DataFusionError::SchemaError)` with optional backtrace #[macro_export] macro_rules! schema_err { - ($ERR:expr) => { - Err($crate::error::DataFusionError::SchemaError( + ($ERR:expr $(; diagnostic = $DIAG:expr)?) => {{ + let err = $crate::error::DataFusionError::SchemaError( $ERR, Box::new(Some($crate::error::DataFusionError::get_back_trace())), - )) + ); + $( + let err = err.with_diagnostic($DIAG); + )? + Err(err) + } }; } @@ -908,7 +945,7 @@ pub fn add_possible_columns_to_diag( .collect(); for name in field_names { - diagnostic.add_note(format!("possible column {}", name), None); + diagnostic.add_note(format!("possible column {name}"), None); } } @@ -1083,8 +1120,7 @@ mod test { ); // assert wrapping other Error - let generic_error: GenericError = - Box::new(std::io::Error::new(std::io::ErrorKind::Other, "io error")); + let generic_error: GenericError = Box::new(std::io::Error::other("io error")); let datafusion_error: DataFusionError = generic_error.into(); println!("{}", datafusion_error.strip_backtrace()); assert_eq!( @@ -1095,13 +1131,12 @@ mod test { #[test] fn external_error_no_recursive() { - let generic_error_1: GenericError = - Box::new(std::io::Error::new(std::io::ErrorKind::Other, "io error")); + let generic_error_1: GenericError = Box::new(std::io::Error::other("io error")); let external_error_1: DataFusionError = generic_error_1.into(); let generic_error_2: GenericError = Box::new(external_error_1); let external_error_2: DataFusionError = generic_error_2.into(); - println!("{}", external_error_2); + println!("{external_error_2}"); assert!(external_error_2 .to_string() .starts_with("External error: io error")); diff --git a/datafusion/common/src/file_options/parquet_writer.rs b/datafusion/common/src/file_options/parquet_writer.rs index 3e33466edf50..60f0f4abb0c0 100644 --- a/datafusion/common/src/file_options/parquet_writer.rs +++ b/datafusion/common/src/file_options/parquet_writer.rs @@ -95,10 +95,17 @@ impl TryFrom<&TableParquetOptions> for WriterPropertiesBuilder { global, column_specific_options, key_value_metadata, + crypto, } = table_parquet_options; let mut builder = global.into_writer_properties_builder()?; + if let Some(file_encryption_properties) = 
&crypto.file_encryption { + builder = builder.with_file_encryption_properties( + file_encryption_properties.clone().into(), + ); + } + // check that the arrow schema is present in the kv_metadata, if configured to do so if !global.skip_arrow_metadata && !key_value_metadata.contains_key(ARROW_SCHEMA_META_KEY) @@ -330,8 +337,7 @@ fn split_compression_string(str_setting: &str) -> Result<(String, Option)> let level = &rh[..rh.len() - 1].parse::().map_err(|_| { DataFusionError::Configuration(format!( "Could not parse compression string. \ - Got codec: {} and unknown level from {}", - codec, str_setting + Got codec: {codec} and unknown level from {str_setting}" )) })?; Ok((codec.to_owned(), Some(*level))) @@ -450,7 +456,10 @@ mod tests { }; use std::collections::HashMap; - use crate::config::{ParquetColumnOptions, ParquetOptions}; + use crate::config::{ + ConfigFileEncryptionProperties, ParquetColumnOptions, ParquetEncryptionOptions, + ParquetOptions, + }; use super::*; @@ -581,6 +590,9 @@ mod tests { HashMap::from([(COL_NAME.into(), configured_col_props)]) }; + let fep: Option = + props.file_encryption_properties().map(|fe| fe.into()); + #[allow(deprecated)] // max_statistics_size TableParquetOptions { global: ParquetOptions { @@ -628,6 +640,10 @@ mod tests { }, column_specific_options, key_value_metadata, + crypto: ParquetEncryptionOptions { + file_encryption: fep, + file_decryption: None, + }, } } @@ -682,6 +698,7 @@ mod tests { )] .into(), key_value_metadata: [(key, value)].into(), + crypto: Default::default(), }; let writer_props = WriterPropertiesBuilder::try_from(&table_parquet_opts) diff --git a/datafusion/common/src/format.rs b/datafusion/common/src/format.rs index 23cfb72314a3..a4ebd1753999 100644 --- a/datafusion/common/src/format.rs +++ b/datafusion/common/src/format.rs @@ -19,6 +19,7 @@ use arrow::compute::CastOptions; use arrow::util::display::{DurationFormat, FormatOptions}; /// The default [`FormatOptions`] to use within DataFusion +/// Also see [`crate::config::FormatOptions`] pub const DEFAULT_FORMAT_OPTIONS: FormatOptions<'static> = FormatOptions::new().with_duration_format(DurationFormat::Pretty); @@ -27,7 +28,3 @@ pub const DEFAULT_CAST_OPTIONS: CastOptions<'static> = CastOptions { safe: false, format_options: DEFAULT_FORMAT_OPTIONS, }; - -pub const DEFAULT_CLI_FORMAT_OPTIONS: FormatOptions<'static> = FormatOptions::new() - .with_duration_format(DurationFormat::Pretty) - .with_null("NULL"); diff --git a/datafusion/common/src/functional_dependencies.rs b/datafusion/common/src/functional_dependencies.rs index c4f2805f8285..63962998ad18 100644 --- a/datafusion/common/src/functional_dependencies.rs +++ b/datafusion/common/src/functional_dependencies.rs @@ -36,35 +36,31 @@ pub enum Constraint { } /// This object encapsulates a list of functional constraints: -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] +#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, PartialOrd)] pub struct Constraints { inner: Vec, } impl Constraints { - /// Create empty constraints - pub fn empty() -> Self { - Constraints::new_unverified(vec![]) - } - /// Create a new [`Constraints`] object from the given `constraints`. - /// Users should use the [`Constraints::empty`] or [`SqlToRel::new_constraint_from_table_constraints`] functions - /// for constructing [`Constraints`]. This constructor is for internal - /// purposes only and does not check whether the argument is valid. The user - /// is responsible for supplying a valid vector of [`Constraint`] objects. 
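Returning to the error-macro changes earlier in this diff: the generated `*_err!` / `*_datafusion_err!` macros, as well as `sql_err!`, `arrow_err!`, and `schema_err!`, now accept an optional trailing `; diagnostic = ...` argument that is attached via `with_diagnostic`. A minimal caller-side sketch, assuming `Diagnostic::new_error(message, span)` from datafusion_common's diagnostic module; the column name and messages are illustrative:

```rust
use datafusion_common::{plan_err, Diagnostic, Result};

fn resolve_column(name: &str) -> Result<()> {
    // Assumed constructor: a Diagnostic carrying a message and an optional source span.
    let diag = Diagnostic::new_error(format!("column '{name}' is not in the schema"), None);
    // The optional tail attaches the diagnostic inside the macro expansion,
    // on top of the usual backtrace wrapping.
    plan_err!("could not resolve column '{name}'"; diagnostic = diag)
}

fn main() {
    let err = resolve_column("user_id").unwrap_err();
    println!("{err}");
}
```

Passing the diagnostic through the macro keeps error construction, backtrace capture, and diagnostic attachment in a single expression instead of a follow-up `with_diagnostic` call.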
+ /// Users should use the [`Constraints::default`] or [`SqlToRel::new_constraint_from_table_constraints`] + /// functions for constructing [`Constraints`] instances. This constructor + /// is for internal purposes only and does not check whether the argument + /// is valid. The user is responsible for supplying a valid vector of + /// [`Constraint`] objects. /// /// [`SqlToRel::new_constraint_from_table_constraints`]: https://docs.rs/datafusion/latest/datafusion/sql/planner/struct.SqlToRel.html#method.new_constraint_from_table_constraints pub fn new_unverified(constraints: Vec) -> Self { Self { inner: constraints } } - /// Check whether constraints is empty - pub fn is_empty(&self) -> bool { - self.inner.is_empty() + /// Extends the current constraints with the given `other` constraints. + pub fn extend(&mut self, other: Constraints) { + self.inner.extend(other.inner); } - /// Projects constraints using the given projection indices. - /// Returns None if any of the constraint columns are not included in the projection. + /// Projects constraints using the given projection indices. Returns `None` + /// if any of the constraint columns are not included in the projection. pub fn project(&self, proj_indices: &[usize]) -> Option { let projected = self .inner @@ -74,14 +70,14 @@ impl Constraints { Constraint::PrimaryKey(indices) => { let new_indices = update_elements_with_matching_indices(indices, proj_indices); - // Only keep constraint if all columns are preserved + // Only keep the constraint if all columns are preserved: (new_indices.len() == indices.len()) .then_some(Constraint::PrimaryKey(new_indices)) } Constraint::Unique(indices) => { let new_indices = update_elements_with_matching_indices(indices, proj_indices); - // Only keep constraint if all columns are preserved + // Only keep the constraint if all columns are preserved: (new_indices.len() == indices.len()) .then_some(Constraint::Unique(new_indices)) } @@ -93,15 +89,9 @@ impl Constraints { } } -impl Default for Constraints { - fn default() -> Self { - Constraints::empty() - } -} - impl IntoIterator for Constraints { type Item = Constraint; - type IntoIter = IntoIter; + type IntoIter = IntoIter; fn into_iter(self) -> Self::IntoIter { self.inner.into_iter() @@ -113,7 +103,7 @@ impl Display for Constraints { let pk = self .inner .iter() - .map(|c| format!("{:?}", c)) + .map(|c| format!("{c:?}")) .collect::>(); let pk = pk.join(", "); write!(f, "constraints=[{pk}]") @@ -374,7 +364,7 @@ impl FunctionalDependencies { // These joins preserve functional dependencies of the left side: left_func_dependencies } - JoinType::RightSemi | JoinType::RightAnti => { + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => { // These joins preserve functional dependencies of the right side: right_func_dependencies } diff --git a/datafusion/common/src/join_type.rs b/datafusion/common/src/join_type.rs index ac81d977b729..d9a1478f0238 100644 --- a/datafusion/common/src/join_type.rs +++ b/datafusion/common/src/join_type.rs @@ -67,6 +67,11 @@ pub enum JoinType { /// /// [1]: http://btw2017.informatik.uni-stuttgart.de/slidesandpapers/F1-10-37/paper_web.pdf LeftMark, + /// Right Mark Join + /// + /// Same logic as the LeftMark Join above, however it returns a record for each record from the + /// right input. 
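To illustrate the `Constraints` changes above (the removal of `Constraints::empty()` in favor of a derived `Default`, plus the new `extend`), here is a small sketch; the column indices are made up for the example:

```rust
use datafusion_common::{Constraint, Constraints};

fn main() {
    // `Constraints::empty()` is gone; the derived `Default` now builds the empty set.
    let mut constraints = Constraints::default();

    // `new_unverified` still takes raw column indices without validating them,
    // and the new `extend` merges another constraint set in place.
    constraints.extend(Constraints::new_unverified(vec![Constraint::PrimaryKey(vec![0])]));
    constraints.extend(Constraints::new_unverified(vec![Constraint::Unique(vec![2])]));

    // Display joins the Debug form of each constraint.
    println!("{constraints}"); // constraints=[PrimaryKey([0]), Unique([2])]
}
```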
+ RightMark, } impl JoinType { @@ -87,13 +92,12 @@ impl JoinType { JoinType::RightSemi => JoinType::LeftSemi, JoinType::LeftAnti => JoinType::RightAnti, JoinType::RightAnti => JoinType::LeftAnti, - JoinType::LeftMark => { - unreachable!("LeftMark join type does not support swapping") - } + JoinType::LeftMark => JoinType::RightMark, + JoinType::RightMark => JoinType::LeftMark, } } - /// Does the join type support swapping inputs? + /// Does the join type support swapping inputs? pub fn supports_swap(&self) -> bool { matches!( self, @@ -121,6 +125,7 @@ impl Display for JoinType { JoinType::LeftAnti => "LeftAnti", JoinType::RightAnti => "RightAnti", JoinType::LeftMark => "LeftMark", + JoinType::RightMark => "RightMark", }; write!(f, "{join_type}") } @@ -141,6 +146,7 @@ impl FromStr for JoinType { "LEFTANTI" => Ok(JoinType::LeftAnti), "RIGHTANTI" => Ok(JoinType::RightAnti), "LEFTMARK" => Ok(JoinType::LeftMark), + "RIGHTMARK" => Ok(JoinType::RightMark), _ => _not_impl_err!("The join type {s} does not exist or is not implemented"), } } diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index b137624532b9..3ea7321ef3b4 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -46,7 +46,10 @@ pub mod file_options; pub mod format; pub mod hash_utils; pub mod instant; +pub mod nested_struct; +mod null_equality; pub mod parsers; +pub mod pruning; pub mod rounding; pub mod scalar; pub mod spans; @@ -78,6 +81,7 @@ pub use functional_dependencies::{ }; use hashbrown::hash_map::DefaultHashBuilder; pub use join_type::{JoinConstraint, JoinSide, JoinType}; +pub use null_equality::NullEquality; pub use param_value::ParamValues; pub use scalar::{ScalarType, ScalarValue}; pub use schema_reference::SchemaReference; @@ -185,9 +189,7 @@ mod tests { let expected_prefix = expected_prefix.as_ref(); assert!( actual.starts_with(expected_prefix), - "Expected '{}' to start with '{}'", - actual, - expected_prefix + "Expected '{actual}' to start with '{expected_prefix}'" ); } } diff --git a/datafusion/common/src/nested_struct.rs b/datafusion/common/src/nested_struct.rs new file mode 100644 index 000000000000..f349b360f238 --- /dev/null +++ b/datafusion/common/src/nested_struct.rs @@ -0,0 +1,329 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::error::{DataFusionError, Result, _plan_err}; +use arrow::{ + array::{new_null_array, Array, ArrayRef, StructArray}, + compute::cast, + datatypes::{DataType::Struct, Field, FieldRef}, +}; +use std::sync::Arc; + +/// Cast a struct column to match target struct fields, handling nested structs recursively. +/// +/// This function implements struct-to-struct casting with the assumption that **structs should +/// always be allowed to cast to other structs**. 
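A quick sketch of the new `JoinType::RightMark` variant in use, based only on the `swap`, `Display`, and `FromStr` changes shown above:

```rust
use datafusion_common::JoinType;

fn main() {
    // LeftMark and RightMark are now symmetric under input swapping.
    assert_eq!(JoinType::LeftMark.swap(), JoinType::RightMark);
    assert_eq!(JoinType::RightMark.swap(), JoinType::LeftMark);

    // The new variant participates in Display and FromStr like the other join types.
    assert_eq!(JoinType::RightMark.to_string(), "RightMark");
    let parsed: JoinType = "RIGHTMARK".parse().unwrap();
    assert_eq!(parsed, JoinType::RightMark);
}
```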
However, the source column must already be +/// a struct type - non-struct sources will result in an error. +/// +/// ## Field Matching Strategy +/// - **By Name**: Source struct fields are matched to target fields by name (case-sensitive) +/// - **Type Adaptation**: When a matching field is found, it is recursively cast to the target field's type +/// - **Missing Fields**: Target fields not present in the source are filled with null values +/// - **Extra Fields**: Source fields not present in the target are ignored +/// +/// ## Nested Struct Handling +/// - Nested structs are handled recursively using the same casting rules +/// - Each level of nesting follows the same field matching and null-filling strategy +/// - This allows for complex struct transformations while maintaining data integrity +/// +/// # Arguments +/// * `source_col` - The source array to cast (must be a struct array) +/// * `target_fields` - The target struct field definitions to cast to +/// +/// # Returns +/// A `Result` containing the cast struct array +/// +/// # Errors +/// Returns a `DataFusionError::Plan` if the source column is not a struct type +fn cast_struct_column( + source_col: &ArrayRef, + target_fields: &[Arc], +) -> Result { + if let Some(struct_array) = source_col.as_any().downcast_ref::() { + let mut children: Vec<(Arc, Arc)> = Vec::new(); + let num_rows = source_col.len(); + + for target_child_field in target_fields { + let field_arc = Arc::clone(target_child_field); + match struct_array.column_by_name(target_child_field.name()) { + Some(source_child_col) => { + let adapted_child = + cast_column(source_child_col, target_child_field)?; + children.push((field_arc, adapted_child)); + } + None => { + children.push(( + field_arc, + new_null_array(target_child_field.data_type(), num_rows), + )); + } + } + } + + let struct_array = StructArray::from(children); + Ok(Arc::new(struct_array)) + } else { + // Return error if source is not a struct type + Err(DataFusionError::Plan(format!( + "Cannot cast column of type {:?} to struct type. Source must be a struct to cast to struct.", + source_col.data_type() + ))) + } +} + +/// Cast a column to match the target field type, with special handling for nested structs. +/// +/// This function serves as the main entry point for column casting operations. For struct +/// types, it enforces that **only struct columns can be cast to struct types**. +/// +/// ## Casting Behavior +/// - **Struct Types**: Delegates to `cast_struct_column` for struct-to-struct casting only +/// - **Non-Struct Types**: Uses Arrow's standard `cast` function for primitive type conversions +/// +/// ## Struct Casting Requirements +/// The struct casting logic requires that the source column must already be a struct type. 
+/// This makes the function useful for: +/// - Schema evolution scenarios where struct layouts change over time +/// - Data migration between different struct schemas +/// - Type-safe data processing pipelines that maintain struct type integrity +/// +/// # Arguments +/// * `source_col` - The source array to cast +/// * `target_field` - The target field definition (including type and metadata) +/// +/// # Returns +/// A `Result` containing the cast array +/// +/// # Errors +/// Returns an error if: +/// - Attempting to cast a non-struct column to a struct type +/// - Arrow's cast function fails for non-struct types +/// - Memory allocation fails during struct construction +/// - Invalid data type combinations are encountered +pub fn cast_column(source_col: &ArrayRef, target_field: &Field) -> Result { + match target_field.data_type() { + Struct(target_fields) => cast_struct_column(source_col, target_fields), + _ => Ok(cast(source_col, target_field.data_type())?), + } +} + +/// Validates compatibility between source and target struct fields for casting operations. +/// +/// This function implements comprehensive struct compatibility checking by examining: +/// - Field name matching between source and target structs +/// - Type castability for each matching field (including recursive struct validation) +/// - Proper handling of missing fields (target fields not in source are allowed - filled with nulls) +/// - Proper handling of extra fields (source fields not in target are allowed - ignored) +/// +/// # Compatibility Rules +/// - **Field Matching**: Fields are matched by name (case-sensitive) +/// - **Missing Target Fields**: Allowed - will be filled with null values during casting +/// - **Extra Source Fields**: Allowed - will be ignored during casting +/// - **Type Compatibility**: Each matching field must be castable using Arrow's type system +/// - **Nested Structs**: Recursively validates nested struct compatibility +/// +/// # Arguments +/// * `source_fields` - Fields from the source struct type +/// * `target_fields` - Fields from the target struct type +/// +/// # Returns +/// * `Ok(true)` if the structs are compatible for casting +/// * `Err(DataFusionError)` with detailed error message if incompatible +/// +/// # Examples +/// ```text +/// // Compatible: source has extra field, target has missing field +/// // Source: {a: i32, b: string, c: f64} +/// // Target: {a: i64, d: bool} +/// // Result: Ok(true) - 'a' can cast i32->i64, 'b','c' ignored, 'd' filled with nulls +/// +/// // Incompatible: matching field has incompatible types +/// // Source: {a: string} +/// // Target: {a: binary} +/// // Result: Err(...) 
- string cannot cast to binary +/// ``` +pub fn validate_struct_compatibility( + source_fields: &[FieldRef], + target_fields: &[FieldRef], +) -> Result { + // Check compatibility for each target field + for target_field in target_fields { + // Look for matching field in source by name + if let Some(source_field) = source_fields + .iter() + .find(|f| f.name() == target_field.name()) + { + // Check if the matching field types are compatible + match (source_field.data_type(), target_field.data_type()) { + // Recursively validate nested structs + (Struct(source_nested), Struct(target_nested)) => { + validate_struct_compatibility(source_nested, target_nested)?; + } + // For non-struct types, use the existing castability check + _ => { + if !arrow::compute::can_cast_types( + source_field.data_type(), + target_field.data_type(), + ) { + return _plan_err!( + "Cannot cast struct field '{}' from type {:?} to type {:?}", + target_field.name(), + source_field.data_type(), + target_field.data_type() + ); + } + } + } + } + // Missing fields in source are OK - they'll be filled with nulls + } + + // Extra fields in source are OK - they'll be ignored + Ok(true) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::{ + array::{Int32Array, Int64Array, StringArray}, + datatypes::{DataType, Field}, + }; + /// Macro to extract and downcast a column from a StructArray + macro_rules! get_column_as { + ($struct_array:expr, $column_name:expr, $array_type:ty) => { + $struct_array + .column_by_name($column_name) + .unwrap() + .as_any() + .downcast_ref::<$array_type>() + .unwrap() + }; + } + + #[test] + fn test_cast_simple_column() { + let source = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef; + let target_field = Field::new("ints", DataType::Int64, true); + let result = cast_column(&source, &target_field).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + assert_eq!(result.len(), 3); + assert_eq!(result.value(0), 1); + assert_eq!(result.value(1), 2); + assert_eq!(result.value(2), 3); + } + + #[test] + fn test_cast_struct_with_missing_field() { + let a_array = Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef; + let source_struct = StructArray::from(vec![( + Arc::new(Field::new("a", DataType::Int32, true)), + Arc::clone(&a_array), + )]); + let source_col = Arc::new(source_struct) as ArrayRef; + + let target_field = Field::new( + "s", + Struct( + vec![ + Arc::new(Field::new("a", DataType::Int32, true)), + Arc::new(Field::new("b", DataType::Utf8, true)), + ] + .into(), + ), + true, + ); + + let result = cast_column(&source_col, &target_field).unwrap(); + let struct_array = result.as_any().downcast_ref::().unwrap(); + assert_eq!(struct_array.fields().len(), 2); + let a_result = get_column_as!(&struct_array, "a", Int32Array); + assert_eq!(a_result.value(0), 1); + assert_eq!(a_result.value(1), 2); + + let b_result = get_column_as!(&struct_array, "b", StringArray); + assert_eq!(b_result.len(), 2); + assert!(b_result.is_null(0)); + assert!(b_result.is_null(1)); + } + + #[test] + fn test_cast_struct_source_not_struct() { + let source = Arc::new(Int32Array::from(vec![10, 20])) as ArrayRef; + let target_field = Field::new( + "s", + Struct(vec![Arc::new(Field::new("a", DataType::Int32, true))].into()), + true, + ); + + let result = cast_column(&source, &target_field); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert!(error_msg.contains("Cannot cast column of type")); + assert!(error_msg.contains("to struct type")); + assert!(error_msg.contains("Source 
must be a struct")); + } + + #[test] + fn test_validate_struct_compatibility_incompatible_types() { + // Source struct: {field1: Binary, field2: String} + let source_fields = vec![ + Arc::new(Field::new("field1", DataType::Binary, true)), + Arc::new(Field::new("field2", DataType::Utf8, true)), + ]; + + // Target struct: {field1: Int32} + let target_fields = vec![Arc::new(Field::new("field1", DataType::Int32, true))]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert!(error_msg.contains("Cannot cast struct field 'field1'")); + assert!(error_msg.contains("Binary")); + assert!(error_msg.contains("Int32")); + } + + #[test] + fn test_validate_struct_compatibility_compatible_types() { + // Source struct: {field1: Int32, field2: String} + let source_fields = vec![ + Arc::new(Field::new("field1", DataType::Int32, true)), + Arc::new(Field::new("field2", DataType::Utf8, true)), + ]; + + // Target struct: {field1: Int64} (Int32 can cast to Int64) + let target_fields = vec![Arc::new(Field::new("field1", DataType::Int64, true))]; + + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_ok()); + assert!(result.unwrap()); + } + + #[test] + fn test_validate_struct_compatibility_missing_field_in_source() { + // Source struct: {field2: String} (missing field1) + let source_fields = vec![Arc::new(Field::new("field2", DataType::Utf8, true))]; + + // Target struct: {field1: Int32} + let target_fields = vec![Arc::new(Field::new("field1", DataType::Int32, true))]; + + // Should be OK - missing fields will be filled with nulls + let result = validate_struct_compatibility(&source_fields, &target_fields); + assert!(result.is_ok()); + assert!(result.unwrap()); + } +} diff --git a/datafusion/common/src/null_equality.rs b/datafusion/common/src/null_equality.rs new file mode 100644 index 000000000000..847fb0975703 --- /dev/null +++ b/datafusion/common/src/null_equality.rs @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Represents the behavior for null values when evaluating equality. Currently, its primary use +/// case is to define the behavior of joins for null values. +/// +/// # Examples +/// +/// The following table shows the expected equality behavior for `NullEquality`. +/// +/// | A | B | NullEqualsNothing | NullEqualsNull | +/// |------|------|-------------------|----------------| +/// | NULL | NULL | false | true | +/// | NULL | 'b' | false | false | +/// | 'a' | NULL | false | false | +/// | 'a' | 'b' | false | false | +/// +/// # Order +/// +/// The order on this type represents the "restrictiveness" of the behavior. 
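The new `datafusion_common::nested_struct` module above is easiest to see end to end with a small schema-evolution sketch; the field names are illustrative and the calls mirror the unit tests included in the diff:

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array, StructArray};
use arrow::datatypes::{DataType, Field, Fields};
use datafusion_common::nested_struct::{cast_column, validate_struct_compatibility};
use datafusion_common::Result;

fn main() -> Result<()> {
    // Source rows {a: 1} and {a: 2}; there is no "b" column yet.
    let source: ArrayRef = Arc::new(StructArray::from(vec![(
        Arc::new(Field::new("a", DataType::Int32, true)),
        Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef,
    )]));

    // Evolved schema: "a" widened to Int64 and a new nullable "b" column added.
    let target_fields: Fields = vec![
        Arc::new(Field::new("a", DataType::Int64, true)),
        Arc::new(Field::new("b", DataType::Utf8, true)),
    ]
    .into();

    // Optional pre-flight check: every field present on both sides must be castable.
    let source_fields = vec![Arc::new(Field::new("a", DataType::Int32, true))];
    validate_struct_compatibility(&source_fields, &target_fields)?;

    // "a" is cast Int32 -> Int64; the missing "b" is filled with nulls.
    let target = Field::new("s", DataType::Struct(target_fields), true);
    let adapted = cast_column(&source, &target)?;
    let adapted = adapted.as_any().downcast_ref::<StructArray>().unwrap();
    assert_eq!(adapted.fields().len(), 2);
    Ok(())
}
```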
The more restrictive +/// a behavior is, the fewer elements are considered to be equal to null. +/// [NullEquality::NullEqualsNothing] represents the most restrictive behavior. +/// +/// This mirrors the old order with `null_equals_null` booleans, as `false` indicated that +/// `null != null`. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)] +pub enum NullEquality { + /// Null is *not* equal to anything (`null != null`) + NullEqualsNothing, + /// Null is equal to null (`null == null`) + NullEqualsNull, +} diff --git a/datafusion/common/src/parsers.rs b/datafusion/common/src/parsers.rs index c73c8a55f18c..41571ebb8576 100644 --- a/datafusion/common/src/parsers.rs +++ b/datafusion/common/src/parsers.rs @@ -64,7 +64,7 @@ impl Display for CompressionTypeVariant { Self::ZSTD => "ZSTD", Self::UNCOMPRESSED => "", }; - write!(f, "{}", str) + write!(f, "{str}") } } diff --git a/datafusion/common/src/pruning.rs b/datafusion/common/src/pruning.rs new file mode 100644 index 000000000000..48750e3c995c --- /dev/null +++ b/datafusion/common/src/pruning.rs @@ -0,0 +1,1122 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, NullArray, UInt64Array}; +use arrow::array::{ArrayRef, BooleanArray}; +use arrow::datatypes::{FieldRef, Schema, SchemaRef}; +use std::collections::HashSet; +use std::sync::Arc; + +use crate::error::DataFusionError; +use crate::stats::Precision; +use crate::{Column, Statistics}; +use crate::{ColumnStatistics, ScalarValue}; + +/// A source of runtime statistical information to [`PruningPredicate`]s. +/// +/// # Supported Information +/// +/// 1. Minimum and maximum values for columns +/// +/// 2. Null counts and row counts for columns +/// +/// 3. Whether the values in a column are contained in a set of literals +/// +/// # Vectorized Interface +/// +/// Information for containers / files are returned as Arrow [`ArrayRef`], so +/// the evaluation happens once on a single `RecordBatch`, which amortizes the +/// overhead of evaluating the predicate. This is important when pruning 1000s +/// of containers which often happens in analytic systems that have 1000s of +/// potential files to consider. 
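As a minimal sketch of adopting the new `NullEquality` enum, the helper below shows the mapping from the old `null_equals_null` boolean described in the doc comment; the helper name is illustrative, not part of the API:

```rust
use datafusion_common::NullEquality;

// Illustrative adapter from the old `null_equals_null: bool` flag to the new enum.
fn from_legacy_flag(null_equals_null: bool) -> NullEquality {
    if null_equals_null {
        NullEquality::NullEqualsNull
    } else {
        NullEquality::NullEqualsNothing
    }
}

fn main() {
    assert_eq!(from_legacy_flag(false), NullEquality::NullEqualsNothing);
    // The derived order tracks "restrictiveness": NullEqualsNothing is the most
    // restrictive behavior and therefore sorts first.
    assert!(NullEquality::NullEqualsNothing < NullEquality::NullEqualsNull);
}
```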
+/// +/// For example, for the following three files with a single column `a`: +/// ```text +/// file1: column a: min=5, max=10 +/// file2: column a: No stats +/// file2: column a: min=20, max=30 +/// ``` +/// +/// PruningStatistics would return: +/// +/// ```text +/// min_values("a") -> Some([5, Null, 20]) +/// max_values("a") -> Some([10, Null, 30]) +/// min_values("X") -> None +/// ``` +/// +/// [`PruningPredicate`]: https://docs.rs/datafusion/latest/datafusion/physical_optimizer/pruning/struct.PruningPredicate.html +pub trait PruningStatistics { + /// Return the minimum values for the named column, if known. + /// + /// If the minimum value for a particular container is not known, the + /// returned array should have `null` in that row. If the minimum value is + /// not known for any row, return `None`. + /// + /// Note: the returned array must contain [`Self::num_containers`] rows + fn min_values(&self, column: &Column) -> Option; + + /// Return the maximum values for the named column, if known. + /// + /// See [`Self::min_values`] for when to return `None` and null values. + /// + /// Note: the returned array must contain [`Self::num_containers`] rows + fn max_values(&self, column: &Column) -> Option; + + /// Return the number of containers (e.g. Row Groups) being pruned with + /// these statistics. + /// + /// This value corresponds to the size of the [`ArrayRef`] returned by + /// [`Self::min_values`], [`Self::max_values`], [`Self::null_counts`], + /// and [`Self::row_counts`]. + fn num_containers(&self) -> usize; + + /// Return the number of null values for the named column as an + /// [`UInt64Array`] + /// + /// See [`Self::min_values`] for when to return `None` and null values. + /// + /// Note: the returned array must contain [`Self::num_containers`] rows + /// + /// [`UInt64Array`]: arrow::array::UInt64Array + fn null_counts(&self, column: &Column) -> Option; + + /// Return the number of rows for the named column in each container + /// as an [`UInt64Array`]. + /// + /// See [`Self::min_values`] for when to return `None` and null values. + /// + /// Note: the returned array must contain [`Self::num_containers`] rows + /// + /// [`UInt64Array`]: arrow::array::UInt64Array + fn row_counts(&self, column: &Column) -> Option; + + /// Returns [`BooleanArray`] where each row represents information known + /// about specific literal `values` in a column. + /// + /// For example, Parquet Bloom Filters implement this API to communicate + /// that `values` are known not to be present in a Row Group. + /// + /// The returned array has one row for each container, with the following + /// meanings: + /// * `true` if the values in `column` ONLY contain values from `values` + /// * `false` if the values in `column` are NOT ANY of `values` + /// * `null` if the neither of the above holds or is unknown. + /// + /// If these statistics can not determine column membership for any + /// container, return `None` (the default). + /// + /// Note: the returned array must contain [`Self::num_containers`] rows + fn contained( + &self, + column: &Column, + values: &HashSet, + ) -> Option; +} + +/// Prune files based on their partition values. +/// +/// This is used both at planning time and execution time to prune +/// files based on their partition values. +/// This feeds into [`CompositePruningStatistics`] to allow pruning +/// with filters that depend both on partition columns and data columns +/// (e.g. `WHERE partition_col = data_col`). 
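A rough usage sketch for `PartitionPruningStatistics` and the `PruningStatistics` trait it implements, modeled on the unit tests later in this file; the `year` partition column and its values are illustrative:

```rust
use std::collections::HashSet;
use std::sync::Arc;

use arrow::datatypes::{DataType, Field};
use datafusion_common::pruning::{PartitionPruningStatistics, PruningStatistics};
use datafusion_common::{Column, Result, ScalarValue};

fn main() -> Result<()> {
    // Two containers (e.g. two files), each with a single partition column "year".
    let partition_values = vec![
        vec![ScalarValue::from(2023i32)],
        vec![ScalarValue::from(2024i32)],
    ];
    let partition_fields = vec![Arc::new(Field::new("year", DataType::Int32, false))];
    let stats = PartitionPruningStatistics::try_new(partition_values, partition_fields)?;

    let year = Column::new_unqualified("year");
    assert_eq!(stats.num_containers(), 2);
    // For partition columns, the single value serves as both min and max.
    assert!(stats.min_values(&year).is_some());
    assert!(stats.max_values(&year).is_some());

    // `contained` answers, per container, whether the column only holds these literals.
    let literals = HashSet::from([ScalarValue::from(2024i32)]);
    let contained = stats.contained(&year, &literals).unwrap();
    assert!(!contained.value(0));
    assert!(contained.value(1));
    Ok(())
}
```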
+#[derive(Clone)] +pub struct PartitionPruningStatistics { + /// Values for each column for each container. + /// + /// The outer vectors represent the columns while the inner vectors + /// represent the containers. The order must match the order of the + /// partition columns in [`PartitionPruningStatistics::partition_schema`]. + partition_values: Vec, + /// The number of containers. + /// + /// Stored since the partition values are column-major and if + /// there are no columns we wouldn't know the number of containers. + num_containers: usize, + /// The schema of the partition columns. + /// + /// This must **not** be the schema of the entire file or table: it must + /// only be the schema of the partition columns, in the same order as the + /// values in [`PartitionPruningStatistics::partition_values`]. + partition_schema: SchemaRef, +} + +impl PartitionPruningStatistics { + /// Create a new instance of [`PartitionPruningStatistics`]. + /// + /// Args: + /// * `partition_values`: A vector of vectors of [`ScalarValue`]s. + /// The outer vector represents the containers while the inner + /// vector represents the partition values for each column. + /// Note that this is the **opposite** of the order of the + /// partition columns in `PartitionPruningStatistics::partition_schema`. + /// * `partition_schema`: The schema of the partition columns. + /// This must **not** be the schema of the entire file or table: + /// instead it must only be the schema of the partition columns, + /// in the same order as the values in `partition_values`. + pub fn try_new( + partition_values: Vec>, + partition_fields: Vec, + ) -> Result { + let num_containers = partition_values.len(); + let partition_schema = Arc::new(Schema::new(partition_fields)); + let mut partition_values_by_column = + vec![ + Vec::with_capacity(partition_values.len()); + partition_schema.fields().len() + ]; + for partition_value in partition_values { + for (i, value) in partition_value.into_iter().enumerate() { + partition_values_by_column[i].push(value); + } + } + Ok(Self { + partition_values: partition_values_by_column + .into_iter() + .map(|v| { + if v.is_empty() { + Ok(Arc::new(NullArray::new(0)) as ArrayRef) + } else { + ScalarValue::iter_to_array(v) + } + }) + .collect::, _>>()?, + num_containers, + partition_schema, + }) + } +} + +impl PruningStatistics for PartitionPruningStatistics { + fn min_values(&self, column: &Column) -> Option { + let index = self.partition_schema.index_of(column.name()).ok()?; + self.partition_values.get(index).and_then(|v| { + if v.is_empty() || v.null_count() == v.len() { + // If the array is empty or all nulls, return None + None + } else { + // Otherwise, return the array as is + Some(Arc::clone(v)) + } + }) + } + + fn max_values(&self, column: &Column) -> Option { + self.min_values(column) + } + + fn num_containers(&self) -> usize { + self.num_containers + } + + fn null_counts(&self, _column: &Column) -> Option { + None + } + + fn row_counts(&self, _column: &Column) -> Option { + None + } + + fn contained( + &self, + column: &Column, + values: &HashSet, + ) -> Option { + let index = self.partition_schema.index_of(column.name()).ok()?; + let array = self.partition_values.get(index)?; + let boolean_array = values.iter().try_fold(None, |acc, v| { + let arrow_value = v.to_scalar().ok()?; + let eq_result = arrow::compute::kernels::cmp::eq(array, &arrow_value).ok()?; + match acc { + None => Some(Some(eq_result)), + Some(acc_array) => { + arrow::compute::kernels::boolean::and(&acc_array, &eq_result) + 
.map(Some) + .ok() + } + } + })??; + // If the boolean array is empty or all null values, return None + if boolean_array.is_empty() || boolean_array.null_count() == boolean_array.len() { + None + } else { + Some(boolean_array) + } + } +} + +/// Prune a set of containers represented by their statistics. +/// +/// Each [`Statistics`] represents a "container" -- some collection of data +/// that has statistics of its columns. +/// +/// It is up to the caller to decide what each container represents. For +/// example, they can come from a file (e.g. [`PartitionedFile`]) or a set of of +/// files (e.g. [`FileGroup`]) +/// +/// [`PartitionedFile`]: https://docs.rs/datafusion/latest/datafusion/datasource/listing/struct.PartitionedFile.html +/// [`FileGroup`]: https://docs.rs/datafusion/latest/datafusion/datasource/physical_plan/struct.FileGroup.html +#[derive(Clone)] +pub struct PrunableStatistics { + /// Statistics for each container. + /// These are taken as a reference since they may be rather large / expensive to clone + /// and we often won't return all of them as ArrayRefs (we only return the columns the predicate requests). + statistics: Vec>, + /// The schema of the file these statistics are for. + schema: SchemaRef, +} + +impl PrunableStatistics { + /// Create a new instance of [`PrunableStatistics`]. + /// Each [`Statistics`] represents a container (e.g. a file or a partition of files). + /// The `schema` is the schema of the data in the containers and should apply to all files. + pub fn new(statistics: Vec>, schema: SchemaRef) -> Self { + Self { statistics, schema } + } + + fn get_exact_column_statistics( + &self, + column: &Column, + get_stat: impl Fn(&ColumnStatistics) -> &Precision, + ) -> Option { + let index = self.schema.index_of(column.name()).ok()?; + let mut has_value = false; + match ScalarValue::iter_to_array(self.statistics.iter().map(|s| { + s.column_statistics + .get(index) + .and_then(|stat| { + if let Precision::Exact(min) = get_stat(stat) { + has_value = true; + Some(min.clone()) + } else { + None + } + }) + .unwrap_or(ScalarValue::Null) + })) { + // If there is any non-null value and no errors, return the array + Ok(array) => has_value.then_some(array), + Err(_) => { + log::warn!( + "Failed to convert min values to array for column {}", + column.name() + ); + None + } + } + } +} + +impl PruningStatistics for PrunableStatistics { + fn min_values(&self, column: &Column) -> Option { + self.get_exact_column_statistics(column, |stat| &stat.min_value) + } + + fn max_values(&self, column: &Column) -> Option { + self.get_exact_column_statistics(column, |stat| &stat.max_value) + } + + fn num_containers(&self) -> usize { + self.statistics.len() + } + + fn null_counts(&self, column: &Column) -> Option { + let index = self.schema.index_of(column.name()).ok()?; + if self.statistics.iter().any(|s| { + s.column_statistics + .get(index) + .is_some_and(|stat| stat.null_count.is_exact().unwrap_or(false)) + }) { + Some(Arc::new( + self.statistics + .iter() + .map(|s| { + s.column_statistics.get(index).and_then(|stat| { + if let Precision::Exact(null_count) = &stat.null_count { + u64::try_from(*null_count).ok() + } else { + None + } + }) + }) + .collect::(), + )) + } else { + None + } + } + + fn row_counts(&self, column: &Column) -> Option { + // If the column does not exist in the schema, return None + if self.schema.index_of(column.name()).is_err() { + return None; + } + if self + .statistics + .iter() + .any(|s| s.num_rows.is_exact().unwrap_or(false)) + { + Some(Arc::new( + 
self.statistics + .iter() + .map(|s| { + if let Precision::Exact(row_count) = &s.num_rows { + u64::try_from(*row_count).ok() + } else { + None + } + }) + .collect::(), + )) + } else { + None + } + } + + fn contained( + &self, + _column: &Column, + _values: &HashSet, + ) -> Option { + None + } +} + +/// Combine multiple [`PruningStatistics`] into a single +/// [`CompositePruningStatistics`]. +/// This can be used to combine statistics from different sources, +/// for example partition values and file statistics. +/// This allows pruning with filters that depend on multiple sources of statistics, +/// such as `WHERE partition_col = data_col`. +/// This is done by iterating over the statistics and returning the first +/// one that has information for the requested column. +/// If multiple statistics have information for the same column, +/// the first one is returned without any regard for completeness or accuracy. +/// That is: if the first statistics has information for a column, even if it is incomplete, +/// that is returned even if a later statistics has more complete information. +pub struct CompositePruningStatistics { + pub statistics: Vec>, +} + +impl CompositePruningStatistics { + /// Create a new instance of [`CompositePruningStatistics`] from + /// a vector of [`PruningStatistics`]. + pub fn new(statistics: Vec>) -> Self { + assert!(!statistics.is_empty()); + // Check that all statistics have the same number of containers + let num_containers = statistics[0].num_containers(); + for stats in &statistics { + assert_eq!(num_containers, stats.num_containers()); + } + Self { statistics } + } +} + +impl PruningStatistics for CompositePruningStatistics { + fn min_values(&self, column: &Column) -> Option { + for stats in &self.statistics { + if let Some(array) = stats.min_values(column) { + return Some(array); + } + } + None + } + + fn max_values(&self, column: &Column) -> Option { + for stats in &self.statistics { + if let Some(array) = stats.max_values(column) { + return Some(array); + } + } + None + } + + fn num_containers(&self) -> usize { + self.statistics[0].num_containers() + } + + fn null_counts(&self, column: &Column) -> Option { + for stats in &self.statistics { + if let Some(array) = stats.null_counts(column) { + return Some(array); + } + } + None + } + + fn row_counts(&self, column: &Column) -> Option { + for stats in &self.statistics { + if let Some(array) = stats.row_counts(column) { + return Some(array); + } + } + None + } + + fn contained( + &self, + column: &Column, + values: &HashSet, + ) -> Option { + for stats in &self.statistics { + if let Some(array) = stats.contained(column, values) { + return Some(array); + } + } + None + } +} + +#[cfg(test)] +mod tests { + use crate::{ + cast::{as_int32_array, as_uint64_array}, + ColumnStatistics, + }; + + use super::*; + use arrow::datatypes::{DataType, Field}; + use std::sync::Arc; + + #[test] + fn test_partition_pruning_statistics() { + let partition_values = vec![ + vec![ScalarValue::from(1i32), ScalarValue::from(2i32)], + vec![ScalarValue::from(3i32), ScalarValue::from(4i32)], + ]; + let partition_fields = vec![ + Arc::new(Field::new("a", DataType::Int32, false)), + Arc::new(Field::new("b", DataType::Int32, false)), + ]; + let partition_stats = + PartitionPruningStatistics::try_new(partition_values, partition_fields) + .unwrap(); + + let column_a = Column::new_unqualified("a"); + let column_b = Column::new_unqualified("b"); + + // Partition values don't know anything about nulls or row counts + 
assert!(partition_stats.null_counts(&column_a).is_none()); + assert!(partition_stats.row_counts(&column_a).is_none()); + assert!(partition_stats.null_counts(&column_b).is_none()); + assert!(partition_stats.row_counts(&column_b).is_none()); + + // Min/max values are the same as the partition values + let min_values_a = + as_int32_array(&partition_stats.min_values(&column_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_values_a = vec![Some(1), Some(3)]; + assert_eq!(min_values_a, expected_values_a); + let max_values_a = + as_int32_array(&partition_stats.max_values(&column_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_values_a = vec![Some(1), Some(3)]; + assert_eq!(max_values_a, expected_values_a); + + let min_values_b = + as_int32_array(&partition_stats.min_values(&column_b).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_values_b = vec![Some(2), Some(4)]; + assert_eq!(min_values_b, expected_values_b); + let max_values_b = + as_int32_array(&partition_stats.max_values(&column_b).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_values_b = vec![Some(2), Some(4)]; + assert_eq!(max_values_b, expected_values_b); + + // Contained values are only true for the partition values + let values = HashSet::from([ScalarValue::from(1i32)]); + let contained_a = partition_stats.contained(&column_a, &values).unwrap(); + let expected_contained_a = BooleanArray::from(vec![true, false]); + assert_eq!(contained_a, expected_contained_a); + let contained_b = partition_stats.contained(&column_b, &values).unwrap(); + let expected_contained_b = BooleanArray::from(vec![false, false]); + assert_eq!(contained_b, expected_contained_b); + + // The number of containers is the length of the partition values + assert_eq!(partition_stats.num_containers(), 2); + } + + #[test] + fn test_partition_pruning_statistics_empty() { + let partition_values = vec![]; + let partition_fields = vec![ + Arc::new(Field::new("a", DataType::Int32, false)), + Arc::new(Field::new("b", DataType::Int32, false)), + ]; + let partition_stats = + PartitionPruningStatistics::try_new(partition_values, partition_fields) + .unwrap(); + + let column_a = Column::new_unqualified("a"); + let column_b = Column::new_unqualified("b"); + + // Partition values don't know anything about nulls or row counts + assert!(partition_stats.null_counts(&column_a).is_none()); + assert!(partition_stats.row_counts(&column_a).is_none()); + assert!(partition_stats.null_counts(&column_b).is_none()); + assert!(partition_stats.row_counts(&column_b).is_none()); + + // Min/max values are all missing + assert!(partition_stats.min_values(&column_a).is_none()); + assert!(partition_stats.max_values(&column_a).is_none()); + assert!(partition_stats.min_values(&column_b).is_none()); + assert!(partition_stats.max_values(&column_b).is_none()); + + // Contained values are all empty + let values = HashSet::from([ScalarValue::from(1i32)]); + assert!(partition_stats.contained(&column_a, &values).is_none()); + } + + #[test] + fn test_statistics_pruning_statistics() { + let statistics = vec![ + Arc::new( + Statistics::default() + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::from(0i32))) + .with_max_value(Precision::Exact(ScalarValue::from(100i32))) + .with_null_count(Precision::Exact(0)), + ) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::from(100i32))) + 
.with_max_value(Precision::Exact(ScalarValue::from(200i32))) + .with_null_count(Precision::Exact(5)), + ) + .with_num_rows(Precision::Exact(100)), + ), + Arc::new( + Statistics::default() + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::from(50i32))) + .with_max_value(Precision::Exact(ScalarValue::from(300i32))) + .with_null_count(Precision::Exact(10)), + ) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::from(200i32))) + .with_max_value(Precision::Exact(ScalarValue::from(400i32))) + .with_null_count(Precision::Exact(0)), + ) + .with_num_rows(Precision::Exact(200)), + ), + ]; + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ])); + let pruning_stats = PrunableStatistics::new(statistics, schema); + + let column_a = Column::new_unqualified("a"); + let column_b = Column::new_unqualified("b"); + + // Min/max values are the same as the statistics + let min_values_a = as_int32_array(&pruning_stats.min_values(&column_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_values_a = vec![Some(0), Some(50)]; + assert_eq!(min_values_a, expected_values_a); + let max_values_a = as_int32_array(&pruning_stats.max_values(&column_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_values_a = vec![Some(100), Some(300)]; + assert_eq!(max_values_a, expected_values_a); + let min_values_b = as_int32_array(&pruning_stats.min_values(&column_b).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_values_b = vec![Some(100), Some(200)]; + assert_eq!(min_values_b, expected_values_b); + let max_values_b = as_int32_array(&pruning_stats.max_values(&column_b).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_values_b = vec![Some(200), Some(400)]; + assert_eq!(max_values_b, expected_values_b); + + // Null counts are the same as the statistics + let null_counts_a = + as_uint64_array(&pruning_stats.null_counts(&column_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_null_counts_a = vec![Some(0), Some(10)]; + assert_eq!(null_counts_a, expected_null_counts_a); + let null_counts_b = + as_uint64_array(&pruning_stats.null_counts(&column_b).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_null_counts_b = vec![Some(5), Some(0)]; + assert_eq!(null_counts_b, expected_null_counts_b); + + // Row counts are the same as the statistics + let row_counts_a = as_uint64_array(&pruning_stats.row_counts(&column_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_row_counts_a = vec![Some(100), Some(200)]; + assert_eq!(row_counts_a, expected_row_counts_a); + let row_counts_b = as_uint64_array(&pruning_stats.row_counts(&column_b).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_row_counts_b = vec![Some(100), Some(200)]; + assert_eq!(row_counts_b, expected_row_counts_b); + + // Contained values are all null/missing (we can't know this just from statistics) + let values = HashSet::from([ScalarValue::from(0i32)]); + assert!(pruning_stats.contained(&column_a, &values).is_none()); + assert!(pruning_stats.contained(&column_b, &values).is_none()); + + // The number of containers is the length of the statistics + assert_eq!(pruning_stats.num_containers(), 2); + + // Test with a column that has no statistics + let column_c = Column::new_unqualified("c"); + 
assert!(pruning_stats.min_values(&column_c).is_none()); + assert!(pruning_stats.max_values(&column_c).is_none()); + assert!(pruning_stats.null_counts(&column_c).is_none()); + // Since row counts uses the first column that has row counts we get them back even + // if this columns does not have them set. + // This is debatable, personally I think `row_count` should not take a `Column` as an argument + // at all since all columns should have the same number of rows. + // But for now we just document the current behavior in this test. + let row_counts_c = as_uint64_array(&pruning_stats.row_counts(&column_c).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_row_counts_c = vec![Some(100), Some(200)]; + assert_eq!(row_counts_c, expected_row_counts_c); + assert!(pruning_stats.contained(&column_c, &values).is_none()); + + // Test with a column that doesn't exist + let column_d = Column::new_unqualified("d"); + assert!(pruning_stats.min_values(&column_d).is_none()); + assert!(pruning_stats.max_values(&column_d).is_none()); + assert!(pruning_stats.null_counts(&column_d).is_none()); + assert!(pruning_stats.row_counts(&column_d).is_none()); + assert!(pruning_stats.contained(&column_d, &values).is_none()); + } + + #[test] + fn test_statistics_pruning_statistics_empty() { + let statistics = vec![]; + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ])); + let pruning_stats = PrunableStatistics::new(statistics, schema); + + let column_a = Column::new_unqualified("a"); + let column_b = Column::new_unqualified("b"); + + // Min/max values are all missing + assert!(pruning_stats.min_values(&column_a).is_none()); + assert!(pruning_stats.max_values(&column_a).is_none()); + assert!(pruning_stats.min_values(&column_b).is_none()); + assert!(pruning_stats.max_values(&column_b).is_none()); + + // Null counts are all missing + assert!(pruning_stats.null_counts(&column_a).is_none()); + assert!(pruning_stats.null_counts(&column_b).is_none()); + + // Row counts are all missing + assert!(pruning_stats.row_counts(&column_a).is_none()); + assert!(pruning_stats.row_counts(&column_b).is_none()); + + // Contained values are all empty + let values = HashSet::from([ScalarValue::from(1i32)]); + assert!(pruning_stats.contained(&column_a, &values).is_none()); + } + + #[test] + fn test_composite_pruning_statistics_partition_and_file() { + // Create partition statistics + let partition_values = vec![ + vec![ScalarValue::from(1i32), ScalarValue::from(10i32)], + vec![ScalarValue::from(2i32), ScalarValue::from(20i32)], + ]; + let partition_fields = vec![ + Arc::new(Field::new("part_a", DataType::Int32, false)), + Arc::new(Field::new("part_b", DataType::Int32, false)), + ]; + let partition_stats = + PartitionPruningStatistics::try_new(partition_values, partition_fields) + .unwrap(); + + // Create file statistics + let file_statistics = vec![ + Arc::new( + Statistics::default() + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::from(100i32))) + .with_max_value(Precision::Exact(ScalarValue::from(200i32))) + .with_null_count(Precision::Exact(0)), + ) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::from(300i32))) + .with_max_value(Precision::Exact(ScalarValue::from(400i32))) + .with_null_count(Precision::Exact(5)), + ) + .with_num_rows(Precision::Exact(100)), + ), + Arc::new( + 
Statistics::default() + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::from(500i32))) + .with_max_value(Precision::Exact(ScalarValue::from(600i32))) + .with_null_count(Precision::Exact(10)), + ) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::from(700i32))) + .with_max_value(Precision::Exact(ScalarValue::from(800i32))) + .with_null_count(Precision::Exact(0)), + ) + .with_num_rows(Precision::Exact(200)), + ), + ]; + + let file_schema = Arc::new(Schema::new(vec![ + Field::new("col_x", DataType::Int32, false), + Field::new("col_y", DataType::Int32, false), + ])); + let file_stats = PrunableStatistics::new(file_statistics, file_schema); + + // Create composite statistics + let composite_stats = CompositePruningStatistics::new(vec![ + Box::new(partition_stats), + Box::new(file_stats), + ]); + + // Test accessing columns that are only in partition statistics + let part_a = Column::new_unqualified("part_a"); + let part_b = Column::new_unqualified("part_b"); + + // Test accessing columns that are only in file statistics + let col_x = Column::new_unqualified("col_x"); + let col_y = Column::new_unqualified("col_y"); + + // For partition columns, should get values from partition statistics + let min_values_part_a = + as_int32_array(&composite_stats.min_values(&part_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_values_part_a = vec![Some(1), Some(2)]; + assert_eq!(min_values_part_a, expected_values_part_a); + + let max_values_part_a = + as_int32_array(&composite_stats.max_values(&part_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + // For partition values, min and max are the same + assert_eq!(max_values_part_a, expected_values_part_a); + + let min_values_part_b = + as_int32_array(&composite_stats.min_values(&part_b).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_values_part_b = vec![Some(10), Some(20)]; + assert_eq!(min_values_part_b, expected_values_part_b); + + // For file columns, should get values from file statistics + let min_values_col_x = + as_int32_array(&composite_stats.min_values(&col_x).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_values_col_x = vec![Some(100), Some(500)]; + assert_eq!(min_values_col_x, expected_values_col_x); + + let max_values_col_x = + as_int32_array(&composite_stats.max_values(&col_x).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_max_values_col_x = vec![Some(200), Some(600)]; + assert_eq!(max_values_col_x, expected_max_values_col_x); + + let min_values_col_y = + as_int32_array(&composite_stats.min_values(&col_y).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_values_col_y = vec![Some(300), Some(700)]; + assert_eq!(min_values_col_y, expected_values_col_y); + + // Test null counts - only available from file statistics + assert!(composite_stats.null_counts(&part_a).is_none()); + assert!(composite_stats.null_counts(&part_b).is_none()); + + let null_counts_col_x = + as_uint64_array(&composite_stats.null_counts(&col_x).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_null_counts_col_x = vec![Some(0), Some(10)]; + assert_eq!(null_counts_col_x, expected_null_counts_col_x); + + // Test row counts - only available from file statistics + assert!(composite_stats.row_counts(&part_a).is_none()); + let row_counts_col_x = + as_uint64_array(&composite_stats.row_counts(&col_x).unwrap()) + .unwrap() + .into_iter() + 
.collect::>(); + let expected_row_counts = vec![Some(100), Some(200)]; + assert_eq!(row_counts_col_x, expected_row_counts); + + // Test contained values - only available from partition statistics + let values = HashSet::from([ScalarValue::from(1i32)]); + let contained_part_a = composite_stats.contained(&part_a, &values).unwrap(); + let expected_contained_part_a = BooleanArray::from(vec![true, false]); + assert_eq!(contained_part_a, expected_contained_part_a); + + // File statistics don't implement contained + assert!(composite_stats.contained(&col_x, &values).is_none()); + + // Non-existent column should return None for everything + let non_existent = Column::new_unqualified("non_existent"); + assert!(composite_stats.min_values(&non_existent).is_none()); + assert!(composite_stats.max_values(&non_existent).is_none()); + assert!(composite_stats.null_counts(&non_existent).is_none()); + assert!(composite_stats.row_counts(&non_existent).is_none()); + assert!(composite_stats.contained(&non_existent, &values).is_none()); + + // Verify num_containers matches + assert_eq!(composite_stats.num_containers(), 2); + } + + #[test] + fn test_composite_pruning_statistics_priority() { + // Create two sets of file statistics with the same column names + // but different values to test that the first one gets priority + + // First set of statistics + let first_statistics = vec![ + Arc::new( + Statistics::default() + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::from(100i32))) + .with_max_value(Precision::Exact(ScalarValue::from(200i32))) + .with_null_count(Precision::Exact(0)), + ) + .with_num_rows(Precision::Exact(100)), + ), + Arc::new( + Statistics::default() + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::from(300i32))) + .with_max_value(Precision::Exact(ScalarValue::from(400i32))) + .with_null_count(Precision::Exact(5)), + ) + .with_num_rows(Precision::Exact(200)), + ), + ]; + + let first_schema = Arc::new(Schema::new(vec![Field::new( + "col_a", + DataType::Int32, + false, + )])); + let first_stats = PrunableStatistics::new(first_statistics, first_schema); + + // Second set of statistics with the same column name but different values + let second_statistics = vec![ + Arc::new( + Statistics::default() + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::from(1000i32))) + .with_max_value(Precision::Exact(ScalarValue::from(2000i32))) + .with_null_count(Precision::Exact(10)), + ) + .with_num_rows(Precision::Exact(1000)), + ), + Arc::new( + Statistics::default() + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::from(3000i32))) + .with_max_value(Precision::Exact(ScalarValue::from(4000i32))) + .with_null_count(Precision::Exact(20)), + ) + .with_num_rows(Precision::Exact(2000)), + ), + ]; + + let second_schema = Arc::new(Schema::new(vec![Field::new( + "col_a", + DataType::Int32, + false, + )])); + let second_stats = PrunableStatistics::new(second_statistics, second_schema); + + // Create composite statistics with first stats having priority + let composite_stats = CompositePruningStatistics::new(vec![ + Box::new(first_stats.clone()), + Box::new(second_stats.clone()), + ]); + + let col_a = Column::new_unqualified("col_a"); + + // Should get values from first statistics since it has priority + let min_values = as_int32_array(&composite_stats.min_values(&col_a).unwrap()) + .unwrap() + 
.into_iter() + .collect::>(); + let expected_min_values = vec![Some(100), Some(300)]; + assert_eq!(min_values, expected_min_values); + + let max_values = as_int32_array(&composite_stats.max_values(&col_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_max_values = vec![Some(200), Some(400)]; + assert_eq!(max_values, expected_max_values); + + let null_counts = as_uint64_array(&composite_stats.null_counts(&col_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_null_counts = vec![Some(0), Some(5)]; + assert_eq!(null_counts, expected_null_counts); + + let row_counts = as_uint64_array(&composite_stats.row_counts(&col_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_row_counts = vec![Some(100), Some(200)]; + assert_eq!(row_counts, expected_row_counts); + + // Create composite statistics with second stats having priority + // Now that we've added Clone trait to PrunableStatistics, we can just clone them + + let composite_stats_reversed = CompositePruningStatistics::new(vec![ + Box::new(second_stats.clone()), + Box::new(first_stats.clone()), + ]); + + // Should get values from second statistics since it now has priority + let min_values = + as_int32_array(&composite_stats_reversed.min_values(&col_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_min_values = vec![Some(1000), Some(3000)]; + assert_eq!(min_values, expected_min_values); + + let max_values = + as_int32_array(&composite_stats_reversed.max_values(&col_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_max_values = vec![Some(2000), Some(4000)]; + assert_eq!(max_values, expected_max_values); + + let null_counts = + as_uint64_array(&composite_stats_reversed.null_counts(&col_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_null_counts = vec![Some(10), Some(20)]; + assert_eq!(null_counts, expected_null_counts); + + let row_counts = + as_uint64_array(&composite_stats_reversed.row_counts(&col_a).unwrap()) + .unwrap() + .into_iter() + .collect::>(); + let expected_row_counts = vec![Some(1000), Some(2000)]; + assert_eq!(row_counts, expected_row_counts); + } + + #[test] + fn test_composite_pruning_statistics_empty_and_mismatched_containers() { + // Test with empty statistics vector + // This should never happen, so we panic instead of returning a Result which would burned callers + let result = std::panic::catch_unwind(|| { + CompositePruningStatistics::new(vec![]); + }); + assert!(result.is_err()); + + // We should panic here because the number of containers is different + let result = std::panic::catch_unwind(|| { + // Create statistics with different number of containers + // Use partition stats for the test + let partition_values_1 = vec![ + vec![ScalarValue::from(1i32), ScalarValue::from(10i32)], + vec![ScalarValue::from(2i32), ScalarValue::from(20i32)], + ]; + let partition_fields_1 = vec![ + Arc::new(Field::new("part_a", DataType::Int32, false)), + Arc::new(Field::new("part_b", DataType::Int32, false)), + ]; + let partition_stats_1 = PartitionPruningStatistics::try_new( + partition_values_1, + partition_fields_1, + ) + .unwrap(); + let partition_values_2 = vec![ + vec![ScalarValue::from(3i32), ScalarValue::from(30i32)], + vec![ScalarValue::from(4i32), ScalarValue::from(40i32)], + vec![ScalarValue::from(5i32), ScalarValue::from(50i32)], + ]; + let partition_fields_2 = vec![ + Arc::new(Field::new("part_x", DataType::Int32, false)), + Arc::new(Field::new("part_y", DataType::Int32, false)), + ]; + let 
partition_stats_2 = PartitionPruningStatistics::try_new( + partition_values_2, + partition_fields_2, + ) + .unwrap(); + + CompositePruningStatistics::new(vec![ + Box::new(partition_stats_1), + Box::new(partition_stats_2), + ]); + }); + assert!(result.is_err()); + } +} diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index b8d9aea810f0..f774f46b424d 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -52,13 +52,14 @@ use arrow::compute::kernels::{ }; use arrow::datatypes::{ i256, ArrowDictionaryKeyType, ArrowNativeType, ArrowTimestampType, DataType, - Date32Type, Date64Type, Field, Float32Type, Int16Type, Int32Type, Int64Type, - Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, - IntervalYearMonthType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, + Date32Type, Field, Float32Type, Int16Type, Int32Type, Int64Type, Int8Type, + IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, + TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, UnionFields, UnionMode, DECIMAL128_MAX_PRECISION, }; use arrow::util::display::{array_value_to_string, ArrayFormatter, FormatOptions}; +use chrono::{Duration, NaiveDate}; use half::f16; pub use struct_builder::ScalarStructBuilder; @@ -506,7 +507,7 @@ impl PartialOrd for ScalarValue { } (List(_), _) | (LargeList(_), _) | (FixedSizeList(_), _) => None, (Struct(struct_arr1), Struct(struct_arr2)) => { - partial_cmp_struct(struct_arr1, struct_arr2) + partial_cmp_struct(struct_arr1.as_ref(), struct_arr2.as_ref()) } (Struct(_), _) => None, (Map(map_arr1), Map(map_arr2)) => partial_cmp_map(map_arr1, map_arr2), @@ -597,10 +598,28 @@ fn partial_cmp_list(arr1: &dyn Array, arr2: &dyn Array) -> Option { let arr1 = first_array_for_list(arr1); let arr2 = first_array_for_list(arr2); - let lt_res = arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?; - let eq_res = arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?; + let min_length = arr1.len().min(arr2.len()); + let arr1_trimmed = arr1.slice(0, min_length); + let arr2_trimmed = arr2.slice(0, min_length); + + let lt_res = arrow::compute::kernels::cmp::lt(&arr1_trimmed, &arr2_trimmed).ok()?; + let eq_res = arrow::compute::kernels::cmp::eq(&arr1_trimmed, &arr2_trimmed).ok()?; for j in 0..lt_res.len() { + // In Postgres, NULL values in lists are always considered to be greater than non-NULL values: + // + // $ SELECT ARRAY[NULL]::integer[] > ARRAY[1] + // true + // + // These next two if statements are introduced for replicating Postgres behavior, as + // arrow::compute does not account for this. 
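To make these ordering rules concrete, here is a minimal sketch that mirrors the new `partial_cmp` test cases added further down in this diff; it assumes `ScalarValue::List` wraps a single-row `ListArray`, as those tests do:

    use std::cmp::Ordering;
    use std::sync::Arc;

    use arrow::array::ListArray;
    use arrow::datatypes::Int32Type;
    use datafusion_common::ScalarValue;

    fn main() {
        // Helper: wrap a single list value in a `ScalarValue::List`
        let list = |values: Vec<Option<i32>>| {
            ScalarValue::List(Arc::new(
                ListArray::from_iter_primitive::<Int32Type, _, _>(vec![Some(values)]),
            ))
        };

        // Element-wise comparison over the common prefix decides first ...
        assert_eq!(
            list(vec![Some(1), Some(2), Some(3)]).partial_cmp(&list(vec![Some(2), Some(3)])),
            Some(Ordering::Less)
        );
        // ... and when the common prefix is equal, the longer list sorts last
        assert_eq!(
            list(vec![Some(1), Some(2), Some(3)]).partial_cmp(&list(vec![Some(1), Some(2)])),
            Some(Ordering::Greater)
        );
        // NULL elements sort above non-NULL elements, matching Postgres
        assert_eq!(
            list(vec![None, Some(2)]).partial_cmp(&list(vec![Some(1), Some(2)])),
            Some(Ordering::Greater)
        );
    }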
+ if arr1_trimmed.is_null(j) && !arr2_trimmed.is_null(j) { + return Some(Ordering::Greater); + } + if !arr1_trimmed.is_null(j) && arr2_trimmed.is_null(j) { + return Some(Ordering::Less); + } + if lt_res.is_valid(j) && lt_res.value(j) { return Some(Ordering::Less); } @@ -609,10 +628,23 @@ fn partial_cmp_list(arr1: &dyn Array, arr2: &dyn Array) -> Option { } } - Some(Ordering::Equal) + Some(arr1.len().cmp(&arr2.len())) } -fn partial_cmp_struct(s1: &Arc, s2: &Arc) -> Option { +fn flatten<'a>(array: &'a StructArray, columns: &mut Vec<&'a ArrayRef>) { + for i in 0..array.num_columns() { + let column = array.column(i); + if let Some(nested_struct) = column.as_any().downcast_ref::() { + // If it's a nested struct, recursively expand + flatten(nested_struct, columns); + } else { + // If it's a primitive type, add directly + columns.push(column); + } + } +} + +pub fn partial_cmp_struct(s1: &StructArray, s2: &StructArray) -> Option { if s1.len() != s2.len() { return None; } @@ -621,9 +653,15 @@ fn partial_cmp_struct(s1: &Arc, s2: &Arc) -> Option() } + + /// Compacts the allocation referenced by `self` to the minimum, copying the data if + /// necessary. + /// + /// This can be relevant when `self` is a list or contains a list as a nested value, as + /// a single list holds an Arc to its entire original array buffer. + pub fn compact(&mut self) { + match self { + ScalarValue::Null + | ScalarValue::Boolean(_) + | ScalarValue::Float16(_) + | ScalarValue::Float32(_) + | ScalarValue::Float64(_) + | ScalarValue::Decimal128(_, _, _) + | ScalarValue::Decimal256(_, _, _) + | ScalarValue::Int8(_) + | ScalarValue::Int16(_) + | ScalarValue::Int32(_) + | ScalarValue::Int64(_) + | ScalarValue::UInt8(_) + | ScalarValue::UInt16(_) + | ScalarValue::UInt32(_) + | ScalarValue::UInt64(_) + | ScalarValue::Date32(_) + | ScalarValue::Date64(_) + | ScalarValue::Time32Second(_) + | ScalarValue::Time32Millisecond(_) + | ScalarValue::Time64Microsecond(_) + | ScalarValue::Time64Nanosecond(_) + | ScalarValue::IntervalYearMonth(_) + | ScalarValue::IntervalDayTime(_) + | ScalarValue::IntervalMonthDayNano(_) + | ScalarValue::DurationSecond(_) + | ScalarValue::DurationMillisecond(_) + | ScalarValue::DurationMicrosecond(_) + | ScalarValue::DurationNanosecond(_) + | ScalarValue::Utf8(_) + | ScalarValue::LargeUtf8(_) + | ScalarValue::Utf8View(_) + | ScalarValue::TimestampSecond(_, _) + | ScalarValue::TimestampMillisecond(_, _) + | ScalarValue::TimestampMicrosecond(_, _) + | ScalarValue::TimestampNanosecond(_, _) + | ScalarValue::Binary(_) + | ScalarValue::FixedSizeBinary(_, _) + | ScalarValue::LargeBinary(_) + | ScalarValue::BinaryView(_) => (), + ScalarValue::FixedSizeList(arr) => { + let array = copy_array_data(&arr.to_data()); + *Arc::make_mut(arr) = FixedSizeListArray::from(array); + } + ScalarValue::List(arr) => { + let array = copy_array_data(&arr.to_data()); + *Arc::make_mut(arr) = ListArray::from(array); + } + ScalarValue::LargeList(arr) => { + let array = copy_array_data(&arr.to_data()); + *Arc::make_mut(arr) = LargeListArray::from(array) + } + ScalarValue::Struct(arr) => { + let array = copy_array_data(&arr.to_data()); + *Arc::make_mut(arr) = StructArray::from(array); + } + ScalarValue::Map(arr) => { + let array = copy_array_data(&arr.to_data()); + *Arc::make_mut(arr) = MapArray::from(array); + } + ScalarValue::Union(val, _, _) => { + if let Some((_, value)) = val.as_mut() { + value.compact(); + } + } + ScalarValue::Dictionary(_, value) => { + value.compact(); + } + } + } + + /// Compacts ([ScalarValue::compact]) the current 
[ScalarValue] and returns it. + pub fn compacted(mut self) -> Self { + self.compact(); + self + } +} + +/// Compacts the data of an `ArrayData` into a new `ArrayData`. +/// +/// This is useful when you want to minimize the memory footprint of an +/// `ArrayData`. For example, the value returned by [`Array::slice`] still +/// points at the same underlying data buffers as the original array, which may +/// hold many more values. Calling `copy_array_data` on the sliced array will +/// create a new, smaller, `ArrayData` that only contains the data for the +/// sliced array. +/// +/// # Example +/// ``` +/// # use arrow::array::{make_array, Array, Int32Array}; +/// use datafusion_common::scalar::copy_array_data; +/// let array = Int32Array::from_iter_values(0..8192); +/// // Take only the first 2 elements +/// let sliced_array = array.slice(0, 2); +/// // The memory footprint of `sliced_array` is close to 8192 * 4 bytes +/// assert_eq!(32864, sliced_array.get_array_memory_size()); +/// // however, we can copy the data to a new `ArrayData` +/// let new_array = make_array(copy_array_data(&sliced_array.into_data())); +/// // The memory footprint of `new_array` is now only 2 * 4 bytes +/// // and overhead: +/// assert_eq!(160, new_array.get_array_memory_size()); +/// ``` +/// +/// See also [`ScalarValue::compact`] which applies to `ScalarValue` instances +/// as necessary. +pub fn copy_array_data(src_data: &ArrayData) -> ArrayData { + let mut copy = MutableArrayData::new(vec![&src_data], true, src_data.len()); + copy.extend(0, 0, src_data.len()); + copy.freeze() } macro_rules! impl_scalar { @@ -3663,12 +3817,28 @@ impl fmt::Display for ScalarValue { ScalarValue::List(arr) => fmt_list(arr.to_owned() as ArrayRef, f)?, ScalarValue::LargeList(arr) => fmt_list(arr.to_owned() as ArrayRef, f)?, ScalarValue::FixedSizeList(arr) => fmt_list(arr.to_owned() as ArrayRef, f)?, - ScalarValue::Date32(e) => { - format_option!(f, e.map(|v| Date32Type::to_naive_date(v).to_string()))? - } - ScalarValue::Date64(e) => { - format_option!(f, e.map(|v| Date64Type::to_naive_date(v).to_string()))? - } + ScalarValue::Date32(e) => format_option!( + f, + e.map(|v| { + let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); + match epoch.checked_add_signed(Duration::try_days(v as i64).unwrap()) + { + Some(date) => date.to_string(), + None => "".to_string(), + } + }) + )?, + ScalarValue::Date64(e) => format_option!( + f, + e.map(|v| { + let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); + match epoch.checked_add_signed(Duration::try_milliseconds(v).unwrap()) + { + Some(date) => date.to_string(), + None => "".to_string(), + } + }) + )?, ScalarValue::Time32Second(e) => format_option!(f, e)?, ScalarValue::Time32Millisecond(e) => format_option!(f, e)?, ScalarValue::Time64Microsecond(e) => format_option!(f, e)?, @@ -3739,7 +3909,7 @@ impl fmt::Display for ScalarValue { array_value_to_string(arr.column(0), i).unwrap(); let value = array_value_to_string(arr.column(1), i).unwrap(); - buffer.push_back(format!("{}:{}", key, value)); + buffer.push_back(format!("{key}:{value}")); } format!( "{{{}}}", @@ -3758,7 +3928,7 @@ impl fmt::Display for ScalarValue { )? 
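Stepping back to the `compact`/`compacted`/`copy_array_data` additions above: their intended use is to shrink a `ScalarValue` that was carved out of a much larger array. A minimal sketch, assuming the methods land as shown in this diff (exact memory sizes depend on the arrow version, so only a relative check is made):

    use std::sync::Arc;

    use arrow::array::{Array, ListArray};
    use arrow::datatypes::Int32Type;
    use datafusion_common::ScalarValue;

    fn main() {
        // 10_000 single-element lists; slicing one row still keeps every buffer alive
        let big = ListArray::from_iter_primitive::<Int32Type, _, _>(
            (0..10_000).map(|i| Some(vec![Some(i)])),
        );
        let one_row = big.slice(0, 1);
        let before = one_row.get_array_memory_size();

        // `compacted` copies only the referenced row into a fresh, minimal allocation
        let scalar = ScalarValue::List(Arc::new(one_row)).compacted();
        if let ScalarValue::List(arr) = &scalar {
            assert_eq!(arr.len(), 1);
            assert!(arr.get_array_memory_size() < before);
        }
    }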
} ScalarValue::Union(val, _fields, _mode) => match val { - Some((id, val)) => write!(f, "{}:{}", id, val)?, + Some((id, val)) => write!(f, "{id}:{val}")?, None => write!(f, "NULL")?, }, ScalarValue::Dictionary(_k, v) => write!(f, "{v}")?, @@ -3935,7 +4105,7 @@ impl fmt::Debug for ScalarValue { write!(f, "DurationNanosecond(\"{self}\")") } ScalarValue::Union(val, _fields, _mode) => match val { - Some((id, val)) => write!(f, "Union {}:{}", id, val), + Some((id, val)) => write!(f, "Union {id}:{val}"), None => write!(f, "Union(NULL)"), }, ScalarValue::Dictionary(k, v) => write!(f, "Dictionary({k:?}, {v:?})"), @@ -4059,7 +4229,7 @@ mod tests { #[test] #[should_panic( - expected = "Error building ScalarValue::Struct. Expected array with exactly one element, found array with 4 elements" + expected = "InvalidArgumentError(\"Incorrect array length for StructArray field \\\"bool\\\", expected 1 got 4\")" )] fn test_scalar_value_from_for_struct_should_panic() { let _ = ScalarStructBuilder::new() @@ -4752,6 +4922,109 @@ mod tests { ])]), )); assert_eq!(a.partial_cmp(&b), Some(Ordering::Less)); + + let a = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + ])]), + )); + let b = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(2), + Some(3), + ])]), + )); + assert_eq!(a.partial_cmp(&b), Some(Ordering::Less)); + + let a = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(2), + Some(3), + Some(4), + ])]), + )); + let b = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + ])]), + )); + assert_eq!(a.partial_cmp(&b), Some(Ordering::Greater)); + + let a = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + ])]), + )); + let b = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + ])]), + )); + assert_eq!(a.partial_cmp(&b), Some(Ordering::Greater)); + + let a = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + None, + Some(2), + Some(3), + ])]), + )); + let b = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + ])]), + )); + assert_eq!(a.partial_cmp(&b), Some(Ordering::Greater)); + + let a = ScalarValue::LargeList(Arc::new(LargeListArray::from_iter_primitive::< + Int64Type, + _, + _, + >(vec![Some(vec![ + None, + Some(2), + Some(3), + ])]))); + let b = ScalarValue::LargeList(Arc::new(LargeListArray::from_iter_primitive::< + Int64Type, + _, + _, + >(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + ])]))); + assert_eq!(a.partial_cmp(&b), Some(Ordering::Greater)); + + let a = ScalarValue::FixedSizeList(Arc::new( + FixedSizeListArray::from_iter_primitive::( + vec![Some(vec![None, Some(2), Some(3)])], + 3, + ), + )); + let b = ScalarValue::FixedSizeList(Arc::new( + FixedSizeListArray::from_iter_primitive::( + vec![Some(vec![Some(1), Some(2), Some(3)])], + 3, + ), + )); + assert_eq!(a.partial_cmp(&b), Some(Ordering::Greater)); } #[test] @@ -6973,6 +7246,19 @@ mod tests { "); } + #[test] + fn test_display_date64_large_values() { + assert_eq!( + format!("{}", ScalarValue::Date64(Some(790179464505))), + "1995-01-15" + ); + // This used to panic, see https://github.com/apache/arrow-rs/issues/7728 + assert_eq!( + format!("{}", ScalarValue::Date64(Some(-790179464505600000))), + "" + ); + } + #[test] fn 
test_struct_display_null() { let fields = vec![Field::new("a", DataType::Int32, false)]; @@ -7162,14 +7448,14 @@ mod tests { fn get_random_timestamps(sample_size: u64) -> Vec { let vector_size = sample_size; let mut timestamp = vec![]; - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for i in 0..vector_size { - let year = rng.gen_range(1995..=2050); - let month = rng.gen_range(1..=12); - let day = rng.gen_range(1..=28); // to exclude invalid dates - let hour = rng.gen_range(0..=23); - let minute = rng.gen_range(0..=59); - let second = rng.gen_range(0..=59); + let year = rng.random_range(1995..=2050); + let month = rng.random_range(1..=12); + let day = rng.random_range(1..=28); // to exclude invalid dates + let hour = rng.random_range(0..=23); + let minute = rng.random_range(0..=59); + let second = rng.random_range(0..=59); if i % 4 == 0 { timestamp.push(ScalarValue::TimestampSecond( Some( @@ -7183,7 +7469,7 @@ mod tests { None, )) } else if i % 4 == 1 { - let millisec = rng.gen_range(0..=999); + let millisec = rng.random_range(0..=999); timestamp.push(ScalarValue::TimestampMillisecond( Some( NaiveDate::from_ymd_opt(year, month, day) @@ -7196,7 +7482,7 @@ mod tests { None, )) } else if i % 4 == 2 { - let microsec = rng.gen_range(0..=999_999); + let microsec = rng.random_range(0..=999_999); timestamp.push(ScalarValue::TimestampMicrosecond( Some( NaiveDate::from_ymd_opt(year, month, day) @@ -7209,7 +7495,7 @@ mod tests { None, )) } else if i % 4 == 3 { - let nanosec = rng.gen_range(0..=999_999_999); + let nanosec = rng.random_range(0..=999_999_999); timestamp.push(ScalarValue::TimestampNanosecond( Some( NaiveDate::from_ymd_opt(year, month, day) @@ -7233,27 +7519,27 @@ mod tests { let vector_size = sample_size; let mut intervals = vec![]; - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); const SECS_IN_ONE_DAY: i32 = 86_400; const MICROSECS_IN_ONE_DAY: i64 = 86_400_000_000; for i in 0..vector_size { if i % 4 == 0 { - let days = rng.gen_range(0..5000); + let days = rng.random_range(0..5000); // to not break second precision - let millis = rng.gen_range(0..SECS_IN_ONE_DAY) * 1000; + let millis = rng.random_range(0..SECS_IN_ONE_DAY) * 1000; intervals.push(ScalarValue::new_interval_dt(days, millis)); } else if i % 4 == 1 { - let days = rng.gen_range(0..5000); - let millisec = rng.gen_range(0..(MILLISECS_IN_ONE_DAY as i32)); + let days = rng.random_range(0..5000); + let millisec = rng.random_range(0..(MILLISECS_IN_ONE_DAY as i32)); intervals.push(ScalarValue::new_interval_dt(days, millisec)); } else if i % 4 == 2 { - let days = rng.gen_range(0..5000); + let days = rng.random_range(0..5000); // to not break microsec precision - let nanosec = rng.gen_range(0..MICROSECS_IN_ONE_DAY) * 1000; + let nanosec = rng.random_range(0..MICROSECS_IN_ONE_DAY) * 1000; intervals.push(ScalarValue::new_interval_mdn(0, days, nanosec)); } else { - let days = rng.gen_range(0..5000); - let nanosec = rng.gen_range(0..NANOSECS_IN_ONE_DAY); + let days = rng.random_range(0..5000); + let nanosec = rng.random_range(0..NANOSECS_IN_ONE_DAY); intervals.push(ScalarValue::new_interval_mdn(0, days, nanosec)); } } diff --git a/datafusion/common/src/scalar/struct_builder.rs b/datafusion/common/src/scalar/struct_builder.rs index 5ed464018401..fd19dccf8963 100644 --- a/datafusion/common/src/scalar/struct_builder.rs +++ b/datafusion/common/src/scalar/struct_builder.rs @@ -17,7 +17,6 @@ //! 
[`ScalarStructBuilder`] for building [`ScalarValue::Struct`] -use crate::error::_internal_err; use crate::{Result, ScalarValue}; use arrow::array::{ArrayRef, StructArray}; use arrow::datatypes::{DataType, Field, FieldRef, Fields}; @@ -109,17 +108,8 @@ impl ScalarStructBuilder { pub fn build(self) -> Result { let Self { fields, arrays } = self; - for array in &arrays { - if array.len() != 1 { - return _internal_err!( - "Error building ScalarValue::Struct. \ - Expected array with exactly one element, found array with {} elements", - array.len() - ); - } - } - - let struct_array = StructArray::try_new(Fields::from(fields), arrays, None)?; + let struct_array = + StructArray::try_new_with_length(Fields::from(fields), arrays, None, 1)?; Ok(ScalarValue::Struct(Arc::new(struct_array))) } } @@ -181,3 +171,15 @@ impl IntoFields for Vec { Fields::from(self) } } + +#[cfg(test)] +mod tests { + use super::*; + + // Other cases are tested by doc tests + #[test] + fn test_empty_struct() { + let sv = ScalarStructBuilder::new().build().unwrap(); + assert_eq!(format!("{sv}"), "{}"); + } +} diff --git a/datafusion/common/src/stats.rs b/datafusion/common/src/stats.rs index 807d885b3a4d..a6d132ef51f6 100644 --- a/datafusion/common/src/stats.rs +++ b/datafusion/common/src/stats.rs @@ -233,8 +233,8 @@ impl Precision { impl Debug for Precision { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Precision::Exact(inner) => write!(f, "Exact({:?})", inner), - Precision::Inexact(inner) => write!(f, "Inexact({:?})", inner), + Precision::Exact(inner) => write!(f, "Exact({inner:?})"), + Precision::Inexact(inner) => write!(f, "Inexact({inner:?})"), Precision::Absent => write!(f, "Absent"), } } @@ -243,8 +243,8 @@ impl Debug for Precision { impl Display for Precision { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Precision::Exact(inner) => write!(f, "Exact({:?})", inner), - Precision::Inexact(inner) => write!(f, "Inexact({:?})", inner), + Precision::Exact(inner) => write!(f, "Exact({inner:?})"), + Precision::Inexact(inner) => write!(f, "Inexact({inner:?})"), Precision::Absent => write!(f, "Absent"), } } @@ -352,6 +352,7 @@ impl Statistics { return self; }; + #[allow(clippy::large_enum_variant)] enum Slot { /// The column is taken and put into the specified statistics location Taken(usize), @@ -451,6 +452,9 @@ impl Statistics { /// Summarize zero or more statistics into a single `Statistics` instance. /// + /// The method assumes that all statistics are for the same schema. + /// If not, maybe you can call `SchemaMapper::map_column_statistics` to make them consistent. + /// /// Returns an error if the statistics do not match the specified schemas. pub fn try_merge_iter<'a, I>(items: I, schema: &Schema) -> Result where @@ -569,7 +573,7 @@ impl Display for Statistics { .iter() .enumerate() .map(|(i, cs)| { - let s = format!("(Col[{}]:", i); + let s = format!("(Col[{i}]:"); let s = if cs.min_value != Precision::Absent { format!("{} Min={}", s, cs.min_value) } else { diff --git a/datafusion/common/src/test_util.rs b/datafusion/common/src/test_util.rs index b801c452af2c..820a230bf6e1 100644 --- a/datafusion/common/src/test_util.rs +++ b/datafusion/common/src/test_util.rs @@ -18,10 +18,25 @@ //! 
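One practical consequence of switching the builder to `StructArray::try_new_with_length(..., 1)` is that a struct scalar with zero fields becomes representable, as the new `test_empty_struct` shows. A short sketch of the builder under that assumption:

    use arrow::datatypes::{DataType, Field};
    use datafusion_common::scalar::ScalarStructBuilder;
    use datafusion_common::{Result, ScalarValue};

    fn main() -> Result<()> {
        // A regular one-field struct scalar builds as before
        let sv = ScalarStructBuilder::new()
            .with_scalar(Field::new("a", DataType::Int32, false), ScalarValue::from(1i32))
            .build()?;
        println!("{sv}");

        // ... and the previously rejected zero-field struct now builds and displays as "{}"
        let empty = ScalarStructBuilder::new().build()?;
        assert_eq!(format!("{empty}"), "{}");
        Ok(())
    }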
Utility functions to make testing DataFusion based crates easier use crate::arrow::util::pretty::pretty_format_batches_with_options; -use crate::format::DEFAULT_FORMAT_OPTIONS; -use arrow::array::RecordBatch; +use arrow::array::{ArrayRef, RecordBatch}; +use arrow::error::ArrowError; +use std::fmt::Display; use std::{error::Error, path::PathBuf}; +/// Converts a vector or array into an ArrayRef. +pub trait IntoArrayRef { + fn into_array_ref(self) -> ArrayRef; +} + +pub fn format_batches(results: &[RecordBatch]) -> Result { + let datafusion_format_options = crate::config::FormatOptions::default(); + + let arrow_format_options: arrow::util::display::FormatOptions = + (&datafusion_format_options).try_into().unwrap(); + + pretty_format_batches_with_options(results, &arrow_format_options) +} + /// Compares formatted output of a record batch with an expected /// vector of strings, with the result of pretty formatting record /// batches. This is a macro so errors appear on the correct line @@ -59,12 +74,9 @@ macro_rules! assert_batches_eq { let expected_lines: Vec = $EXPECTED_LINES.iter().map(|&s| s.into()).collect(); - let formatted = $crate::arrow::util::pretty::pretty_format_batches_with_options( - $CHUNKS, - &$crate::format::DEFAULT_FORMAT_OPTIONS, - ) - .unwrap() - .to_string(); + let formatted = $crate::test_util::format_batches($CHUNKS) + .unwrap() + .to_string(); let actual_lines: Vec<&str> = formatted.trim().lines().collect(); @@ -77,18 +89,13 @@ macro_rules! assert_batches_eq { } pub fn batches_to_string(batches: &[RecordBatch]) -> String { - let actual = pretty_format_batches_with_options(batches, &DEFAULT_FORMAT_OPTIONS) - .unwrap() - .to_string(); + let actual = format_batches(batches).unwrap().to_string(); actual.trim().to_string() } pub fn batches_to_sort_string(batches: &[RecordBatch]) -> String { - let actual_lines = - pretty_format_batches_with_options(batches, &DEFAULT_FORMAT_OPTIONS) - .unwrap() - .to_string(); + let actual_lines = format_batches(batches).unwrap().to_string(); let mut actual_lines: Vec<&str> = actual_lines.trim().lines().collect(); @@ -122,12 +129,9 @@ macro_rules! assert_batches_sorted_eq { expected_lines.as_mut_slice()[2..num_lines - 1].sort_unstable() } - let formatted = $crate::arrow::util::pretty::pretty_format_batches_with_options( - $CHUNKS, - &$crate::format::DEFAULT_FORMAT_OPTIONS, - ) - .unwrap() - .to_string(); + let formatted = $crate::test_util::format_batches($CHUNKS) + .unwrap() + .to_string(); // fix for windows: \r\n --> let mut actual_lines: Vec<&str> = formatted.trim().lines().collect(); @@ -384,6 +388,326 @@ macro_rules! 
record_batch { } } +pub mod array_conversion { + use arrow::array::ArrayRef; + + use super::IntoArrayRef; + + impl IntoArrayRef for Vec { + fn into_array_ref(self) -> ArrayRef { + create_array!(Boolean, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + create_array!(Boolean, self) + } + } + + impl IntoArrayRef for &[bool] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Boolean, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Boolean, self.to_vec()) + } + } + + impl IntoArrayRef for Vec { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int8, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int8, self) + } + } + + impl IntoArrayRef for &[i8] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int8, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int8, self.to_vec()) + } + } + + impl IntoArrayRef for Vec { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int16, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int16, self) + } + } + + impl IntoArrayRef for &[i16] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int16, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int16, self.to_vec()) + } + } + + impl IntoArrayRef for Vec { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int32, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int32, self) + } + } + + impl IntoArrayRef for &[i32] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int32, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int32, self.to_vec()) + } + } + + impl IntoArrayRef for Vec { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int64, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int64, self) + } + } + + impl IntoArrayRef for &[i64] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int64, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Int64, self.to_vec()) + } + } + + impl IntoArrayRef for Vec { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt8, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt8, self) + } + } + + impl IntoArrayRef for &[u8] { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt8, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option] { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt8, self.to_vec()) + } + } + + impl IntoArrayRef for Vec { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt16, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt16, self) + } + } + + impl IntoArrayRef for &[u16] { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt16, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option] { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt16, self.to_vec()) + } + } + + impl IntoArrayRef for Vec { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt32, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + 
create_array!(UInt32, self) + } + } + + impl IntoArrayRef for &[u32] { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt32, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option] { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt32, self.to_vec()) + } + } + + impl IntoArrayRef for Vec { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt64, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt64, self) + } + } + + impl IntoArrayRef for &[u64] { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt64, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option] { + fn into_array_ref(self) -> ArrayRef { + create_array!(UInt64, self.to_vec()) + } + } + + //#TODO add impl for f16 + + impl IntoArrayRef for Vec { + fn into_array_ref(self) -> ArrayRef { + create_array!(Float32, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + create_array!(Float32, self) + } + } + + impl IntoArrayRef for &[f32] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Float32, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Float32, self.to_vec()) + } + } + + impl IntoArrayRef for Vec { + fn into_array_ref(self) -> ArrayRef { + create_array!(Float64, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + create_array!(Float64, self) + } + } + + impl IntoArrayRef for &[f64] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Float64, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Float64, self.to_vec()) + } + } + + impl IntoArrayRef for Vec<&str> { + fn into_array_ref(self) -> ArrayRef { + create_array!(Utf8, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + create_array!(Utf8, self) + } + } + + impl IntoArrayRef for &[&str] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Utf8, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option<&str>] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Utf8, self.to_vec()) + } + } + + impl IntoArrayRef for Vec { + fn into_array_ref(self) -> ArrayRef { + create_array!(Utf8, self) + } + } + + impl IntoArrayRef for Vec> { + fn into_array_ref(self) -> ArrayRef { + create_array!(Utf8, self) + } + } + + impl IntoArrayRef for &[String] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Utf8, self.to_vec()) + } + } + + impl IntoArrayRef for &[Option] { + fn into_array_ref(self) -> ArrayRef { + create_array!(Utf8, self.to_vec()) + } + } +} + #[cfg(test)] mod tests { use crate::cast::{as_float64_array, as_int32_array, as_string_array}; diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index c70389b63177..cf51dadf6b4a 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -2354,7 +2354,7 @@ pub(crate) mod tests { fn test_large_tree() { let mut item = TestTreeNode::new_leaf("initial".to_string()); for i in 0..3000 { - item = TestTreeNode::new(vec![item], format!("parent-{}", i)); + item = TestTreeNode::new(vec![item], format!("parent-{i}")); } let mut visitor = diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index 409f248621f7..c09859c46e15 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -950,7 +950,7 @@ pub fn get_available_parallelism() -> usize { .get() } -/// 
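The `IntoArrayRef` impls above exist to cut boilerplate when constructing test arrays. A hedged usage sketch, assuming the trait stays exported from `datafusion_common::test_util` as introduced here:

    use arrow::array::{Array, ArrayRef};
    use datafusion_common::test_util::IntoArrayRef;

    fn main() {
        // Plain vectors and option vectors convert straight into `ArrayRef`s
        let ints: ArrayRef = vec![1i32, 2, 3].into_array_ref();
        let strs: ArrayRef = vec![Some("a"), None, Some("c")].into_array_ref();

        assert_eq!(ints.len(), 3);
        assert_eq!(strs.null_count(), 1);
    }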
Converts a collection of function arguments into an fixed-size array of length N +/// Converts a collection of function arguments into a fixed-size array of length N /// producing a reasonable error message in case of unexpected number of arguments. /// /// # Example diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index edc0d34b539a..9747f4424060 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -78,6 +78,7 @@ recursive_protection = [ "datafusion-optimizer/recursive_protection", "datafusion-physical-optimizer/recursive_protection", "datafusion-sql/recursive_protection", + "sqlparser/recursive-protection", ] serde = [ "dep:serde", @@ -95,10 +96,10 @@ extended_tests = [] [dependencies] arrow = { workspace = true } arrow-ipc = { workspace = true } -arrow-schema = { workspace = true } +arrow-schema = { workspace = true, features = ["canonical_extension_types"] } async-trait = { workspace = true } bytes = { workspace = true } -bzip2 = { version = "0.5.2", optional = true } +bzip2 = { version = "0.6.0", optional = true } chrono = { workspace = true } datafusion-catalog = { workspace = true } datafusion-catalog-listing = { workspace = true } @@ -117,7 +118,6 @@ datafusion-functions-aggregate = { workspace = true } datafusion-functions-nested = { workspace = true, optional = true } datafusion-functions-table = { workspace = true } datafusion-functions-window = { workspace = true } -datafusion-macros = { workspace = true } datafusion-optimizer = { workspace = true } datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } @@ -125,7 +125,7 @@ datafusion-physical-optimizer = { workspace = true } datafusion-physical-plan = { workspace = true } datafusion-session = { workspace = true } datafusion-sql = { workspace = true } -flate2 = { version = "1.1.1", optional = true } +flate2 = { version = "1.1.2", optional = true } futures = { workspace = true } itertools = { workspace = true } log = { workspace = true } @@ -139,7 +139,7 @@ sqlparser = { workspace = true } tempfile = { workspace = true } tokio = { workspace = true } url = { workspace = true } -uuid = { version = "1.16", features = ["v4", "js"] } +uuid = { version = "1.17", features = ["v4", "js"] } xz2 = { version = "0.1", optional = true, features = ["static"] } zstd = { version = "0.13", optional = true, default-features = false } @@ -150,22 +150,23 @@ ctor = { workspace = true } dashmap = "6.1.0" datafusion-doc = { workspace = true } datafusion-functions-window-common = { workspace = true } +datafusion-macros = { workspace = true } datafusion-physical-optimizer = { workspace = true } doc-comment = { workspace = true } env_logger = { workspace = true } insta = { workspace = true } paste = "^1.0" rand = { workspace = true, features = ["small_rng"] } -rand_distr = "0.4.3" +rand_distr = "0.5" regex = { workspace = true } rstest = { workspace = true } serde_json = { workspace = true } -sysinfo = "0.34.2" +sysinfo = "0.35.2" test-utils = { path = "../../test-utils" } tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot", "fs"] } [target.'cfg(not(target_os = "windows"))'.dev-dependencies] -nix = { version = "0.29.0", features = ["fs"] } +nix = { version = "0.30.1", features = ["fs"] } [[bench]] harness = false @@ -179,6 +180,10 @@ name = "csv_load" harness = false name = "distinct_query_sql" +[[bench]] +harness = false +name = "push_down_filter" + [[bench]] harness = false name = "sort_limit_query_sql" diff --git 
a/datafusion/core/benches/aggregate_query_sql.rs b/datafusion/core/benches/aggregate_query_sql.rs index b29bfc487340..057a0e1d1b54 100644 --- a/datafusion/core/benches/aggregate_query_sql.rs +++ b/datafusion/core/benches/aggregate_query_sql.rs @@ -158,7 +158,7 @@ fn criterion_benchmark(c: &mut Criterion) { query( ctx.clone(), &rt, - "SELECT utf8, approx_percentile_cont(u64_wide, 0.5, 2500) \ + "SELECT utf8, approx_percentile_cont(0.5, 2500) WITHIN GROUP (ORDER BY u64_wide) \ FROM t GROUP BY utf8", ) }) @@ -169,7 +169,7 @@ fn criterion_benchmark(c: &mut Criterion) { query( ctx.clone(), &rt, - "SELECT utf8, approx_percentile_cont(f32, 0.5, 2500) \ + "SELECT utf8, approx_percentile_cont(0.5, 2500) WITHIN GROUP (ORDER BY f32) \ FROM t GROUP BY utf8", ) }) diff --git a/datafusion/core/benches/data_utils/mod.rs b/datafusion/core/benches/data_utils/mod.rs index fc5f8945c439..c0477b1306f7 100644 --- a/datafusion/core/benches/data_utils/mod.rs +++ b/datafusion/core/benches/data_utils/mod.rs @@ -26,8 +26,8 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::datasource::MemTable; use datafusion::error::Result; use datafusion_common::DataFusionError; +use rand::prelude::IndexedRandom; use rand::rngs::StdRng; -use rand::seq::SliceRandom; use rand::{Rng, SeedableRng}; use rand_distr::Distribution; use rand_distr::{Normal, Pareto}; @@ -49,11 +49,6 @@ pub fn create_table_provider( MemTable::try_new(schema, partitions).map(Arc::new) } -/// create a seedable [`StdRng`](rand::StdRng) -fn seedable_rng() -> StdRng { - StdRng::seed_from_u64(42) -} - /// Create test data schema pub fn create_schema() -> Schema { Schema::new(vec![ @@ -73,14 +68,14 @@ pub fn create_schema() -> Schema { fn create_data(size: usize, null_density: f64) -> Vec> { // use random numbers to avoid spurious compiler optimizations wrt to branching - let mut rng = seedable_rng(); + let mut rng = StdRng::seed_from_u64(42); (0..size) .map(|_| { - if rng.gen::() > null_density { + if rng.random::() > null_density { None } else { - Some(rng.gen::()) + Some(rng.random::()) } }) .collect() @@ -88,14 +83,14 @@ fn create_data(size: usize, null_density: f64) -> Vec> { fn create_integer_data(size: usize, value_density: f64) -> Vec> { // use random numbers to avoid spurious compiler optimizations wrt to branching - let mut rng = seedable_rng(); + let mut rng = StdRng::seed_from_u64(42); (0..size) .map(|_| { - if rng.gen::() > value_density { + if rng.random::() > value_density { None } else { - Some(rng.gen::()) + Some(rng.random::()) } }) .collect() @@ -125,7 +120,7 @@ fn create_record_batch( // Integer values between [0, 9]. 
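Much of the benchmark churn in this and the following files is the rand 0.9 upgrade, which renames the main entry points. A short reference sketch, assuming rand 0.9:

    use rand::prelude::*;

    fn main() {
        // rand 0.9 renames: `thread_rng()` -> `rng()`, `gen()` -> `random()`,
        // `gen_range(..)` -> `random_range(..)`, `gen_bool(p)` -> `random_bool(p)`,
        // and the `rand::distributions` module -> `rand::distr`
        let mut rng = rand::rng();
        let f: f64 = rng.random();
        let n = rng.random_range(0..10);
        let coin = rng.random_bool(0.5);
        let letter = char::from(rng.sample(rand::distr::Alphanumeric));
        println!("{f} {n} {coin} {letter}");
    }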
let integer_values_narrow = (0..batch_size) - .map(|_| rng.gen_range(0_u64..10)) + .map(|_| rng.random_range(0_u64..10)) .collect::>(); RecordBatch::try_new( @@ -149,7 +144,7 @@ pub fn create_record_batches( partitions_len: usize, batch_size: usize, ) -> Vec> { - let mut rng = seedable_rng(); + let mut rng = StdRng::seed_from_u64(42); (0..partitions_len) .map(|_| { (0..array_len / batch_size / partitions_len) @@ -217,7 +212,7 @@ pub(crate) fn make_data( let mut ts_builder = Int64Builder::new(); let gen_id = |rng: &mut rand::rngs::SmallRng| { - rng.gen::<[u8; 16]>() + rng.random::<[u8; 16]>() .iter() .fold(String::new(), |mut output, b| { let _ = write!(output, "{b:02X}"); @@ -233,7 +228,7 @@ pub(crate) fn make_data( .map(|_| gen_sample_cnt(&mut rng)) .collect::>(); for _ in 0..sample_cnt { - let random_index = rng.gen_range(0..simultaneous_group_cnt); + let random_index = rng.random_range(0..simultaneous_group_cnt); let trace_id = &mut group_ids[random_index]; let sample_cnt = &mut group_sample_cnts[random_index]; *sample_cnt -= 1; diff --git a/datafusion/core/benches/dataframe.rs b/datafusion/core/benches/dataframe.rs index 832553ebed82..12eb34719e4b 100644 --- a/datafusion/core/benches/dataframe.rs +++ b/datafusion/core/benches/dataframe.rs @@ -32,7 +32,7 @@ use tokio::runtime::Runtime; fn create_context(field_count: u32) -> datafusion_common::Result> { let mut fields = vec![]; for i in 0..field_count { - fields.push(Field::new(format!("str{}", i), DataType::Utf8, true)) + fields.push(Field::new(format!("str{i}"), DataType::Utf8, true)) } let schema = Arc::new(Schema::new(fields)); @@ -49,8 +49,8 @@ fn run(column_count: u32, ctx: Arc, rt: &Runtime) { let mut data_frame = ctx.table("t").await.unwrap(); for i in 0..column_count { - let field_name = &format!("str{}", i); - let new_field_name = &format!("newstr{}", i); + let field_name = &format!("str{i}"); + let new_field_name = &format!("newstr{i}"); data_frame = data_frame .with_column_renamed(field_name, new_field_name) diff --git a/datafusion/core/benches/distinct_query_sql.rs b/datafusion/core/benches/distinct_query_sql.rs index c7056aab8689..c1ef55992689 100644 --- a/datafusion/core/benches/distinct_query_sql.rs +++ b/datafusion/core/benches/distinct_query_sql.rs @@ -154,7 +154,7 @@ fn criterion_benchmark_limited_distinct_sampled(c: &mut Criterion) { let sql = format!("select DISTINCT trace_id from traces group by trace_id limit {limit};"); c.bench_function( - format!("distinct query with {} partitions and {} samples per partition with limit {}", partitions, samples, limit).as_str(), + format!("distinct query with {partitions} partitions and {samples} samples per partition with limit {limit}").as_str(), |b| b.iter(|| { let (plan, ctx) = rt.block_on( create_context_sampled_data(sql.as_str(), partitions, samples) @@ -168,7 +168,7 @@ fn criterion_benchmark_limited_distinct_sampled(c: &mut Criterion) { let sql = format!("select DISTINCT trace_id from traces group by trace_id limit {limit};"); c.bench_function( - format!("distinct query with {} partitions and {} samples per partition with limit {}", partitions, samples, limit).as_str(), + format!("distinct query with {partitions} partitions and {samples} samples per partition with limit {limit}").as_str(), |b| b.iter(|| { let (plan, ctx) = rt.block_on( create_context_sampled_data(sql.as_str(), partitions, samples) @@ -182,7 +182,7 @@ fn criterion_benchmark_limited_distinct_sampled(c: &mut Criterion) { let sql = format!("select DISTINCT trace_id from traces group by trace_id limit 
{limit};"); c.bench_function( - format!("distinct query with {} partitions and {} samples per partition with limit {}", partitions, samples, limit).as_str(), + format!("distinct query with {partitions} partitions and {samples} samples per partition with limit {limit}").as_str(), |b| b.iter(|| { let (plan, ctx) = rt.block_on( create_context_sampled_data(sql.as_str(), partitions, samples) diff --git a/datafusion/core/benches/map_query_sql.rs b/datafusion/core/benches/map_query_sql.rs index 79229dfc2fbd..063b8e6c86bb 100644 --- a/datafusion/core/benches/map_query_sql.rs +++ b/datafusion/core/benches/map_query_sql.rs @@ -34,7 +34,7 @@ mod data_utils; fn build_keys(rng: &mut ThreadRng) -> Vec { let mut keys = vec![]; for _ in 0..1000 { - keys.push(rng.gen_range(0..9999).to_string()); + keys.push(rng.random_range(0..9999).to_string()); } keys } @@ -42,7 +42,7 @@ fn build_keys(rng: &mut ThreadRng) -> Vec { fn build_values(rng: &mut ThreadRng) -> Vec { let mut values = vec![]; for _ in 0..1000 { - values.push(rng.gen_range(0..9999)); + values.push(rng.random_range(0..9999)); } values } @@ -64,15 +64,18 @@ fn criterion_benchmark(c: &mut Criterion) { let rt = Runtime::new().unwrap(); let df = rt.block_on(ctx.lock().table("t")).unwrap(); - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let keys = build_keys(&mut rng); let values = build_values(&mut rng); let mut key_buffer = Vec::new(); let mut value_buffer = Vec::new(); for i in 0..1000 { - key_buffer.push(Expr::Literal(ScalarValue::Utf8(Some(keys[i].clone())))); - value_buffer.push(Expr::Literal(ScalarValue::Int32(Some(values[i])))); + key_buffer.push(Expr::Literal( + ScalarValue::Utf8(Some(keys[i].clone())), + None, + )); + value_buffer.push(Expr::Literal(ScalarValue::Int32(Some(values[i])), None)); } c.bench_function("map_1000_1", |b| { b.iter(|| { diff --git a/datafusion/core/benches/parquet_query_sql.rs b/datafusion/core/benches/parquet_query_sql.rs index f82a126c5652..14dcdf15f173 100644 --- a/datafusion/core/benches/parquet_query_sql.rs +++ b/datafusion/core/benches/parquet_query_sql.rs @@ -29,9 +29,10 @@ use datafusion_common::instant::Instant; use futures::stream::StreamExt; use parquet::arrow::ArrowWriter; use parquet::file::properties::{WriterProperties, WriterVersion}; -use rand::distributions::uniform::SampleUniform; -use rand::distributions::Alphanumeric; +use rand::distr::uniform::SampleUniform; +use rand::distr::Alphanumeric; use rand::prelude::*; +use rand::rng; use std::fs::File; use std::io::Read; use std::ops::Range; @@ -97,13 +98,13 @@ fn generate_string_dictionary( len: usize, valid_percent: f64, ) -> ArrayRef { - let mut rng = thread_rng(); + let mut rng = rng(); let strings: Vec<_> = (0..cardinality).map(|x| format!("{prefix}#{x}")).collect(); Arc::new(DictionaryArray::::from_iter((0..len).map( |_| { - rng.gen_bool(valid_percent) - .then(|| strings[rng.gen_range(0..cardinality)].as_str()) + rng.random_bool(valid_percent) + .then(|| strings[rng.random_range(0..cardinality)].as_str()) }, ))) } @@ -113,10 +114,10 @@ fn generate_strings( len: usize, valid_percent: f64, ) -> ArrayRef { - let mut rng = thread_rng(); + let mut rng = rng(); Arc::new(StringArray::from_iter((0..len).map(|_| { - rng.gen_bool(valid_percent).then(|| { - let string_len = rng.gen_range(string_length_range.clone()); + rng.random_bool(valid_percent).then(|| { + let string_len = rng.random_range(string_length_range.clone()); (0..string_len) .map(|_| char::from(rng.sample(Alphanumeric))) .collect::() @@ -133,10 +134,10 @@ where T: 
ArrowPrimitiveType, T::Native: SampleUniform, { - let mut rng = thread_rng(); + let mut rng = rng(); Arc::new(PrimitiveArray::::from_iter((0..len).map(|_| { - rng.gen_bool(valid_percent) - .then(|| rng.gen_range(range.clone())) + rng.random_bool(valid_percent) + .then(|| rng.random_range(range.clone())) }))) } diff --git a/datafusion/core/benches/physical_plan.rs b/datafusion/core/benches/physical_plan.rs index 0a65c52f72de..e4838572f60f 100644 --- a/datafusion/core/benches/physical_plan.rs +++ b/datafusion/core/benches/physical_plan.rs @@ -50,11 +50,8 @@ fn sort_preserving_merge_operator( let sort = sort .iter() - .map(|name| PhysicalSortExpr { - expr: col(name, &schema).unwrap(), - options: Default::default(), - }) - .collect::(); + .map(|name| PhysicalSortExpr::new_default(col(name, &schema).unwrap())); + let sort = LexOrdering::new(sort).unwrap(); let exec = MemorySourceConfig::try_new_exec( &batches.into_iter().map(|rb| vec![rb]).collect::>(), diff --git a/datafusion/core/benches/push_down_filter.rs b/datafusion/core/benches/push_down_filter.rs new file mode 100644 index 000000000000..139fb12c3094 --- /dev/null +++ b/datafusion/core/benches/push_down_filter.rs @@ -0,0 +1,123 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::array::RecordBatch; +use arrow::datatypes::{DataType, Field, Schema}; +use bytes::{BufMut, BytesMut}; +use criterion::{criterion_group, criterion_main, Criterion}; +use datafusion::config::ConfigOptions; +use datafusion::prelude::{ParquetReadOptions, SessionContext}; +use datafusion_execution::object_store::ObjectStoreUrl; +use datafusion_physical_optimizer::filter_pushdown::FilterPushdown; +use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_plan::ExecutionPlan; +use object_store::memory::InMemory; +use object_store::path::Path; +use object_store::ObjectStore; +use parquet::arrow::ArrowWriter; +use std::sync::Arc; + +async fn create_plan() -> Arc { + let ctx = SessionContext::new(); + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, true), + Field::new("name", DataType::Utf8, true), + Field::new("age", DataType::UInt16, true), + Field::new("salary", DataType::Float64, true), + ])); + let batch = RecordBatch::new_empty(schema); + + let store = Arc::new(InMemory::new()) as Arc; + let mut out = BytesMut::new().writer(); + { + let mut writer = ArrowWriter::try_new(&mut out, batch.schema(), None).unwrap(); + writer.write(&batch).unwrap(); + writer.finish().unwrap(); + } + let data = out.into_inner().freeze(); + store + .put(&Path::from("test.parquet"), data.into()) + .await + .unwrap(); + ctx.register_object_store( + ObjectStoreUrl::parse("memory://").unwrap().as_ref(), + store, + ); + + ctx.register_parquet("t", "memory:///", ParquetReadOptions::default()) + .await + .unwrap(); + + let df = ctx + .sql( + r" + WITH brackets AS ( + SELECT age % 10 AS age_bracket + FROM t + GROUP BY age % 10 + HAVING COUNT(*) > 10 + ) + SELECT id, name, age, salary + FROM t + JOIN brackets ON t.age % 10 = brackets.age_bracket + WHERE age > 20 AND t.salary > 1000 + ORDER BY t.salary DESC + LIMIT 100 + ", + ) + .await + .unwrap(); + + df.create_physical_plan().await.unwrap() +} + +#[derive(Clone)] +struct BenchmarkPlan { + plan: Arc, + config: ConfigOptions, +} + +impl std::fmt::Display for BenchmarkPlan { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "BenchmarkPlan") + } +} + +fn bench_push_down_filter(c: &mut Criterion) { + // Create a relatively complex plan + let plan = tokio::runtime::Runtime::new() + .unwrap() + .block_on(create_plan()); + let mut config = ConfigOptions::default(); + config.execution.parquet.pushdown_filters = true; + let plan = BenchmarkPlan { plan, config }; + let optimizer = FilterPushdown::new(); + + c.bench_function("push_down_filter", |b| { + b.iter(|| { + optimizer + .optimize(Arc::clone(&plan.plan), &plan.config) + .unwrap(); + }); + }); +} + +// It's a bit absurd that it's this complicated but to generate a flamegraph you can run: +// `cargo flamegraph -p datafusion --bench push_down_filter --flamechart --root --profile profiling --freq 1000 -- --bench` +// See https://github.com/flamegraph-rs/flamegraph +criterion_group!(benches, bench_push_down_filter); +criterion_main!(benches); diff --git a/datafusion/core/benches/sort.rs b/datafusion/core/benches/sort.rs index 85f456ce5dc2..276151e253f7 100644 --- a/datafusion/core/benches/sort.rs +++ b/datafusion/core/benches/sort.rs @@ -71,7 +71,6 @@ use std::sync::Arc; use arrow::array::StringViewArray; use arrow::{ array::{DictionaryArray, Float64Array, Int64Array, StringArray}, - compute::SortOptions, datatypes::{Int32Type, Schema}, record_batch::RecordBatch, }; @@ -272,14 +271,11 @@ impl BenchCase { /// Make sort exprs for each 
column in `schema` fn make_sort_exprs(schema: &Schema) -> LexOrdering { - schema + let sort_exprs = schema .fields() .iter() - .map(|f| PhysicalSortExpr { - expr: col(f.name(), schema).unwrap(), - options: SortOptions::default(), - }) - .collect() + .map(|f| PhysicalSortExpr::new_default(col(f.name(), schema).unwrap())); + LexOrdering::new(sort_exprs).unwrap() } /// Create streams of int64 (where approximately 1/3 values is repeated) @@ -595,7 +591,7 @@ impl DataGenerator { /// Create an array of i64 sorted values (where approximately 1/3 values is repeated) fn i64_values(&mut self) -> Vec { let mut vec: Vec<_> = (0..INPUT_SIZE) - .map(|_| self.rng.gen_range(0..INPUT_SIZE as i64)) + .map(|_| self.rng.random_range(0..INPUT_SIZE as i64)) .collect(); vec.sort_unstable(); @@ -620,7 +616,7 @@ impl DataGenerator { // pick from the 100 strings randomly let mut input = (0..INPUT_SIZE) .map(|_| { - let idx = self.rng.gen_range(0..strings.len()); + let idx = self.rng.random_range(0..strings.len()); let s = Arc::clone(&strings[idx]); Some(s) }) @@ -643,7 +639,7 @@ impl DataGenerator { fn random_string(&mut self) -> String { let rng = &mut self.rng; - rng.sample_iter(rand::distributions::Alphanumeric) + rng.sample_iter(rand::distr::Alphanumeric) .filter(|c| c.is_ascii_alphabetic()) .take(20) .map(char::from) @@ -665,7 +661,7 @@ where let mut outputs: Vec>> = (0..NUM_STREAMS).map(|_| Vec::new()).collect(); for i in input { - let stream_idx = rng.gen_range(0..NUM_STREAMS); + let stream_idx = rng.random_range(0..NUM_STREAMS); let stream = &mut outputs[stream_idx]; match stream.last_mut() { Some(x) if x.len() < BATCH_SIZE => x.push(i), diff --git a/datafusion/core/benches/spm.rs b/datafusion/core/benches/spm.rs index 63b06f20cd86..d13407864297 100644 --- a/datafusion/core/benches/spm.rs +++ b/datafusion/core/benches/spm.rs @@ -21,7 +21,6 @@ use arrow::array::{ArrayRef, Int32Array, Int64Array, RecordBatch, StringArray}; use datafusion_execution::TaskContext; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::PhysicalSortExpr; -use datafusion_physical_expr_common::sort_expr::LexOrdering; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_physical_plan::{collect, ExecutionPlan}; @@ -70,7 +69,7 @@ fn generate_spm_for_round_robin_tie_breaker( let partitiones = vec![rbs.clone(); partition_count]; let schema = rb.schema(); - let sort = LexOrdering::new(vec![ + let sort = [ PhysicalSortExpr { expr: col("b", &schema).unwrap(), options: Default::default(), @@ -79,7 +78,8 @@ fn generate_spm_for_round_robin_tie_breaker( expr: col("c", &schema).unwrap(), options: Default::default(), }, - ]); + ] + .into(); let exec = MemorySourceConfig::try_new_exec(&partitiones, schema, None).unwrap(); SortPreservingMergeExec::new(sort, exec) @@ -125,8 +125,7 @@ fn criterion_benchmark(c: &mut Criterion) { for &batch_count in &batch_counts { for &partition_count in &partition_counts { let description = format!( - "{}_batch_count_{}_partition_count_{}", - cardinality_label, batch_count, partition_count + "{cardinality_label}_batch_count_{batch_count}_partition_count_{partition_count}" ); run_bench( c, diff --git a/datafusion/core/benches/sql_planner.rs b/datafusion/core/benches/sql_planner.rs index 49cc830d58bc..d02478d2b479 100644 --- a/datafusion/core/benches/sql_planner.rs +++ b/datafusion/core/benches/sql_planner.rs @@ -30,9 +30,6 @@ use datafusion::datasource::MemTable; use datafusion::execution::context::SessionContext; use datafusion_common::ScalarValue; 
use datafusion_expr::col; -use itertools::Itertools; -use std::fs::File; -use std::io::{BufRead, BufReader}; use std::path::PathBuf; use std::sync::Arc; use test_utils::tpcds::tpcds_schemas; @@ -136,10 +133,10 @@ fn benchmark_with_param_values_many_columns( if i > 0 { aggregates.push_str(", "); } - aggregates.push_str(format!("MAX(a{})", i).as_str()); + aggregates.push_str(format!("MAX(a{i})").as_str()); } // SELECT max(attr0), ..., max(attrN) FROM t1. - let query = format!("SELECT {} FROM t1", aggregates); + let query = format!("SELECT {aggregates} FROM t1"); let statement = ctx.state().sql_to_statement(&query, "Generic").unwrap(); let plan = rt.block_on(async { ctx.state().statement_to_plan(statement).await.unwrap() }); @@ -164,7 +161,7 @@ fn register_union_order_table(ctx: &SessionContext, num_columns: usize, num_rows .map(|j| j as u64 * 100 + i) .collect::>(), )); - (format!("c{}", i), array) + (format!("c{i}"), array) }); let batch = RecordBatch::try_from_iter(iter).unwrap(); let schema = batch.schema(); @@ -172,7 +169,7 @@ fn register_union_order_table(ctx: &SessionContext, num_columns: usize, num_rows // tell DataFusion that the table is sorted by all columns let sort_order = (0..num_columns) - .map(|i| col(format!("c{}", i)).sort(true, true)) + .map(|i| col(format!("c{i}")).sort(true, true)) .collect::>(); // create the table @@ -208,12 +205,12 @@ fn union_orderby_query(n: usize) -> String { }) .collect::>() .join(", "); - query.push_str(&format!("(SELECT {} FROM t ORDER BY c{})", select_list, i)); + query.push_str(&format!("(SELECT {select_list} FROM t ORDER BY c{i})")); } query.push_str(&format!( "\nORDER BY {}", (0..n) - .map(|i| format!("c{}", i)) + .map(|i| format!("c{i}")) .collect::>() .join(", ") )); @@ -293,9 +290,9 @@ fn criterion_benchmark(c: &mut Criterion) { if i > 0 { aggregates.push_str(", "); } - aggregates.push_str(format!("MAX(a{})", i).as_str()); + aggregates.push_str(format!("MAX(a{i})").as_str()); } - let query = format!("SELECT {} FROM t1", aggregates); + let query = format!("SELECT {aggregates} FROM t1"); b.iter(|| { physical_plan(&ctx, &rt, &query); }); @@ -402,7 +399,7 @@ fn criterion_benchmark(c: &mut Criterion) { for q in tpch_queries { let sql = std::fs::read_to_string(format!("{benchmarks_path}queries/{q}.sql")).unwrap(); - c.bench_function(&format!("physical_plan_tpch_{}", q), |b| { + c.bench_function(&format!("physical_plan_tpch_{q}"), |b| { b.iter(|| physical_plan(&tpch_ctx, &rt, &sql)) }); } @@ -466,17 +463,20 @@ fn criterion_benchmark(c: &mut Criterion) { // }); // -- clickbench -- - - let queries_file = - File::open(format!("{benchmarks_path}queries/clickbench/queries.sql")).unwrap(); - let extended_file = - File::open(format!("{benchmarks_path}queries/clickbench/extended.sql")).unwrap(); - - let clickbench_queries: Vec = BufReader::new(queries_file) - .lines() - .chain(BufReader::new(extended_file).lines()) - .map(|l| l.expect("Could not parse line")) - .collect_vec(); + let clickbench_queries = (0..=42) + .map(|q| { + std::fs::read_to_string(format!( + "{benchmarks_path}queries/clickbench/queries/q{q}.sql" + )) + .unwrap() + }) + .chain((0..=7).map(|q| { + std::fs::read_to_string(format!( + "{benchmarks_path}queries/clickbench/extended/q{q}.sql" + )) + .unwrap() + })) + .collect::>(); let clickbench_ctx = register_clickbench_hits_table(&rt); diff --git a/datafusion/core/benches/sql_query_with_io.rs b/datafusion/core/benches/sql_query_with_io.rs index 58d71ee5b2eb..58797dfed6b6 100644 --- a/datafusion/core/benches/sql_query_with_io.rs +++ 
b/datafusion/core/benches/sql_query_with_io.rs @@ -66,7 +66,7 @@ fn create_parquet_file(rng: &mut StdRng, id_offset: usize) -> Bytes { let mut payload_builder = Int64Builder::new(); for row in 0..FILE_ROWS { id_builder.append_value((row + id_offset) as u64); - payload_builder.append_value(rng.gen()); + payload_builder.append_value(rng.random()); } let batch = RecordBatch::try_new( Arc::clone(&schema), diff --git a/datafusion/core/src/bin/print_functions_docs.rs b/datafusion/core/src/bin/print_functions_docs.rs index 7afb90282a80..1044717aaffb 100644 --- a/datafusion/core/src/bin/print_functions_docs.rs +++ b/datafusion/core/src/bin/print_functions_docs.rs @@ -46,7 +46,7 @@ fn main() -> Result<()> { "scalar" => print_scalar_docs(), "window" => print_window_docs(), _ => { - panic!("Unknown function type: {}", function_type) + panic!("Unknown function type: {function_type}") } }?; @@ -92,7 +92,7 @@ fn print_window_docs() -> Result { fn save_doc_code_text(documentation: &Documentation, name: &str) { let attr_text = documentation.to_doc_attribute(); - let file_path = format!("{}.txt", name); + let file_path = format!("{name}.txt"); if std::path::Path::new(&file_path).exists() { std::fs::remove_file(&file_path).unwrap(); } @@ -215,16 +215,15 @@ fn print_docs( r#" #### Example -{} -"#, - example +{example} +"# ); } if let Some(alt_syntax) = &documentation.alternative_syntax { let _ = writeln!(docs, "#### Alternative Syntax\n"); for syntax in alt_syntax { - let _ = writeln!(docs, "```sql\n{}\n```", syntax); + let _ = writeln!(docs, "```sql\n{syntax}\n```"); } } diff --git a/datafusion/core/src/bin/print_runtime_config_docs.rs b/datafusion/core/src/bin/print_runtime_config_docs.rs new file mode 100644 index 000000000000..31425da73d35 --- /dev/null +++ b/datafusion/core/src/bin/print_runtime_config_docs.rs @@ -0,0 +1,23 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use datafusion_execution::runtime_env::RuntimeEnvBuilder; + +fn main() { + let docs = RuntimeEnvBuilder::generate_config_markdown(); + println!("{docs}"); +} diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 9a70f8f43fb6..7101a30c5df0 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -33,8 +33,8 @@ use crate::execution::context::{SessionState, TaskContext}; use crate::execution::FunctionRegistry; use crate::logical_expr::utils::find_window_exprs; use crate::logical_expr::{ - col, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, LogicalPlanBuilderOptions, - Partitioning, TableType, + col, ident, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, + LogicalPlanBuilderOptions, Partitioning, TableType, }; use crate::physical_plan::{ collect, collect_partitioned, execute_stream, execute_stream_partitioned, @@ -166,9 +166,12 @@ impl Default for DataFrameWriteOptions { /// /// # Example /// ``` +/// # use std::sync::Arc; /// # use datafusion::prelude::*; /// # use datafusion::error::Result; /// # use datafusion::functions_aggregate::expr_fn::min; +/// # use datafusion::arrow::array::{Int32Array, RecordBatch, StringArray}; +/// # use datafusion::arrow::datatypes::{DataType, Field, Schema}; /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); @@ -181,6 +184,28 @@ impl Default for DataFrameWriteOptions { /// .limit(0, Some(100))?; /// // Perform the actual computation /// let results = df.collect(); +/// +/// // Create a new dataframe with in-memory data +/// let schema = Schema::new(vec![ +/// Field::new("id", DataType::Int32, true), +/// Field::new("name", DataType::Utf8, true), +/// ]); +/// let batch = RecordBatch::try_new( +/// Arc::new(schema), +/// vec![ +/// Arc::new(Int32Array::from(vec![1, 2, 3])), +/// Arc::new(StringArray::from(vec!["foo", "bar", "baz"])), +/// ], +/// )?; +/// let df = ctx.read_batch(batch)?; +/// df.show().await?; +/// +/// // Create a new dataframe with in-memory data using macro +/// let df = dataframe!( +/// "id" => [1, 2, 3], +/// "name" => ["foo", "bar", "baz"] +/// )?; +/// df.show().await?; /// # Ok(()) /// # } /// ``` @@ -350,15 +375,12 @@ impl DataFrame { let expr_list: Vec = expr_list.into_iter().map(|e| e.into()).collect::>(); - let expressions = expr_list - .iter() - .filter_map(|e| match e { - SelectExpr::Expression(expr) => Some(expr.clone()), - _ => None, - }) - .collect::>(); + let expressions = expr_list.iter().filter_map(|e| match e { + SelectExpr::Expression(expr) => Some(expr), + _ => None, + }); - let window_func_exprs = find_window_exprs(&expressions); + let window_func_exprs = find_window_exprs(expressions); let plan = if window_func_exprs.is_empty() { self.plan } else { @@ -934,7 +956,7 @@ impl DataFrame { vec![], original_schema_fields .clone() - .map(|f| count(col(f.name())).alias(f.name())) + .map(|f| count(ident(f.name())).alias(f.name())) .collect::>(), ), // null_count aggregation @@ -943,7 +965,7 @@ impl DataFrame { original_schema_fields .clone() .map(|f| { - sum(case(is_null(col(f.name()))) + sum(case(is_null(ident(f.name()))) .when(lit(true), lit(1)) .otherwise(lit(0)) .unwrap()) @@ -957,7 +979,7 @@ impl DataFrame { original_schema_fields .clone() .filter(|f| f.data_type().is_numeric()) - .map(|f| avg(col(f.name())).alias(f.name())) + .map(|f| avg(ident(f.name())).alias(f.name())) .collect::>(), ), // std aggregation @@ -966,7 +988,7 @@ impl DataFrame { original_schema_fields .clone() .filter(|f| 
f.data_type().is_numeric()) - .map(|f| stddev(col(f.name())).alias(f.name())) + .map(|f| stddev(ident(f.name())).alias(f.name())) .collect::>(), ), // min aggregation @@ -977,7 +999,7 @@ impl DataFrame { .filter(|f| { !matches!(f.data_type(), DataType::Binary | DataType::Boolean) }) - .map(|f| min(col(f.name())).alias(f.name())) + .map(|f| min(ident(f.name())).alias(f.name())) .collect::>(), ), // max aggregation @@ -988,7 +1010,7 @@ impl DataFrame { .filter(|f| { !matches!(f.data_type(), DataType::Binary | DataType::Boolean) }) - .map(|f| max(col(f.name())).alias(f.name())) + .map(|f| max(ident(f.name())).alias(f.name())) .collect::>(), ), // median aggregation @@ -997,7 +1019,7 @@ impl DataFrame { original_schema_fields .clone() .filter(|f| f.data_type().is_numeric()) - .map(|f| median(col(f.name())).alias(f.name())) + .map(|f| median(ident(f.name())).alias(f.name())) .collect::>(), ), ]; @@ -1312,7 +1334,10 @@ impl DataFrame { /// ``` pub async fn count(self) -> Result { let rows = self - .aggregate(vec![], vec![count(Expr::Literal(COUNT_STAR_EXPANSION))])? + .aggregate( + vec![], + vec![count(Expr::Literal(COUNT_STAR_EXPANSION, None))], + )? .collect() .await?; let len = *rows @@ -1366,8 +1391,47 @@ impl DataFrame { /// # } /// ``` pub async fn show(self) -> Result<()> { + println!("{}", self.to_string().await?); + Ok(()) + } + + /// Execute the `DataFrame` and return a string representation of the results. + /// + /// # Example + /// ``` + /// # use datafusion::prelude::*; + /// # use datafusion::error::Result; + /// # use datafusion::execution::SessionStateBuilder; + /// + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// let cfg = SessionConfig::new() + /// .set_str("datafusion.format.null", "no-value"); + /// let session_state = SessionStateBuilder::new() + /// .with_config(cfg) + /// .with_default_features() + /// .build(); + /// let ctx = SessionContext::new_with_state(session_state); + /// let df = ctx.sql("select null as 'null-column'").await?; + /// let result = df.to_string().await?; + /// assert_eq!(result, + /// "+-------------+ + /// | null-column | + /// +-------------+ + /// | no-value | + /// +-------------+" + /// ); + /// # Ok(()) + /// # } + pub async fn to_string(self) -> Result { + let options = self.session_state.config().options().format.clone(); + let arrow_options: arrow::util::display::FormatOptions = (&options).try_into()?; + let results = self.collect().await?; - Ok(pretty::print_batches(&results)?) + Ok( + pretty::pretty_format_batches_with_options(&results, &arrow_options)? + .to_string(), + ) } /// Execute the `DataFrame` and print only the first `num` rows of the @@ -1856,7 +1920,7 @@ impl DataFrame { /// # } /// ``` pub fn with_column(self, name: &str, expr: Expr) -> Result { - let window_func_exprs = find_window_exprs(std::slice::from_ref(&expr)); + let window_func_exprs = find_window_exprs([&expr]); let (window_fn_str, plan) = if window_func_exprs.is_empty() { (None, self.plan) @@ -2160,6 +2224,94 @@ impl DataFrame { }) .collect() } + + /// Helper for creating DataFrame. 
+ /// # Example + /// ``` + /// use std::sync::Arc; + /// use arrow::array::{ArrayRef, Int32Array, StringArray}; + /// use datafusion::prelude::DataFrame; + /// let id: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); + /// let name: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar", "baz"])); + /// let df = DataFrame::from_columns(vec![("id", id), ("name", name)]).unwrap(); + /// // +----+------+, + /// // | id | name |, + /// // +----+------+, + /// // | 1 | foo |, + /// // | 2 | bar |, + /// // | 3 | baz |, + /// // +----+------+, + /// ``` + pub fn from_columns(columns: Vec<(&str, ArrayRef)>) -> Result { + let fields = columns + .iter() + .map(|(name, array)| Field::new(*name, array.data_type().clone(), true)) + .collect::>(); + + let arrays = columns + .into_iter() + .map(|(_, array)| array) + .collect::>(); + + let schema = Arc::new(Schema::new(fields)); + let batch = RecordBatch::try_new(schema, arrays)?; + let ctx = SessionContext::new(); + let df = ctx.read_batch(batch)?; + Ok(df) + } +} + +/// Macro for creating DataFrame. +/// # Example +/// ``` +/// use datafusion::prelude::dataframe; +/// # use datafusion::error::Result; +/// # #[tokio::main] +/// # async fn main() -> Result<()> { +/// let df = dataframe!( +/// "id" => [1, 2, 3], +/// "name" => ["foo", "bar", "baz"] +/// )?; +/// df.show().await?; +/// // +----+------+, +/// // | id | name |, +/// // +----+------+, +/// // | 1 | foo |, +/// // | 2 | bar |, +/// // | 3 | baz |, +/// // +----+------+, +/// let df_empty = dataframe!()?; // empty DataFrame +/// assert_eq!(df_empty.schema().fields().len(), 0); +/// assert_eq!(df_empty.count().await?, 0); +/// # Ok(()) +/// # } +/// ``` +#[macro_export] +macro_rules! dataframe { + () => {{ + use std::sync::Arc; + + use datafusion::prelude::SessionContext; + use datafusion::arrow::array::RecordBatch; + use datafusion::arrow::datatypes::Schema; + + let ctx = SessionContext::new(); + let batch = RecordBatch::new_empty(Arc::new(Schema::empty())); + ctx.read_batch(batch) + }}; + + ($($name:expr => $data:expr),+ $(,)?) 
=> {{ + use datafusion::prelude::DataFrame; + use datafusion::common::test_util::IntoArrayRef; + + let columns = vec![ + $( + ($name, $data.into_array_ref()), + )+ + ]; + + DataFrame::from_columns(columns) + }}; } #[derive(Debug)] diff --git a/datafusion/core/src/dataframe/parquet.rs b/datafusion/core/src/dataframe/parquet.rs index 1bb5444ca009..a2bec74ee140 100644 --- a/datafusion/core/src/dataframe/parquet.rs +++ b/datafusion/core/src/dataframe/parquet.rs @@ -246,4 +246,72 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn roundtrip_parquet_with_encryption() -> Result<()> { + use parquet::encryption::decrypt::FileDecryptionProperties; + use parquet::encryption::encrypt::FileEncryptionProperties; + + let test_df = test_util::test_table().await?; + + let schema = test_df.schema(); + let footer_key = b"0123456789012345".to_vec(); // 128bit/16 + let column_key = b"1234567890123450".to_vec(); // 128bit/16 + + let mut encrypt = FileEncryptionProperties::builder(footer_key.clone()); + let mut decrypt = FileDecryptionProperties::builder(footer_key.clone()); + + for field in schema.fields().iter() { + encrypt = encrypt.with_column_key(field.name().as_str(), column_key.clone()); + decrypt = decrypt.with_column_key(field.name().as_str(), column_key.clone()); + } + + let encrypt = encrypt.build()?; + let decrypt = decrypt.build()?; + + let df = test_df.clone(); + let tmp_dir = TempDir::new()?; + let tempfile = tmp_dir.path().join("roundtrip.parquet"); + let tempfile_str = tempfile.into_os_string().into_string().unwrap(); + + // Write encrypted parquet using write_parquet + let mut options = TableParquetOptions::default(); + options.crypto.file_encryption = Some((&encrypt).into()); + + df.write_parquet( + tempfile_str.as_str(), + DataFrameWriteOptions::new().with_single_file_output(true), + Some(options), + ) + .await?; + let num_rows_written = test_df.count().await?; + + // Read encrypted parquet + let ctx: SessionContext = SessionContext::new(); + let read_options = + ParquetReadOptions::default().file_decryption_properties((&decrypt).into()); + + ctx.register_parquet("roundtrip_parquet", &tempfile_str, read_options.clone()) + .await?; + + let df_enc = ctx.sql("SELECT * FROM roundtrip_parquet").await?; + let num_rows_read = df_enc.count().await?; + + assert_eq!(num_rows_read, num_rows_written); + + // Read encrypted parquet and subset rows + columns + let encrypted_parquet_df = ctx.read_parquet(tempfile_str, read_options).await?; + + // Select three columns and filter the results + // Test that the filter works as expected + let selected = encrypted_parquet_df + .clone() + .select_columns(&["c1", "c2", "c3"])? 
+ .filter(col("c2").gt(lit(4)))?; + + let num_rows_selected = selected.count().await?; + assert_eq!(num_rows_selected, 14); + + Ok(()) + } } diff --git a/datafusion/core/src/datasource/file_format/arrow.rs b/datafusion/core/src/datasource/file_format/arrow.rs index 7fc27453d1ad..b620ff62d9a6 100644 --- a/datafusion/core/src/datasource/file_format/arrow.rs +++ b/datafusion/core/src/datasource/file_format/arrow.rs @@ -27,7 +27,7 @@ use std::sync::Arc; use super::file_compression_type::FileCompressionType; use super::write::demux::DemuxedStreamReceiver; -use super::write::{create_writer, SharedBuffer}; +use super::write::SharedBuffer; use super::FileFormatFactory; use crate::datasource::file_format::write::get_writer_schema; use crate::datasource::file_format::FileFormat; @@ -51,9 +51,9 @@ use datafusion_datasource::display::FileGroupDisplay; use datafusion_datasource::file::FileSource; use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder}; use datafusion_datasource::sink::{DataSink, DataSinkExec}; +use datafusion_datasource::write::ObjectWriterBuilder; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::dml::InsertOp; -use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::LexRequirement; use async_trait::async_trait; @@ -173,7 +173,6 @@ impl FileFormat for ArrowFormat { &self, _state: &dyn Session, conf: FileScanConfig, - _filters: Option<&Arc>, ) -> Result> { let source = Arc::new(ArrowSource::default()); let config = FileScanConfigBuilder::from(conf) @@ -223,7 +222,7 @@ impl FileSink for ArrowFileSink { async fn spawn_writer_tasks_and_join( &self, - _context: &Arc, + context: &Arc, demux_task: SpawnedTask>, mut file_stream_rx: DemuxedStreamReceiver, object_store: Arc, @@ -241,12 +240,19 @@ impl FileSink for ArrowFileSink { &get_writer_schema(&self.config), ipc_options.clone(), )?; - let mut object_store_writer = create_writer( + let mut object_store_writer = ObjectWriterBuilder::new( FileCompressionType::UNCOMPRESSED, &path, Arc::clone(&object_store), ) - .await?; + .with_buffer_size(Some( + context + .session_config() + .options() + .execution + .objectstore_writer_buffer_size, + )) + .build()?; file_write_tasks.spawn(async move { let mut row_count = 0; while let Some(batch) = rx.recv().await { diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index 323bc28057d4..9022e340cd36 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -33,6 +33,7 @@ mod tests { use arrow_schema::{DataType, Field, Schema, SchemaRef}; use datafusion_catalog::Session; use datafusion_common::cast::as_string_array; + use datafusion_common::config::CsvOptions; use datafusion_common::internal_err; use datafusion_common::stats::Precision; use datafusion_common::test_util::{arrow_test_data, batches_to_string}; @@ -217,8 +218,11 @@ mod tests { assert_eq!(tt_batches, 50 /* 100/2 */); // test metadata - assert_eq!(exec.statistics()?.num_rows, Precision::Absent); - assert_eq!(exec.statistics()?.total_byte_size, Precision::Absent); + assert_eq!(exec.partition_statistics(None)?.num_rows, Precision::Absent); + assert_eq!( + exec.partition_statistics(None)?.total_byte_size, + Precision::Absent + ); Ok(()) } @@ -792,6 +796,62 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_csv_write_empty_file() -> Result<()> { + // Case 1. 
write to a single file + // Expect: an empty file created + let tmp_dir = tempfile::TempDir::new().unwrap(); + let path = format!("{}/empty.csv", tmp_dir.path().to_string_lossy()); + + let ctx = SessionContext::new(); + + let df = ctx.sql("SELECT 1 limit 0").await?; + + let cfg1 = + crate::dataframe::DataFrameWriteOptions::new().with_single_file_output(true); + let cfg2 = CsvOptions::default().with_has_header(true); + + df.write_csv(&path, cfg1, Some(cfg2)).await?; + assert!(std::path::Path::new(&path).exists()); + + // Case 2. write to a directory without partition columns + // Expect: under the directory, an empty file is created + let tmp_dir = tempfile::TempDir::new().unwrap(); + let path = format!("{}", tmp_dir.path().to_string_lossy()); + + let cfg1 = + crate::dataframe::DataFrameWriteOptions::new().with_single_file_output(true); + let cfg2 = CsvOptions::default().with_has_header(true); + + let df = ctx.sql("SELECT 1 limit 0").await?; + + df.write_csv(&path, cfg1, Some(cfg2)).await?; + assert!(std::path::Path::new(&path).exists()); + + let files = std::fs::read_dir(&path).unwrap(); + assert!(files.count() == 1); + + // Case 3. write to a directory with partition columns + // Expect: No file is created + let tmp_dir = tempfile::TempDir::new().unwrap(); + let path = format!("{}", tmp_dir.path().to_string_lossy()); + + let df = ctx.sql("SELECT 1 as col1, 2 as col2 limit 0").await?; + + let cfg1 = crate::dataframe::DataFrameWriteOptions::new() + .with_single_file_output(true) + .with_partition_by(vec!["col1".to_string()]); + let cfg2 = CsvOptions::default().with_has_header(true); + + df.write_csv(&path, cfg1, Some(cfg2)).await?; + + assert!(std::path::Path::new(&path).exists()); + let files = std::fs::read_dir(&path).unwrap(); + assert!(files.count() == 0); + + Ok(()) + } + /// Read a single empty csv file with header /// /// empty.csv: @@ -1023,7 +1083,7 @@ mod tests { for _ in 0..batch_count { let output = deserializer.next()?; let DeserializerOutput::RecordBatch(batch) = output else { - panic!("Expected RecordBatch, got {:?}", output); + panic!("Expected RecordBatch, got {output:?}"); }; all_batches = concat_batches(&schema, &[all_batches, batch])?; } @@ -1061,7 +1121,7 @@ mod tests { for _ in 0..batch_count { let output = deserializer.next()?; let DeserializerOutput::RecordBatch(batch) = output else { - panic!("Expected RecordBatch, got {:?}", output); + panic!("Expected RecordBatch, got {output:?}"); }; all_batches = concat_batches(&schema, &[all_batches, batch])?; } @@ -1142,18 +1202,14 @@ mod tests { fn csv_line(line_number: usize) -> Bytes { let (int_value, float_value, bool_value, char_value) = csv_values(line_number); - format!( - "{},{},{},{}\n", - int_value, float_value, bool_value, char_value - ) - .into() + format!("{int_value},{float_value},{bool_value},{char_value}\n").into() } fn csv_values(line_number: usize) -> (i32, f64, bool, String) { let int_value = line_number as i32; let float_value = line_number as f64; let bool_value = line_number % 2 == 0; - let char_value = format!("{}-string", line_number); + let char_value = format!("{line_number}-string"); (int_value, float_value, bool_value, char_value) } diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs index a70a0f51d330..d818187bb307 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -31,6 +31,7 @@ mod tests { use arrow_schema::Schema; use bytes::Bytes; use 
datafusion_catalog::Session; + use datafusion_common::config::JsonOptions; use datafusion_common::test_util::batches_to_string; use datafusion_datasource::decoder::{ BatchDeserializer, DecoderDeserializer, DeserializerOutput, @@ -75,8 +76,11 @@ mod tests { assert_eq!(tt_batches, 6 /* 12/2 */); // test metadata - assert_eq!(exec.statistics()?.num_rows, Precision::Absent); - assert_eq!(exec.statistics()?.total_byte_size, Precision::Absent); + assert_eq!(exec.partition_statistics(None)?.num_rows, Precision::Absent); + assert_eq!( + exec.partition_statistics(None)?.total_byte_size, + Precision::Absent + ); Ok(()) } @@ -254,6 +258,61 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_json_write_empty_file() -> Result<()> { + // Case 1. write to a single file + // Expect: an empty file created + let tmp_dir = tempfile::TempDir::new().unwrap(); + let path = format!("{}/empty.json", tmp_dir.path().to_string_lossy()); + + let ctx = SessionContext::new(); + + let df = ctx.sql("SELECT 1 limit 0").await?; + + let cfg1 = + crate::dataframe::DataFrameWriteOptions::new().with_single_file_output(true); + let cfg2 = JsonOptions::default(); + + df.write_json(&path, cfg1, Some(cfg2)).await?; + assert!(std::path::Path::new(&path).exists()); + + // Case 2. write to a directory without partition columns + // Expect: under the directory, an empty file is created + let tmp_dir = tempfile::TempDir::new().unwrap(); + let path = format!("{}", tmp_dir.path().to_string_lossy()); + + let cfg1 = + crate::dataframe::DataFrameWriteOptions::new().with_single_file_output(true); + let cfg2 = JsonOptions::default(); + + let df = ctx.sql("SELECT 1 limit 0").await?; + + df.write_json(&path, cfg1, Some(cfg2)).await?; + assert!(std::path::Path::new(&path).exists()); + + let files = std::fs::read_dir(&path).unwrap(); + assert!(files.count() == 1); + + // Case 3. write to a directory with partition columns + // Expect: No file is created + let tmp_dir = tempfile::TempDir::new().unwrap(); + let path = format!("{}", tmp_dir.path().to_string_lossy()); + + let df = ctx.sql("SELECT 1 as col1, 2 as col2 limit 0").await?; + + let cfg1 = crate::dataframe::DataFrameWriteOptions::new() + .with_single_file_output(true) + .with_partition_by(vec!["col1".to_string()]); + let cfg2 = JsonOptions::default(); + + df.write_json(&path, cfg1, Some(cfg2)).await?; + + assert!(std::path::Path::new(&path).exists()); + let files = std::fs::read_dir(&path).unwrap(); + assert!(files.count() == 0); + Ok(()) + } + #[test] fn test_json_deserializer_finish() -> Result<()> { let schema = Arc::new(Schema::new(vec![ @@ -275,7 +334,7 @@ mod tests { for _ in 0..3 { let output = deserializer.next()?; let DeserializerOutput::RecordBatch(batch) = output else { - panic!("Expected RecordBatch, got {:?}", output); + panic!("Expected RecordBatch, got {output:?}"); }; all_batches = concat_batches(&schema, &[all_batches, batch])? } @@ -315,7 +374,7 @@ mod tests { for _ in 0..2 { let output = deserializer.next()?; let DeserializerOutput::RecordBatch(batch) = output else { - panic!("Expected RecordBatch, got {:?}", output); + panic!("Expected RecordBatch, got {output:?}"); }; all_batches = concat_batches(&schema, &[all_batches, batch])? 
} diff --git a/datafusion/core/src/datasource/file_format/mod.rs b/datafusion/core/src/datasource/file_format/mod.rs index 3a098301f14e..e165707c2eb0 100644 --- a/datafusion/core/src/datasource/file_format/mod.rs +++ b/datafusion/core/src/datasource/file_format/mod.rs @@ -93,7 +93,6 @@ pub(crate) mod test_util { .with_projection(projection) .with_limit(limit) .build(), - None, ) .await?; Ok(exec) diff --git a/datafusion/core/src/datasource/file_format/options.rs b/datafusion/core/src/datasource/file_format/options.rs index 08e9a628dd61..02b792823a82 100644 --- a/datafusion/core/src/datasource/file_format/options.rs +++ b/datafusion/core/src/datasource/file_format/options.rs @@ -34,7 +34,7 @@ use crate::error::Result; use crate::execution::context::{SessionConfig, SessionState}; use arrow::datatypes::{DataType, Schema, SchemaRef}; -use datafusion_common::config::TableOptions; +use datafusion_common::config::{ConfigFileDecryptionProperties, TableOptions}; use datafusion_common::{ DEFAULT_ARROW_EXTENSION, DEFAULT_AVRO_EXTENSION, DEFAULT_CSV_EXTENSION, DEFAULT_JSON_EXTENSION, DEFAULT_PARQUET_EXTENSION, @@ -252,6 +252,8 @@ pub struct ParquetReadOptions<'a> { pub schema: Option<&'a Schema>, /// Indicates how the file is sorted pub file_sort_order: Vec>, + /// Properties for decryption of Parquet files that use modular encryption + pub file_decryption_properties: Option, } impl Default for ParquetReadOptions<'_> { @@ -263,6 +265,7 @@ impl Default for ParquetReadOptions<'_> { skip_metadata: None, schema: None, file_sort_order: vec![], + file_decryption_properties: None, } } } @@ -313,6 +316,15 @@ impl<'a> ParquetReadOptions<'a> { self.file_sort_order = file_sort_order; self } + + /// Configure file decryption properties for reading encrypted Parquet files + pub fn file_decryption_properties( + mut self, + file_decryption_properties: ConfigFileDecryptionProperties, + ) -> Self { + self.file_decryption_properties = Some(file_decryption_properties); + self + } } /// Options that control the reading of ARROW files. 
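Editor's note: the hunk above adds a `file_decryption_properties` option to `ParquetReadOptions`, enabling reads of Parquet files that use modular encryption. Below is a minimal, hedged sketch of how a caller might use it, assuming footer-key-only encryption and the same `(&decrypt).into()` conversion to `ConfigFileDecryptionProperties` exercised by the roundtrip test earlier in this diff; the file path and 16-byte key are placeholders, not values from the PR.

```rust
// Sketch only: read an encrypted Parquet file through the new
// ParquetReadOptions::file_decryption_properties builder. The path and key
// are placeholders; column-specific keys could be added with
// FileDecryptionProperties::builder(..).with_column_key(..), as the
// roundtrip test in datafusion/core/src/dataframe/parquet.rs does.
use datafusion::error::Result;
use datafusion::prelude::{ParquetReadOptions, SessionContext};
use parquet::encryption::decrypt::FileDecryptionProperties;

async fn read_encrypted_parquet(path: &str) -> Result<()> {
    let footer_key = b"0123456789012345".to_vec(); // placeholder 128-bit key
    let decrypt = FileDecryptionProperties::builder(footer_key).build()?;

    let ctx = SessionContext::new();
    let options =
        ParquetReadOptions::default().file_decryption_properties((&decrypt).into());

    let df = ctx.read_parquet(path, options).await?;
    df.show().await?;
    Ok(())
}
```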
@@ -550,7 +562,7 @@ impl ReadOptions<'_> for CsvReadOptions<'_> { ListingOptions::new(Arc::new(file_format)) .with_file_extension(self.file_extension) - .with_target_partitions(config.target_partitions()) + .with_session_config_options(config) .with_table_partition_cols(self.table_partition_cols.clone()) .with_file_sort_order(self.file_sort_order.clone()) } @@ -574,7 +586,11 @@ impl ReadOptions<'_> for ParquetReadOptions<'_> { config: &SessionConfig, table_options: TableOptions, ) -> ListingOptions { - let mut file_format = ParquetFormat::new().with_options(table_options.parquet); + let mut options = table_options.parquet; + if let Some(file_decryption_properties) = &self.file_decryption_properties { + options.crypto.file_decryption = Some(file_decryption_properties.clone()); + } + let mut file_format = ParquetFormat::new().with_options(options); if let Some(parquet_pruning) = self.parquet_pruning { file_format = file_format.with_enable_pruning(parquet_pruning) @@ -585,9 +601,9 @@ impl ReadOptions<'_> for ParquetReadOptions<'_> { ListingOptions::new(Arc::new(file_format)) .with_file_extension(self.file_extension) - .with_target_partitions(config.target_partitions()) .with_table_partition_cols(self.table_partition_cols.clone()) .with_file_sort_order(self.file_sort_order.clone()) + .with_session_config_options(config) } async fn get_resolved_schema( @@ -615,7 +631,7 @@ impl ReadOptions<'_> for NdJsonReadOptions<'_> { ListingOptions::new(Arc::new(file_format)) .with_file_extension(self.file_extension) - .with_target_partitions(config.target_partitions()) + .with_session_config_options(config) .with_table_partition_cols(self.table_partition_cols.clone()) .with_file_sort_order(self.file_sort_order.clone()) } @@ -643,7 +659,7 @@ impl ReadOptions<'_> for AvroReadOptions<'_> { ListingOptions::new(Arc::new(file_format)) .with_file_extension(self.file_extension) - .with_target_partitions(config.target_partitions()) + .with_session_config_options(config) .with_table_partition_cols(self.table_partition_cols.clone()) } @@ -669,7 +685,7 @@ impl ReadOptions<'_> for ArrowReadOptions<'_> { ListingOptions::new(Arc::new(file_format)) .with_file_extension(self.file_extension) - .with_target_partitions(config.target_partitions()) + .with_session_config_options(config) .with_table_partition_cols(self.table_partition_cols.clone()) } diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 7b8b99273f4e..8a2db3431fa0 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -27,7 +27,10 @@ pub(crate) mod test_util { use crate::test::object_store::local_unpartitioned_file; - /// Writes `batches` to a temporary parquet file + /// Writes each `batch` to at least one temporary parquet file + /// + /// For example, if `batches` contains 2 batches, the function will create + /// 2 temporary files, each containing the contents of one batch /// /// If multi_page is set to `true`, the parquet file(s) are written /// with 2 rows per data page (used to test page filtering and @@ -52,7 +55,7 @@ pub(crate) mod test_util { } } - // we need the tmp files to be sorted as some tests rely on the how the returning files are ordered + // we need the tmp files to be sorted as some tests rely on the returned file ordering // https://github.com/apache/datafusion/pull/6629 let tmp_files = { let mut tmp_files: Vec<_> = (0..batches.len()) @@ -104,10 +107,8 @@ pub(crate) mod test_util { mod tests { 
use std::fmt::{self, Display, Formatter}; - use std::pin::Pin; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; - use std::task::{Context, Poll}; use std::time::Duration; use crate::datasource::file_format::parquet::test_util::store_parquet; @@ -117,7 +118,7 @@ mod tests { use crate::prelude::{ParquetReadOptions, SessionConfig, SessionContext}; use arrow::array::RecordBatch; - use arrow_schema::{Schema, SchemaRef}; + use arrow_schema::Schema; use datafusion_catalog::Session; use datafusion_common::cast::{ as_binary_array, as_binary_view_array, as_boolean_array, as_float32_array, @@ -137,7 +138,7 @@ mod tests { }; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::runtime_env::RuntimeEnv; - use datafusion_execution::{RecordBatchStream, TaskContext}; + use datafusion_execution::TaskContext; use datafusion_expr::dml::InsertOp; use datafusion_physical_plan::stream::RecordBatchStreamAdapter; use datafusion_physical_plan::{collect, ExecutionPlan}; @@ -150,7 +151,7 @@ mod tests { use async_trait::async_trait; use datafusion_datasource::file_groups::FileGroup; use futures::stream::BoxStream; - use futures::{Stream, StreamExt}; + use futures::StreamExt; use insta::assert_snapshot; use log::error; use object_store::local::LocalFileSystem; @@ -166,6 +167,8 @@ mod tests { use parquet::format::FileMetaData; use tokio::fs::File; + use crate::test_util::bounded_stream; + enum ForceViews { Yes, No, @@ -193,7 +196,8 @@ mod tests { let schema = format.infer_schema(&ctx, &store, &meta).await.unwrap(); let stats = - fetch_statistics(store.as_ref(), schema.clone(), &meta[0], None).await?; + fetch_statistics(store.as_ref(), schema.clone(), &meta[0], None, None) + .await?; assert_eq!(stats.num_rows, Precision::Exact(3)); let c1_stats = &stats.column_statistics[0]; @@ -201,7 +205,8 @@ mod tests { assert_eq!(c1_stats.null_count, Precision::Exact(1)); assert_eq!(c2_stats.null_count, Precision::Exact(3)); - let stats = fetch_statistics(store.as_ref(), schema, &meta[1], None).await?; + let stats = + fetch_statistics(store.as_ref(), schema, &meta[1], None, None).await?; assert_eq!(stats.num_rows, Precision::Exact(3)); let c1_stats = &stats.column_statistics[0]; let c2_stats = &stats.column_statistics[1]; @@ -373,9 +378,14 @@ mod tests { // Use a size hint larger than the parquet footer but smaller than the actual metadata, requiring a second fetch // for the remaining metadata - fetch_parquet_metadata(store.as_ref() as &dyn ObjectStore, &meta[0], Some(9)) - .await - .expect("error reading metadata with hint"); + fetch_parquet_metadata( + store.as_ref() as &dyn ObjectStore, + &meta[0], + Some(9), + None, + ) + .await + .expect("error reading metadata with hint"); assert_eq!(store.request_count(), 2); @@ -393,9 +403,14 @@ mod tests { .await .unwrap(); - let stats = - fetch_statistics(store.upcast().as_ref(), schema.clone(), &meta[0], Some(9)) - .await?; + let stats = fetch_statistics( + store.upcast().as_ref(), + schema.clone(), + &meta[0], + Some(9), + None, + ) + .await?; assert_eq!(stats.num_rows, Precision::Exact(3)); let c1_stats = &stats.column_statistics[0]; @@ -410,7 +425,7 @@ mod tests { // Use the file size as the hint so we can get the full metadata from the first fetch let size_hint = meta[0].size as usize; - fetch_parquet_metadata(store.upcast().as_ref(), &meta[0], Some(size_hint)) + fetch_parquet_metadata(store.upcast().as_ref(), &meta[0], Some(size_hint), None) .await .expect("error reading metadata with hint"); @@ -429,6 +444,7 @@ mod tests { 
schema.clone(), &meta[0], Some(size_hint), + None, ) .await?; @@ -445,7 +461,7 @@ mod tests { // Use the a size hint larger than the file size to make sure we don't panic let size_hint = (meta[0].size + 100) as usize; - fetch_parquet_metadata(store.upcast().as_ref(), &meta[0], Some(size_hint)) + fetch_parquet_metadata(store.upcast().as_ref(), &meta[0], Some(size_hint), None) .await .expect("error reading metadata with hint"); @@ -484,7 +500,8 @@ mod tests { let schema = format.infer_schema(&state, &store, &files).await.unwrap(); // Fetch statistics for first file - let pq_meta = fetch_parquet_metadata(store.as_ref(), &files[0], None).await?; + let pq_meta = + fetch_parquet_metadata(store.as_ref(), &files[0], None, None).await?; let stats = statistics_from_parquet_meta_calc(&pq_meta, schema.clone())?; assert_eq!(stats.num_rows, Precision::Exact(4)); @@ -542,7 +559,8 @@ mod tests { }; // Fetch statistics for first file - let pq_meta = fetch_parquet_metadata(store.as_ref(), &files[0], None).await?; + let pq_meta = + fetch_parquet_metadata(store.as_ref(), &files[0], None, None).await?; let stats = statistics_from_parquet_meta_calc(&pq_meta, schema.clone())?; assert_eq!(stats.num_rows, Precision::Exact(3)); // column c1 @@ -568,7 +586,8 @@ mod tests { assert_eq!(c2_stats.min_value, Precision::Exact(null_i64.clone())); // Fetch statistics for second file - let pq_meta = fetch_parquet_metadata(store.as_ref(), &files[1], None).await?; + let pq_meta = + fetch_parquet_metadata(store.as_ref(), &files[1], None, None).await?; let stats = statistics_from_parquet_meta_calc(&pq_meta, schema.clone())?; assert_eq!(stats.num_rows, Precision::Exact(3)); // column c1: missing from the file so the table treats all 3 rows as null @@ -616,9 +635,15 @@ mod tests { assert_eq!(tt_batches, 4 /* 8/2 */); // test metadata - assert_eq!(exec.statistics()?.num_rows, Precision::Exact(8)); + assert_eq!( + exec.partition_statistics(None)?.num_rows, + Precision::Exact(8) + ); // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 - assert_eq!(exec.statistics()?.total_byte_size, Precision::Exact(671)); + assert_eq!( + exec.partition_statistics(None)?.total_byte_size, + Precision::Exact(671) + ); Ok(()) } @@ -659,9 +684,15 @@ mod tests { get_exec(&state, "alltypes_plain.parquet", projection, Some(1)).await?; // note: even if the limit is set, the executor rounds up to the batch size - assert_eq!(exec.statistics()?.num_rows, Precision::Exact(8)); + assert_eq!( + exec.partition_statistics(None)?.num_rows, + Precision::Exact(8) + ); // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 - assert_eq!(exec.statistics()?.total_byte_size, Precision::Exact(671)); + assert_eq!( + exec.partition_statistics(None)?.total_byte_size, + Precision::Exact(671) + ); let batches = collect(exec, task_ctx).await?; assert_eq!(1, batches.len()); assert_eq!(11, batches[0].num_columns()); @@ -1073,7 +1104,7 @@ mod tests { let format = state .get_file_format_factory("parquet") .map(|factory| factory.create(state, &Default::default()).unwrap()) - .unwrap_or(Arc::new(ParquetFormat::new())); + .unwrap_or_else(|| Arc::new(ParquetFormat::new())); scan_format( state, &*format, None, &testdata, file_name, projection, limit, @@ -1231,6 +1262,61 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_parquet_write_empty_file() -> Result<()> { + // Case 1. 
write to a single file + // Expect: an empty file created + let tmp_dir = tempfile::TempDir::new().unwrap(); + let path = format!("{}/empty.parquet", tmp_dir.path().to_string_lossy()); + + let ctx = SessionContext::new(); + + let df = ctx.sql("SELECT 1 limit 0").await?; + + let cfg1 = + crate::dataframe::DataFrameWriteOptions::new().with_single_file_output(true); + let cfg2 = TableParquetOptions::default(); + + df.write_parquet(&path, cfg1, Some(cfg2)).await?; + assert!(std::path::Path::new(&path).exists()); + + // Case 2. write to a directory without partition columns + // Expect: under the directory, an empty file is created + let tmp_dir = tempfile::TempDir::new().unwrap(); + let path = format!("{}", tmp_dir.path().to_string_lossy()); + + let cfg1 = + crate::dataframe::DataFrameWriteOptions::new().with_single_file_output(true); + let cfg2 = TableParquetOptions::default(); + + let df = ctx.sql("SELECT 1 limit 0").await?; + + df.write_parquet(&path, cfg1, Some(cfg2)).await?; + assert!(std::path::Path::new(&path).exists()); + + let files = std::fs::read_dir(&path).unwrap(); + assert!(files.count() == 1); + + // Case 3. write to a directory with partition columns + // Expect: No file is created + let tmp_dir = tempfile::TempDir::new().unwrap(); + let path = format!("{}", tmp_dir.path().to_string_lossy()); + + let df = ctx.sql("SELECT 1 as col1, 2 as col2 limit 0").await?; + + let cfg1 = crate::dataframe::DataFrameWriteOptions::new() + .with_single_file_output(true) + .with_partition_by(vec!["col1".to_string()]); + let cfg2 = TableParquetOptions::default(); + + df.write_parquet(&path, cfg1, Some(cfg2)).await?; + + assert!(std::path::Path::new(&path).exists()); + let files = std::fs::read_dir(&path).unwrap(); + assert!(files.count() == 0); + Ok(()) + } + #[tokio::test] async fn parquet_sink_write_insert_schema_into_metadata() -> Result<()> { // expected kv metadata without schema @@ -1308,7 +1394,7 @@ mod tests { #[tokio::test] async fn parquet_sink_write_with_extension() -> Result<()> { let filename = "test_file.custom_ext"; - let file_path = format!("file:///path/to/{}", filename); + let file_path = format!("file:///path/to/{filename}"); let parquet_sink = create_written_parquet_sink(file_path.as_str()).await?; // assert written to proper path @@ -1523,8 +1609,7 @@ mod tests { let prefix = path_parts[0].as_ref(); assert!( expected_partitions.contains(prefix), - "expected path prefix to match partition, instead found {:?}", - prefix + "expected path prefix to match partition, instead found {prefix:?}" ); expected_partitions.remove(prefix); @@ -1648,43 +1733,4 @@ mod tests { Ok(()) } - - /// Creates an bounded stream for testing purposes. 
- fn bounded_stream( - batch: RecordBatch, - limit: usize, - ) -> datafusion_execution::SendableRecordBatchStream { - Box::pin(BoundedStream { - count: 0, - limit, - batch, - }) - } - - struct BoundedStream { - limit: usize, - count: usize, - batch: RecordBatch, - } - - impl Stream for BoundedStream { - type Item = Result; - - fn poll_next( - mut self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { - if self.count >= self.limit { - return Poll::Ready(None); - } - self.count += 1; - Poll::Ready(Some(Ok(self.batch.clone()))) - } - } - - impl RecordBatchStream for BoundedStream { - fn schema(&self) -> SchemaRef { - self.batch.schema() - } - } } diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index f32a32355cbd..3ddf1c85e241 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -17,106 +17,135 @@ //! The table implementation. -use super::helpers::{expr_applicable_for_cols, pruned_partition_list}; -use super::{ListingTableUrl, PartitionedFile}; -use std::collections::HashMap; -use std::{any::Any, str::FromStr, sync::Arc}; - -use crate::datasource::{ - create_ordering, - file_format::{ - file_compression_type::FileCompressionType, FileFormat, FilePushdownSupport, - }, - physical_plan::FileSinkConfig, +use super::{ + helpers::{expr_applicable_for_cols, pruned_partition_list}, + ListingTableUrl, PartitionedFile, }; -use crate::execution::context::SessionState; -use datafusion_catalog::TableProvider; -use datafusion_common::{config_err, DataFusionError, Result}; -use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder}; -use datafusion_expr::dml::InsertOp; -use datafusion_expr::{utils::conjunction, Expr, TableProviderFilterPushDown}; -use datafusion_expr::{SortExpr, TableType}; -use datafusion_physical_plan::empty::EmptyExec; -use datafusion_physical_plan::{ExecutionPlan, Statistics}; - -use arrow::datatypes::{DataType, Field, Schema, SchemaBuilder, SchemaRef}; +use crate::{ + datasource::file_format::{file_compression_type::FileCompressionType, FileFormat}, + datasource::{create_ordering, physical_plan::FileSinkConfig}, + execution::context::SessionState, +}; +use arrow::datatypes::{DataType, Field, SchemaBuilder, SchemaRef}; +use arrow_schema::Schema; +use async_trait::async_trait; +use datafusion_catalog::{Session, TableProvider}; use datafusion_common::{ - config_datafusion_err, internal_err, plan_err, project_schema, Constraints, - SchemaExt, ToDFSchema, + config_datafusion_err, config_err, internal_err, plan_err, project_schema, + stats::Precision, Constraints, DataFusionError, Result, SchemaExt, }; -use datafusion_execution::cache::{ - cache_manager::FileStatisticsCache, cache_unit::DefaultFileStatisticsCache, +use datafusion_datasource::{ + compute_all_files_statistics, + file_groups::FileGroup, + file_scan_config::{FileScanConfig, FileScanConfigBuilder}, + schema_adapter::DefaultSchemaAdapterFactory, }; -use datafusion_physical_expr::{ - create_physical_expr, LexOrdering, PhysicalSortRequirement, +use datafusion_execution::{ + cache::{cache_manager::FileStatisticsCache, cache_unit::DefaultFileStatisticsCache}, + config::SessionConfig, }; - -use async_trait::async_trait; -use datafusion_catalog::Session; -use datafusion_common::stats::Precision; -use datafusion_datasource::compute_all_files_statistics; -use datafusion_datasource::file_groups::FileGroup; -use datafusion_physical_expr_common::sort_expr::LexRequirement; +use 
datafusion_expr::{ + dml::InsertOp, Expr, SortExpr, TableProviderFilterPushDown, TableType, +}; +use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_plan::{empty::EmptyExec, ExecutionPlan, Statistics}; use futures::{future, stream, Stream, StreamExt, TryStreamExt}; use itertools::Itertools; use object_store::ObjectStore; +use std::{any::Any, collections::HashMap, str::FromStr, sync::Arc}; +/// Indicates the source of the schema for a [`ListingTable`] +// PartialEq required for assert_eq! in tests +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum SchemaSource { + /// Schema is not yet set (initial state) + None, + /// Schema was inferred from first table_path + Inferred, + /// Schema was specified explicitly via with_schema + Specified, +} /// Configuration for creating a [`ListingTable`] +/// +/// #[derive(Debug, Clone)] pub struct ListingTableConfig { /// Paths on the `ObjectStore` for creating `ListingTable`. /// They should share the same schema and object store. pub table_paths: Vec, /// Optional `SchemaRef` for the to be created `ListingTable`. + /// + /// See details on [`ListingTableConfig::with_schema`] pub file_schema: Option, - /// Optional `ListingOptions` for the to be created `ListingTable`. + /// Optional [`ListingOptions`] for the to be created [`ListingTable`]. + /// + /// See details on [`ListingTableConfig::with_listing_options`] pub options: Option, + /// Tracks the source of the schema information + schema_source: SchemaSource, } impl ListingTableConfig { - /// Creates new [`ListingTableConfig`]. - /// - /// The [`SchemaRef`] and [`ListingOptions`] are inferred based on - /// the suffix of the provided `table_paths` first element. + /// Creates new [`ListingTableConfig`] for reading the specified URL pub fn new(table_path: ListingTableUrl) -> Self { let table_paths = vec![table_path]; Self { table_paths, file_schema: None, options: None, + schema_source: SchemaSource::None, } } /// Creates new [`ListingTableConfig`] with multiple table paths. /// - /// The [`SchemaRef`] and [`ListingOptions`] are inferred based on - /// the suffix of the provided `table_paths` first element. + /// See [`Self::infer_options`] for details on what happens with multiple paths pub fn new_with_multi_paths(table_paths: Vec) -> Self { Self { table_paths, file_schema: None, options: None, + schema_source: SchemaSource::None, } } - /// Add `schema` to [`ListingTableConfig`] + + /// Returns the source of the schema for this configuration + pub fn schema_source(&self) -> SchemaSource { + self.schema_source + } + /// Set the `schema` for the overall [`ListingTable`] + /// + /// [`ListingTable`] will automatically coerce, when possible, the schema + /// for individual files to match this schema. + /// + /// If a schema is not provided, it is inferred using + /// [`Self::infer_schema`]. + /// + /// If the schema is provided, it must contain only the fields in the file + /// without the table partitioning columns. pub fn with_schema(self, schema: SchemaRef) -> Self { Self { table_paths: self.table_paths, file_schema: Some(schema), options: self.options, + schema_source: SchemaSource::Specified, } } /// Add `listing_options` to [`ListingTableConfig`] + /// + /// If not provided, format and other options are inferred via + /// [`Self::infer_options`]. 
pub fn with_listing_options(self, listing_options: ListingOptions) -> Self { Self { table_paths: self.table_paths, file_schema: self.file_schema, options: Some(listing_options), + schema_source: self.schema_source, } } - ///Returns a tupe of (file_extension, optional compression_extension) + /// Returns a tuple of `(file_extension, optional compression_extension)` /// /// For example a path ending with blah.test.csv.gz returns `("csv", Some("gz"))` /// For example a path ending with blah.test.csv returns `("csv", None)` @@ -138,7 +167,9 @@ impl ListingTableConfig { } } - /// Infer `ListingOptions` based on `table_path` suffix. + /// Infer `ListingOptions` based on `table_path` and file suffix. + /// + /// The format is inferred based on the first `table_path`. pub async fn infer_options(self, state: &dyn Session) -> Result { let store = if let Some(url) = self.table_paths.first() { state.runtime_env().object_store(url)? @@ -183,41 +214,68 @@ impl ListingTableConfig { let listing_options = ListingOptions::new(file_format) .with_file_extension(listing_file_extension) - .with_target_partitions(state.config().target_partitions()); + .with_target_partitions(state.config().target_partitions()) + .with_collect_stat(state.config().collect_statistics()); Ok(Self { table_paths: self.table_paths, file_schema: self.file_schema, options: Some(listing_options), + schema_source: self.schema_source, }) } - /// Infer the [`SchemaRef`] based on `table_path` suffix. Requires `self.options` to be set prior to using. + /// Infer the [`SchemaRef`] based on `table_path`s. + /// + /// This method infers the table schema using the first `table_path`. + /// See [`ListingOptions::infer_schema`] for more details + /// + /// # Errors + /// * if `self.options` is not set. See [`Self::with_listing_options`] pub async fn infer_schema(self, state: &dyn Session) -> Result { match self.options { Some(options) => { - let schema = if let Some(url) = self.table_paths.first() { - options.infer_schema(state, url).await? - } else { - Arc::new(Schema::empty()) + let ListingTableConfig { + table_paths, + file_schema, + options: _, + schema_source, + } = self; + + let (schema, new_schema_source) = match file_schema { + Some(schema) => (schema, schema_source), // Keep existing source if schema exists + None => { + if let Some(url) = table_paths.first() { + ( + options.infer_schema(state, url).await?, + SchemaSource::Inferred, + ) + } else { + (Arc::new(Schema::empty()), SchemaSource::Inferred) + } + } }; Ok(Self { - table_paths: self.table_paths, + table_paths, file_schema: Some(schema), options: Some(options), + schema_source: new_schema_source, }) } None => internal_err!("No `ListingOptions` set for inferring schema"), } } - /// Convenience wrapper for calling `infer_options` and `infer_schema` + /// Convenience method to call both [`Self::infer_options`] and [`Self::infer_schema`] pub async fn infer(self, state: &dyn Session) -> Result { self.infer_options(state).await?.infer_schema(state).await } - /// Infer the partition columns from the path. Requires `self.options` to be set prior to using. + /// Infer the partition columns from `table_paths`. + /// + /// # Errors + /// * if `self.options` is not set. 
See [`Self::with_listing_options`] pub async fn infer_partitions_from_path(self, state: &dyn Session) -> Result { match self.options { Some(options) => { @@ -243,6 +301,7 @@ impl ListingTableConfig { table_paths: self.table_paths, file_schema: self.file_schema, options: Some(options), + schema_source: self.schema_source, }) } None => config_err!("No `ListingOptions` set for inferring schema"), @@ -277,6 +336,7 @@ pub struct ListingOptions { /// parquet metadata. /// /// See + /// /// NOTE: This attribute stores all equivalent orderings (the outer `Vec`) /// where each ordering consists of an individual lexicographic /// ordering (encapsulated by a `Vec`). If there aren't @@ -291,18 +351,29 @@ impl ListingOptions { /// - use default file extension filter /// - no input partition to discover /// - one target partition - /// - stat collection + /// - do not collect statistics pub fn new(format: Arc) -> Self { Self { file_extension: format.get_ext(), format, table_partition_cols: vec![], - collect_stat: true, + collect_stat: false, target_partitions: 1, file_sort_order: vec![], } } + /// Set options from [`SessionConfig`] and returns self. + /// + /// Currently this sets `target_partitions` and `collect_stat` + /// but if more options are added in the future that need to be coordinated + /// they will be synchronized thorugh this method. + pub fn with_session_config_options(mut self, config: &SessionConfig) -> Self { + self = self.with_target_partitions(config.target_partitions()); + self = self.with_collect_stat(config.collect_statistics()); + self + } + /// Set file extension on [`ListingOptions`] and returns self. /// /// # Example @@ -479,11 +550,13 @@ impl ListingOptions { } /// Infer the schema of the files at the given path on the provided object store. - /// The inferred schema does not include the partitioning columns. /// - /// This method will not be called by the table itself but before creating it. - /// This way when creating the logical plan we can decide to resolve the schema - /// locally or ask a remote service to do it (e.g a scheduler). + /// If the table_path contains one or more files (i.e. it is a directory / + /// prefix of files) their schema is merged by calling [`FileFormat::infer_schema`] + /// + /// Note: The inferred schema does not include any partitioning columns. + /// + /// This method is called as part of creating a [`ListingTable`]. pub async fn infer_schema<'a>( &'a self, state: &dyn Session, @@ -656,16 +729,14 @@ impl ListingOptions { /// `ListingTable` also supports limit, filter and projection pushdown for formats that /// support it as such as Parquet. /// -/// # Implementation +/// # See Also /// -/// `ListingTable` Uses [`DataSourceExec`] to execute the data. See that struct -/// for more details. +/// 1. [`ListingTableConfig`]: Configuration options +/// 1. [`DataSourceExec`]: `ExecutionPlan` used by `ListingTable` /// /// [`DataSourceExec`]: crate::datasource::source::DataSourceExec /// -/// # Example -/// -/// To read a directory of parquet files using a [`ListingTable`]: +/// # Example: Read a directory of parquet files using a [`ListingTable`] /// /// ```no_run /// # use datafusion::prelude::SessionContext; @@ -712,7 +783,7 @@ impl ListingOptions { /// # Ok(()) /// # } /// ``` -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct ListingTable { table_paths: Vec, /// `file_schema` contains only the columns physically stored in the data files themselves. 
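Editor's note: the `ListingOptions` changes above flip the default of `collect_stat` to `false` and add `with_session_config_options`, which copies `target_partitions` and `collect_statistics` from the active `SessionConfig`; the `ListingTableConfig` changes track where the schema came from via the new `SchemaSource` enum. A minimal sketch of wiring these together follows, assuming the usual `datafusion::datasource::listing` re-exports; the table URL is a placeholder.

```rust
// Sketch only: build a ListingTable whose partitioning and statistics
// behaviour follow the SessionConfig, then inspect how its schema was
// obtained via the new SchemaSource accessor. The file URL is a placeholder.
use std::sync::Arc;

use datafusion::datasource::file_format::parquet::ParquetFormat;
use datafusion::datasource::listing::{
    ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
};
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

async fn register_listing_table(ctx: &SessionContext) -> Result<()> {
    let state = ctx.state();

    // Pick up target_partitions and collect_statistics from the session
    // instead of hard-coding them.
    let options = ListingOptions::new(Arc::new(ParquetFormat::new()))
        .with_session_config_options(state.config());

    let config = ListingTableConfig::new(ListingTableUrl::parse("file:///data/")?)
        .with_listing_options(options)
        .infer_schema(&state)
        .await?;

    // No explicit schema was supplied, so this should report
    // SchemaSource::Inferred rather than SchemaSource::Specified.
    println!("schema source: {:?}", config.schema_source());

    let table = ListingTable::try_new(config)?;
    ctx.register_table("t", Arc::new(table))?;
    Ok(())
}
```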
@@ -723,6 +794,8 @@ pub struct ListingTable { /// - Partition columns are derived from directory paths (not stored in files) /// - These are columns like "year=2022/month=01" in paths like `/data/year=2022/month=01/file.parquet` table_schema: SchemaRef, + /// Indicates how the schema was derived (inferred or explicitly specified) + schema_source: SchemaSource, options: ListingOptions, definition: Option, collected_statistics: FileStatisticsCache, @@ -731,17 +804,13 @@ pub struct ListingTable { } impl ListingTable { - /// Create new [`ListingTable`] that lists the FS to get the files - /// to scan. See [`ListingTable`] for and example. - /// - /// Takes a `ListingTableConfig` as input which requires an `ObjectStore` and `table_path`. - /// `ListingOptions` and `SchemaRef` are optional. If they are not - /// provided the file type is inferred based on the file suffix. - /// If the schema is provided then it must be resolved before creating the table - /// and should contain the fields of the file without the table - /// partitioning columns. + /// Create new [`ListingTable`] /// + /// See documentation and example on [`ListingTable`] and [`ListingTableConfig`] pub fn try_new(config: ListingTableConfig) -> Result { + // Extract schema_source before moving other parts of the config + let schema_source = config.schema_source(); + let file_schema = config .file_schema .ok_or_else(|| DataFusionError::Internal("No schema provided.".into()))?; @@ -766,10 +835,11 @@ impl ListingTable { table_paths: config.table_paths, file_schema, table_schema, + schema_source, options, definition: None, collected_statistics: Arc::new(DefaultFileStatisticsCache::default()), - constraints: Constraints::empty(), + constraints: Constraints::default(), column_defaults: HashMap::new(), }; @@ -799,7 +869,7 @@ impl ListingTable { /// If `None`, creates a new [`DefaultFileStatisticsCache`] scoped to this query. 
pub fn with_cache(mut self, cache: Option) -> Self { self.collected_statistics = - cache.unwrap_or(Arc::new(DefaultFileStatisticsCache::default())); + cache.unwrap_or_else(|| Arc::new(DefaultFileStatisticsCache::default())); self } @@ -819,6 +889,11 @@ impl ListingTable { &self.options } + /// Get the schema source + pub fn schema_source(&self) -> SchemaSource { + self.schema_source + } + /// If file_sort_order is specified, creates the appropriate physical expressions fn try_create_output_ordering(&self) -> Result> { create_ordering(&self.table_schema, &self.options.file_sort_order) @@ -921,19 +996,6 @@ impl TableProvider for ListingTable { None => {} // no ordering required }; - let filters = match conjunction(filters.to_vec()) { - Some(expr) => { - let table_df_schema = self.table_schema.as_ref().clone().to_dfschema()?; - let filters = create_physical_expr( - &expr, - &table_df_schema, - state.execution_props(), - )?; - Some(filters) - } - None => None, - }; - let Some(object_store_url) = self.table_paths.first().map(ListingTableUrl::object_store) else { @@ -958,7 +1020,6 @@ impl TableProvider for ListingTable { .with_output_ordering(output_ordering) .with_table_partition_cols(table_partition_cols) .build(), - filters.as_ref(), ) .await } @@ -982,18 +1043,6 @@ impl TableProvider for ListingTable { return Ok(TableProviderFilterPushDown::Exact); } - // if we can't push it down completely with only the filename-based/path-based - // column names, then we should check if we can do parquet predicate pushdown - let supports_pushdown = self.options.format.supports_filters_pushdown( - &self.file_schema, - &self.table_schema, - &[filter], - )?; - - if supports_pushdown == FilePushdownSupport::Supported { - return Ok(TableProviderFilterPushDown::Exact); - } - Ok(TableProviderFilterPushDown::Inexact) }) .collect() @@ -1051,25 +1100,9 @@ impl TableProvider for ListingTable { file_extension: self.options().format.get_ext(), }; - let order_requirements = if !self.options().file_sort_order.is_empty() { - // Multiple sort orders in outer vec are equivalent, so we pass only the first one - let orderings = self.try_create_output_ordering()?; - let Some(ordering) = orderings.first() else { - return internal_err!( - "Expected ListingTable to have a sort order, but none found!" 
- ); - }; - // Converts Vec> into type required by execution plan to specify its required input ordering - Some(LexRequirement::new( - ordering - .into_iter() - .cloned() - .map(PhysicalSortRequirement::from) - .collect::>(), - )) - } else { - None - }; + let orderings = self.try_create_output_ordering()?; + // It is sufficient to pass only one of the equivalent orderings: + let order_requirements = orderings.into_iter().next().map(Into::into); self.options() .format @@ -1130,12 +1163,24 @@ impl ListingTable { get_files_with_limit(files, limit, self.options.collect_stat).await?; let file_groups = file_group.split_files(self.options.target_partitions); - compute_all_files_statistics( + let (mut file_groups, mut stats) = compute_all_files_statistics( file_groups, self.schema(), self.options.collect_stat, inexact_stats, - ) + )?; + let (schema_mapper, _) = DefaultSchemaAdapterFactory::from_schema(self.schema()) + .map_schema(self.file_schema.as_ref())?; + stats.column_statistics = + schema_mapper.map_column_statistics(&stats.column_statistics)?; + file_groups.iter_mut().try_for_each(|file_group| { + if let Some(stat) = file_group.statistics_mut() { + stat.column_statistics = + schema_mapper.map_column_statistics(&stat.column_statistics)?; + } + Ok::<_, DataFusionError>(()) + })?; + Ok((file_groups, stats)) } /// Collects statistics for a given partitioned file. @@ -1256,100 +1301,139 @@ async fn get_files_with_limit( #[cfg(test)] mod tests { use super::*; - use crate::datasource::file_format::csv::CsvFormat; - use crate::datasource::file_format::json::JsonFormat; #[cfg(feature = "parquet")] use crate::datasource::file_format::parquet::ParquetFormat; - use crate::datasource::{provider_as_source, DefaultTableSource, MemTable}; - use crate::execution::options::ArrowReadOptions; use crate::prelude::*; - use crate::test::{columns, object_store::register_test_store}; - - use arrow::compute::SortOptions; - use arrow::record_batch::RecordBatch; - use datafusion_common::stats::Precision; - use datafusion_common::test_util::batches_to_string; - use datafusion_common::{assert_contains, ScalarValue}; + use crate::{ + datasource::{ + file_format::csv::CsvFormat, file_format::json::JsonFormat, + provider_as_source, DefaultTableSource, MemTable, + }, + execution::options::ArrowReadOptions, + test::{ + columns, object_store::ensure_head_concurrency, + object_store::make_test_store_and_state, object_store::register_test_store, + }, + }; + use arrow::{compute::SortOptions, record_batch::RecordBatch}; + use datafusion_common::{ + assert_contains, + stats::Precision, + test_util::{batches_to_string, datafusion_test_data}, + ScalarValue, + }; use datafusion_expr::{BinaryExpr, LogicalPlanBuilder, Operator}; use datafusion_physical_expr::PhysicalSortExpr; - use datafusion_physical_plan::collect; - use datafusion_physical_plan::ExecutionPlanProperties; - - use crate::test::object_store::{ensure_head_concurrency, make_test_store_and_state}; + use datafusion_physical_plan::{collect, ExecutionPlanProperties}; + use std::io::Write; use tempfile::TempDir; use url::Url; + /// Creates a test schema with standard field types used in tests + fn create_test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Float32, true), + Field::new("c2", DataType::Float64, true), + Field::new("c3", DataType::Boolean, true), + Field::new("c4", DataType::Utf8, true), + ])) + } + + /// Helper function to generate test file paths with given prefix, count, and optional start index + fn generate_test_files(prefix: 
&str, count: usize) -> Vec { + generate_test_files_with_start(prefix, count, 0) + } + + /// Helper function to generate test file paths with given prefix, count, and start index + fn generate_test_files_with_start( + prefix: &str, + count: usize, + start_index: usize, + ) -> Vec { + (start_index..start_index + count) + .map(|i| format!("{prefix}/file{i}")) + .collect() + } + #[tokio::test] - async fn read_single_file() -> Result<()> { + async fn test_schema_source_tracking_comprehensive() -> Result<()> { let ctx = SessionContext::new(); + let testdata = datafusion_test_data(); + let filename = format!("{testdata}/aggregate_simple.csv"); + let table_path = ListingTableUrl::parse(filename).unwrap(); - let table = load_table(&ctx, "alltypes_plain.parquet").await?; - let projection = None; - let exec = table - .scan(&ctx.state(), projection, &[], None) - .await - .expect("Scan table"); + // Test default schema source + let config = ListingTableConfig::new(table_path.clone()); + assert_eq!(config.schema_source(), SchemaSource::None); - assert_eq!(exec.children().len(), 0); - assert_eq!(exec.output_partitioning().partition_count(), 1); + // Test schema source after setting a schema explicitly + let provided_schema = create_test_schema(); + let config_with_schema = config.clone().with_schema(provided_schema.clone()); + assert_eq!(config_with_schema.schema_source(), SchemaSource::Specified); - // test metadata - assert_eq!(exec.statistics()?.num_rows, Precision::Exact(8)); - assert_eq!(exec.statistics()?.total_byte_size, Precision::Exact(671)); + // Test schema source after inferring schema + let format = CsvFormat::default(); + let options = ListingOptions::new(Arc::new(format)); + let config_with_options = config.with_listing_options(options.clone()); + assert_eq!(config_with_options.schema_source(), SchemaSource::None); - Ok(()) - } + let config_with_inferred = config_with_options.infer_schema(&ctx.state()).await?; + assert_eq!(config_with_inferred.schema_source(), SchemaSource::Inferred); - #[cfg(feature = "parquet")] - #[tokio::test] - async fn load_table_stats_by_default() -> Result<()> { - use crate::datasource::file_format::parquet::ParquetFormat; - - let testdata = crate::test_util::parquet_test_data(); - let filename = format!("{}/{}", testdata, "alltypes_plain.parquet"); - let table_path = ListingTableUrl::parse(filename).unwrap(); + // Test schema preservation through operations + let config_with_schema_and_options = config_with_schema + .clone() + .with_listing_options(options.clone()); + assert_eq!( + config_with_schema_and_options.schema_source(), + SchemaSource::Specified + ); - let ctx = SessionContext::new(); - let state = ctx.state(); + // Make sure inferred schema doesn't override specified schema + let config_with_schema_and_infer = config_with_schema_and_options + .clone() + .infer(&ctx.state()) + .await?; + assert_eq!( + config_with_schema_and_infer.schema_source(), + SchemaSource::Specified + ); - let opt = ListingOptions::new(Arc::new(ParquetFormat::default())); - let schema = opt.infer_schema(&state, &table_path).await?; - let config = ListingTableConfig::new(table_path) - .with_listing_options(opt) - .with_schema(schema); - let table = ListingTable::try_new(config)?; + // Verify sources in actual ListingTable objects + let table_specified = ListingTable::try_new(config_with_schema_and_options)?; + assert_eq!(table_specified.schema_source(), SchemaSource::Specified); - let exec = table.scan(&state, None, &[], None).await?; - assert_eq!(exec.statistics()?.num_rows, 
Precision::Exact(8)); - // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 - assert_eq!(exec.statistics()?.total_byte_size, Precision::Exact(671)); + let table_inferred = ListingTable::try_new(config_with_inferred)?; + assert_eq!(table_inferred.schema_source(), SchemaSource::Inferred); Ok(()) } - #[cfg(feature = "parquet")] #[tokio::test] - async fn load_table_stats_when_no_stats() -> Result<()> { - use crate::datasource::file_format::parquet::ParquetFormat; - - let testdata = crate::test_util::parquet_test_data(); - let filename = format!("{}/{}", testdata, "alltypes_plain.parquet"); - let table_path = ListingTableUrl::parse(filename).unwrap(); + async fn read_single_file() -> Result<()> { + let ctx = SessionContext::new_with_config( + SessionConfig::new().with_collect_statistics(true), + ); - let ctx = SessionContext::new(); - let state = ctx.state(); + let table = load_table(&ctx, "alltypes_plain.parquet").await?; + let projection = None; + let exec = table + .scan(&ctx.state(), projection, &[], None) + .await + .expect("Scan table"); - let opt = ListingOptions::new(Arc::new(ParquetFormat::default())) - .with_collect_stat(false); - let schema = opt.infer_schema(&state, &table_path).await?; - let config = ListingTableConfig::new(table_path) - .with_listing_options(opt) - .with_schema(schema); - let table = ListingTable::try_new(config)?; + assert_eq!(exec.children().len(), 0); + assert_eq!(exec.output_partitioning().partition_count(), 1); - let exec = table.scan(&state, None, &[], None).await?; - assert_eq!(exec.statistics()?.num_rows, Precision::Absent); - assert_eq!(exec.statistics()?.total_byte_size, Precision::Absent); + // test metadata + assert_eq!( + exec.partition_statistics(None)?.num_rows, + Precision::Exact(8) + ); + assert_eq!( + exec.partition_statistics(None)?.total_byte_size, + Precision::Exact(671) + ); Ok(()) } @@ -1372,7 +1456,7 @@ mod tests { // (file_sort_order, expected_result) let cases = vec![ - (vec![], Ok(vec![])), + (vec![], Ok(Vec::::new())), // sort expr, but non column ( vec![vec![ @@ -1383,15 +1467,13 @@ mod tests { // ok with one column ( vec![vec![col("string_col").sort(true, false)]], - Ok(vec![LexOrdering::new( - vec![PhysicalSortExpr { + Ok(vec![[PhysicalSortExpr { expr: physical_col("string_col", &schema).unwrap(), options: SortOptions { descending: false, nulls_first: false, }, - }], - ) + }].into(), ]) ), // ok with two columns, different options @@ -1400,16 +1482,14 @@ mod tests { col("string_col").sort(true, false), col("int_col").sort(false, true), ]], - Ok(vec![LexOrdering::new( - vec![ + Ok(vec![[ PhysicalSortExpr::new_default(physical_col("string_col", &schema).unwrap()) .asc() .nulls_last(), PhysicalSortExpr::new_default(physical_col("int_col", &schema).unwrap()) .desc() .nulls_first() - ], - ) + ].into(), ]) ), ]; @@ -1488,263 +1568,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn test_assert_list_files_for_scan_grouping() -> Result<()> { - // more expected partitions than files - assert_list_files_for_scan_grouping( - &[ - "bucket/key-prefix/file0", - "bucket/key-prefix/file1", - "bucket/key-prefix/file2", - "bucket/key-prefix/file3", - "bucket/key-prefix/file4", - ], - "test:///bucket/key-prefix/", - 12, - 5, - Some(""), - ) - .await?; - - // as many expected partitions as files - assert_list_files_for_scan_grouping( - &[ - "bucket/key-prefix/file0", - "bucket/key-prefix/file1", - "bucket/key-prefix/file2", - "bucket/key-prefix/file3", - ], - "test:///bucket/key-prefix/", - 4, - 4, - Some(""), - ) - .await?; - 
- // more files as expected partitions - assert_list_files_for_scan_grouping( - &[ - "bucket/key-prefix/file0", - "bucket/key-prefix/file1", - "bucket/key-prefix/file2", - "bucket/key-prefix/file3", - "bucket/key-prefix/file4", - ], - "test:///bucket/key-prefix/", - 2, - 2, - Some(""), - ) - .await?; - - // no files => no groups - assert_list_files_for_scan_grouping( - &[], - "test:///bucket/key-prefix/", - 2, - 0, - Some(""), - ) - .await?; - - // files that don't match the prefix - assert_list_files_for_scan_grouping( - &[ - "bucket/key-prefix/file0", - "bucket/key-prefix/file1", - "bucket/other-prefix/roguefile", - ], - "test:///bucket/key-prefix/", - 10, - 2, - Some(""), - ) - .await?; - - // files that don't match the prefix or the default file extention - assert_list_files_for_scan_grouping( - &[ - "bucket/key-prefix/file0.json", - "bucket/key-prefix/file1.parquet", - "bucket/other-prefix/roguefile.json", - ], - "test:///bucket/key-prefix/", - 10, - 1, - None, - ) - .await?; - Ok(()) - } - - #[tokio::test] - async fn test_assert_list_files_for_multi_path() -> Result<()> { - // more expected partitions than files - assert_list_files_for_multi_paths( - &[ - "bucket/key1/file0", - "bucket/key1/file1", - "bucket/key1/file2", - "bucket/key2/file3", - "bucket/key2/file4", - "bucket/key3/file5", - ], - &["test:///bucket/key1/", "test:///bucket/key2/"], - 12, - 5, - Some(""), - ) - .await?; - - // as many expected partitions as files - assert_list_files_for_multi_paths( - &[ - "bucket/key1/file0", - "bucket/key1/file1", - "bucket/key1/file2", - "bucket/key2/file3", - "bucket/key2/file4", - "bucket/key3/file5", - ], - &["test:///bucket/key1/", "test:///bucket/key2/"], - 5, - 5, - Some(""), - ) - .await?; - - // more files as expected partitions - assert_list_files_for_multi_paths( - &[ - "bucket/key1/file0", - "bucket/key1/file1", - "bucket/key1/file2", - "bucket/key2/file3", - "bucket/key2/file4", - "bucket/key3/file5", - ], - &["test:///bucket/key1/"], - 2, - 2, - Some(""), - ) - .await?; - - // no files => no groups - assert_list_files_for_multi_paths(&[], &["test:///bucket/key1/"], 2, 0, Some("")) - .await?; - - // files that don't match the prefix - assert_list_files_for_multi_paths( - &[ - "bucket/key1/file0", - "bucket/key1/file1", - "bucket/key1/file2", - "bucket/key2/file3", - "bucket/key2/file4", - "bucket/key3/file5", - ], - &["test:///bucket/key3/"], - 2, - 1, - Some(""), - ) - .await?; - - // files that don't match the prefix or the default file ext - assert_list_files_for_multi_paths( - &[ - "bucket/key1/file0.json", - "bucket/key1/file1.csv", - "bucket/key1/file2.json", - "bucket/key2/file3.csv", - "bucket/key2/file4.json", - "bucket/key3/file5.csv", - ], - &["test:///bucket/key1/", "test:///bucket/key3/"], - 2, - 2, - None, - ) - .await?; - Ok(()) - } - - #[tokio::test] - async fn test_assert_list_files_for_exact_paths() -> Result<()> { - // more expected partitions than files - assert_list_files_for_exact_paths( - &[ - "bucket/key1/file0", - "bucket/key1/file1", - "bucket/key1/file2", - "bucket/key2/file3", - "bucket/key2/file4", - ], - 12, - 5, - Some(""), - ) - .await?; - - // more files than meta_fetch_concurrency (32) - let files: Vec = - (0..64).map(|i| format!("bucket/key1/file{}", i)).collect(); - // Collect references to each string - let file_refs: Vec<&str> = files.iter().map(|s| s.as_str()).collect(); - assert_list_files_for_exact_paths(file_refs.as_slice(), 5, 5, Some("")).await?; - - // as many expected partitions as files - assert_list_files_for_exact_paths( - 
&[ - "bucket/key1/file0", - "bucket/key1/file1", - "bucket/key1/file2", - "bucket/key2/file3", - "bucket/key2/file4", - ], - 5, - 5, - Some(""), - ) - .await?; - - // more files as expected partitions - assert_list_files_for_exact_paths( - &[ - "bucket/key1/file0", - "bucket/key1/file1", - "bucket/key1/file2", - "bucket/key2/file3", - "bucket/key2/file4", - ], - 2, - 2, - Some(""), - ) - .await?; - - // no files => no groups - assert_list_files_for_exact_paths(&[], 2, 0, Some("")).await?; - - // files that don't match the default file ext - assert_list_files_for_exact_paths( - &[ - "bucket/key1/file0.json", - "bucket/key1/file1.csv", - "bucket/key1/file2.json", - "bucket/key2/file3.csv", - "bucket/key2/file4.json", - "bucket/key3/file5.csv", - ], - 2, - 2, - None, - ) - .await?; - Ok(()) - } - async fn load_table( ctx: &SessionContext, name: &str, @@ -1847,10 +1670,10 @@ mod tests { .execution .meta_fetch_concurrency; let expected_concurrency = files.len().min(meta_fetch_concurrency); - let head_blocking_store = ensure_head_concurrency(store, expected_concurrency); + let head_concurrency_store = ensure_head_concurrency(store, expected_concurrency); let url = Url::parse("test://").unwrap(); - ctx.register_object_store(&url, head_blocking_store.clone()); + ctx.register_object_store(&url, head_concurrency_store.clone()); let format = JsonFormat::default(); @@ -1862,7 +1685,7 @@ mod tests { let table_paths = files .iter() - .map(|t| ListingTableUrl::parse(format!("test:///{}", t)).unwrap()) + .map(|t| ListingTableUrl::parse(format!("test:///{t}")).unwrap()) .collect(); let config = ListingTableConfig::new_with_multi_paths(table_paths) .with_listing_options(opt) @@ -1877,80 +1700,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn test_insert_into_append_new_json_files() -> Result<()> { - let mut config_map: HashMap = HashMap::new(); - config_map.insert("datafusion.execution.batch_size".into(), "10".into()); - config_map.insert( - "datafusion.execution.soft_max_rows_per_output_file".into(), - "10".into(), - ); - helper_test_append_new_files_to_table( - JsonFormat::default().get_ext(), - FileCompressionType::UNCOMPRESSED, - Some(config_map), - 2, - ) - .await?; - Ok(()) - } - - #[tokio::test] - async fn test_insert_into_append_new_csv_files() -> Result<()> { - let mut config_map: HashMap = HashMap::new(); - config_map.insert("datafusion.execution.batch_size".into(), "10".into()); - config_map.insert( - "datafusion.execution.soft_max_rows_per_output_file".into(), - "10".into(), - ); - helper_test_append_new_files_to_table( - CsvFormat::default().get_ext(), - FileCompressionType::UNCOMPRESSED, - Some(config_map), - 2, - ) - .await?; - Ok(()) - } - - #[cfg(feature = "parquet")] - #[tokio::test] - async fn test_insert_into_append_2_new_parquet_files_defaults() -> Result<()> { - let mut config_map: HashMap = HashMap::new(); - config_map.insert("datafusion.execution.batch_size".into(), "10".into()); - config_map.insert( - "datafusion.execution.soft_max_rows_per_output_file".into(), - "10".into(), - ); - helper_test_append_new_files_to_table( - ParquetFormat::default().get_ext(), - FileCompressionType::UNCOMPRESSED, - Some(config_map), - 2, - ) - .await?; - Ok(()) - } - - #[cfg(feature = "parquet")] - #[tokio::test] - async fn test_insert_into_append_1_new_parquet_files_defaults() -> Result<()> { - let mut config_map: HashMap = HashMap::new(); - config_map.insert("datafusion.execution.batch_size".into(), "20".into()); - config_map.insert( - "datafusion.execution.soft_max_rows_per_output_file".into(), 
- "20".into(), - ); - helper_test_append_new_files_to_table( - ParquetFormat::default().get_ext(), - FileCompressionType::UNCOMPRESSED, - Some(config_map), - 1, - ) - .await?; - Ok(()) - } - #[tokio::test] async fn test_insert_into_sql_csv_defaults() -> Result<()> { helper_test_insert_into_sql("csv", FileCompressionType::UNCOMPRESSED, "", None) @@ -2183,7 +1932,7 @@ mod tests { let filter_predicate = Expr::BinaryExpr(BinaryExpr::new( Box::new(Expr::Column("column1".into())), Operator::GtEq, - Box::new(Expr::Literal(ScalarValue::Int32(Some(0)))), + Box::new(Expr::Literal(ScalarValue::Int32(Some(0)), None)), )); // Create a new batch of data to insert into the table @@ -2367,8 +2116,10 @@ mod tests { // create table let tmp_dir = TempDir::new()?; - let tmp_path = tmp_dir.into_path(); - let str_path = tmp_path.to_str().expect("Temp path should convert to &str"); + let str_path = tmp_dir + .path() + .to_str() + .expect("Temp path should convert to &str"); session_ctx .sql(&format!( "create external table foo(a varchar, b varchar, c int) \ @@ -2409,7 +2160,7 @@ mod tests { #[tokio::test] async fn test_infer_options_compressed_csv() -> Result<()> { let testdata = crate::test_util::arrow_test_data(); - let filename = format!("{}/csv/aggregate_test_100.csv.gz", testdata); + let filename = format!("{testdata}/csv/aggregate_test_100.csv.gz"); let table_path = ListingTableUrl::parse(filename).unwrap(); let ctx = SessionContext::new(); @@ -2424,4 +2175,382 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn infer_preserves_provided_schema() -> Result<()> { + let ctx = SessionContext::new(); + + let testdata = datafusion_test_data(); + let filename = format!("{testdata}/aggregate_simple.csv"); + let table_path = ListingTableUrl::parse(filename).unwrap(); + + let provided_schema = create_test_schema(); + + let config = + ListingTableConfig::new(table_path).with_schema(Arc::clone(&provided_schema)); + + let config = config.infer(&ctx.state()).await?; + + assert_eq!(*config.file_schema.unwrap(), *provided_schema); + + Ok(()) + } + + #[tokio::test] + async fn test_listing_table_config_with_multiple_files_comprehensive() -> Result<()> { + let ctx = SessionContext::new(); + + // Create test files with different schemas + let tmp_dir = TempDir::new()?; + let file_path1 = tmp_dir.path().join("file1.csv"); + let file_path2 = tmp_dir.path().join("file2.csv"); + + // File 1: c1,c2,c3 + let mut file1 = std::fs::File::create(&file_path1)?; + writeln!(file1, "c1,c2,c3")?; + writeln!(file1, "1,2,3")?; + writeln!(file1, "4,5,6")?; + + // File 2: c1,c2,c3,c4 + let mut file2 = std::fs::File::create(&file_path2)?; + writeln!(file2, "c1,c2,c3,c4")?; + writeln!(file2, "7,8,9,10")?; + writeln!(file2, "11,12,13,14")?; + + // Parse paths + let table_path1 = ListingTableUrl::parse(file_path1.to_str().unwrap())?; + let table_path2 = ListingTableUrl::parse(file_path2.to_str().unwrap())?; + + // Create format and options + let format = CsvFormat::default().with_has_header(true); + let options = ListingOptions::new(Arc::new(format)); + + // Test case 1: Infer schema using first file's schema + let config1 = ListingTableConfig::new_with_multi_paths(vec![ + table_path1.clone(), + table_path2.clone(), + ]) + .with_listing_options(options.clone()); + let config1 = config1.infer_schema(&ctx.state()).await?; + assert_eq!(config1.schema_source(), SchemaSource::Inferred); + + // Verify schema matches first file + let schema1 = config1.file_schema.as_ref().unwrap().clone(); + assert_eq!(schema1.fields().len(), 3); + 
assert_eq!(schema1.field(0).name(), "c1"); + assert_eq!(schema1.field(1).name(), "c2"); + assert_eq!(schema1.field(2).name(), "c3"); + + // Test case 2: Use specified schema with 3 columns + let schema_3cols = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Utf8, true), + Field::new("c2", DataType::Utf8, true), + Field::new("c3", DataType::Utf8, true), + ])); + + let config2 = ListingTableConfig::new_with_multi_paths(vec![ + table_path1.clone(), + table_path2.clone(), + ]) + .with_schema(schema_3cols) + .with_listing_options(options.clone()); + let config2 = config2.infer_schema(&ctx.state()).await?; + assert_eq!(config2.schema_source(), SchemaSource::Specified); + + // Verify that the schema is still the one we specified (3 columns) + let schema2 = config2.file_schema.as_ref().unwrap().clone(); + assert_eq!(schema2.fields().len(), 3); + assert_eq!(schema2.field(0).name(), "c1"); + assert_eq!(schema2.field(1).name(), "c2"); + assert_eq!(schema2.field(2).name(), "c3"); + + // Test case 3: Use specified schema with 4 columns + let schema_4cols = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Utf8, true), + Field::new("c2", DataType::Utf8, true), + Field::new("c3", DataType::Utf8, true), + Field::new("c4", DataType::Utf8, true), + ])); + + let config3 = ListingTableConfig::new_with_multi_paths(vec![ + table_path1.clone(), + table_path2.clone(), + ]) + .with_schema(schema_4cols) + .with_listing_options(options.clone()); + let config3 = config3.infer_schema(&ctx.state()).await?; + assert_eq!(config3.schema_source(), SchemaSource::Specified); + + // Verify that the schema is still the one we specified (4 columns) + let schema3 = config3.file_schema.as_ref().unwrap().clone(); + assert_eq!(schema3.fields().len(), 4); + assert_eq!(schema3.field(0).name(), "c1"); + assert_eq!(schema3.field(1).name(), "c2"); + assert_eq!(schema3.field(2).name(), "c3"); + assert_eq!(schema3.field(3).name(), "c4"); + + // Test case 4: Verify order matters when inferring schema + let config4 = ListingTableConfig::new_with_multi_paths(vec![ + table_path2.clone(), + table_path1.clone(), + ]) + .with_listing_options(options); + let config4 = config4.infer_schema(&ctx.state()).await?; + + // Should use first file's schema, which now has 4 columns + let schema4 = config4.file_schema.as_ref().unwrap().clone(); + assert_eq!(schema4.fields().len(), 4); + assert_eq!(schema4.field(0).name(), "c1"); + assert_eq!(schema4.field(1).name(), "c2"); + assert_eq!(schema4.field(2).name(), "c3"); + assert_eq!(schema4.field(3).name(), "c4"); + + Ok(()) + } + + #[tokio::test] + async fn test_list_files_configurations() -> Result<()> { + // Define common test cases as (description, files, paths, target_partitions, expected_partitions, file_ext) + let test_cases = vec![ + // Single path cases + ( + "Single path, more partitions than files", + generate_test_files("bucket/key-prefix", 5), + vec!["test:///bucket/key-prefix/"], + 12, + 5, + Some(""), + ), + ( + "Single path, equal partitions and files", + generate_test_files("bucket/key-prefix", 4), + vec!["test:///bucket/key-prefix/"], + 4, + 4, + Some(""), + ), + ( + "Single path, more files than partitions", + generate_test_files("bucket/key-prefix", 5), + vec!["test:///bucket/key-prefix/"], + 2, + 2, + Some(""), + ), + // Multi path cases + ( + "Multi path, more partitions than files", + { + let mut files = generate_test_files("bucket/key1", 3); + files.extend(generate_test_files_with_start("bucket/key2", 2, 3)); + files.extend(generate_test_files_with_start("bucket/key3", 1, 
5)); + files + }, + vec!["test:///bucket/key1/", "test:///bucket/key2/"], + 12, + 5, + Some(""), + ), + // No files case + ( + "No files", + vec![], + vec!["test:///bucket/key-prefix/"], + 2, + 0, + Some(""), + ), + // Exact path cases + ( + "Exact paths test", + { + let mut files = generate_test_files("bucket/key1", 3); + files.extend(generate_test_files_with_start("bucket/key2", 2, 3)); + files + }, + vec![ + "test:///bucket/key1/file0", + "test:///bucket/key1/file1", + "test:///bucket/key1/file2", + "test:///bucket/key2/file3", + "test:///bucket/key2/file4", + ], + 12, + 5, + Some(""), + ), + ]; + + // Run each test case + for (test_name, files, paths, target_partitions, expected_partitions, file_ext) in + test_cases + { + println!("Running test: {test_name}"); + + if files.is_empty() { + // Test empty files case + assert_list_files_for_multi_paths( + &[], + &paths, + target_partitions, + expected_partitions, + file_ext, + ) + .await?; + } else if paths.len() == 1 { + // Test using single path API + let file_refs: Vec<&str> = files.iter().map(|s| s.as_str()).collect(); + assert_list_files_for_scan_grouping( + &file_refs, + paths[0], + target_partitions, + expected_partitions, + file_ext, + ) + .await?; + } else if paths[0].contains("test:///bucket/key") { + // Test using multi path API + let file_refs: Vec<&str> = files.iter().map(|s| s.as_str()).collect(); + assert_list_files_for_multi_paths( + &file_refs, + &paths, + target_partitions, + expected_partitions, + file_ext, + ) + .await?; + } else { + // Test using exact path API for specific cases + let file_refs: Vec<&str> = files.iter().map(|s| s.as_str()).collect(); + assert_list_files_for_exact_paths( + &file_refs, + target_partitions, + expected_partitions, + file_ext, + ) + .await?; + } + } + + Ok(()) + } + + #[cfg(feature = "parquet")] + #[tokio::test] + async fn test_table_stats_behaviors() -> Result<()> { + use crate::datasource::file_format::parquet::ParquetFormat; + + let testdata = crate::test_util::parquet_test_data(); + let filename = format!("{}/{}", testdata, "alltypes_plain.parquet"); + let table_path = ListingTableUrl::parse(filename).unwrap(); + + let ctx = SessionContext::new(); + let state = ctx.state(); + + // Test 1: Default behavior - stats not collected + let opt_default = ListingOptions::new(Arc::new(ParquetFormat::default())); + let schema_default = opt_default.infer_schema(&state, &table_path).await?; + let config_default = ListingTableConfig::new(table_path.clone()) + .with_listing_options(opt_default) + .with_schema(schema_default); + let table_default = ListingTable::try_new(config_default)?; + + let exec_default = table_default.scan(&state, None, &[], None).await?; + assert_eq!( + exec_default.partition_statistics(None)?.num_rows, + Precision::Absent + ); + + // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 + assert_eq!( + exec_default.partition_statistics(None)?.total_byte_size, + Precision::Absent + ); + + // Test 2: Explicitly disable stats + let opt_disabled = ListingOptions::new(Arc::new(ParquetFormat::default())) + .with_collect_stat(false); + let schema_disabled = opt_disabled.infer_schema(&state, &table_path).await?; + let config_disabled = ListingTableConfig::new(table_path.clone()) + .with_listing_options(opt_disabled) + .with_schema(schema_disabled); + let table_disabled = ListingTable::try_new(config_disabled)?; + + let exec_disabled = table_disabled.scan(&state, None, &[], None).await?; + assert_eq!( + exec_disabled.partition_statistics(None)?.num_rows, + 
Precision::Absent + ); + assert_eq!( + exec_disabled.partition_statistics(None)?.total_byte_size, + Precision::Absent + ); + + // Test 3: Explicitly enable stats + let opt_enabled = ListingOptions::new(Arc::new(ParquetFormat::default())) + .with_collect_stat(true); + let schema_enabled = opt_enabled.infer_schema(&state, &table_path).await?; + let config_enabled = ListingTableConfig::new(table_path) + .with_listing_options(opt_enabled) + .with_schema(schema_enabled); + let table_enabled = ListingTable::try_new(config_enabled)?; + + let exec_enabled = table_enabled.scan(&state, None, &[], None).await?; + assert_eq!( + exec_enabled.partition_statistics(None)?.num_rows, + Precision::Exact(8) + ); + // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 + assert_eq!( + exec_enabled.partition_statistics(None)?.total_byte_size, + Precision::Exact(671) + ); + + Ok(()) + } + + #[tokio::test] + async fn test_insert_into_parameterized() -> Result<()> { + let test_cases = vec![ + // (file_format, batch_size, soft_max_rows, expected_files) + ("json", 10, 10, 2), + ("csv", 10, 10, 2), + #[cfg(feature = "parquet")] + ("parquet", 10, 10, 2), + #[cfg(feature = "parquet")] + ("parquet", 20, 20, 1), + ]; + + for (format, batch_size, soft_max_rows, expected_files) in test_cases { + println!("Testing insert with format: {format}, batch_size: {batch_size}, expected files: {expected_files}"); + + let mut config_map = HashMap::new(); + config_map.insert( + "datafusion.execution.batch_size".into(), + batch_size.to_string(), + ); + config_map.insert( + "datafusion.execution.soft_max_rows_per_output_file".into(), + soft_max_rows.to_string(), + ); + + let file_extension = match format { + "json" => JsonFormat::default().get_ext(), + "csv" => CsvFormat::default().get_ext(), + #[cfg(feature = "parquet")] + "parquet" => ParquetFormat::default().get_ext(), + _ => unreachable!("Unsupported format"), + }; + + helper_test_append_new_files_to_table( + file_extension, + FileCompressionType::UNCOMPRESSED, + Some(config_map), + expected_files, + ) + .await?; + } + + Ok(()) + } } diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index 636d1623c5e9..580fa4be47af 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -108,12 +108,11 @@ impl TableProviderFactory for ListingTableFactory { (Some(schema), table_partition_cols) }; - let table_path = ListingTableUrl::parse(&cmd.location)?; + let mut table_path = ListingTableUrl::parse(&cmd.location)?; let options = ListingOptions::new(file_format) - .with_collect_stat(state.config().collect_statistics()) - .with_file_extension(file_extension) - .with_target_partitions(state.config().target_partitions()) + .with_file_extension(&file_extension) + .with_session_config_options(session_state.config()) .with_table_partition_cols(table_partition_cols); options @@ -126,6 +125,13 @@ impl TableProviderFactory for ListingTableFactory { // specifically for parquet file format. 
// See: https://github.com/apache/datafusion/issues/7317 None => { + // if the location is a folder with no explicit glob, rewrite the path as + // 'path/*.parquet' so that only files the reader can understand are read + if table_path.is_folder() && table_path.get_glob().is_none() { + table_path = table_path.with_glob( + format!("*.{}", cmd.file_type.to_lowercase()).as_ref(), + )?; + } let schema = options.infer_schema(session_state, &table_path).await?; let df_schema = Arc::clone(&schema).to_dfschema()?; let column_refs: HashSet<_> = cmd @@ -202,7 +208,7 @@ mod tests { order_exprs: vec![], unbounded: false, options: HashMap::from([("format.has_header".into(), "true".into())]), - constraints: Constraints::empty(), + constraints: Constraints::default(), column_defaults: HashMap::new(), }; let table_provider = factory.create(&state, &cmd).await.unwrap(); @@ -242,7 +248,7 @@ mod tests { order_exprs: vec![], unbounded: false, options, - constraints: Constraints::empty(), + constraints: Constraints::default(), column_defaults: HashMap::new(), }; let table_provider = factory.create(&state, &cmd).await.unwrap(); diff --git a/datafusion/core/src/datasource/mod.rs b/datafusion/core/src/datasource/mod.rs index 25a89644cd2a..94d651ddadd5 100644 --- a/datafusion/core/src/datasource/mod.rs +++ b/datafusion/core/src/datasource/mod.rs @@ -52,27 +52,26 @@ pub use datafusion_physical_expr::create_ordering; mod tests { use crate::prelude::SessionContext; - - use std::fs; - use std::sync::Arc; - - use arrow::array::{Int32Array, StringArray}; - use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; - use arrow::record_batch::RecordBatch; - use datafusion_common::test_util::batches_to_sort_string; - use datafusion_datasource::file_scan_config::FileScanConfigBuilder; - use datafusion_datasource::schema_adapter::{ - DefaultSchemaAdapterFactory, SchemaAdapter, SchemaAdapterFactory, SchemaMapper, + use ::object_store::{path::Path, ObjectMeta}; + use arrow::{ + array::{Int32Array, StringArray}, + datatypes::{DataType, Field, Schema, SchemaRef}, + record_batch::RecordBatch, + }; + use datafusion_common::{record_batch, test_util::batches_to_sort_string}; + use datafusion_datasource::{ + file::FileSource, + file_scan_config::FileScanConfigBuilder, + schema_adapter::{ + DefaultSchemaAdapterFactory, SchemaAdapter, SchemaAdapterFactory, + SchemaMapper, + }, + source::DataSourceExec, + PartitionedFile, }; - use datafusion_datasource::PartitionedFile; use datafusion_datasource_parquet::source::ParquetSource; - - use datafusion_common::record_batch; - - use ::object_store::path::Path; - use ::object_store::ObjectMeta; - use datafusion_datasource::source::DataSourceExec; use datafusion_physical_plan::collect; + use std::{fs, sync::Arc}; use tempfile::TempDir; #[tokio::test] @@ -124,10 +123,9 @@ mod tests { let f2 = Field::new("extra_column", DataType::Utf8, true); let schema = Arc::new(Schema::new(vec![f1.clone(), f2.clone()])); - let source = Arc::new( - ParquetSource::default() - .with_schema_adapter_factory(Arc::new(TestSchemaAdapterFactory {})), - ); + let source = ParquetSource::default() + .with_schema_adapter_factory(Arc::new(TestSchemaAdapterFactory {})) + .unwrap(); let base_conf = FileScanConfigBuilder::new( ObjectStoreUrl::local_filesystem(), schema, @@ -264,5 +262,12 @@ mod tests { Ok(RecordBatch::try_new(schema, new_columns).unwrap()) } + + fn map_column_statistics( + &self, + _file_col_statistics: &[datafusion_common::ColumnStatistics], + ) -> datafusion_common::Result> { + unimplemented!() + } } } diff --git
a/datafusion/core/src/datasource/physical_plan/arrow_file.rs b/datafusion/core/src/datasource/physical_plan/arrow_file.rs index f0a1f94d87e1..d0af96329b5f 100644 --- a/datafusion/core/src/datasource/physical_plan/arrow_file.rs +++ b/datafusion/core/src/datasource/physical_plan/arrow_file.rs @@ -15,196 +15,40 @@ // specific language governing permissions and limitations // under the License. -//! Execution plan for reading Arrow files - use std::any::Any; use std::sync::Arc; use crate::datasource::physical_plan::{FileMeta, FileOpenFuture, FileOpener}; use crate::error::Result; +use datafusion_datasource::as_file_source; +use datafusion_datasource::schema_adapter::SchemaAdapterFactory; use arrow::buffer::Buffer; use arrow::datatypes::SchemaRef; use arrow_ipc::reader::FileDecoder; -use datafusion_common::config::ConfigOptions; -use datafusion_common::{Constraints, Statistics}; +use datafusion_common::Statistics; use datafusion_datasource::file::FileSource; use datafusion_datasource::file_scan_config::FileScanConfig; -use datafusion_datasource::source::DataSourceExec; -use datafusion_datasource_json::source::JsonSource; -use datafusion_execution::{SendableRecordBatchStream, TaskContext}; -use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; -use datafusion_physical_expr_common::sort_expr::LexOrdering; -use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; -use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; -use datafusion_physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, -}; +use datafusion_datasource::PartitionedFile; +use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; -use datafusion_datasource::file_groups::FileGroup; use futures::StreamExt; use itertools::Itertools; use object_store::{GetOptions, GetRange, GetResultPayload, ObjectStore}; -/// Execution plan for scanning Arrow data source -#[derive(Debug, Clone)] -#[deprecated(since = "46.0.0", note = "use DataSourceExec instead")] -pub struct ArrowExec { - inner: DataSourceExec, - base_config: FileScanConfig, -} - -#[allow(unused, deprecated)] -impl ArrowExec { - /// Create a new Arrow reader execution plan provided base configurations - pub fn new(base_config: FileScanConfig) -> Self { - let ( - projected_schema, - projected_constraints, - projected_statistics, - projected_output_ordering, - ) = base_config.project(); - let cache = Self::compute_properties( - Arc::clone(&projected_schema), - &projected_output_ordering, - projected_constraints, - &base_config, - ); - let arrow = ArrowSource::default(); - let base_config = base_config.with_source(Arc::new(arrow)); - Self { - inner: DataSourceExec::new(Arc::new(base_config.clone())), - base_config, - } - } - /// Ref to the base configs - pub fn base_config(&self) -> &FileScanConfig { - &self.base_config - } - - fn file_scan_config(&self) -> FileScanConfig { - self.inner - .data_source() - .as_any() - .downcast_ref::() - .unwrap() - .clone() - } - - fn json_source(&self) -> JsonSource { - self.file_scan_config() - .file_source() - .as_any() - .downcast_ref::() - .unwrap() - .clone() - } - - fn output_partitioning_helper(file_scan_config: &FileScanConfig) -> Partitioning { - Partitioning::UnknownPartitioning(file_scan_config.file_groups.len()) - } - - /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. 
- fn compute_properties( - schema: SchemaRef, - output_ordering: &[LexOrdering], - constraints: Constraints, - file_scan_config: &FileScanConfig, - ) -> PlanProperties { - // Equivalence Properties - let eq_properties = - EquivalenceProperties::new_with_orderings(schema, output_ordering) - .with_constraints(constraints); - - PlanProperties::new( - eq_properties, - Self::output_partitioning_helper(file_scan_config), // Output Partitioning - EmissionType::Incremental, - Boundedness::Bounded, - ) - } - - fn with_file_groups(mut self, file_groups: Vec) -> Self { - self.base_config.file_groups = file_groups.clone(); - let mut file_source = self.file_scan_config(); - file_source = file_source.with_file_groups(file_groups); - self.inner = self.inner.with_data_source(Arc::new(file_source)); - self - } -} - -#[allow(unused, deprecated)] -impl DisplayAs for ArrowExec { - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - self.inner.fmt_as(t, f) - } -} - -#[allow(unused, deprecated)] -impl ExecutionPlan for ArrowExec { - fn name(&self) -> &'static str { - "ArrowExec" - } - - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { - self.inner.properties() - } - fn children(&self) -> Vec<&Arc> { - Vec::new() - } - - fn with_new_children( - self: Arc, - _: Vec>, - ) -> Result> { - Ok(self) - } - - /// Redistribute files across partitions according to their size - /// See comments on `FileGroupPartitioner` for more detail. - fn repartitioned( - &self, - target_partitions: usize, - config: &ConfigOptions, - ) -> Result>> { - self.inner.repartitioned(target_partitions, config) - } - fn execute( - &self, - partition: usize, - context: Arc, - ) -> Result { - self.inner.execute(partition, context) - } - fn metrics(&self) -> Option { - self.inner.metrics() - } - fn statistics(&self) -> Result { - self.inner.statistics() - } - fn fetch(&self) -> Option { - self.inner.fetch() - } - - fn with_fetch(&self, limit: Option) -> Option> { - self.inner.with_fetch(limit) - } -} - /// Arrow configuration struct that is given to DataSourceExec /// Does not hold anything special, since [`FileScanConfig`] is sufficient for arrow #[derive(Clone, Default)] pub struct ArrowSource { metrics: ExecutionPlanMetricsSet, projected_statistics: Option, + schema_adapter_factory: Option>, +} + +impl From for Arc { + fn from(source: ArrowSource) -> Self { + as_file_source(source) + } } impl FileSource for ArrowSource { @@ -255,6 +99,20 @@ impl FileSource for ArrowSource { fn file_type(&self) -> &str { "arrow" } + + fn with_schema_adapter_factory( + &self, + schema_adapter_factory: Arc, + ) -> Result> { + Ok(Arc::new(Self { + schema_adapter_factory: Some(schema_adapter_factory), + ..self.clone() + })) + } + + fn schema_adapter_factory(&self) -> Option> { + self.schema_adapter_factory.clone() + } } /// The struct arrow that implements `[FileOpener]` trait @@ -264,7 +122,11 @@ pub struct ArrowOpener { } impl FileOpener for ArrowOpener { - fn open(&self, file_meta: FileMeta) -> Result { + fn open( + &self, + file_meta: FileMeta, + _file: PartitionedFile, + ) -> Result { let object_store = Arc::clone(&self.object_store); let projection = self.projection.clone(); Ok(Box::pin(async move { diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index 3ef403013452..e33761a0abb3 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ 
-369,7 +369,8 @@ mod tests { .build(); // Add partition columns - config.table_partition_cols = vec![Field::new("date", DataType::Utf8, false)]; + config.table_partition_cols = + vec![Arc::new(Field::new("date", DataType::Utf8, false))]; config.file_groups[0][0].partition_values = vec![ScalarValue::from("2021-10-26")]; // We should be able to project on the partition column diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs index 736248fbd95d..0d45711c76fb 100644 --- a/datafusion/core/src/datasource/physical_plan/json.rs +++ b/datafusion/core/src/datasource/physical_plan/json.rs @@ -19,7 +19,6 @@ //! //! [`FileSource`]: datafusion_datasource::file::FileSource -#[allow(deprecated)] pub use datafusion_datasource_json::source::*; #[cfg(test)] diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index e3f237803b34..3f71b253d969 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -27,30 +27,18 @@ pub mod parquet; #[cfg(feature = "avro")] pub mod avro; -#[allow(deprecated)] #[cfg(feature = "avro")] -pub use avro::{AvroExec, AvroSource}; +pub use avro::AvroSource; #[cfg(feature = "parquet")] pub use datafusion_datasource_parquet::source::ParquetSource; #[cfg(feature = "parquet")] -#[allow(deprecated)] -pub use datafusion_datasource_parquet::{ - ParquetExec, ParquetExecBuilder, ParquetFileMetrics, ParquetFileReaderFactory, -}; +pub use datafusion_datasource_parquet::{ParquetFileMetrics, ParquetFileReaderFactory}; -#[allow(deprecated)] -pub use arrow_file::ArrowExec; pub use arrow_file::ArrowSource; -#[allow(deprecated)] -pub use json::NdJsonExec; - pub use json::{JsonOpener, JsonSource}; -#[allow(deprecated)] -pub use csv::{CsvExec, CsvExecBuilder}; - pub use csv::{CsvOpener, CsvSource}; pub use datafusion_datasource::file::FileSource; pub use datafusion_datasource::file_groups::FileGroup; diff --git a/datafusion/core/src/datasource/physical_plan/parquet.rs b/datafusion/core/src/datasource/physical_plan/parquet.rs index e9bb8b0db368..55db0d854204 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet.rs @@ -39,12 +39,12 @@ mod tests { use crate::test::object_store::local_unpartitioned_file; use arrow::array::{ ArrayRef, AsArray, Date64Array, Int32Array, Int64Array, Int8Array, StringArray, - StructArray, + StringViewArray, StructArray, TimestampNanosecondArray, }; use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaBuilder}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; - use arrow_schema::SchemaRef; + use arrow_schema::{SchemaRef, TimeUnit}; use bytes::{BufMut, BytesMut}; use datafusion_common::config::TableParquetOptions; use datafusion_common::test_util::{batches_to_sort_string, batches_to_string}; @@ -54,6 +54,7 @@ mod tests { use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_datasource::source::DataSourceExec; + use datafusion_datasource::file::FileSource; use datafusion_datasource::{FileRange, PartitionedFile}; use datafusion_datasource_parquet::source::ParquetSource; use datafusion_datasource_parquet::{ @@ -95,10 +96,15 @@ mod tests { #[derive(Debug, Default)] struct RoundTrip { projection: Option>, - schema: Option, + /// Optional logical table schema to use when reading the parquet files + /// + /// If None, the logical schema to 
use will be inferred from the + /// original data via [`Schema::try_merge`] + table_schema: Option, predicate: Option, pushdown_predicate: bool, page_index_predicate: bool, + bloom_filters: bool, } impl RoundTrip { @@ -111,8 +117,11 @@ mod tests { self } - fn with_schema(mut self, schema: SchemaRef) -> Self { - self.schema = Some(schema); + /// Specify table schema. + /// + ///See [`Self::table_schema`] for more details + fn with_table_schema(mut self, schema: SchemaRef) -> Self { + self.table_schema = Some(schema); self } @@ -131,6 +140,11 @@ mod tests { self } + fn with_bloom_filters(mut self) -> Self { + self.bloom_filters = true; + self + } + /// run the test, returning only the resulting RecordBatches async fn round_trip_to_batches( self, @@ -139,36 +153,46 @@ mod tests { self.round_trip(batches).await.batches } - fn build_file_source(&self, file_schema: SchemaRef) -> Arc { + fn build_file_source(&self, table_schema: SchemaRef) -> Arc { // set up predicate (this is normally done by a layer higher up) let predicate = self .predicate .as_ref() - .map(|p| logical2physical(p, &file_schema)); + .map(|p| logical2physical(p, &table_schema)); let mut source = ParquetSource::default(); if let Some(predicate) = predicate { - source = source.with_predicate(Arc::clone(&file_schema), predicate); + source = source.with_predicate(predicate); } if self.pushdown_predicate { source = source .with_pushdown_filters(true) .with_reorder_filters(true); + } else { + source = source.with_pushdown_filters(false); } if self.page_index_predicate { source = source.with_enable_page_index(true); + } else { + source = source.with_enable_page_index(false); + } + + if self.bloom_filters { + source = source.with_bloom_filter_on_read(true); + } else { + source = source.with_bloom_filter_on_read(false); } - Arc::new(source) + source.with_schema(Arc::clone(&table_schema)) } fn build_parquet_exec( &self, file_schema: SchemaRef, file_group: FileGroup, - source: Arc, + source: Arc, ) -> Arc { let base_config = FileScanConfigBuilder::new( ObjectStoreUrl::local_filesystem(), @@ -182,8 +206,14 @@ mod tests { } /// run the test, returning the `RoundTripResult` + /// + /// Each input batch is written into one or more parquet files (and thus + /// they could potentially have different schemas). The resulting + /// parquet files are then read back and filters are applied to the async fn round_trip(&self, batches: Vec) -> RoundTripResult { - let file_schema = match &self.schema { + // If table_schema is not set, we need to merge the schema of the + // input batches to get a unified schema. 
+ let table_schema = match &self.table_schema { Some(schema) => schema, None => &Arc::new( Schema::try_merge( @@ -192,7 +222,6 @@ mod tests { .unwrap(), ), }; - let file_schema = Arc::clone(file_schema); // If testing with page_index_predicate, write parquet // files with multiple pages let multi_page = self.page_index_predicate; @@ -200,9 +229,9 @@ mod tests { let file_group: FileGroup = meta.into_iter().map(Into::into).collect(); // build a ParquetExec to return the results - let parquet_source = self.build_file_source(file_schema.clone()); + let parquet_source = self.build_file_source(Arc::clone(table_schema)); let parquet_exec = self.build_parquet_exec( - file_schema.clone(), + Arc::clone(table_schema), file_group.clone(), Arc::clone(&parquet_source), ); @@ -212,9 +241,9 @@ mod tests { false, // use a new ParquetSource to avoid sharing execution metrics self.build_parquet_exec( - file_schema.clone(), + Arc::clone(table_schema), file_group.clone(), - self.build_file_source(file_schema.clone()), + self.build_file_source(Arc::clone(table_schema)), ), Arc::new(Schema::new(vec![ Field::new("plan_type", DataType::Utf8, true), @@ -287,7 +316,7 @@ mod tests { // Thus this predicate will come back as false. let filter = col("c2").eq(lit(1_i32)); let rt = RoundTrip::new() - .with_schema(table_schema.clone()) + .with_table_schema(table_schema.clone()) .with_predicate(filter.clone()) .with_pushdown_predicate() .round_trip(vec![batch.clone()]) @@ -306,7 +335,7 @@ mod tests { // If we excplicitly allow nulls the rest of the predicate should work let filter = col("c2").is_null().and(col("c1").eq(lit(1_i32))); let rt = RoundTrip::new() - .with_schema(table_schema.clone()) + .with_table_schema(table_schema.clone()) .with_predicate(filter.clone()) .with_pushdown_predicate() .round_trip(vec![batch.clone()]) @@ -345,7 +374,7 @@ mod tests { // Thus this predicate will come back as false. let filter = col("c2").eq(lit("abc")); let rt = RoundTrip::new() - .with_schema(table_schema.clone()) + .with_table_schema(table_schema.clone()) .with_predicate(filter.clone()) .with_pushdown_predicate() .round_trip(vec![batch.clone()]) @@ -364,7 +393,7 @@ mod tests { // If we excplicitly allow nulls the rest of the predicate should work let filter = col("c2").is_null().and(col("c1").eq(lit(1_i32))); let rt = RoundTrip::new() - .with_schema(table_schema.clone()) + .with_table_schema(table_schema.clone()) .with_predicate(filter.clone()) .with_pushdown_predicate() .round_trip(vec![batch.clone()]) @@ -407,7 +436,7 @@ mod tests { // Thus this predicate will come back as false. let filter = col("c2").eq(lit("abc")); let rt = RoundTrip::new() - .with_schema(table_schema.clone()) + .with_table_schema(table_schema.clone()) .with_predicate(filter.clone()) .with_pushdown_predicate() .round_trip(vec![batch.clone()]) @@ -426,7 +455,7 @@ mod tests { // If we excplicitly allow nulls the rest of the predicate should work let filter = col("c2").is_null().and(col("c1").eq(lit(1_i32))); let rt = RoundTrip::new() - .with_schema(table_schema.clone()) + .with_table_schema(table_schema.clone()) .with_predicate(filter.clone()) .with_pushdown_predicate() .round_trip(vec![batch.clone()]) @@ -469,7 +498,7 @@ mod tests { // Thus this predicate will come back as false. 
let filter = col("c2").eq(lit("abc")); let rt = RoundTrip::new() - .with_schema(table_schema.clone()) + .with_table_schema(table_schema.clone()) .with_predicate(filter.clone()) .with_pushdown_predicate() .round_trip(vec![batch.clone()]) @@ -488,7 +517,7 @@ mod tests { // If we excplicitly allow nulls the rest of the predicate should work let filter = col("c2").is_null().and(col("c3").eq(lit(7_i32))); let rt = RoundTrip::new() - .with_schema(table_schema.clone()) + .with_table_schema(table_schema.clone()) .with_predicate(filter.clone()) .with_pushdown_predicate() .round_trip(vec![batch.clone()]) @@ -536,7 +565,7 @@ mod tests { .and(col("c3").eq(lit(10_i32)).or(col("c2").is_null())); let rt = RoundTrip::new() - .with_schema(table_schema.clone()) + .with_table_schema(table_schema.clone()) .with_predicate(filter.clone()) .with_pushdown_predicate() .round_trip(vec![batch.clone()]) @@ -566,7 +595,7 @@ mod tests { .or(col("c3").gt(lit(20_i32)).and(col("c2").is_null())); let rt = RoundTrip::new() - .with_schema(table_schema) + .with_table_schema(table_schema) .with_predicate(filter.clone()) .with_pushdown_predicate() .round_trip(vec![batch]) @@ -816,7 +845,7 @@ mod tests { } #[tokio::test] - async fn evolved_schema_filter() { + async fn evolved_schema_column_order_filter() { let c1: ArrayRef = Arc::new(StringArray::from(vec![Some("Foo"), None, Some("bar")])); @@ -847,6 +876,156 @@ mod tests { assert_eq!(read.len(), 0); } + #[tokio::test] + async fn evolved_schema_column_type_filter_strings() { + // The table and filter have a common data type, but the file schema differs + let c1: ArrayRef = + Arc::new(StringViewArray::from(vec![Some("foo"), Some("bar")])); + let batch = create_batch(vec![("c1", c1.clone())]); + + // Table schema is Utf8 but file schema is StringView + let table_schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Utf8, false)])); + + // Predicate should prune all row groups + let filter = col("c1").eq(lit(ScalarValue::Utf8(Some("aaa".to_string())))); + let rt = RoundTrip::new() + .with_predicate(filter) + .with_table_schema(table_schema.clone()) + .round_trip(vec![batch.clone()]) + .await; + // There should be no predicate evaluation errors + let metrics = rt.parquet_exec.metrics().unwrap(); + assert_eq!(get_value(&metrics, "predicate_evaluation_errors"), 0); + assert_eq!(get_value(&metrics, "pushdown_rows_matched"), 0); + assert_eq!(rt.batches.unwrap().len(), 0); + + // Predicate should prune no row groups + let filter = col("c1").eq(lit(ScalarValue::Utf8(Some("foo".to_string())))); + let rt = RoundTrip::new() + .with_predicate(filter) + .with_table_schema(table_schema) + .round_trip(vec![batch]) + .await; + // There should be no predicate evaluation errors + let metrics = rt.parquet_exec.metrics().unwrap(); + assert_eq!(get_value(&metrics, "predicate_evaluation_errors"), 0); + assert_eq!(get_value(&metrics, "pushdown_rows_matched"), 0); + let read = rt + .batches + .unwrap() + .iter() + .map(|b| b.num_rows()) + .sum::(); + assert_eq!(read, 2, "Expected 2 rows to match the predicate"); + } + + #[tokio::test] + async fn evolved_schema_column_type_filter_ints() { + // The table and filter have a common data type, but the file schema differs + let c1: ArrayRef = Arc::new(Int8Array::from(vec![Some(1), Some(2)])); + let batch = create_batch(vec![("c1", c1.clone())]); + + let table_schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::UInt64, false)])); + + // Predicate should prune all row groups + let filter = col("c1").eq(lit(ScalarValue::UInt64(Some(5)))); + 
let rt = RoundTrip::new() + .with_predicate(filter) + .with_table_schema(table_schema.clone()) + .round_trip(vec![batch.clone()]) + .await; + // There should be no predicate evaluation errors + let metrics = rt.parquet_exec.metrics().unwrap(); + assert_eq!(get_value(&metrics, "predicate_evaluation_errors"), 0); + assert_eq!(rt.batches.unwrap().len(), 0); + + // Predicate should prune no row groups + let filter = col("c1").eq(lit(ScalarValue::UInt64(Some(1)))); + let rt = RoundTrip::new() + .with_predicate(filter) + .with_table_schema(table_schema) + .round_trip(vec![batch]) + .await; + // There should be no predicate evaluation errors + let metrics = rt.parquet_exec.metrics().unwrap(); + assert_eq!(get_value(&metrics, "predicate_evaluation_errors"), 0); + let read = rt + .batches + .unwrap() + .iter() + .map(|b| b.num_rows()) + .sum::(); + assert_eq!(read, 2, "Expected 2 rows to match the predicate"); + } + + #[tokio::test] + async fn evolved_schema_column_type_filter_timestamp_units() { + // The table and filter have a common data type + // The table schema is in milliseconds, but the file schema is in nanoseconds + let c1: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![ + Some(1_000_000_000), // 1970-01-01T00:00:01Z + Some(2_000_000_000), // 1970-01-01T00:00:02Z + Some(3_000_000_000), // 1970-01-01T00:00:03Z + Some(4_000_000_000), // 1970-01-01T00:00:04Z + ])); + let batch = create_batch(vec![("c1", c1.clone())]); + let table_schema = Arc::new(Schema::new(vec![Field::new( + "c1", + DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())), + false, + )])); + // One row should match, 2 pruned via page index, 1 pruned via filter pushdown + let filter = col("c1").eq(lit(ScalarValue::TimestampMillisecond( + Some(1_000), + Some("UTC".into()), + ))); + let rt = RoundTrip::new() + .with_predicate(filter) + .with_pushdown_predicate() + .with_page_index_predicate() // produces pages with 2 rows each (2 pages total for our data) + .with_table_schema(table_schema.clone()) + .round_trip(vec![batch.clone()]) + .await; + // There should be no predicate evaluation errors and we keep 1 row + let metrics = rt.parquet_exec.metrics().unwrap(); + assert_eq!(get_value(&metrics, "predicate_evaluation_errors"), 0); + let read = rt + .batches + .unwrap() + .iter() + .map(|b| b.num_rows()) + .sum::(); + assert_eq!(read, 1, "Expected 1 rows to match the predicate"); + assert_eq!(get_value(&metrics, "row_groups_pruned_statistics"), 0); + assert_eq!(get_value(&metrics, "page_index_rows_pruned"), 2); + assert_eq!(get_value(&metrics, "pushdown_rows_pruned"), 1); + // If we filter with a value that is completely out of the range of the data + // we prune at the row group level. 
+ let filter = col("c1").eq(lit(ScalarValue::TimestampMillisecond( + Some(5_000), + Some("UTC".into()), + ))); + let rt = RoundTrip::new() + .with_predicate(filter) + .with_pushdown_predicate() + .with_table_schema(table_schema) + .round_trip(vec![batch]) + .await; + // There should be no predicate evaluation errors and we keep 0 rows + let metrics = rt.parquet_exec.metrics().unwrap(); + assert_eq!(get_value(&metrics, "predicate_evaluation_errors"), 0); + let read = rt + .batches + .unwrap() + .iter() + .map(|b| b.num_rows()) + .sum::(); + assert_eq!(read, 0, "Expected 0 rows to match the predicate"); + assert_eq!(get_value(&metrics, "row_groups_pruned_statistics"), 1); + } + #[tokio::test] async fn evolved_schema_disjoint_schema_filter() { let c1: ArrayRef = @@ -1084,7 +1263,7 @@ mod tests { // batch2: c3(int8), c2(int64), c1(string), c4(string) let batch2 = create_batch(vec![("c3", c4), ("c2", c2), ("c1", c1)]); - let schema = Schema::new(vec![ + let table_schema = Schema::new(vec![ Field::new("c1", DataType::Utf8, true), Field::new("c2", DataType::Int64, true), Field::new("c3", DataType::Int8, true), @@ -1092,7 +1271,7 @@ mod tests { // read/write them files: let read = RoundTrip::new() - .with_schema(Arc::new(schema)) + .with_table_schema(Arc::new(table_schema)) .round_trip_to_batches(vec![batch1, batch2]) .await; assert_contains!(read.unwrap_err().to_string(), @@ -1228,6 +1407,124 @@ mod tests { Ok(()) } + #[tokio::test] + async fn parquet_exec_with_int96_nested() -> Result<()> { + // This test ensures that we maintain compatibility with coercing int96 to the desired + // resolution when they're within a nested type (e.g., struct, map, list). This file + // originates from a modified CometFuzzTestSuite ParquetGenerator to generate combinations + // of primitive and complex columns using int96. Other tests cover reading the data + // correctly with this coercion. Here we're only checking the coerced schema is correct. 
+ let testdata = "../../datafusion/core/tests/data"; + let filename = "int96_nested.parquet"; + let session_ctx = SessionContext::new(); + let state = session_ctx.state(); + let task_ctx = state.task_ctx(); + + let parquet_exec = scan_format( + &state, + &ParquetFormat::default().with_coerce_int96(Some("us".to_string())), + None, + testdata, + filename, + None, + None, + ) + .await + .unwrap(); + assert_eq!(parquet_exec.output_partitioning().partition_count(), 1); + + let mut results = parquet_exec.execute(0, task_ctx.clone())?; + let batch = results.next().await.unwrap()?; + + let expected_schema = Arc::new(Schema::new(vec![ + Field::new("c0", DataType::Timestamp(TimeUnit::Microsecond, None), true), + Field::new_struct( + "c1", + vec![Field::new( + "c0", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + )], + true, + ), + Field::new_struct( + "c2", + vec![Field::new_list( + "c0", + Field::new( + "element", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + true, + )], + true, + ), + Field::new_map( + "c3", + "key_value", + Field::new( + "key", + DataType::Timestamp(TimeUnit::Microsecond, None), + false, + ), + Field::new( + "value", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + false, + true, + ), + Field::new_list( + "c4", + Field::new( + "element", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + true, + ), + Field::new_list( + "c5", + Field::new_struct( + "element", + vec![Field::new( + "c0", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + )], + true, + ), + true, + ), + Field::new_list( + "c6", + Field::new_map( + "element", + "key_value", + Field::new( + "key", + DataType::Timestamp(TimeUnit::Microsecond, None), + false, + ), + Field::new( + "value", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + false, + true, + ), + true, + ), + ])); + + assert_eq!(batch.schema(), expected_schema); + + Ok(()) + } + #[tokio::test] async fn parquet_exec_with_range() -> Result<()> { fn file_range(meta: &ObjectMeta, start: i64, end: i64) -> PartitionedFile { @@ -1629,6 +1926,7 @@ mod tests { let rt = RoundTrip::new() .with_predicate(filter.clone()) .with_pushdown_predicate() + .with_bloom_filters() .round_trip(vec![batch1]) .await; diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index fc110a0699df..dbe5c2c00f17 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -35,7 +35,11 @@ use crate::{ }, datasource::{provider_as_source, MemTable, ViewTable}, error::{DataFusionError, Result}, - execution::{options::ArrowReadOptions, runtime_env::RuntimeEnv, FunctionRegistry}, + execution::{ + options::ArrowReadOptions, + runtime_env::{RuntimeEnv, RuntimeEnvBuilder}, + FunctionRegistry, + }, logical_expr::AggregateUDF, logical_expr::ScalarUDF, logical_expr::{ @@ -1036,13 +1040,70 @@ impl SessionContext { variable, value, .. 
} = stmt; - let mut state = self.state.write(); - state.config_mut().options_mut().set(&variable, &value)?; - drop(state); + // Check if this is a runtime configuration + if variable.starts_with("datafusion.runtime.") { + self.set_runtime_variable(&variable, &value)?; + } else { + let mut state = self.state.write(); + state.config_mut().options_mut().set(&variable, &value)?; + drop(state); + } self.return_empty_dataframe() } + fn set_runtime_variable(&self, variable: &str, value: &str) -> Result<()> { + let key = variable.strip_prefix("datafusion.runtime.").unwrap(); + + match key { + "memory_limit" => { + let memory_limit = Self::parse_memory_limit(value)?; + + let mut state = self.state.write(); + let mut builder = + RuntimeEnvBuilder::from_runtime_env(state.runtime_env()); + builder = builder.with_memory_limit(memory_limit, 1.0); + *state = SessionStateBuilder::from(state.clone()) + .with_runtime_env(Arc::new(builder.build()?)) + .build(); + } + _ => { + return Err(DataFusionError::Plan(format!( + "Unknown runtime configuration: {variable}" + ))) + } + } + Ok(()) + } + + /// Parse memory limit from string to number of bytes + /// Supports formats like '1.5G', '100M', '512K' + /// + /// # Examples + /// ``` + /// use datafusion::execution::context::SessionContext; + /// + /// assert_eq!(SessionContext::parse_memory_limit("1M").unwrap(), 1024 * 1024); + /// assert_eq!(SessionContext::parse_memory_limit("1.5G").unwrap(), (1.5 * 1024.0 * 1024.0 * 1024.0) as usize); + /// ``` + pub fn parse_memory_limit(limit: &str) -> Result { + let (number, unit) = limit.split_at(limit.len() - 1); + let number: f64 = number.parse().map_err(|_| { + DataFusionError::Plan(format!( + "Failed to parse number from memory limit '{limit}'" + )) + })?; + + match unit { + "K" => Ok((number * 1024.0) as usize), + "M" => Ok((number * 1024.0 * 1024.0) as usize), + "G" => Ok((number * 1024.0 * 1024.0 * 1024.0) as usize), + _ => Err(DataFusionError::Plan(format!( + "Unsupported unit '{unit}' in memory limit '{limit}'" + ))), + } + } + async fn create_custom_table( &self, cmd: &CreateExternalTable, @@ -1153,7 +1214,7 @@ impl SessionContext { let mut params: Vec = parameters .into_iter() .map(|e| match e { - Expr::Literal(scalar) => Ok(scalar), + Expr::Literal(scalar, _) => Ok(scalar), _ => not_impl_err!("Unsupported parameter type: {}", e), }) .collect::>()?; @@ -1647,7 +1708,7 @@ impl FunctionRegistry for SessionContext { } fn expr_planners(&self) -> Vec> { - self.state.read().expr_planners() + self.state.read().expr_planners().to_vec() } fn register_expr_planner( @@ -1833,7 +1894,6 @@ mod tests { use crate::test; use crate::test_util::{plan_and_collect, populate_csv_partitions}; use arrow::datatypes::{DataType, TimeUnit}; - use std::env; use std::error::Error; use std::path::PathBuf; diff --git a/datafusion/core/src/execution/context/parquet.rs b/datafusion/core/src/execution/context/parquet.rs index 6ec9796fe90d..731f7e59ecfa 100644 --- a/datafusion/core/src/execution/context/parquet.rs +++ b/datafusion/core/src/execution/context/parquet.rs @@ -31,6 +31,21 @@ impl SessionContext { /// [`read_table`](Self::read_table) with a [`super::ListingTable`]. /// /// For an example, see [`read_csv`](Self::read_csv) + /// + /// # Note: Statistics + /// + /// NOTE: by default, statistics are collected when reading the Parquet + /// files This can slow down the initial DataFrame creation while + /// greatly accelerating queries with certain filters. 
+ /// + /// To disable statistics collection, set the [config option] + /// `datafusion.execution.collect_statistics` to `false`. See + /// [`ConfigOptions`] and [`ExecutionOptions::collect_statistics`] for more + /// details. + /// + /// [config option]: https://datafusion.apache.org/user-guide/configs.html + /// [`ConfigOptions`]: crate::config::ConfigOptions + /// [`ExecutionOptions::collect_statistics`]: crate::config::ExecutionOptions::collect_statistics pub async fn read_parquet( &self, table_paths: P, @@ -41,6 +56,13 @@ impl SessionContext { /// Registers a Parquet file as a table that can be referenced from SQL /// statements executed against this context. + /// + /// # Note: Statistics + /// + /// Statistics are not collected by default. See [`read_parquet`] for more + /// details and how to enable them. + /// + /// [`read_parquet`]: Self::read_parquet pub async fn register_parquet( &self, table_ref: impl Into, @@ -84,10 +106,14 @@ mod tests { use crate::parquet::basic::Compression; use crate::test_util::parquet_test_data; + use arrow::util::pretty::pretty_format_batches; use datafusion_common::config::TableParquetOptions; + use datafusion_common::{ + assert_batches_eq, assert_batches_sorted_eq, assert_contains, + }; use datafusion_execution::config::SessionConfig; - use tempfile::tempdir; + use tempfile::{tempdir, TempDir}; #[tokio::test] async fn read_with_glob_path() -> Result<()> { @@ -129,6 +155,49 @@ mod tests { Ok(()) } + async fn explain_query_all_with_config(config: SessionConfig) -> Result { + let ctx = SessionContext::new_with_config(config); + + ctx.register_parquet( + "test", + &format!("{}/alltypes_plain*.parquet", parquet_test_data()), + ParquetReadOptions::default(), + ) + .await?; + let df = ctx.sql("EXPLAIN SELECT * FROM test").await?; + let results = df.collect().await?; + let content = pretty_format_batches(&results).unwrap().to_string(); + Ok(content) + } + + #[tokio::test] + async fn register_parquet_respects_collect_statistics_config() -> Result<()> { + // The default is true + let mut config = SessionConfig::new(); + config.options_mut().explain.physical_plan_only = true; + config.options_mut().explain.show_statistics = true; + let content = explain_query_all_with_config(config).await?; + assert_contains!(content, "statistics=[Rows=Exact("); + + // Explicitly set to true + let mut config = SessionConfig::new(); + config.options_mut().explain.physical_plan_only = true; + config.options_mut().explain.show_statistics = true; + config.options_mut().execution.collect_statistics = true; + let content = explain_query_all_with_config(config).await?; + assert_contains!(content, "statistics=[Rows=Exact("); + + // Explicitly set to false + let mut config = SessionConfig::new(); + config.options_mut().explain.physical_plan_only = true; + config.options_mut().explain.show_statistics = true; + config.options_mut().execution.collect_statistics = false; + let content = explain_query_all_with_config(config).await?; + assert_contains!(content, "statistics=[Rows=Absent,"); + + Ok(()) + } + #[tokio::test] async fn read_from_registered_table_with_glob_path() -> Result<()> { let ctx = SessionContext::new(); @@ -286,7 +355,7 @@ mod tests { let expected_path = binding[0].as_str(); assert_eq!( read_df.unwrap_err().strip_backtrace(), - format!("Execution error: File path '{}' does not match the expected extension '.parquet'", expected_path) + format!("Execution error: File path '{expected_path}' does not match the expected extension '.parquet'") ); // Read the dataframe from 
'output3.parquet.snappy.parquet' with the correct file extension. @@ -333,4 +402,124 @@ mod tests { assert_eq!(total_rows, 5); Ok(()) } + + #[tokio::test] + async fn read_from_parquet_folder() -> Result<()> { + let ctx = SessionContext::new(); + let tmp_dir = TempDir::new()?; + let test_path = tmp_dir.path().to_str().unwrap().to_string(); + + ctx.sql("SELECT 1 a") + .await? + .write_parquet(&test_path, DataFrameWriteOptions::default(), None) + .await?; + + ctx.sql("SELECT 2 a") + .await? + .write_parquet(&test_path, DataFrameWriteOptions::default(), None) + .await?; + + // Adding CSV to check it is not read with Parquet reader + ctx.sql("SELECT 3 a") + .await? + .write_csv(&test_path, DataFrameWriteOptions::default(), None) + .await?; + + let actual = ctx + .read_parquet(&test_path, ParquetReadOptions::default()) + .await? + .collect() + .await?; + + #[cfg_attr(any(), rustfmt::skip)] + assert_batches_sorted_eq!(&[ + "+---+", + "| a |", + "+---+", + "| 2 |", + "| 1 |", + "+---+", + ], &actual); + + let actual = ctx + .read_parquet(test_path, ParquetReadOptions::default()) + .await? + .collect() + .await?; + + #[cfg_attr(any(), rustfmt::skip)] + assert_batches_sorted_eq!(&[ + "+---+", + "| a |", + "+---+", + "| 2 |", + "| 1 |", + "+---+", + ], &actual); + + Ok(()) + } + + #[tokio::test] + async fn read_from_parquet_folder_table() -> Result<()> { + let ctx = SessionContext::new(); + let tmp_dir = TempDir::new()?; + let test_path = tmp_dir.path().to_str().unwrap().to_string(); + + ctx.sql("SELECT 1 a") + .await? + .write_parquet(&test_path, DataFrameWriteOptions::default(), None) + .await?; + + ctx.sql("SELECT 2 a") + .await? + .write_parquet(&test_path, DataFrameWriteOptions::default(), None) + .await?; + + // Adding CSV to check it is not read with Parquet reader + ctx.sql("SELECT 3 a") + .await? + .write_csv(&test_path, DataFrameWriteOptions::default(), None) + .await?; + + ctx.sql(format!("CREATE EXTERNAL TABLE parquet_folder_t1 STORED AS PARQUET LOCATION '{test_path}'").as_ref()) + .await?; + + let actual = ctx + .sql("select * from parquet_folder_t1") + .await? + .collect() + .await?; + #[cfg_attr(any(), rustfmt::skip)] + assert_batches_sorted_eq!(&[ + "+---+", + "| a |", + "+---+", + "| 2 |", + "| 1 |", + "+---+", + ], &actual); + + Ok(()) + } + + #[tokio::test] + async fn read_dummy_folder() -> Result<()> { + let ctx = SessionContext::new(); + let test_path = "/foo/"; + + let actual = ctx + .read_parquet(test_path, ParquetReadOptions::default()) + .await? 
+ .collect() + .await?; + + #[cfg_attr(any(), rustfmt::skip)] + assert_batches_eq!(&[ + "++", + "++", + ], &actual); + + Ok(()) + } } diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 28f599304f8c..1c0363f421af 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -494,7 +494,7 @@ impl SessionState { enable_options_value_normalization: sql_parser_options .enable_options_value_normalization, support_varchar_with_length: sql_parser_options.support_varchar_with_length, - map_varchar_to_utf8view: sql_parser_options.map_varchar_to_utf8view, + map_string_types_to_utf8view: sql_parser_options.map_string_types_to_utf8view, collect_spans: sql_parser_options.collect_spans, } } @@ -552,6 +552,11 @@ impl SessionState { &self.optimizer } + /// Returns the [`ExprPlanner`]s for this session + pub fn expr_planners(&self) -> &[Arc] { + &self.expr_planners + } + /// Returns the [`QueryPlanner`] for this session pub fn query_planner(&self) -> &Arc { &self.query_planner @@ -1348,28 +1353,30 @@ impl SessionStateBuilder { } = self; let config = config.unwrap_or_default(); - let runtime_env = runtime_env.unwrap_or(Arc::new(RuntimeEnv::default())); + let runtime_env = runtime_env.unwrap_or_else(|| Arc::new(RuntimeEnv::default())); let mut state = SessionState { - session_id: session_id.unwrap_or(Uuid::new_v4().to_string()), + session_id: session_id.unwrap_or_else(|| Uuid::new_v4().to_string()), analyzer: analyzer.unwrap_or_default(), expr_planners: expr_planners.unwrap_or_default(), type_planner, optimizer: optimizer.unwrap_or_default(), physical_optimizers: physical_optimizers.unwrap_or_default(), - query_planner: query_planner.unwrap_or(Arc::new(DefaultQueryPlanner {})), - catalog_list: catalog_list - .unwrap_or(Arc::new(MemoryCatalogProviderList::new()) - as Arc), + query_planner: query_planner + .unwrap_or_else(|| Arc::new(DefaultQueryPlanner {})), + catalog_list: catalog_list.unwrap_or_else(|| { + Arc::new(MemoryCatalogProviderList::new()) as Arc + }), table_functions: table_functions.unwrap_or_default(), scalar_functions: HashMap::new(), aggregate_functions: HashMap::new(), window_functions: HashMap::new(), serializer_registry: serializer_registry - .unwrap_or(Arc::new(EmptySerializerRegistry)), + .unwrap_or_else(|| Arc::new(EmptySerializerRegistry)), file_formats: HashMap::new(), - table_options: table_options - .unwrap_or(TableOptions::default_from_session_config(config.options())), + table_options: table_options.unwrap_or_else(|| { + TableOptions::default_from_session_config(config.options()) + }), config, execution_props: execution_props.unwrap_or_default(), table_factories: table_factories.unwrap_or_default(), @@ -1635,7 +1642,7 @@ struct SessionContextProvider<'a> { impl ContextProvider for SessionContextProvider<'_> { fn get_expr_planners(&self) -> &[Arc] { - &self.state.expr_planners + self.state.expr_planners() } fn get_type_planner(&self) -> Option> { @@ -1668,6 +1675,13 @@ impl ContextProvider for SessionContextProvider<'_> { .get(name) .cloned() .ok_or_else(|| plan_datafusion_err!("table function '{name}' not found"))?; + let dummy_schema = DFSchema::empty(); + let simplifier = + ExprSimplifier::new(SessionSimplifyProvider::new(self.state, &dummy_schema)); + let args = args + .into_iter() + .map(|arg| simplifier.simplify(arg)) + .collect::>>()?; let provider = tbl_func.create_table_provider(&args)?; Ok(provider_as_source(provider)) @@ -1751,7 +1765,7 @@ impl 
FunctionRegistry for SessionState { let result = self.scalar_functions.get(name); result.cloned().ok_or_else(|| { - plan_datafusion_err!("There is no UDF named \"{name}\" in the registry") + plan_datafusion_err!("There is no UDF named \"{name}\" in the registry. Use session context `register_udf` function to register a custom UDF") }) } @@ -1759,7 +1773,7 @@ impl FunctionRegistry for SessionState { let result = self.aggregate_functions.get(name); result.cloned().ok_or_else(|| { - plan_datafusion_err!("There is no UDAF named \"{name}\" in the registry") + plan_datafusion_err!("There is no UDAF named \"{name}\" in the registry. Use session context `register_udaf` function to register a custom UDAF") }) } @@ -1767,7 +1781,7 @@ impl FunctionRegistry for SessionState { let result = self.window_functions.get(name); result.cloned().ok_or_else(|| { - plan_datafusion_err!("There is no UDWF named \"{name}\" in the registry") + plan_datafusion_err!("There is no UDWF named \"{name}\" in the registry. Use session context `register_udwf` function to register a custom UDWF") }) } @@ -1957,8 +1971,17 @@ pub(crate) struct PreparedPlan { #[cfg(test)] mod tests { use super::{SessionContextProvider, SessionStateBuilder}; + use crate::common::assert_contains; + use crate::config::ConfigOptions; + use crate::datasource::empty::EmptyTable; + use crate::datasource::provider_as_source; use crate::datasource::MemTable; use crate::execution::context::SessionState; + use crate::logical_expr::planner::ExprPlanner; + use crate::logical_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF}; + use crate::physical_plan::ExecutionPlan; + use crate::sql::planner::ContextProvider; + use crate::sql::{ResolvedTableReference, TableReference}; use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_catalog::MemoryCatalogProviderList; @@ -1968,6 +1991,7 @@ mod tests { use datafusion_expr::Expr; use datafusion_optimizer::optimizer::OptimizerRule; use datafusion_optimizer::Optimizer; + use datafusion_physical_plan::display::DisplayableExecutionPlan; use datafusion_sql::planner::{PlannerContext, SqlToRel}; use std::collections::HashMap; use std::sync::Arc; @@ -2125,4 +2149,148 @@ mod tests { Ok(()) } + + /// This test demonstrates why it's more convenient and somewhat necessary to provide + /// an `expr_planners` method for `SessionState`. + #[tokio::test] + async fn test_with_expr_planners() -> Result<()> { + // A helper method for planning count wildcard with or without expr planners. + async fn plan_count_wildcard( + with_expr_planners: bool, + ) -> Result> { + let mut context_provider = MyContextProvider::new().with_table( + "t", + provider_as_source(Arc::new(EmptyTable::new(Schema::empty().into()))), + ); + if with_expr_planners { + context_provider = context_provider.with_expr_planners(); + } + + let state = &context_provider.state; + let statement = state.sql_to_statement("select count(*) from t", "mysql")?; + let plan = SqlToRel::new(&context_provider).statement_to_plan(statement)?; + state.create_physical_plan(&plan).await + } + + // Planning count wildcard without expr planners should fail. + let got = plan_count_wildcard(false).await; + assert_contains!( + got.unwrap_err().to_string(), + "Physical plan does not support logical expression Wildcard" + ); + + // Planning count wildcard with expr planners should succeed. 
+ let got = plan_count_wildcard(true).await?; + let displayable = DisplayableExecutionPlan::new(got.as_ref()); + assert_eq!( + displayable.indent(false).to_string(), + "ProjectionExec: expr=[0 as count(*)]\n PlaceholderRowExec\n" + ); + + Ok(()) + } + + /// A `ContextProvider` based on `SessionState`. + /// + /// Almost all planning context are retrieved from the `SessionState`. + struct MyContextProvider { + /// The session state. + state: SessionState, + /// Registered tables. + tables: HashMap>, + /// Controls whether to return expression planners when called `ContextProvider::expr_planners`. + return_expr_planners: bool, + } + + impl MyContextProvider { + /// Creates a new `SessionContextProvider`. + pub fn new() -> Self { + Self { + state: SessionStateBuilder::default() + .with_default_features() + .build(), + tables: HashMap::new(), + return_expr_planners: false, + } + } + + /// Registers a table. + /// + /// The catalog and schema are provided by default. + pub fn with_table(mut self, table: &str, source: Arc) -> Self { + self.tables.insert( + ResolvedTableReference { + catalog: "default".to_string().into(), + schema: "public".to_string().into(), + table: table.to_string().into(), + }, + source, + ); + self + } + + /// Sets the `return_expr_planners` flag to true. + pub fn with_expr_planners(self) -> Self { + Self { + return_expr_planners: true, + ..self + } + } + } + + impl ContextProvider for MyContextProvider { + fn get_table_source(&self, name: TableReference) -> Result> { + let resolved_table_ref = ResolvedTableReference { + catalog: "default".to_string().into(), + schema: "public".to_string().into(), + table: name.table().to_string().into(), + }; + let source = self.tables.get(&resolved_table_ref).cloned().unwrap(); + Ok(source) + } + + /// We use a `return_expr_planners` flag to demonstrate why it's necessary to + /// return the expression planners in the `SessionState`. + /// + /// Note, the default implementation returns an empty slice. + fn get_expr_planners(&self) -> &[Arc] { + if self.return_expr_planners { + self.state.expr_planners() + } else { + &[] + } + } + + fn get_function_meta(&self, name: &str) -> Option> { + self.state.scalar_functions().get(name).cloned() + } + + fn get_aggregate_meta(&self, name: &str) -> Option> { + self.state.aggregate_functions().get(name).cloned() + } + + fn get_window_meta(&self, name: &str) -> Option> { + self.state.window_functions().get(name).cloned() + } + + fn get_variable_type(&self, _variable_names: &[String]) -> Option { + None + } + + fn options(&self) -> &ConfigOptions { + self.state.config_options() + } + + fn udf_names(&self) -> Vec { + self.state.scalar_functions().keys().cloned().collect() + } + + fn udaf_names(&self) -> Vec { + self.state.aggregate_functions().keys().cloned().collect() + } + + fn udwf_names(&self) -> Vec { + self.state.window_functions().keys().cloned().collect() + } + } } diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index cc510bc81f1a..dc9f7cf1cc18 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -22,13 +22,25 @@ #![cfg_attr(docsrs, feature(doc_auto_cfg))] // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 -#![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] +// +// Eliminate unnecessary function calls(some may be not cheap) due to `xxx_or` +// for performance. 
Also avoid abusing `xxx_or_else` for readability: +// https://github.com/apache/datafusion/issues/15802 +#![cfg_attr( + not(test), + deny( + clippy::clone_on_ref_ptr, + clippy::or_fun_call, + clippy::unnecessary_lazy_evaluations + ) +)] #![warn(missing_docs, clippy::needless_borrow)] //! [DataFusion] is an extensible query engine written in Rust that //! uses [Apache Arrow] as its in-memory format. DataFusion's target users are //! developers building fast and feature rich database and analytic systems, -//! customized to particular workloads. See [use cases] for examples. +//! customized to particular workloads. Please see the [DataFusion website] for +//! additional documentation, [use cases] and examples. //! //! "Out of the box," DataFusion offers [SQL] and [`Dataframe`] APIs, //! excellent [performance], built-in support for CSV, Parquet, JSON, and Avro, @@ -42,6 +54,7 @@ //! See the [Architecture] section below for more details. //! //! [DataFusion]: https://datafusion.apache.org/ +//! [DataFusion website]: https://datafusion.apache.org //! [Apache Arrow]: https://arrow.apache.org //! [use cases]: https://datafusion.apache.org/user-guide/introduction.html#use-cases //! [SQL]: https://datafusion.apache.org/user-guide/sql/index.html @@ -300,14 +313,17 @@ //! ``` //! //! A [`TableProvider`] provides information for planning and -//! an [`ExecutionPlan`]s for execution. DataFusion includes [`ListingTable`] -//! which supports reading several common file formats, and you can support any -//! new file format by implementing the [`TableProvider`] trait. See also: +//! an [`ExecutionPlan`] for execution. DataFusion includes [`ListingTable`], +//! a [`TableProvider`] which reads individual files or directories of files +//! ("partitioned datasets") of the same file format. Users can add +//! support for new file formats by implementing the [`TableProvider`] +//! trait. +//! +//! See also: //! -//! 1. [`ListingTable`]: Reads data from Parquet, JSON, CSV, or AVRO -//! files. Supports single files or multiple files with HIVE style -//! partitioning, optional compression, directly reading from remote -//! object store and more. +//! 1. [`ListingTable`]: Reads data from one or more Parquet, JSON, CSV, or AVRO +//! files supporting HIVE style partitioning, optional compression, directly +//! reading from remote object store and more. //! //! 2. [`MemTable`]: Reads data from in memory [`RecordBatch`]es. //! @@ -326,11 +342,11 @@ //! A [`LogicalPlan`] is a Directed Acyclic Graph (DAG) of other //! [`LogicalPlan`]s, each potentially containing embedded [`Expr`]s. //! -//! `LogicalPlan`s can be rewritten with [`TreeNode`] API, see the +//! [`LogicalPlan`]s can be rewritten with [`TreeNode`] API, see the //! [`tree_node module`] for more details. //! //! [`Expr`]s can also be rewritten with [`TreeNode`] API and simplified using -//! [`ExprSimplifier`]. Examples of working with and executing `Expr`s can be +//! [`ExprSimplifier`]. Examples of working with and executing [`Expr`]s can be //! found in the [`expr_api`.rs] example //! //! [`TreeNode`]: datafusion_common::tree_node::TreeNode @@ -415,17 +431,17 @@ //! //! ## Streaming Execution //! -//! DataFusion is a "streaming" query engine which means `ExecutionPlan`s incrementally +//! DataFusion is a "streaming" query engine which means [`ExecutionPlan`]s incrementally //! read from their input(s) and compute output one [`RecordBatch`] at a time //! by continually polling [`SendableRecordBatchStream`]s. Output and -//! 
intermediate `RecordBatch`s each have approximately `batch_size` rows, +//! intermediate [`RecordBatch`]s each have approximately `batch_size` rows, //! which amortizes per-batch overhead of execution. //! //! Note that certain operations, sometimes called "pipeline breakers", //! (for example full sorts or hash aggregations) are fundamentally non streaming and //! must read their input fully before producing **any** output. As much as possible, //! other operators read a single [`RecordBatch`] from their input to produce a -//! single `RecordBatch` as output. +//! single [`RecordBatch`] as output. //! //! For example, given this SQL query: //! @@ -434,9 +450,9 @@ //! ``` //! //! The diagram below shows the call sequence when a consumer calls [`next()`] to -//! get the next `RecordBatch` of output. While it is possible that some +//! get the next [`RecordBatch`] of output. While it is possible that some //! steps run on different threads, typically tokio will use the same thread -//! that called `next()` to read from the input, apply the filter, and +//! that called [`next()`] to read from the input, apply the filter, and //! return the results without interleaving any other operations. This results //! in excellent cache locality as the same CPU core that produces the data often //! consumes it immediately as well. @@ -474,39 +490,53 @@ //! DataFusion automatically runs each plan with multiple CPU cores using //! a [Tokio] [`Runtime`] as a thread pool. While tokio is most commonly used //! for asynchronous network I/O, the combination of an efficient, work-stealing -//! scheduler and first class compiler support for automatic continuation -//! generation (`async`), also makes it a compelling choice for CPU intensive +//! scheduler, and first class compiler support for automatic continuation +//! generation (`async`) also makes it a compelling choice for CPU intensive //! applications as explained in the [Using Rustlang’s Async Tokio //! Runtime for CPU-Bound Tasks] blog. //! //! The number of cores used is determined by the `target_partitions` //! configuration setting, which defaults to the number of CPU cores. //! While preparing for execution, DataFusion tries to create this many distinct -//! `async` [`Stream`]s for each `ExecutionPlan`. -//! The `Stream`s for certain `ExecutionPlans`, such as as [`RepartitionExec`] -//! and [`CoalescePartitionsExec`], spawn [Tokio] [`task`]s, that are run by -//! threads managed by the `Runtime`. -//! Many DataFusion `Stream`s perform CPU intensive processing. +//! `async` [`Stream`]s for each [`ExecutionPlan`]. +//! The [`Stream`]s for certain [`ExecutionPlan`]s, such as [`RepartitionExec`] +//! and [`CoalescePartitionsExec`], spawn [Tokio] [`task`]s, that run on +//! threads managed by the [`Runtime`]. +//! Many DataFusion [`Stream`]s perform CPU intensive processing. +//! +//! ### Cooperative Scheduling +//! +//! DataFusion uses cooperative scheduling, which means that each [`Stream`] +//! is responsible for yielding control back to the [`Runtime`] after +//! some amount of work is done. Please see the [`coop`] module documentation +//! for more details. +//! +//! [`coop`]: datafusion_physical_plan::coop +//! +//! ### Network I/O and CPU intensive tasks //! //! Using `async` for CPU intensive tasks makes it easy for [`TableProvider`]s //! to perform network I/O using standard Rust `async` during execution. //! However, this design also makes it very easy to mix CPU intensive and latency //! sensitive I/O work on the same thread pool ([`Runtime`]). 
-//! Using the same (default) `Runtime` is convenient, and often works well for +//! Using the same (default) [`Runtime`] is convenient, and often works well for //! initial development and processing local files, but it can lead to problems //! under load and/or when reading from network sources such as AWS S3. //! +//! ### Optimizing Latency: Throttled CPU / IO under Highly Concurrent Load +//! //! If your system does not fully utilize either the CPU or network bandwidth //! during execution, or you see significantly higher tail (e.g. p99) latencies //! responding to network requests, **it is likely you need to use a different -//! `Runtime` for CPU intensive DataFusion plans**. This effect can be especially -//! pronounced when running several queries concurrently. +//! [`Runtime`] for DataFusion plans**. The [thread_pools example] +//! has an example of how to do so. //! -//! As shown in the following figure, using the same `Runtime` for both CPU -//! intensive processing and network requests can introduce significant -//! delays in responding to those network requests. Delays in processing network -//! requests can and does lead network flow control to throttle the available -//! bandwidth in response. +//! As shown below, using the same [`Runtime`] for both CPU intensive processing +//! and network requests can introduce significant delays in responding to +//! those network requests. Delays in processing network requests can and does +//! lead network flow control to throttle the available bandwidth in response. +//! This effect can be especially pronounced when running multiple queries +//! concurrently. //! //! ```text //! Legend @@ -588,6 +618,7 @@ //! //! [Tokio]: https://tokio.rs //! [`Runtime`]: tokio::runtime::Runtime +//! [thread_pools example]: https://github.com/apache/datafusion/tree/main/datafusion-examples/examples/thread_pools.rs //! [`task`]: tokio::task //! [Using Rustlang’s Async Tokio Runtime for CPU-Bound Tasks]: https://thenewstack.io/using-rustlangs-async-tokio-runtime-for-cpu-bound-tasks/ //! [`RepartitionExec`]: physical_plan::repartition::RepartitionExec @@ -603,8 +634,8 @@ //! The state required to execute queries is managed by the following //! structures: //! -//! 1. [`SessionContext`]: State needed for create [`LogicalPlan`]s such -//! as the table definitions, and the function registries. +//! 1. [`SessionContext`]: State needed to create [`LogicalPlan`]s such +//! as the table definitions and the function registries. //! //! 2. [`TaskContext`]: State needed for execution such as the //! [`MemoryPool`], [`DiskManager`], and [`ObjectStoreRegistry`]. 
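The lib.rs documentation changes in the hunk above recommend using a separate Tokio `Runtime` for CPU-bound plan execution and point at the repository's `thread_pools` example. As a rough, hedged illustration of that idea only (the names `io_runtime` and `cpu_runtime` are invented here, and the real `thread_pools.rs` example shows the complete pattern, including keeping object-store I/O on the original runtime):

```rust
use datafusion::error::Result;
use datafusion::prelude::*;

fn main() -> Result<()> {
    // Runtime that would normally serve latency-sensitive network requests.
    let io_runtime = tokio::runtime::Runtime::new().expect("io runtime");
    // Dedicated runtime reserved for CPU-heavy plan execution.
    let cpu_runtime = tokio::runtime::Runtime::new().expect("cpu runtime");

    io_runtime.block_on(async {
        let ctx = SessionContext::new();
        // Hand the CPU-bound work (planning and polling the plan's streams)
        // to the dedicated runtime so this runtime stays responsive.
        let handle = cpu_runtime.spawn(async move {
            let df = ctx.sql("SELECT 1 AS x").await?;
            df.collect().await
        });
        let batches = handle.await.expect("cpu task panicked")?;
        assert_eq!(batches[0].num_rows(), 1);
        Ok(())
    })
}
```

This is only a sketch of the "two runtimes" idea described in the docs; it does not reproduce the example referenced by the `[thread_pools example]` link added in this PR.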
@@ -1021,14 +1052,20 @@ doc_comment::doctest!( #[cfg(doctest)] doc_comment::doctest!( - "../../../docs/source/user-guide/sql/write_options.md", - user_guide_sql_write_options + "../../../docs/source/user-guide/sql/format_options.md", + user_guide_sql_format_options +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/functions/adding-udfs.md", + library_user_guide_functions_adding_udfs ); #[cfg(doctest)] doc_comment::doctest!( - "../../../docs/source/library-user-guide/adding-udfs.md", - library_user_guide_adding_udfs + "../../../docs/source/library-user-guide/functions/spark.md", + library_user_guide_functions_spark ); #[cfg(doctest)] diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index be24206c676c..90cc0b572fef 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -55,18 +55,20 @@ use crate::physical_plan::{ displayable, windows, ExecutionPlan, ExecutionPlanProperties, InputOrderMode, Partitioning, PhysicalExpr, WindowExpr, }; -use datafusion_physical_plan::empty::EmptyExec; -use datafusion_physical_plan::recursive_query::RecursiveQueryExec; +use crate::schema_equivalence::schema_satisfied_by; use arrow::array::{builder::StringBuilder, RecordBatch}; use arrow::compute::SortOptions; use arrow::datatypes::{Schema, SchemaRef}; use datafusion_common::display::ToStringifiedPlan; -use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}; +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeVisitor, +}; use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema, ScalarValue, }; +use datafusion_datasource::file_groups::FileGroup; use datafusion_datasource::memory::MemorySourceConfig; use datafusion_expr::dml::{CopyTo, InsertOp}; use datafusion_expr::expr::{ @@ -82,19 +84,22 @@ use datafusion_expr::{ }; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; use datafusion_physical_expr::expressions::{Column, Literal}; -use datafusion_physical_expr::LexOrdering; +use datafusion_physical_expr::{ + create_physical_sort_exprs, LexOrdering, PhysicalSortExpr, +}; use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_plan::empty::EmptyExec; use datafusion_physical_plan::execution_plan::InvariantLevel; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; +use datafusion_physical_plan::recursive_query::RecursiveQueryExec; use datafusion_physical_plan::unnest::ListUnnest; +use sqlparser::ast::NullTreatment; -use crate::schema_equivalence::schema_satisfied_by; use async_trait::async_trait; -use datafusion_datasource::file_groups::FileGroup; +use datafusion_physical_plan::async_func::{AsyncFuncExec, AsyncMapper}; use futures::{StreamExt, TryStreamExt}; use itertools::{multiunzip, Itertools}; use log::{debug, trace}; -use sqlparser::ast::NullTreatment; use tokio::sync::Mutex; /// Physical query planner that converts a `LogicalPlan` to an @@ -522,7 +527,7 @@ impl DefaultPhysicalPlanner { Some("true") => true, Some("false") => false, Some(value) => - return Err(DataFusionError::Configuration(format!("provided value for 'execution.keep_partition_by_columns' was not recognized: \"{}\"", value))), + return Err(DataFusionError::Configuration(format!("provided value for 'execution.keep_partition_by_columns' was not recognized: \"{value}\""))), }; let sink_format = 
file_type_to_format(file_type)? @@ -572,27 +577,25 @@ impl DefaultPhysicalPlanner { let input_exec = children.one()?; let get_sort_keys = |expr: &Expr| match expr { - Expr::WindowFunction(WindowFunction { - params: - WindowFunctionParams { - ref partition_by, - ref order_by, - .. - }, - .. - }) => generate_sort_key(partition_by, order_by), + Expr::WindowFunction(window_fun) => { + let WindowFunctionParams { + ref partition_by, + ref order_by, + .. + } = &window_fun.as_ref().params; + generate_sort_key(partition_by, order_by) + } Expr::Alias(Alias { expr, .. }) => { // Convert &Box to &T match &**expr { - Expr::WindowFunction(WindowFunction { - params: - WindowFunctionParams { - ref partition_by, - ref order_by, - .. - }, - .. - }) => generate_sort_key(partition_by, order_by), + Expr::WindowFunction(window_fun) => { + let WindowFunctionParams { + ref partition_by, + ref order_by, + .. + } = &window_fun.as_ref().params; + generate_sort_key(partition_by, order_by) + } _ => unreachable!(), } } @@ -694,7 +697,7 @@ impl DefaultPhysicalPlanner { } return internal_err!("Physical input schema should be the same as the one converted from logical input schema. Differences: {}", differences .iter() - .map(|s| format!("\n\t- {}", s)) + .map(|s| format!("\n\t- {s}")) .join("")); } @@ -775,12 +778,46 @@ impl DefaultPhysicalPlanner { let runtime_expr = self.create_physical_expr(predicate, input_dfschema, session_state)?; + + let input_schema = input.schema(); + let filter = match self.try_plan_async_exprs( + input_schema.fields().len(), + PlannedExprResult::Expr(vec![runtime_expr]), + input_schema.as_arrow(), + )? { + PlanAsyncExpr::Sync(PlannedExprResult::Expr(runtime_expr)) => { + FilterExec::try_new(Arc::clone(&runtime_expr[0]), physical_input)? + } + PlanAsyncExpr::Async( + async_map, + PlannedExprResult::Expr(runtime_expr), + ) => { + let async_exec = AsyncFuncExec::try_new( + async_map.async_exprs, + physical_input, + )?; + FilterExec::try_new( + Arc::clone(&runtime_expr[0]), + Arc::new(async_exec), + )? + // project the output columns excluding the async functions + // The async functions are always appended to the end of the schema. + .with_projection(Some( + (0..input.schema().fields().len()).collect(), + ))? + } + _ => { + return internal_err!( + "Unexpected result from try_plan_async_exprs" + ) + } + }; + let selectivity = session_state .config() .options() .optimizer .default_filter_selectivity; - let filter = FilterExec::try_new(runtime_expr, physical_input)?; Arc::new(filter.with_default_selectivity(selectivity)?) } LogicalPlan::Repartition(Repartition { @@ -822,13 +859,17 @@ impl DefaultPhysicalPlanner { }) => { let physical_input = children.one()?; let input_dfschema = input.as_ref().schema(); - let sort_expr = create_physical_sort_exprs( + let sort_exprs = create_physical_sort_exprs( expr, input_dfschema, session_state.execution_props(), )?; - let new_sort = - SortExec::new(sort_expr, physical_input).with_fetch(*fetch); + let Some(ordering) = LexOrdering::new(sort_exprs) else { + return internal_err!( + "SortExec requires at least one sort expression" + ); + }; + let new_sort = SortExec::new(ordering, physical_input).with_fetch(*fetch); Arc::new(new_sort) } LogicalPlan::Subquery(_) => todo!(), @@ -895,12 +936,10 @@ impl DefaultPhysicalPlanner { on: keys, filter, join_type, - null_equals_null, + null_equality, schema: join_schema, .. 
}) => { - let null_equals_null = *null_equals_null; - let [physical_left, physical_right] = children.two()?; // If join has expression equijoin keys, add physical projection. @@ -1113,8 +1152,6 @@ impl DefaultPhysicalPlanner { && !prefer_hash_join { // Use SortMergeJoin if hash join is not preferred - // Sort-Merge join support currently is experimental - let join_on_len = join_on.len(); Arc::new(SortMergeJoinExec::try_new( physical_left, @@ -1123,7 +1160,7 @@ impl DefaultPhysicalPlanner { join_filter, *join_type, vec![SortOptions::default(); join_on_len], - null_equals_null, + *null_equality, )?) } else if session_state.config().target_partitions() > 1 && session_state.config().repartition_joins() @@ -1137,7 +1174,7 @@ impl DefaultPhysicalPlanner { join_type, None, PartitionMode::Auto, - null_equals_null, + *null_equality, )?) } else { Arc::new(HashJoinExec::try_new( @@ -1148,7 +1185,7 @@ impl DefaultPhysicalPlanner { join_type, None, PartitionMode::CollectLeft, - null_equals_null, + *null_equality, )?) }; @@ -1506,17 +1543,18 @@ pub fn create_window_expr_with_name( let name = name.into(); let physical_schema: &Schema = &logical_schema.into(); match e { - Expr::WindowFunction(WindowFunction { - fun, - params: - WindowFunctionParams { - args, - partition_by, - order_by, - window_frame, - null_treatment, - }, - }) => { + Expr::WindowFunction(window_fun) => { + let WindowFunction { + fun, + params: + WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + null_treatment, + }, + } = window_fun.as_ref(); let physical_args = create_physical_exprs(args, logical_schema, execution_props)?; let partition_by = @@ -1539,7 +1577,7 @@ pub fn create_window_expr_with_name( name, &physical_args, &partition_by, - order_by.as_ref(), + &order_by, window_frame, physical_schema, ignore_nulls, @@ -1567,8 +1605,8 @@ type AggregateExprWithOptionalArgs = ( Arc, // The filter clause, if any Option>, - // Ordering requirements, if any - Option, + // Expressions in the ORDER BY clause + Vec, ); /// Create an aggregate expression with a name from a logical expression @@ -1612,22 +1650,19 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( let ignore_nulls = null_treatment.unwrap_or(NullTreatment::RespectNulls) == NullTreatment::IgnoreNulls; - let (agg_expr, filter, order_by) = { - let physical_sort_exprs = match order_by { - Some(exprs) => Some(create_physical_sort_exprs( + let (agg_expr, filter, order_bys) = { + let order_bys = match order_by { + Some(exprs) => create_physical_sort_exprs( exprs, logical_input_schema, execution_props, - )?), - None => None, + )?, + None => vec![], }; - let ordering_reqs: LexOrdering = - physical_sort_exprs.clone().unwrap_or_default(); - let agg_expr = AggregateExprBuilder::new(func.to_owned(), physical_args.to_vec()) - .order_by(ordering_reqs) + .order_by(order_bys.clone()) .schema(Arc::new(physical_input_schema.to_owned())) .alias(name) .human_display(human_displan) @@ -1636,10 +1671,10 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( .build() .map(Arc::new)?; - (agg_expr, filter, physical_sort_exprs) + (agg_expr, filter, order_bys) }; - Ok((agg_expr, filter, order_by)) + Ok((agg_expr, filter, order_bys)) } other => internal_err!("Invalid aggregate expression '{other:?}'"), } @@ -1675,14 +1710,6 @@ pub fn create_aggregate_expr_and_maybe_filter( ) } -#[deprecated( - since = "47.0.0", - note = "use datafusion::{create_physical_sort_expr, create_physical_sort_exprs}" -)] -pub use datafusion_physical_expr::{ - create_physical_sort_expr, 
create_physical_sort_exprs, -}; - impl DefaultPhysicalPlanner { /// Handles capturing the various plans for EXPLAIN queries /// @@ -1714,6 +1741,14 @@ impl DefaultPhysicalPlanner { let config = &session_state.config_options().explain; let explain_format = &e.explain_format; + if !e.logical_optimization_succeeded { + return Ok(Arc::new(ExplainExec::new( + Arc::clone(e.schema.inner()), + e.stringified_plans.clone(), + true, + ))); + } + match explain_format { ExplainFormat::Indent => { /* fall through */ } ExplainFormat::Tree => { @@ -1952,7 +1987,7 @@ impl DefaultPhysicalPlanner { "Optimized physical plan:\n{}\n", displayable(new_plan.as_ref()).indent(false) ); - trace!("Detailed optimized physical plan:\n{:?}", new_plan); + trace!("Detailed optimized physical plan:\n{new_plan:?}"); Ok(new_plan) } @@ -2044,13 +2079,91 @@ impl DefaultPhysicalPlanner { }) .collect::>>()?; - Ok(Arc::new(ProjectionExec::try_new( - physical_exprs, - input_exec, - )?)) + let num_input_columns = input_exec.schema().fields().len(); + + match self.try_plan_async_exprs( + num_input_columns, + PlannedExprResult::ExprWithName(physical_exprs), + input_physical_schema.as_ref(), + )? { + PlanAsyncExpr::Sync(PlannedExprResult::ExprWithName(physical_exprs)) => Ok( + Arc::new(ProjectionExec::try_new(physical_exprs, input_exec)?), + ), + PlanAsyncExpr::Async( + async_map, + PlannedExprResult::ExprWithName(physical_exprs), + ) => { + let async_exec = + AsyncFuncExec::try_new(async_map.async_exprs, input_exec)?; + let new_proj_exec = + ProjectionExec::try_new(physical_exprs, Arc::new(async_exec))?; + Ok(Arc::new(new_proj_exec)) + } + _ => internal_err!("Unexpected PlanAsyncExpressions variant"), + } + } + + fn try_plan_async_exprs( + &self, + num_input_columns: usize, + physical_expr: PlannedExprResult, + schema: &Schema, + ) -> Result { + let mut async_map = AsyncMapper::new(num_input_columns); + match &physical_expr { + PlannedExprResult::ExprWithName(exprs) => { + exprs + .iter() + .try_for_each(|(expr, _)| async_map.find_references(expr, schema))?; + } + PlannedExprResult::Expr(exprs) => { + exprs + .iter() + .try_for_each(|expr| async_map.find_references(expr, schema))?; + } + } + + if async_map.is_empty() { + return Ok(PlanAsyncExpr::Sync(physical_expr)); + } + + let new_exprs = match physical_expr { + PlannedExprResult::ExprWithName(exprs) => PlannedExprResult::ExprWithName( + exprs + .iter() + .map(|(expr, column_name)| { + let new_expr = Arc::clone(expr) + .transform_up(|e| Ok(async_map.map_expr(e)))?; + Ok((new_expr.data, column_name.to_string())) + }) + .collect::>()?, + ), + PlannedExprResult::Expr(exprs) => PlannedExprResult::Expr( + exprs + .iter() + .map(|expr| { + let new_expr = Arc::clone(expr) + .transform_up(|e| Ok(async_map.map_expr(e)))?; + Ok(new_expr.data) + }) + .collect::>()?, + ), + }; + // rewrite the projection's expressions in terms of the columns with the result of async evaluation + Ok(PlanAsyncExpr::Async(async_map, new_exprs)) } } +enum PlannedExprResult { + ExprWithName(Vec<(Arc, String)>), + Expr(Vec>), +} + +enum PlanAsyncExpr { + Sync(PlannedExprResult), + Async(AsyncMapper, PlannedExprResult), +} + fn tuple_err(value: (Result, Result)) -> Result<(T, R)> { match value { (Ok(e), Ok(e1)) => Ok((e, e1)), @@ -2069,29 +2182,36 @@ fn maybe_fix_physical_column_name( expr: Result>, input_physical_schema: &SchemaRef, ) -> Result> { - if let Ok(e) = &expr { - if let Some(column) = e.as_any().downcast_ref::() { - let physical_field = input_physical_schema.field(column.index()); + let Ok(expr) = expr 
else { return expr }; + expr.transform_down(|node| { + if let Some(column) = node.as_any().downcast_ref::() { + let idx = column.index(); + let physical_field = input_physical_schema.field(idx); let expr_col_name = column.name(); let physical_name = physical_field.name(); - if physical_name != expr_col_name { + if expr_col_name != physical_name { // handle edge cases where the physical_name contains ':'. let colon_count = physical_name.matches(':').count(); let mut splits = expr_col_name.match_indices(':'); let split_pos = splits.nth(colon_count); - if let Some((idx, _)) = split_pos { - let base_name = &expr_col_name[..idx]; + if let Some((i, _)) = split_pos { + let base_name = &expr_col_name[..i]; if base_name == physical_name { - let updated_column = Column::new(physical_name, column.index()); - return Ok(Arc::new(updated_column)); + let updated_column = Column::new(physical_name, idx); + return Ok(Transformed::yes(Arc::new(updated_column))); } } } + + // If names already match or fix is not possible, just leave it as it is + Ok(Transformed::no(node)) + } else { + Ok(Transformed::no(node)) } - } - expr + }) + .data() } struct OptimizationInvariantChecker<'a> { @@ -2192,11 +2312,16 @@ mod tests { use arrow::array::{ArrayRef, DictionaryArray, Int32Array}; use arrow::datatypes::{DataType, Field, Int32Type}; use datafusion_common::config::ConfigOptions; - use datafusion_common::{assert_contains, DFSchemaRef, TableReference}; + use datafusion_common::{ + assert_contains, DFSchemaRef, TableReference, ToDFSchema as _, + }; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; - use datafusion_expr::{col, lit, LogicalPlanBuilder, UserDefinedLogicalNodeCore}; + use datafusion_expr::{ + col, lit, LogicalPlanBuilder, Operator, UserDefinedLogicalNodeCore, + }; use datafusion_functions_aggregate::expr_fn::sum; + use datafusion_physical_expr::expressions::{BinaryExpr, IsNotNullExpr}; use datafusion_physical_expr::EquivalenceProperties; use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; @@ -2238,7 +2363,8 @@ mod tests { // verify that the plan correctly casts u8 to i64 // the cast from u8 to i64 for literal will be simplified, and get lit(int64(5)) // the cast here is implicit so has CastOptions with safe=true - let expected = "BinaryExpr { left: Column { name: \"c7\", index: 2 }, op: Lt, right: Literal { value: Int64(5) }, fail_on_overflow: false }"; + let expected = r#"BinaryExpr { left: Column { name: "c7", index: 2 }, op: Lt, right: Literal { value: Int64(5), field: Field { name: "lit", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, fail_on_overflow: false }"#; + assert!(format!("{exec_plan:?}").contains(expected)); Ok(()) } @@ -2263,7 +2389,7 @@ mod tests { &session_state, ); - let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[false, false, false], [true, false, false], [false, true, false], [false, false, true], [true, true, false], [true, false, true], [false, true, true], [true, true, true]] })"#; + let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL), field: Field { name: "lit", 
data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} } }, "c1"), (Literal { value: Int64(NULL), field: Field { name: "lit", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} } }, "c2"), (Literal { value: Int64(NULL), field: Field { name: "lit", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} } }, "c3")], groups: [[false, false, false], [true, false, false], [false, true, false], [false, false, true], [true, true, false], [true, false, true], [false, true, true], [true, true, true]] })"#; assert_eq!(format!("{cube:?}"), expected); @@ -2290,7 +2416,7 @@ mod tests { &session_state, ); - let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[true, true, true], [false, true, true], [false, false, true], [false, false, false]] })"#; + let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL), field: Field { name: "lit", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} } }, "c1"), (Literal { value: Int64(NULL), field: Field { name: "lit", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} } }, "c2"), (Literal { value: Int64(NULL), field: Field { name: "lit", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} } }, "c3")], groups: [[true, true, true], [false, true, true], [false, false, true], [false, false, false]] })"#; assert_eq!(format!("{rollup:?}"), expected); @@ -2474,7 +2600,7 @@ mod tests { let execution_plan = plan(&logical_plan).await?; // verify that the plan correctly adds cast from Int64(1) to Utf8, and the const will be evaluated. 
- let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\") }, fail_on_overflow: false }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\") }, fail_on_overflow: false }, fail_on_overflow: false }"; + let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\"), field: Field { name: \"lit\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, fail_on_overflow: false }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\"), field: Field { name: \"lit\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, fail_on_overflow: false }, fail_on_overflow: false }"; let actual = format!("{execution_plan:?}"); assert!(actual.contains(expected), "{}", actual); @@ -2689,6 +2815,54 @@ mod tests { } } + #[tokio::test] + async fn test_explain_indent_err() { + let planner = DefaultPhysicalPlanner::default(); + let ctx = SessionContext::new(); + let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]); + let plan = Arc::new( + scan_empty(Some("employee"), &schema, None) + .unwrap() + .explain(true, false) + .unwrap() + .build() + .unwrap(), + ); + + // Create a schema + let schema = Arc::new(Schema::new(vec![ + Field::new("plan_type", DataType::Utf8, false), + Field::new("plan", DataType::Utf8, false), + ])); + + // Create invalid indentation in the plan + let stringified_plans = + vec![StringifiedPlan::new(PlanType::FinalLogicalPlan, "Test Err")]; + + let explain = Explain { + verbose: false, + explain_format: ExplainFormat::Indent, + plan, + stringified_plans, + schema: schema.to_dfschema_ref().unwrap(), + logical_optimization_succeeded: false, + }; + let plan = planner + .handle_explain(&explain, &ctx.state()) + .await + .unwrap(); + if let Some(plan) = plan.as_any().downcast_ref::() { + let stringified_plans = plan.stringified_plans(); + assert_eq!(stringified_plans.len(), 1); + assert_eq!(stringified_plans[0].plan.as_str(), "Test Err"); + } else { + panic!( + "Plan was not an explain plan: {}", + displayable(plan.as_ref()).indent(true) + ); + } + } + #[tokio::test] async fn test_maybe_fix_colon_in_physical_name() { // The physical schema has a field name with a colon @@ -2713,6 +2887,47 @@ mod tests { assert_eq!(col.name(), "metric:avg"); } + + #[tokio::test] + async fn test_maybe_fix_nested_column_name_with_colon() { + let schema = Schema::new(vec![Field::new("column", DataType::Int32, false)]); + let schema_ref: SchemaRef = Arc::new(schema); + + // Construct the nested expr + let col_expr = Arc::new(Column::new("column:1", 0)) as Arc; + let is_not_null_expr = Arc::new(IsNotNullExpr::new(col_expr.clone())); + + // Create a binary expression and put the column inside + let binary_expr = Arc::new(BinaryExpr::new( + is_not_null_expr.clone(), + Operator::Or, + is_not_null_expr.clone(), + )) as Arc; + + let fixed_expr = + maybe_fix_physical_column_name(Ok(binary_expr), &schema_ref).unwrap(); + + let bin = fixed_expr + .as_any() + .downcast_ref::() + .expect("Expected BinaryExpr"); + + // Check that both sides where renamed + for expr in &[bin.left(), bin.right()] { + let is_not_null = expr + .as_any() + .downcast_ref::() + .expect("Expected IsNotNull"); + + let col = is_not_null + .arg() + .as_any() + .downcast_ref::() + .expect("Expected 
Column"); + + assert_eq!(col.name(), "column"); + } + } struct ErrorExtensionPlanner {} #[async_trait] diff --git a/datafusion/core/src/prelude.rs b/datafusion/core/src/prelude.rs index 9c9fcd04bf09..d723620d3232 100644 --- a/datafusion/core/src/prelude.rs +++ b/datafusion/core/src/prelude.rs @@ -25,6 +25,7 @@ //! use datafusion::prelude::*; //! ``` +pub use crate::dataframe; pub use crate::dataframe::DataFrame; pub use crate::execution::context::{SQLOptions, SessionConfig, SessionContext}; pub use crate::execution::options::{ diff --git a/datafusion/core/src/test/object_store.rs b/datafusion/core/src/test/object_store.rs index 8b19658bb147..ed8474bbfc81 100644 --- a/datafusion/core/src/test/object_store.rs +++ b/datafusion/core/src/test/object_store.rs @@ -148,7 +148,7 @@ impl ObjectStore for BlockingObjectStore { "{} barrier wait timed out for {location}", BlockingObjectStore::NAME ); - log::error!("{}", error_message); + log::error!("{error_message}"); return Err(Error::Generic { store: BlockingObjectStore::NAME, source: error_message.into(), diff --git a/datafusion/core/src/test_util/mod.rs b/datafusion/core/src/test_util/mod.rs index d6865ca3d532..2f8e66a2bbfb 100644 --- a/datafusion/core/src/test_util/mod.rs +++ b/datafusion/core/src/test_util/mod.rs @@ -22,12 +22,14 @@ pub mod parquet; pub mod csv; +use futures::Stream; use std::any::Any; use std::collections::HashMap; use std::fs::File; use std::io::Write; use std::path::Path; use std::sync::Arc; +use std::task::{Context, Poll}; use crate::catalog::{TableProvider, TableProviderFactory}; use crate::dataframe::DataFrame; @@ -38,11 +40,13 @@ use crate::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE}; use crate::physical_plan::ExecutionPlan; use crate::prelude::{CsvReadOptions, SessionContext}; +use crate::execution::SendableRecordBatchStream; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_catalog::Session; use datafusion_common::TableReference; use datafusion_expr::{CreateExternalTable, Expr, SortExpr, TableType}; +use std::pin::Pin; use async_trait::async_trait; @@ -52,6 +56,8 @@ use tempfile::TempDir; pub use datafusion_common::test_util::parquet_test_data; pub use datafusion_common::test_util::{arrow_test_data, get_data_dir}; +use crate::execution::RecordBatchStream; + /// Scan an empty data source, mainly used in tests pub fn scan_empty( name: Option<&str>, @@ -234,3 +240,44 @@ pub fn register_unbounded_file_with_ordering( ctx.register_table(table_name, Arc::new(StreamTable::new(Arc::new(config))))?; Ok(()) } + +/// Creates a bounded stream that emits the same record batch a specified number of times. +/// This is useful for testing purposes. 
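The bounded_stream helper added just below returns a SendableRecordBatchStream that replays a single batch a fixed number of times. A hypothetical usage sketch, assuming the helper is reachable as datafusion::test_util::bounded_stream (schema, row counts, and the test name are made up):

use std::sync::Arc;

use arrow::array::{ArrayRef, Int64Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use datafusion::test_util::bounded_stream;
use futures::StreamExt;

#[tokio::test]
async fn drains_bounded_stream() -> datafusion_common::Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
    let array: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3]));
    let batch = RecordBatch::try_new(schema, vec![array])?;

    // Replay the 3-row batch 4 times; the stream then terminates on its own.
    let mut stream = bounded_stream(batch, 4);
    let mut rows = 0;
    while let Some(batch) = stream.next().await {
        rows += batch?.num_rows();
    }
    assert_eq!(rows, 12);
    Ok(())
}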
+pub fn bounded_stream( + record_batch: RecordBatch, + limit: usize, +) -> SendableRecordBatchStream { + Box::pin(BoundedStream { + record_batch, + count: 0, + limit, + }) +} + +struct BoundedStream { + record_batch: RecordBatch, + count: usize, + limit: usize, +} + +impl Stream for BoundedStream { + type Item = Result<RecordBatch>; + + fn poll_next( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll<Option<Self::Item>> { + if self.count >= self.limit { + Poll::Ready(None) + } else { + self.count += 1; + Poll::Ready(Some(Ok(self.record_batch.clone()))) + } + } +} + +impl RecordBatchStream for BoundedStream { + fn schema(&self) -> SchemaRef { + self.record_batch.schema() + } +} diff --git a/datafusion/core/src/test_util/parquet.rs b/datafusion/core/src/test_util/parquet.rs index f5753af64d93..eb4c61c02524 100644 --- a/datafusion/core/src/test_util/parquet.rs +++ b/datafusion/core/src/test_util/parquet.rs @@ -37,6 +37,7 @@ use crate::physical_plan::metrics::MetricsSet; use crate::physical_plan::ExecutionPlan; use crate::prelude::{Expr, SessionConfig, SessionContext}; +use datafusion_datasource::file::FileSource; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_datasource::source::DataSourceExec; use object_store::path::Path; @@ -82,23 +83,22 @@ impl TestParquetFile { props: WriterProperties, batches: impl IntoIterator<Item = RecordBatch>, ) -> Result<Self> { - let file = File::create(&path).unwrap(); + let file = File::create(&path)?; let mut batches = batches.into_iter(); let first_batch = batches.next().expect("need at least one record batch"); let schema = first_batch.schema(); - let mut writer = - ArrowWriter::try_new(file, Arc::clone(&schema), Some(props)).unwrap(); + let mut writer = ArrowWriter::try_new(file, Arc::clone(&schema), Some(props))?; - writer.write(&first_batch).unwrap(); + writer.write(&first_batch)?; let mut num_rows = first_batch.num_rows(); for batch in batches { - writer.write(&batch).unwrap(); + writer.write(&batch)?; num_rows += batch.num_rows(); } - writer.close().unwrap(); + writer.close()?; println!("Generated test dataset with {num_rows} rows"); @@ -182,10 +182,11 @@ impl TestParquetFile { let physical_filter_expr = create_physical_expr(&filter, &df_schema, &ExecutionProps::default())?; - let source = Arc::new(ParquetSource::new(parquet_options).with_predicate( - Arc::clone(&self.schema), - Arc::clone(&physical_filter_expr), - )); + let source = Arc::new( + ParquetSource::new(parquet_options) + .with_predicate(Arc::clone(&physical_filter_expr)), + ) + .with_schema(Arc::clone(&self.schema)); let config = scan_config_builder.with_source(source).build(); let parquet_exec = DataSourceExec::from_data_source(config); diff --git a/datafusion/core/tests/custom_sources_cases/mod.rs b/datafusion/core/tests/custom_sources_cases/mod.rs index eb930b9a60bc..cbdc4a448ea4 100644 --- a/datafusion/core/tests/custom_sources_cases/mod.rs +++ b/datafusion/core/tests/custom_sources_cases/mod.rs @@ -180,6 +180,13 @@ impl ExecutionPlan for CustomExecutionPlan { } fn statistics(&self) -> Result<Statistics> { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> { + if partition.is_some() { + return Ok(Statistics::new_unknown(&self.schema())); + } let batch = TEST_CUSTOM_RECORD_BATCH!().unwrap(); Ok(Statistics { num_rows: Precision::Exact(batch.num_rows()), diff --git a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs index f68bcfaf1550..c80c0b4bf54b 100644 ---
a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs +++ b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs @@ -179,12 +179,12 @@ impl TableProvider for CustomProvider { match &filters[0] { Expr::BinaryExpr(BinaryExpr { right, .. }) => { let int_value = match &**right { - Expr::Literal(ScalarValue::Int8(Some(i))) => *i as i64, - Expr::Literal(ScalarValue::Int16(Some(i))) => *i as i64, - Expr::Literal(ScalarValue::Int32(Some(i))) => *i as i64, - Expr::Literal(ScalarValue::Int64(Some(i))) => *i, + Expr::Literal(ScalarValue::Int8(Some(i)), _) => *i as i64, + Expr::Literal(ScalarValue::Int16(Some(i)), _) => *i as i64, + Expr::Literal(ScalarValue::Int32(Some(i)), _) => *i as i64, + Expr::Literal(ScalarValue::Int64(Some(i)), _) => *i, Expr::Cast(Cast { expr, data_type: _ }) => match expr.deref() { - Expr::Literal(lit_value) => match lit_value { + Expr::Literal(lit_value, _) => match lit_value { ScalarValue::Int8(Some(v)) => *v as i64, ScalarValue::Int16(Some(v)) => *v as i64, ScalarValue::Int32(Some(v)) => *v as i64, diff --git a/datafusion/core/tests/custom_sources_cases/statistics.rs b/datafusion/core/tests/custom_sources_cases/statistics.rs index 66c886510e96..f9b0db0e808c 100644 --- a/datafusion/core/tests/custom_sources_cases/statistics.rs +++ b/datafusion/core/tests/custom_sources_cases/statistics.rs @@ -184,6 +184,14 @@ impl ExecutionPlan for StatisticsValidation { fn statistics(&self) -> Result { Ok(self.stats.clone()) } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_some() { + Ok(Statistics::new_unknown(&self.schema)) + } else { + Ok(self.stats.clone()) + } + } } fn init_ctx(stats: Statistics, schema: Schema) -> Result { @@ -232,7 +240,7 @@ async fn sql_basic() -> Result<()> { let physical_plan = df.create_physical_plan().await.unwrap(); // the statistics should be those of the source - assert_eq!(stats, physical_plan.statistics()?); + assert_eq!(stats, physical_plan.partition_statistics(None)?); Ok(()) } @@ -248,7 +256,7 @@ async fn sql_filter() -> Result<()> { .unwrap(); let physical_plan = df.create_physical_plan().await.unwrap(); - let stats = physical_plan.statistics()?; + let stats = physical_plan.partition_statistics(None)?; assert_eq!(stats.num_rows, Precision::Inexact(1)); Ok(()) @@ -270,7 +278,7 @@ async fn sql_limit() -> Result<()> { column_statistics: col_stats, total_byte_size: Precision::Absent }, - physical_plan.statistics()? + physical_plan.partition_statistics(None)? 
); let df = ctx @@ -279,7 +287,7 @@ async fn sql_limit() -> Result<()> { .unwrap(); let physical_plan = df.create_physical_plan().await.unwrap(); // when the limit is larger than the original number of lines, statistics remain unchanged - assert_eq!(stats, physical_plan.statistics()?); + assert_eq!(stats, physical_plan.partition_statistics(None)?); Ok(()) } @@ -296,7 +304,7 @@ async fn sql_window() -> Result<()> { let physical_plan = df.create_physical_plan().await.unwrap(); - let result = physical_plan.statistics()?; + let result = physical_plan.partition_statistics(None)?; assert_eq!(stats.num_rows, result.num_rows); let col_stats = result.column_statistics; diff --git a/datafusion/core/tests/data/filter_pushdown/single_file.gz.parquet b/datafusion/core/tests/data/filter_pushdown/single_file.gz.parquet new file mode 100644 index 000000000000..ed700576a5af Binary files /dev/null and b/datafusion/core/tests/data/filter_pushdown/single_file.gz.parquet differ diff --git a/datafusion/core/tests/data/filter_pushdown/single_file_small_pages.gz.parquet b/datafusion/core/tests/data/filter_pushdown/single_file_small_pages.gz.parquet new file mode 100644 index 000000000000..29282cfbb622 Binary files /dev/null and b/datafusion/core/tests/data/filter_pushdown/single_file_small_pages.gz.parquet differ diff --git a/datafusion/core/tests/data/int96_nested.parquet b/datafusion/core/tests/data/int96_nested.parquet new file mode 100644 index 000000000000..708823ded6fa Binary files /dev/null and b/datafusion/core/tests/data/int96_nested.parquet differ diff --git a/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-01/j5fUeSDQo22oPyPU.parquet b/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-01/j5fUeSDQo22oPyPU.parquet new file mode 100644 index 000000000000..ec164c6df7b5 Binary files /dev/null and b/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-01/j5fUeSDQo22oPyPU.parquet differ diff --git a/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-02/j5fUeSDQo22oPyPU.parquet b/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-02/j5fUeSDQo22oPyPU.parquet new file mode 100644 index 000000000000..4b78cf963c11 Binary files /dev/null and b/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-02/j5fUeSDQo22oPyPU.parquet differ diff --git a/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-03/j5fUeSDQo22oPyPU.parquet b/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-03/j5fUeSDQo22oPyPU.parquet new file mode 100644 index 000000000000..09a01771d503 Binary files /dev/null and b/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-03/j5fUeSDQo22oPyPU.parquet differ diff --git a/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-04/j5fUeSDQo22oPyPU.parquet b/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-04/j5fUeSDQo22oPyPU.parquet new file mode 100644 index 000000000000..6398cc43a2f5 Binary files /dev/null and b/datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-04/j5fUeSDQo22oPyPU.parquet differ diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index c763d4c8de2d..40590d74ad91 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -384,7 +384,7 @@ async fn test_fn_approx_median() -> Result<()> { #[tokio::test] async fn 
test_fn_approx_percentile_cont() -> Result<()> { - let expr = approx_percentile_cont(col("b"), lit(0.5), None); + let expr = approx_percentile_cont(col("b").sort(true, false), lit(0.5), None); let df = create_test_table().await?; let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?; @@ -392,11 +392,26 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { assert_snapshot!( batches_to_string(&batches), @r" - +---------------------------------------------+ - | approx_percentile_cont(test.b,Float64(0.5)) | - +---------------------------------------------+ - | 10 | - +---------------------------------------------+ + +---------------------------------------------------------------------------+ + | approx_percentile_cont(Float64(0.5)) WITHIN GROUP [test.b ASC NULLS LAST] | + +---------------------------------------------------------------------------+ + | 10 | + +---------------------------------------------------------------------------+ + "); + + let expr = approx_percentile_cont(col("b").sort(false, false), lit(0.1), None); + + let df = create_test_table().await?; + let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?; + + assert_snapshot!( + batches_to_string(&batches), + @r" + +----------------------------------------------------------------------------+ + | approx_percentile_cont(Float64(0.1)) WITHIN GROUP [test.b DESC NULLS LAST] | + +----------------------------------------------------------------------------+ + | 100 | + +----------------------------------------------------------------------------+ "); // the arg2 parameter is a complex expr, but it can be evaluated to the literal value @@ -405,23 +420,59 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { None::<&str>, "arg_2".to_string(), )); - let expr = approx_percentile_cont(col("b"), alias_expr, None); + let expr = approx_percentile_cont(col("b").sort(true, false), alias_expr, None); let df = create_test_table().await?; let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?; assert_snapshot!( batches_to_string(&batches), @r" - +--------------------------------------+ - | approx_percentile_cont(test.b,arg_2) | - +--------------------------------------+ - | 10 | - +--------------------------------------+ + +--------------------------------------------------------------------+ + | approx_percentile_cont(arg_2) WITHIN GROUP [test.b ASC NULLS LAST] | + +--------------------------------------------------------------------+ + | 10 | + +--------------------------------------------------------------------+ + " + ); + + let alias_expr = Expr::Alias(Alias::new( + cast(lit(0.1), DataType::Float32), + None::<&str>, + "arg_2".to_string(), + )); + let expr = approx_percentile_cont(col("b").sort(false, false), alias_expr, None); + let df = create_test_table().await?; + let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?; + + assert_snapshot!( + batches_to_string(&batches), + @r" + +---------------------------------------------------------------------+ + | approx_percentile_cont(arg_2) WITHIN GROUP [test.b DESC NULLS LAST] | + +---------------------------------------------------------------------+ + | 100 | + +---------------------------------------------------------------------+ " ); // with number of centroids set - let expr = approx_percentile_cont(col("b"), lit(0.5), Some(lit(2))); + let expr = approx_percentile_cont(col("b").sort(true, false), lit(0.5), Some(lit(2))); + + let df = create_test_table().await?; + let batches = df.aggregate(vec![], 
vec![expr]).unwrap().collect().await?; + + assert_snapshot!( + batches_to_string(&batches), + @r" + +------------------------------------------------------------------------------------+ + | approx_percentile_cont(Float64(0.5),Int32(2)) WITHIN GROUP [test.b ASC NULLS LAST] | + +------------------------------------------------------------------------------------+ + | 30 | + +------------------------------------------------------------------------------------+ + "); + + let expr = + approx_percentile_cont(col("b").sort(false, false), lit(0.1), Some(lit(2))); let df = create_test_table().await?; let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?; @@ -429,11 +480,11 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { assert_snapshot!( batches_to_string(&batches), @r" - +------------------------------------------------------+ - | approx_percentile_cont(test.b,Float64(0.5),Int32(2)) | - +------------------------------------------------------+ - | 30 | - +------------------------------------------------------+ + +-------------------------------------------------------------------------------------+ + | approx_percentile_cont(Float64(0.1),Int32(2)) WITHIN GROUP [test.b DESC NULLS LAST] | + +-------------------------------------------------------------------------------------+ + | 69 | + +-------------------------------------------------------------------------------------+ "); Ok(()) diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 1855a512048d..8d60dbea3d01 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -32,6 +32,7 @@ use arrow::datatypes::{ }; use arrow::error::ArrowError; use arrow::util::pretty::pretty_format_batches; +use datafusion::{assert_batches_eq, dataframe}; use datafusion_functions_aggregate::count::{count_all, count_all_window}; use datafusion_functions_aggregate::expr_fn::{ array_agg, avg, count, count_distinct, max, median, min, sum, @@ -69,7 +70,7 @@ use datafusion_common::{ use datafusion_common_runtime::SpawnedTask; use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnv; -use datafusion_expr::expr::{GroupingSet, Sort, WindowFunction}; +use datafusion_expr::expr::{FieldMetadata, GroupingSet, Sort, WindowFunction}; use datafusion_expr::var_provider::{VarProvider, VarType}; use datafusion_expr::{ cast, col, create_udf, exists, in_subquery, lit, out_ref_col, placeholder, @@ -906,7 +907,7 @@ async fn window_using_aggregates() -> Result<()> { vec![col("c3")], ); - Expr::WindowFunction(w) + Expr::from(w) .null_treatment(NullTreatment::IgnoreNulls) .order_by(vec![col("c2").sort(true, true), col("c3").sort(true, true)]) .window_frame(WindowFrame::new_bounds( @@ -1209,7 +1210,7 @@ async fn join_on_filter_datatype() -> Result<()> { let join = left.clone().join_on( right.clone(), JoinType::Inner, - Some(Expr::Literal(ScalarValue::Null)), + Some(Expr::Literal(ScalarValue::Null, None)), )?; assert_snapshot!(join.into_optimized_plan().unwrap(), @"EmptyRelation"); @@ -1852,6 +1853,56 @@ async fn with_column_renamed_case_sensitive() -> Result<()> { Ok(()) } +#[tokio::test] +async fn describe_lookup_via_quoted_identifier() -> Result<()> { + let ctx = SessionContext::new(); + let name = "aggregate_test_100"; + register_aggregate_csv(&ctx, name).await?; + let df = ctx.table(name); + + let df = df + .await? + .filter(col("c2").eq(lit(3)).and(col("c1").eq(lit("a"))))? + .limit(0, Some(1))? 
+ .sort(vec![ + // make the test deterministic + col("c1").sort(true, true), + col("c2").sort(true, true), + col("c3").sort(true, true), + ])? + .select_columns(&["c1"])?; + + let df_renamed = df.clone().with_column_renamed("c1", "CoLu.Mn[\"1\"]")?; + + let describe_result = df_renamed.describe().await?; + describe_result + .clone() + .sort(vec![ + col("describe").sort(true, true), + col("CoLu.Mn[\"1\"]").sort(true, true), + ])? + .show() + .await?; + assert_snapshot!( + batches_to_sort_string(&describe_result.clone().collect().await?), + @r###" + +------------+--------------+ + | describe | CoLu.Mn["1"] | + +------------+--------------+ + | count | 1 | + | max | a | + | mean | null | + | median | null | + | min | a | + | null_count | 0 | + | std | null | + +------------+--------------+ + "### + ); + + Ok(()) +} + #[tokio::test] async fn cast_expr_test() -> Result<()> { let df = test_table() @@ -2094,6 +2145,7 @@ async fn verify_join_output_partitioning() -> Result<()> { JoinType::LeftAnti, JoinType::RightAnti, JoinType::LeftMark, + JoinType::RightMark, ]; let default_partition_count = SessionConfig::new().target_partitions(); @@ -2127,7 +2179,8 @@ async fn verify_join_output_partitioning() -> Result<()> { JoinType::Inner | JoinType::Right | JoinType::RightSemi - | JoinType::RightAnti => { + | JoinType::RightAnti + | JoinType::RightMark => { let right_exprs: Vec> = vec![ Arc::new(Column::new_with_schema("c2_c1", &join_schema)?), Arc::new(Column::new_with_schema("c2_c2", &join_schema)?), @@ -2454,6 +2507,11 @@ async fn write_table_with_order() -> Result<()> { write_df = write_df .with_column_renamed("column1", "tablecol1") .unwrap(); + + // Ensure the column type matches the target table + write_df = + write_df.with_column("tablecol1", cast(col("tablecol1"), DataType::Utf8View))?; + let sql_str = "create external table data(tablecol1 varchar) stored as parquet location '" .to_owned() @@ -2525,7 +2583,7 @@ async fn test_count_wildcard_on_sort() -> Result<()> { | | TableScan: t1 projection=[b] | | physical_plan | ProjectionExec: expr=[b@0 as b, count(*)@1 as count(*)] | | | SortPreservingMergeExec: [count(Int64(1))@2 ASC NULLS LAST] | - | | SortExec: expr=[count(Int64(1))@2 ASC NULLS LAST], preserve_partitioning=[true] | + | | SortExec: expr=[count(*)@1 ASC NULLS LAST], preserve_partitioning=[true] | | | ProjectionExec: expr=[b@0 as b, count(Int64(1))@1 as count(*), count(Int64(1))@1 as count(Int64(1))] | | | AggregateExec: mode=FinalPartitioned, gby=[b@0 as b], aggr=[count(Int64(1))] | | | CoalesceBatchesExec: target_batch_size=8192 | @@ -3570,16 +3628,15 @@ async fn unnest_columns() -> Result<()> { assert_snapshot!( batches_to_sort_string(&results), @r###" - +----------+------------------------------------------------+--------------------+ - | shape_id | points | tags | - +----------+------------------------------------------------+--------------------+ - | 1 | [{x: -3, y: -4}, {x: -3, y: 6}, {x: 2, y: -2}] | [tag1] | - | 2 | | [tag1, tag2] | - | 3 | [{x: -9, y: 2}, {x: -10, y: -4}] | | - | 4 | [{x: -3, y: 5}, {x: 2, y: -1}] | [tag1, tag2, tag3] | - +----------+------------------------------------------------+--------------------+ - "### - ); + +----------+---------------------------------+--------------------------+ + | shape_id | points | tags | + +----------+---------------------------------+--------------------------+ + | 1 | [{x: 5, y: -8}, {x: -3, y: -4}] | [tag1] | + | 2 | [{x: 6, y: 2}, {x: -2, y: -8}] | [tag1] | + | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | [tag1, tag2, tag3, tag4] | + 
| 4 | | [tag1, tag2, tag3] | + +----------+---------------------------------+--------------------------+ + "###); // Unnest tags let df = table_with_nested_types(NUM_ROWS).await?; @@ -3587,19 +3644,20 @@ async fn unnest_columns() -> Result<()> { assert_snapshot!( batches_to_sort_string(&results), @r###" - +----------+------------------------------------------------+------+ - | shape_id | points | tags | - +----------+------------------------------------------------+------+ - | 1 | [{x: -3, y: -4}, {x: -3, y: 6}, {x: 2, y: -2}] | tag1 | - | 2 | | tag1 | - | 2 | | tag2 | - | 3 | [{x: -9, y: 2}, {x: -10, y: -4}] | | - | 4 | [{x: -3, y: 5}, {x: 2, y: -1}] | tag1 | - | 4 | [{x: -3, y: 5}, {x: 2, y: -1}] | tag2 | - | 4 | [{x: -3, y: 5}, {x: 2, y: -1}] | tag3 | - +----------+------------------------------------------------+------+ - "### - ); + +----------+---------------------------------+------+ + | shape_id | points | tags | + +----------+---------------------------------+------+ + | 1 | [{x: 5, y: -8}, {x: -3, y: -4}] | tag1 | + | 2 | [{x: 6, y: 2}, {x: -2, y: -8}] | tag1 | + | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | tag1 | + | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | tag2 | + | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | tag3 | + | 3 | [{x: -9, y: -7}, {x: -2, y: 5}] | tag4 | + | 4 | | tag1 | + | 4 | | tag2 | + | 4 | | tag3 | + +----------+---------------------------------+------+ + "###); // Test aggregate results for tags. let df = table_with_nested_types(NUM_ROWS).await?; @@ -3612,20 +3670,18 @@ async fn unnest_columns() -> Result<()> { assert_snapshot!( batches_to_sort_string(&results), @r###" - +----------+-----------------+--------------------+ - | shape_id | points | tags | - +----------+-----------------+--------------------+ - | 1 | {x: -3, y: -4} | [tag1] | - | 1 | {x: -3, y: 6} | [tag1] | - | 1 | {x: 2, y: -2} | [tag1] | - | 2 | | [tag1, tag2] | - | 3 | {x: -10, y: -4} | | - | 3 | {x: -9, y: 2} | | - | 4 | {x: -3, y: 5} | [tag1, tag2, tag3] | - | 4 | {x: 2, y: -1} | [tag1, tag2, tag3] | - +----------+-----------------+--------------------+ - "### - ); + +----------+----------------+--------------------------+ + | shape_id | points | tags | + +----------+----------------+--------------------------+ + | 1 | {x: -3, y: -4} | [tag1] | + | 1 | {x: 5, y: -8} | [tag1] | + | 2 | {x: -2, y: -8} | [tag1] | + | 2 | {x: 6, y: 2} | [tag1] | + | 3 | {x: -2, y: 5} | [tag1, tag2, tag3, tag4] | + | 3 | {x: -9, y: -7} | [tag1, tag2, tag3, tag4] | + | 4 | | [tag1, tag2, tag3] | + +----------+----------------+--------------------------+ + "###); // Test aggregate results for points. 
let df = table_with_nested_types(NUM_ROWS).await?; @@ -3642,25 +3698,26 @@ async fn unnest_columns() -> Result<()> { assert_snapshot!( batches_to_sort_string(&results), @r###" - +----------+-----------------+------+ - | shape_id | points | tags | - +----------+-----------------+------+ - | 1 | {x: -3, y: -4} | tag1 | - | 1 | {x: -3, y: 6} | tag1 | - | 1 | {x: 2, y: -2} | tag1 | - | 2 | | tag1 | - | 2 | | tag2 | - | 3 | {x: -10, y: -4} | | - | 3 | {x: -9, y: 2} | | - | 4 | {x: -3, y: 5} | tag1 | - | 4 | {x: -3, y: 5} | tag2 | - | 4 | {x: -3, y: 5} | tag3 | - | 4 | {x: 2, y: -1} | tag1 | - | 4 | {x: 2, y: -1} | tag2 | - | 4 | {x: 2, y: -1} | tag3 | - +----------+-----------------+------+ - "### - ); + +----------+----------------+------+ + | shape_id | points | tags | + +----------+----------------+------+ + | 1 | {x: -3, y: -4} | tag1 | + | 1 | {x: 5, y: -8} | tag1 | + | 2 | {x: -2, y: -8} | tag1 | + | 2 | {x: 6, y: 2} | tag1 | + | 3 | {x: -2, y: 5} | tag1 | + | 3 | {x: -2, y: 5} | tag2 | + | 3 | {x: -2, y: 5} | tag3 | + | 3 | {x: -2, y: 5} | tag4 | + | 3 | {x: -9, y: -7} | tag1 | + | 3 | {x: -9, y: -7} | tag2 | + | 3 | {x: -9, y: -7} | tag3 | + | 3 | {x: -9, y: -7} | tag4 | + | 4 | | tag1 | + | 4 | | tag2 | + | 4 | | tag3 | + +----------+----------------+------+ + "###); // Test aggregate results for points and tags. let df = table_with_nested_types(NUM_ROWS).await?; @@ -3994,15 +4051,15 @@ async fn unnest_aggregate_columns() -> Result<()> { assert_snapshot!( batches_to_sort_string(&results), @r###" - +--------------------+ - | tags | - +--------------------+ - | | - | [tag1, tag2, tag3] | - | [tag1, tag2, tag3] | - | [tag1, tag2] | - | [tag1] | - +--------------------+ + +--------------------------+ + | tags | + +--------------------------+ + | [tag1, tag2, tag3, tag4] | + | [tag1, tag2, tag3] | + | [tag1, tag2] | + | [tag1] | + | [tag1] | + +--------------------------+ "### ); @@ -4018,7 +4075,7 @@ async fn unnest_aggregate_columns() -> Result<()> { +-------------+ | count(tags) | +-------------+ - | 9 | + | 11 | +-------------+ "### ); @@ -4267,7 +4324,7 @@ async fn unnest_analyze_metrics() -> Result<()> { assert_contains!(&formatted, "elapsed_compute="); assert_contains!(&formatted, "input_batches=1"); assert_contains!(&formatted, "input_rows=5"); - assert_contains!(&formatted, "output_rows=10"); + assert_contains!(&formatted, "output_rows=11"); assert_contains!(&formatted, "output_batches=1"); Ok(()) @@ -4472,7 +4529,10 @@ async fn consecutive_projection_same_schema() -> Result<()> { // Add `t` column full of nulls let df = df - .with_column("t", cast(Expr::Literal(ScalarValue::Null), DataType::Int32)) + .with_column( + "t", + cast(Expr::Literal(ScalarValue::Null, None), DataType::Int32), + ) .unwrap(); df.clone().show().await.unwrap(); @@ -4614,7 +4674,7 @@ async fn table_with_nested_types(n: usize) -> Result { shape_id_builder.append_value(idx as u32 + 1); // Add a random number of points - let num_points: usize = rng.gen_range(0..4); + let num_points: usize = rng.random_range(0..4); if num_points > 0 { for _ in 0..num_points.max(2) { // Add x value @@ -4622,13 +4682,13 @@ async fn table_with_nested_types(n: usize) -> Result { .values() .field_builder::(0) .unwrap() - .append_value(rng.gen_range(-10..10)); + .append_value(rng.random_range(-10..10)); // Add y value points_builder .values() .field_builder::(1) .unwrap() - .append_value(rng.gen_range(-10..10)); + .append_value(rng.random_range(-10..10)); points_builder.values().append(true); } } @@ -4637,7 +4697,7 @@ async fn 
table_with_nested_types(n: usize) -> Result { points_builder.append(num_points > 0); // Append tags. - let num_tags: usize = rng.gen_range(0..5); + let num_tags: usize = rng.random_range(0..5); for id in 0..num_tags { tags_builder.values().append_value(format!("tag{}", id + 1)); } @@ -5079,7 +5139,7 @@ async fn write_partitioned_parquet_results() -> Result<()> { .await?; // Explicitly read the parquet file at c2=123 to verify the physical files are partitioned - let partitioned_file = format!("{out_dir}/c2=123", out_dir = out_dir); + let partitioned_file = format!("{out_dir}/c2=123"); let filter_df = ctx .read_parquet(&partitioned_file, ParquetReadOptions::default()) .await?; @@ -5616,6 +5676,7 @@ async fn test_alias() -> Result<()> { async fn test_alias_with_metadata() -> Result<()> { let mut metadata = HashMap::new(); metadata.insert(String::from("k"), String::from("v")); + let metadata = FieldMetadata::from(metadata); let df = create_test_table("test") .await? .select(vec![col("a").alias_with_metadata("b", Some(metadata))])? @@ -6017,3 +6078,62 @@ async fn test_insert_into_casting_support() -> Result<()> { ); Ok(()) } + +#[tokio::test] +async fn test_dataframe_from_columns() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); + let b: ArrayRef = Arc::new(BooleanArray::from(vec![true, true, false])); + let c: ArrayRef = Arc::new(StringArray::from(vec![Some("foo"), Some("bar"), None])); + let df = DataFrame::from_columns(vec![("a", a), ("b", b), ("c", c)])?; + + assert_eq!(df.schema().fields().len(), 3); + assert_eq!(df.clone().count().await?, 3); + + let rows = df.sort(vec![col("a").sort(true, true)])?; + assert_batches_eq!( + &[ + "+---+-------+-----+", + "| a | b | c |", + "+---+-------+-----+", + "| 1 | true | foo |", + "| 2 | true | bar |", + "| 3 | false | |", + "+---+-------+-----+", + ], + &rows.collect().await? + ); + + Ok(()) +} + +#[tokio::test] +async fn test_dataframe_macro() -> Result<()> { + let df = dataframe!( + "a" => [1, 2, 3], + "b" => [true, true, false], + "c" => [Some("foo"), Some("bar"), None] + )?; + + assert_eq!(df.schema().fields().len(), 3); + assert_eq!(df.clone().count().await?, 3); + + let rows = df.sort(vec![col("a").sort(true, true)])?; + assert_batches_eq!( + &[ + "+---+-------+-----+", + "| a | b | c |", + "+---+-------+-----+", + "| 1 | true | foo |", + "| 2 | true | bar |", + "| 3 | false | |", + "+---+-------+-----+", + ], + &rows.collect().await? + ); + + let df_empty = dataframe!()?; + assert_eq!(df_empty.schema().fields().len(), 0); + assert_eq!(df_empty.count().await?, 0); + + Ok(()) +} diff --git a/datafusion/core/tests/execution/coop.rs b/datafusion/core/tests/execution/coop.rs new file mode 100644 index 000000000000..d8aceadcec66 --- /dev/null +++ b/datafusion/core/tests/execution/coop.rs @@ -0,0 +1,755 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Int64Array, RecordBatch}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow_schema::SortOptions; +use datafusion::common::NullEquality; +use datafusion::functions_aggregate::sum; +use datafusion::physical_expr::aggregate::AggregateExprBuilder; +use datafusion::physical_plan; +use datafusion::physical_plan::aggregates::{ + AggregateExec, AggregateMode, PhysicalGroupBy, +}; +use datafusion::physical_plan::execution_plan::Boundedness; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::SessionContext; +use datafusion_common::{DataFusionError, JoinType, ScalarValue}; +use datafusion_execution::TaskContext; +use datafusion_expr_common::operator::Operator; +use datafusion_expr_common::operator::Operator::{Divide, Eq, Gt, Modulo}; +use datafusion_functions_aggregate::min_max; +use datafusion_physical_expr::expressions::{ + binary, col, lit, BinaryExpr, Column, Literal, +}; +use datafusion_physical_expr::Partitioning; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion_physical_optimizer::ensure_coop::EnsureCooperative; +use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; +use datafusion_physical_plan::filter::FilterExec; +use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode, SortMergeJoinExec}; +use datafusion_physical_plan::memory::{LazyBatchGenerator, LazyMemoryExec}; +use datafusion_physical_plan::projection::ProjectionExec; +use datafusion_physical_plan::repartition::RepartitionExec; +use datafusion_physical_plan::sorts::sort::SortExec; +use datafusion_physical_plan::union::InterleaveExec; +use futures::StreamExt; +use parking_lot::RwLock; +use rstest::rstest; +use std::error::Error; +use std::fmt::Formatter; +use std::ops::Range; +use std::sync::Arc; +use std::task::Poll; +use std::time::Duration; +use tokio::runtime::{Handle, Runtime}; +use tokio::select; + +#[derive(Debug)] +struct RangeBatchGenerator { + schema: SchemaRef, + value_range: Range, + boundedness: Boundedness, + batch_size: usize, + poll_count: usize, +} + +impl std::fmt::Display for RangeBatchGenerator { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + // Display current counter + write!(f, "InfiniteGenerator(counter={})", self.poll_count) + } +} + +impl LazyBatchGenerator for RangeBatchGenerator { + fn boundedness(&self) -> Boundedness { + self.boundedness + } + + /// Generate the next RecordBatch. 
+ fn generate_next_batch(&mut self) -> datafusion_common::Result> { + self.poll_count += 1; + + let mut builder = Int64Array::builder(self.batch_size); + for _ in 0..self.batch_size { + match self.value_range.next() { + None => break, + Some(v) => builder.append_value(v), + } + } + let array = builder.finish(); + + if array.is_empty() { + return Ok(None); + } + + let batch = + RecordBatch::try_new(Arc::clone(&self.schema), vec![Arc::new(array)])?; + Ok(Some(batch)) + } +} + +fn make_lazy_exec(column_name: &str, pretend_infinite: bool) -> LazyMemoryExec { + make_lazy_exec_with_range(column_name, i64::MIN..i64::MAX, pretend_infinite) +} + +fn make_lazy_exec_with_range( + column_name: &str, + range: Range, + pretend_infinite: bool, +) -> LazyMemoryExec { + let schema = Arc::new(Schema::new(vec![Field::new( + column_name, + DataType::Int64, + false, + )])); + + let boundedness = if pretend_infinite { + Boundedness::Unbounded { + requires_infinite_memory: false, + } + } else { + Boundedness::Bounded + }; + + // Instantiate the generator with the batch and limit + let gen = RangeBatchGenerator { + schema: Arc::clone(&schema), + boundedness, + value_range: range, + batch_size: 8192, + poll_count: 0, + }; + + // Wrap the generator in a trait object behind Arc> + let generator: Arc> = Arc::new(RwLock::new(gen)); + + // Create a LazyMemoryExec with one partition using our generator + let mut exec = LazyMemoryExec::try_new(schema, vec![generator]).unwrap(); + + exec.add_ordering(vec![PhysicalSortExpr::new( + Arc::new(Column::new(column_name, 0)), + SortOptions::new(false, true), + )]); + + exec +} + +#[rstest] +#[tokio::test] +async fn agg_no_grouping_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // build session + let session_ctx = SessionContext::new(); + + // set up an aggregation without grouping + let inf = Arc::new(make_lazy_exec("value", pretend_infinite)); + let aggr = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new(vec![], vec![], vec![]), + vec![Arc::new( + AggregateExprBuilder::new( + sum::sum_udaf(), + vec![col("value", &inf.schema())?], + ) + .schema(inf.schema()) + .alias("sum") + .build()?, + )], + vec![None], + inf.clone(), + inf.schema(), + )?); + + query_yields(aggr, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn agg_grouping_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // build session + let session_ctx = SessionContext::new(); + + // set up an aggregation with grouping + let inf = Arc::new(make_lazy_exec("value", pretend_infinite)); + + let value_col = col("value", &inf.schema())?; + let group = binary(value_col.clone(), Divide, lit(1000000i64), &inf.schema())?; + + let aggr = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new(vec![(group, "group".to_string())], vec![], vec![]), + vec![Arc::new( + AggregateExprBuilder::new(sum::sum_udaf(), vec![value_col.clone()]) + .schema(inf.schema()) + .alias("sum") + .build()?, + )], + vec![None], + inf.clone(), + inf.schema(), + )?); + + query_yields(aggr, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn agg_grouped_topk_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // build session + let session_ctx = SessionContext::new(); + + // set up a top-k aggregation + let inf = Arc::new(make_lazy_exec("value", pretend_infinite)); + + let value_col = col("value", &inf.schema())?; + let group = binary(value_col.clone(), Divide, 
lit(1000000i64), &inf.schema())?; + + let aggr = Arc::new( + AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new( + vec![(group, "group".to_string())], + vec![], + vec![vec![false]], + ), + vec![Arc::new( + AggregateExprBuilder::new(min_max::max_udaf(), vec![value_col.clone()]) + .schema(inf.schema()) + .alias("max") + .build()?, + )], + vec![None], + inf.clone(), + inf.schema(), + )? + .with_limit(Some(100)), + ); + + query_yields(aggr, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn sort_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // build session + let session_ctx = SessionContext::new(); + + // set up the infinite source + let inf = Arc::new(make_lazy_exec("value", pretend_infinite)); + + // set up a SortExec that will not be able to finish in time because input is very large + let sort_expr = PhysicalSortExpr::new( + col("value", &inf.schema())?, + SortOptions { + descending: true, + nulls_first: true, + }, + ); + + let lex_ordering = LexOrdering::new(vec![sort_expr]).unwrap(); + let sort_exec = Arc::new(SortExec::new(lex_ordering, inf.clone())); + + query_yields(sort_exec, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn sort_merge_join_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // build session + let session_ctx = SessionContext::new(); + + // set up the join sources + let inf1 = Arc::new(make_lazy_exec_with_range( + "value1", + i64::MIN..0, + pretend_infinite, + )); + let inf2 = Arc::new(make_lazy_exec_with_range( + "value2", + 0..i64::MAX, + pretend_infinite, + )); + + // set up a SortMergeJoinExec that will take a long time skipping left side content to find + // the first right side match + let join = Arc::new(SortMergeJoinExec::try_new( + inf1.clone(), + inf2.clone(), + vec![( + col("value1", &inf1.schema())?, + col("value2", &inf2.schema())?, + )], + None, + JoinType::Inner, + vec![inf1.properties().eq_properties.output_ordering().unwrap()[0].options], + NullEquality::NullEqualsNull, + )?); + + query_yields(join, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn filter_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // build session + let session_ctx = SessionContext::new(); + + // set up the infinite source + let inf = Arc::new(make_lazy_exec("value", pretend_infinite)); + + // set up a FilterExec that will filter out entire batches + let filter_expr = binary( + col("value", &inf.schema())?, + Operator::Lt, + lit(i64::MIN), + &inf.schema(), + )?; + let filter = Arc::new(FilterExec::try_new(filter_expr, inf.clone())?); + + query_yields(filter, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn filter_reject_all_batches_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // Create a Session, Schema, and an 8K-row RecordBatch + let session_ctx = SessionContext::new(); + + // Wrap this batch in an InfiniteExec + let infinite = make_lazy_exec_with_range("value", i64::MIN..0, pretend_infinite); + + // 2b) Construct a FilterExec that is always false: “value > 10000” (no rows pass) + let false_predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("value", 0)), + Gt, + Arc::new(Literal::new(ScalarValue::Int64(Some(0)))), + )); + let filtered = Arc::new(FilterExec::try_new(false_predicate, Arc::new(infinite))?); + + // Use CoalesceBatchesExec to guarantee each Filter pull always yields an 8192-row batch + let coalesced = 
Arc::new(CoalesceBatchesExec::new(filtered, 8_192)); + + query_yields(coalesced, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn interleave_then_filter_all_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // Build a session and a schema with one i64 column. + let session_ctx = SessionContext::new(); + + // Create multiple infinite sources, each filtered by a different threshold. + // This ensures InterleaveExec has many children. + let mut infinite_children = vec![]; + + // Use 32 distinct thresholds (each >0 and <8 192) to force 32 infinite inputs + for thr in 1..32 { + // One infinite exec: + let mut inf = make_lazy_exec_with_range("value", 0..i64::MAX, pretend_infinite); + + // Now repartition so that all children share identical Hash partitioning + // on “value” into 1 bucket. This is required for InterleaveExec::try_new. + let exprs = vec![Arc::new(Column::new("value", 0)) as _]; + let partitioning = Partitioning::Hash(exprs, 1); + inf.try_set_partitioning(partitioning)?; + + // Apply a FilterExec: “(value / 8192) % thr == 0”. + let filter_expr = binary( + binary( + binary( + col("value", &inf.schema())?, + Divide, + lit(8192i64), + &inf.schema(), + )?, + Modulo, + lit(thr as i64), + &inf.schema(), + )?, + Eq, + lit(0i64), + &inf.schema(), + )?; + let filtered = Arc::new(FilterExec::try_new(filter_expr, Arc::new(inf))?); + + infinite_children.push(filtered as _); + } + + // Build an InterleaveExec over all infinite children. + let interleave = Arc::new(InterleaveExec::try_new(infinite_children)?); + + // Wrap the InterleaveExec in a FilterExec that always returns false, + // ensuring that no rows are ever emitted. + let always_false = Arc::new(Literal::new(ScalarValue::Boolean(Some(false)))); + let filtered_interleave = Arc::new(FilterExec::try_new(always_false, interleave)?); + + query_yields(filtered_interleave, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn interleave_then_aggregate_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // Build session, schema, and a sample batch. + let session_ctx = SessionContext::new(); + + // Create N infinite sources, each filtered by a different predicate. + // That way, the InterleaveExec will have multiple children. + let mut infinite_children = vec![]; + + // Use 32 distinct thresholds (each >0 and <8 192) to force 32 infinite inputs + for thr in 1..32 { + // One infinite exec: + let mut inf = make_lazy_exec_with_range("value", 0..i64::MAX, pretend_infinite); + + // Now repartition so that all children share identical Hash partitioning + // on “value” into 1 bucket. This is required for InterleaveExec::try_new. + let exprs = vec![Arc::new(Column::new("value", 0)) as _]; + let partitioning = Partitioning::Hash(exprs, 1); + inf.try_set_partitioning(partitioning)?; + + // Apply a FilterExec: “(value / 8192) % thr == 0”. + let filter_expr = binary( + binary( + binary( + col("value", &inf.schema())?, + Divide, + lit(8192i64), + &inf.schema(), + )?, + Modulo, + lit(thr as i64), + &inf.schema(), + )?, + Eq, + lit(0i64), + &inf.schema(), + )?; + let filtered = Arc::new(FilterExec::try_new(filter_expr, Arc::new(inf))?); + + infinite_children.push(filtered as _); + } + + // Build an InterleaveExec over all N children. + // Since each child now has Partitioning::Hash([col "value"], 1), InterleaveExec::try_new succeeds. 
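A stripped-down sketch of the precondition mentioned in the comment above: InterleaveExec::try_new only accepts children that advertise identical hash partitioning, which is why every generator is first switched to Partitioning::Hash on the value column. The two-child helper below is illustrative and not part of the test file:

use std::sync::Arc;

use datafusion_common::Result;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::Partitioning;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use datafusion_physical_plan::memory::LazyMemoryExec;
use datafusion_physical_plan::union::InterleaveExec;
use datafusion_physical_plan::ExecutionPlan;

/// Interleave two generators after giving them identical Hash partitioning.
fn interleave_two(
    mut left: LazyMemoryExec,
    mut right: LazyMemoryExec,
) -> Result<Arc<InterleaveExec>> {
    let exprs: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(Column::new("value", 0))];
    let partitioning = Partitioning::Hash(exprs, 1);
    left.try_set_partitioning(partitioning.clone())?;
    right.try_set_partitioning(partitioning)?;

    // With matching partitioning on every child, try_new succeeds.
    Ok(Arc::new(InterleaveExec::try_new(vec![
        Arc::new(left) as Arc<dyn ExecutionPlan>,
        Arc::new(right) as Arc<dyn ExecutionPlan>,
    ])?))
}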
+ let interleave = Arc::new(InterleaveExec::try_new(infinite_children)?); + let interleave_schema = interleave.schema(); + + // Build a global AggregateExec that sums “value” over all rows. + // Because we use `AggregateMode::Single` with no GROUP BY columns, this plan will + // only produce one “final” row once all inputs finish. But our inputs never finish, + // so we should never get any output. + let aggregate_expr = AggregateExprBuilder::new( + sum::sum_udaf(), + vec![Arc::new(Column::new("value", 0))], + ) + .schema(interleave_schema.clone()) + .alias("total") + .build()?; + + let aggr = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new( + vec![], // no GROUP BY columns + vec![], // no GROUP BY expressions + vec![], // no GROUP BY physical expressions + ), + vec![Arc::new(aggregate_expr)], + vec![None], // no “distinct” flags + interleave, + interleave_schema, + )?); + + query_yields(aggr, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn join_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // Session, schema, and a single 8 K‐row batch for each side + let session_ctx = SessionContext::new(); + + // on the right side, we’ll shift each value by +1 so that not everything joins, + // but plenty of matching keys exist (e.g. 0 on left matches 1 on right, etc.) + let infinite_left = make_lazy_exec_with_range("value", -10..10, false); + let infinite_right = + make_lazy_exec_with_range("value", 0..i64::MAX, pretend_infinite); + + // Create Join keys → join on “value” = “value” + let left_keys: Vec> = vec![Arc::new(Column::new("value", 0))]; + let right_keys: Vec> = vec![Arc::new(Column::new("value", 0))]; + + // Wrap each side in CoalesceBatches + Repartition so they are both hashed into 1 partition + let coalesced_left = + Arc::new(CoalesceBatchesExec::new(Arc::new(infinite_left), 8_192)); + let coalesced_right = + Arc::new(CoalesceBatchesExec::new(Arc::new(infinite_right), 8_192)); + + let part_left = Partitioning::Hash(left_keys, 1); + let part_right = Partitioning::Hash(right_keys, 1); + + let hashed_left = Arc::new(RepartitionExec::try_new(coalesced_left, part_left)?); + let hashed_right = Arc::new(RepartitionExec::try_new(coalesced_right, part_right)?); + + // Build an Inner HashJoinExec → left.value = right.value + let join = Arc::new(HashJoinExec::try_new( + hashed_left, + hashed_right, + vec![( + Arc::new(Column::new("value", 0)), + Arc::new(Column::new("value", 0)), + )], + None, + &JoinType::Inner, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNull, + )?); + + query_yields(join, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn join_agg_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // Session, schema, and a single 8 K‐row batch for each side + let session_ctx = SessionContext::new(); + + // on the right side, we’ll shift each value by +1 so that not everything joins, + // but plenty of matching keys exist (e.g. 0 on left matches 1 on right, etc.) 
+ let infinite_left = make_lazy_exec_with_range("value", -10..10, false); + let infinite_right = + make_lazy_exec_with_range("value", 0..i64::MAX, pretend_infinite); + + // 2b) Create Join keys → join on “value” = “value” + let left_keys: Vec> = vec![Arc::new(Column::new("value", 0))]; + let right_keys: Vec> = vec![Arc::new(Column::new("value", 0))]; + + // Wrap each side in CoalesceBatches + Repartition so they are both hashed into 1 partition + let coalesced_left = + Arc::new(CoalesceBatchesExec::new(Arc::new(infinite_left), 8_192)); + let coalesced_right = + Arc::new(CoalesceBatchesExec::new(Arc::new(infinite_right), 8_192)); + + let part_left = Partitioning::Hash(left_keys, 1); + let part_right = Partitioning::Hash(right_keys, 1); + + let hashed_left = Arc::new(RepartitionExec::try_new(coalesced_left, part_left)?); + let hashed_right = Arc::new(RepartitionExec::try_new(coalesced_right, part_right)?); + + // Build an Inner HashJoinExec → left.value = right.value + let join = Arc::new(HashJoinExec::try_new( + hashed_left, + hashed_right, + vec![( + Arc::new(Column::new("value", 0)), + Arc::new(Column::new("value", 0)), + )], + None, + &JoinType::Inner, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNull, + )?); + + // Project only one column (“value” from the left side) because we just want to sum that + let input_schema = join.schema(); + + let proj_expr = vec![( + Arc::new(Column::new_with_schema("value", &input_schema)?) as _, + "value".to_string(), + )]; + + let projection = Arc::new(ProjectionExec::try_new(proj_expr, join)?); + let projection_schema = projection.schema(); + + let output_fields = vec![Field::new("total", DataType::Int64, true)]; + let output_schema = Arc::new(Schema::new(output_fields)); + + // 4) Global aggregate (Single) over “value” + let aggregate_expr = AggregateExprBuilder::new( + sum::sum_udaf(), + vec![Arc::new(Column::new_with_schema( + "value", + &projection.schema(), + )?)], + ) + .schema(output_schema) + .alias("total") + .build()?; + + let aggr = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new(vec![], vec![], vec![]), + vec![Arc::new(aggregate_expr)], + vec![None], + projection, + projection_schema, + )?); + + query_yields(aggr, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn hash_join_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // build session + let session_ctx = SessionContext::new(); + + // set up the join sources + let inf1 = Arc::new(make_lazy_exec("value1", pretend_infinite)); + let inf2 = Arc::new(make_lazy_exec("value2", pretend_infinite)); + + // set up a HashJoinExec that will take a long time in the build phase + let join = Arc::new(HashJoinExec::try_new( + inf1.clone(), + inf2.clone(), + vec![( + col("value1", &inf1.schema())?, + col("value2", &inf2.schema())?, + )], + None, + &JoinType::Left, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNull, + )?); + + query_yields(join, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn hash_join_without_repartition_and_no_agg( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box> { + // Create Session, schema, and an 8K-row RecordBatch for each side + let session_ctx = SessionContext::new(); + + // on the right side, we’ll shift each value by +1 so that not everything joins, + // but plenty of matching keys exist (e.g. 0 on left matches 1 on right, etc.) 
+ let infinite_left = make_lazy_exec_with_range("value", -10..10, false); + let infinite_right = + make_lazy_exec_with_range("value", 0..i64::MAX, pretend_infinite); + + // Directly feed `infinite_left` and `infinite_right` into HashJoinExec. + // Do not use aggregation or repartition. + let join = Arc::new(HashJoinExec::try_new( + Arc::new(infinite_left), + Arc::new(infinite_right), + vec![( + Arc::new(Column::new("value", 0)), + Arc::new(Column::new("value", 0)), + )], + /* filter */ None, + &JoinType::Inner, + /* output64 */ None, + // Using CollectLeft is fine—just avoid RepartitionExec’s partitioned channels. + PartitionMode::CollectLeft, + NullEquality::NullEqualsNull, + )?); + + query_yields(join, session_ctx.task_ctx()).await +} + +#[derive(Debug)] +enum Yielded { + ReadyOrPending, + Err(#[allow(dead_code)] DataFusionError), + Timeout, +} + +async fn query_yields( + plan: Arc, + task_ctx: Arc, +) -> Result<(), Box> { + // Run plan through EnsureCooperative + let optimized = + EnsureCooperative::new().optimize(plan, task_ctx.session_config().options())?; + + // Get the stream + let mut stream = physical_plan::execute_stream(optimized, task_ctx)?; + + // Create an independent executor pool + let child_runtime = Runtime::new()?; + + // Spawn a task that tries to poll the stream + // The task returns Ready when the stream yielded with either Ready or Pending + let join_handle = child_runtime.spawn(std::future::poll_fn(move |cx| { + match stream.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(_))) => Poll::Ready(Poll::Ready(Ok(()))), + Poll::Ready(Some(Err(e))) => Poll::Ready(Poll::Ready(Err(e))), + Poll::Ready(None) => Poll::Ready(Poll::Ready(Ok(()))), + Poll::Pending => Poll::Ready(Poll::Pending), + } + })); + + let abort_handle = join_handle.abort_handle(); + + // Now select on the join handle of the task running in the child executor with a timeout + let yielded = select! { + result = join_handle => { + match result { + Ok(Pending) => Yielded::ReadyOrPending, + Ok(Ready(Ok(_))) => Yielded::ReadyOrPending, + Ok(Ready(Err(e))) => Yielded::Err(e), + Err(_) => Yielded::Err(DataFusionError::Execution("join error".into())), + } + }, + _ = tokio::time::sleep(Duration::from_secs(10)) => { + Yielded::Timeout + } + }; + + // Try to abort the poll task and shutdown the child runtime + abort_handle.abort(); + Handle::current().spawn_blocking(move || { + child_runtime.shutdown_timeout(Duration::from_secs(5)); + }); + + // Finally, check if poll_next yielded + assert!( + matches!(yielded, Yielded::ReadyOrPending), + "Result is not Ready or Pending: {yielded:?}" + ); + Ok(()) +} diff --git a/datafusion/core/tests/execution/logical_plan.rs b/datafusion/core/tests/execution/logical_plan.rs index b30636ddf6a8..f5a8a30e0130 100644 --- a/datafusion/core/tests/execution/logical_plan.rs +++ b/datafusion/core/tests/execution/logical_plan.rs @@ -19,15 +19,19 @@ //! create them and depend on them. Test executable semantics of logical plans. 
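Back in coop.rs, the query_yields helper above is the heart of these tests: it polls the optimized plan's stream once from a separate runtime and races that poll against a timeout, so an operator that never returns control is caught. A condensed sketch of the same probe idea; the helper name and the shutdown_background teardown are choices made for this sketch, not taken from the patch:

use std::task::Poll;
use std::time::Duration;

use datafusion_execution::SendableRecordBatchStream;
use futures::StreamExt;
use tokio::runtime::Runtime;

/// Returns true if the first poll of `stream` comes back (Ready or Pending)
/// within `deadline`; only an operator that never yields control will time out.
async fn first_poll_returns(
    mut stream: SendableRecordBatchStream,
    deadline: Duration,
) -> bool {
    // Run the probe on its own runtime so a non-yielding poll cannot starve the timer.
    let child = Runtime::new().expect("child runtime");
    let probe = child.spawn(std::future::poll_fn(move |cx| {
        // Ready and Pending both mean the operator handed control back to the runtime.
        Poll::Ready(stream.poll_next_unpin(cx).is_ready())
    }));

    let yielded = tokio::time::timeout(deadline, probe).await.is_ok();

    // Tear the child runtime down without waiting for a probe that may be stuck.
    child.shutdown_background();
    yielded
}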
use arrow::array::Int64Array; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, Schema}; +use datafusion::datasource::{provider_as_source, ViewTable}; use datafusion::execution::session_state::SessionStateBuilder; -use datafusion_common::{Column, DFSchema, Result, ScalarValue, Spans}; +use datafusion_common::{Column, DFSchema, DFSchemaRef, Result, ScalarValue, Spans}; use datafusion_execution::TaskContext; use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams}; use datafusion_expr::logical_plan::{LogicalPlan, Values}; -use datafusion_expr::{Aggregate, AggregateUDF, Expr}; +use datafusion_expr::{ + Aggregate, AggregateUDF, EmptyRelation, Expr, LogicalPlanBuilder, UNNAMED_TABLE, +}; use datafusion_functions_aggregate::count::Count; use datafusion_physical_plan::collect; +use insta::assert_snapshot; use std::collections::HashMap; use std::fmt::Debug; use std::ops::Deref; @@ -43,9 +47,9 @@ async fn count_only_nulls() -> Result<()> { let input = Arc::new(LogicalPlan::Values(Values { schema: input_schema, values: vec![ - vec![Expr::Literal(ScalarValue::Null)], - vec![Expr::Literal(ScalarValue::Null)], - vec![Expr::Literal(ScalarValue::Null)], + vec![Expr::Literal(ScalarValue::Null, None)], + vec![Expr::Literal(ScalarValue::Null, None)], + vec![Expr::Literal(ScalarValue::Null, None)], ], })); let input_col_ref = Expr::Column(Column { @@ -92,7 +96,41 @@ where T: Debug, { let [element] = elements else { - panic!("Expected exactly one element, got {:?}", elements); + panic!("Expected exactly one element, got {elements:?}"); }; element } + +#[test] +fn inline_scan_projection_test() -> Result<()> { + let name = UNNAMED_TABLE; + let column = "a"; + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + let projection = vec![schema.index_of(column)?]; + + let provider = ViewTable::new( + LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: DFSchemaRef::new(DFSchema::try_from(schema)?), + }), + None, + ); + let source = provider_as_source(Arc::new(provider)); + + let plan = LogicalPlanBuilder::scan(name, source, Some(projection))?.build()?; + + assert_snapshot!( + format!("{plan}"), + @r" + SubqueryAlias: ?table? + Projection: a + EmptyRelation + " + ); + + Ok(()) +} diff --git a/datafusion/core/tests/execution/mod.rs b/datafusion/core/tests/execution/mod.rs index 8169db1a4611..f367a29017a3 100644 --- a/datafusion/core/tests/execution/mod.rs +++ b/datafusion/core/tests/execution/mod.rs @@ -15,4 +15,5 @@ // specific language governing permissions and limitations // under the License. 
+mod coop; mod logical_plan; diff --git a/datafusion/core/tests/expr_api/mod.rs b/datafusion/core/tests/expr_api/mod.rs index aef10379da07..a9cf7f04bb3a 100644 --- a/datafusion/core/tests/expr_api/mod.rs +++ b/datafusion/core/tests/expr_api/mod.rs @@ -358,8 +358,7 @@ async fn evaluate_agg_test(expr: Expr, expected_lines: Vec<&str>) { assert_eq!( expected_lines, actual_lines, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected_lines, actual_lines + "\n\nexpected:\n\n{expected_lines:#?}\nactual:\n\n{actual_lines:#?}\n\n" ); } @@ -379,8 +378,7 @@ fn evaluate_expr_test(expr: Expr, expected_lines: Vec<&str>) { assert_eq!( expected_lines, actual_lines, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected_lines, actual_lines + "\n\nexpected:\n\n{expected_lines:#?}\nactual:\n\n{actual_lines:#?}\n\n" ); } diff --git a/datafusion/core/tests/expr_api/simplification.rs b/datafusion/core/tests/expr_api/simplification.rs index 7bb21725ef40..89651726a69a 100644 --- a/datafusion/core/tests/expr_api/simplification.rs +++ b/datafusion/core/tests/expr_api/simplification.rs @@ -17,6 +17,8 @@ //! This program demonstrates the DataFusion expression simplification API. +use insta::assert_snapshot; + use arrow::array::types::IntervalDayTime; use arrow::array::{ArrayRef, Int32Array}; use arrow::datatypes::{DataType, Field, Schema}; @@ -237,11 +239,15 @@ fn to_timestamp_expr_folded() -> Result<()> { .project(proj)? .build()?; - let expected = "Projection: TimestampNanosecond(1599566400000000000, None) AS to_timestamp(Utf8(\"2020-09-08T12:00:00+00:00\"))\ - \n TableScan: test" - .to_string(); - let actual = get_optimized_plan_formatted(plan, &Utc::now()); - assert_eq!(expected, actual); + let formatted = get_optimized_plan_formatted(plan, &Utc::now()); + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r###" + Projection: TimestampNanosecond(1599566400000000000, None) AS to_timestamp(Utf8("2020-09-08T12:00:00+00:00")) + TableScan: test + "### + ); Ok(()) } @@ -262,11 +268,16 @@ fn now_less_than_timestamp() -> Result<()> { // Note that constant folder runs and folds the entire // expression down to a single constant (true) - let expected = "Filter: Boolean(true)\ - \n TableScan: test"; - let actual = get_optimized_plan_formatted(plan, &time); - - assert_eq!(expected, actual); + let formatted = get_optimized_plan_formatted(plan, &time); + let actual = formatted.trim(); + + assert_snapshot!( + actual, + @r###" + Filter: Boolean(true) + TableScan: test + "### + ); Ok(()) } @@ -282,10 +293,13 @@ fn select_date_plus_interval() -> Result<()> { let date_plus_interval_expr = to_timestamp_expr(ts_string) .cast_to(&DataType::Date32, schema)? - + Expr::Literal(ScalarValue::IntervalDayTime(Some(IntervalDayTime { - days: 123, - milliseconds: 0, - }))); + + Expr::Literal( + ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 123, + milliseconds: 0, + })), + None, + ); let plan = LogicalPlanBuilder::from(table_scan.clone()) .project(vec![date_plus_interval_expr])? 
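The hunks above and below apply the same mechanical change in many places: `Expr::Literal` now takes a second argument, and these tests pass `None`. A minimal sketch of the new call shape, assuming the extra parameter is the optional literal metadata (the crate paths and variant come from the surrounding diff; the rest is illustrative):

```rust
use datafusion_common::ScalarValue;
use datafusion_expr::Expr;

fn main() {
    // Old shape: Expr::Literal(ScalarValue::Null)
    // New shape: a second, optional argument; passing `None` preserves the previous
    // behavior, which is what every updated test in this diff does.
    let lit = Expr::Literal(ScalarValue::Null, None);
    println!("{lit:?}");
}
```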
@@ -293,11 +307,16 @@ fn select_date_plus_interval() -> Result<()> { // Note that constant folder runs and folds the entire // expression down to a single constant (true) - let expected = r#"Projection: Date32("2021-01-09") AS to_timestamp(Utf8("2020-09-08T12:05:00+00:00")) + IntervalDayTime("IntervalDayTime { days: 123, milliseconds: 0 }") - TableScan: test"#; - let actual = get_optimized_plan_formatted(plan, &time); - - assert_eq!(expected, actual); + let formatted = get_optimized_plan_formatted(plan, &time); + let actual = formatted.trim(); + + assert_snapshot!( + actual, + @r###" + Projection: Date32("2021-01-09") AS to_timestamp(Utf8("2020-09-08T12:05:00+00:00")) + IntervalDayTime("IntervalDayTime { days: 123, milliseconds: 0 }") + TableScan: test + "### + ); Ok(()) } @@ -311,10 +330,15 @@ fn simplify_project_scalar_fn() -> Result<()> { // before simplify: power(t.f, 1.0) // after simplify: t.f as "power(t.f, 1.0)" - let expected = "Projection: test.f AS power(test.f,Float64(1))\ - \n TableScan: test"; - let actual = get_optimized_plan_formatted(plan, &Utc::now()); - assert_eq!(expected, actual); + let formatter = get_optimized_plan_formatted(plan, &Utc::now()); + let actual = formatter.trim(); + assert_snapshot!( + actual, + @r###" + Projection: test.f AS power(test.f,Float64(1)) + TableScan: test + "### + ); Ok(()) } @@ -334,9 +358,9 @@ fn simplify_scan_predicate() -> Result<()> { // before simplify: t.g = power(t.f, 1.0) // after simplify: t.g = t.f" - let expected = "TableScan: test, full_filters=[g = f]"; - let actual = get_optimized_plan_formatted(plan, &Utc::now()); - assert_eq!(expected, actual); + let formatted = get_optimized_plan_formatted(plan, &Utc::now()); + let actual = formatted.trim(); + assert_snapshot!(actual, @"TableScan: test, full_filters=[g = f]"); Ok(()) } @@ -547,9 +571,9 @@ fn test_simplify_with_cycle_count( }; let simplifier = ExprSimplifier::new(info); let (simplified_expr, count) = simplifier - .simplify_with_cycle_count(input_expr.clone()) + .simplify_with_cycle_count_transformed(input_expr.clone()) .expect("successfully evaluated"); - + let simplified_expr = simplified_expr.data; assert_eq!( simplified_expr, expected_expr, "Mismatch evaluating {input_expr}\n Expected:{expected_expr}\n Got:{simplified_expr}" diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index ff3b66986ced..a1b23d263c0f 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -17,46 +17,58 @@ use std::sync::Arc; +use super::record_batch_generator::get_supported_types_columns; +use crate::fuzz_cases::aggregation_fuzzer::query_builder::QueryBuilder; use crate::fuzz_cases::aggregation_fuzzer::{ - AggregationFuzzerBuilder, DatasetGeneratorConfig, QueryBuilder, + AggregationFuzzerBuilder, DatasetGeneratorConfig, }; +use std::pin::Pin; +use std::sync::Arc; use arrow::array::{ types::Int64Type, Array, ArrayRef, AsArray, Int32Array, Int64Array, RecordBatch, - StringArray, + StringArray, UInt64Array, }; use arrow::compute::{concat_batches, SortOptions}; use arrow::datatypes::DataType; use arrow::util::pretty::pretty_format_batches; use arrow_schema::{Field, Schema, SchemaRef}; -use datafusion::common::Result; use datafusion::datasource::memory::MemorySourceConfig; use datafusion::datasource::source::DataSourceExec; use datafusion::datasource::MemTable; -use datafusion::physical_expr::aggregate::AggregateExprBuilder; -use datafusion::physical_plan::aggregates::{ - 
AggregateExec, AggregateMode, PhysicalGroupBy, -}; -use datafusion::physical_plan::{collect, displayable, ExecutionPlan}; use datafusion::prelude::{DataFrame, SessionConfig, SessionContext}; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}; -use datafusion_common::HashMap; +use datafusion_execution::TaskContext; +use datafusion_common::{DataFusionError, HashMap, Result}; use datafusion_common_runtime::JoinSet; use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_physical_expr::expressions::{col, lit, Column}; use datafusion_physical_expr::PhysicalSortExpr; use datafusion_physical_expr_common::sort_expr::LexOrdering; -use datafusion_physical_plan::InputOrderMode; +use datafusion_physical_plan::{ + collect, displayable, ExecutionPlan, DisplayAs, DisplayFormatType, InputOrderMode, PlanProperties, +}; +use futures::{Stream, StreamExt}; use test_utils::{add_empty_batches, StringBatchGenerator}; -use datafusion_execution::memory_pool::FairSpillPool; +use super::record_batch_generator::get_supported_types_columns; +use crate::fuzz_cases::stream_exec::StreamExec; +use datafusion_execution::memory_pool::units::{KB, MB}; +use datafusion_execution::memory_pool::{ + FairSpillPool, MemoryConsumer, MemoryReservation, +}; use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_execution::TaskContext; +use datafusion_functions_aggregate::array_agg::array_agg_udaf; +use datafusion_physical_plan::aggregates::{ + AggregateExec, AggregateMode, AggregateExprBuilder, PhysicalGroupBy, +}; + +use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion_physical_plan::metrics::MetricValue; +use datafusion_physical_plan::stream::RecordBatchStreamAdapter; use rand::rngs::StdRng; -use rand::{random, thread_rng, Rng, SeedableRng}; - -use super::record_batch_generator::get_supported_types_columns; +use rand::{random, thread_rng, rng, Rng, SeedableRng}; // ======================================================================== // The new aggregation fuzz tests based on [`AggregationFuzzer`] @@ -85,6 +97,7 @@ async fn test_min() { .with_aggregate_function("min") // min works on all column types .with_aggregate_arguments(data_gen_config.all_columns()) + .with_dataset_sort_keys(data_gen_config.sort_keys_set.clone()) .set_group_by_columns(data_gen_config.all_columns()); AggregationFuzzerBuilder::from(data_gen_config) @@ -111,6 +124,7 @@ async fn test_first_val() { .with_table_name("fuzz_table") .with_aggregate_function("first_value") .with_aggregate_arguments(data_gen_config.all_columns()) + .with_dataset_sort_keys(data_gen_config.sort_keys_set.clone()) .set_group_by_columns(data_gen_config.all_columns()); AggregationFuzzerBuilder::from(data_gen_config) @@ -137,6 +151,7 @@ async fn test_last_val() { .with_table_name("fuzz_table") .with_aggregate_function("last_value") .with_aggregate_arguments(data_gen_config.all_columns()) + .with_dataset_sort_keys(data_gen_config.sort_keys_set.clone()) .set_group_by_columns(data_gen_config.all_columns()); AggregationFuzzerBuilder::from(data_gen_config) @@ -156,6 +171,7 @@ async fn test_max() { .with_aggregate_function("max") // max works on all column types .with_aggregate_arguments(data_gen_config.all_columns()) + .with_dataset_sort_keys(data_gen_config.sort_keys_set.clone()) .set_group_by_columns(data_gen_config.all_columns()); AggregationFuzzerBuilder::from(data_gen_config) @@ -176,6 +192,7 @@ async fn test_sum() { .with_distinct_aggregate_function("sum") // sum only works on numeric columns 
.with_aggregate_arguments(data_gen_config.numeric_columns()) + .with_dataset_sort_keys(data_gen_config.sort_keys_set.clone()) .set_group_by_columns(data_gen_config.all_columns()); AggregationFuzzerBuilder::from(data_gen_config) @@ -196,6 +213,7 @@ async fn test_count() { .with_distinct_aggregate_function("count") // count work for all arguments .with_aggregate_arguments(data_gen_config.all_columns()) + .with_dataset_sort_keys(data_gen_config.sort_keys_set.clone()) .set_group_by_columns(data_gen_config.all_columns()); AggregationFuzzerBuilder::from(data_gen_config) @@ -216,6 +234,7 @@ async fn test_median() { .with_distinct_aggregate_function("median") // median only works on numeric columns .with_aggregate_arguments(data_gen_config.numeric_columns()) + .with_dataset_sort_keys(data_gen_config.sort_keys_set.clone()) .set_group_by_columns(data_gen_config.all_columns()); AggregationFuzzerBuilder::from(data_gen_config) @@ -233,8 +252,8 @@ async fn test_median() { /// 1. Floating point numbers /// 1. structured types fn baseline_config() -> DatasetGeneratorConfig { - let mut rng = thread_rng(); - let columns = get_supported_types_columns(rng.gen()); + let mut rng = rng(); + let columns = get_supported_types_columns(rng.random()); let min_num_rows = 512; let max_num_rows = 1024; @@ -246,6 +265,12 @@ fn baseline_config() -> DatasetGeneratorConfig { // low cardinality to try and get many repeated runs vec![String::from("u8_low")], vec![String::from("utf8_low"), String::from("u8_low")], + vec![String::from("dictionary_utf8_low")], + vec![ + String::from("dictionary_utf8_low"), + String::from("utf8_low"), + String::from("u8_low"), + ], ], } } @@ -295,13 +320,9 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str let schema = input1[0].schema(); let session_config = SessionConfig::new().with_batch_size(50); let ctx = SessionContext::new_with_config(session_config); - let mut sort_keys = LexOrdering::default(); - for ordering_col in ["a", "b", "c"] { - sort_keys.push(PhysicalSortExpr { - expr: col(ordering_col, &schema).unwrap(), - options: SortOptions::default(), - }) - } + let sort_keys = ["a", "b", "c"].map(|ordering_col| { + PhysicalSortExpr::new_default(col(ordering_col, &schema).unwrap()) + }); let concat_input_record = concat_batches(&schema, &input1).unwrap(); @@ -315,7 +336,7 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str let running_source = DataSourceExec::from_data_source( MemorySourceConfig::try_new(&[input1.clone()], schema.clone(), None) .unwrap() - .try_with_sort_information(vec![sort_keys]) + .try_with_sort_information(vec![sort_keys.into()]) .unwrap(), ); @@ -423,13 +444,13 @@ pub(crate) fn make_staggered_batches( let mut input4: Vec = vec![0; len]; input123.iter_mut().for_each(|v| { *v = ( - rng.gen_range(0..n_distinct) as i64, - rng.gen_range(0..n_distinct) as i64, - rng.gen_range(0..n_distinct) as i64, + rng.random_range(0..n_distinct) as i64, + rng.random_range(0..n_distinct) as i64, + rng.random_range(0..n_distinct) as i64, ) }); input4.iter_mut().for_each(|v| { - *v = rng.gen_range(0..n_distinct) as i64; + *v = rng.random_range(0..n_distinct) as i64; }); input123.sort(); let input1 = Int64Array::from_iter_values(input123.clone().into_iter().map(|k| k.0)); @@ -449,7 +470,7 @@ pub(crate) fn make_staggered_batches( let mut batches = vec![]; if STREAM { while remainder.num_rows() > 0 { - let batch_size = rng.gen_range(0..50); + let batch_size = rng.random_range(0..50); if remainder.num_rows() < batch_size { break; } @@ -458,7 +479,7 @@ 
pub(crate) fn make_staggered_batches(
         }
     } else {
         while remainder.num_rows() > 0 {
-            let batch_size = rng.gen_range(0..remainder.num_rows() + 1);
+            let batch_size = rng.random_range(0..remainder.num_rows() + 1);
             batches.push(remainder.slice(0, batch_size));
             remainder = remainder.slice(batch_size, remainder.num_rows() - batch_size);
         }
@@ -504,7 +525,9 @@ async fn group_by_string_test(
     let expected = compute_counts(&input, column_name);
 
     let schema = input[0].schema();
-    let session_config = SessionConfig::new().with_batch_size(50);
+    let session_config = SessionConfig::new()
+        .with_batch_size(50)
+        .with_repartition_file_scans(false);
     let ctx = SessionContext::new_with_config(session_config);
 
     let provider = MemTable::try_new(schema.clone(), vec![input]).unwrap();
@@ -623,7 +646,10 @@ fn extract_result_counts(results: Vec<RecordBatch>) -> HashMap<Option<String>, i
     output
 }
 
-fn assert_spill_count_metric(expect_spill: bool, single_aggregate: Arc<AggregateExec>) {
+fn assert_spill_count_metric(
+    expect_spill: bool,
+    single_aggregate: Arc<AggregateExec>,
+) -> usize {
     if let Some(metrics_set) = single_aggregate.metrics() {
         let mut spill_count = 0;
 
@@ -640,6 +666,8 @@ fn assert_spill_count_metric(expect_spill: bool, single_aggregate: Arc<Aggregate
         } else if !expect_spill && spill_count > 0 {
             panic!("Expected no spill but found SpillCount metric with value greater than 0.");
         }
+
+        spill_count
     } else {
         panic!("No metrics returned from the operator; cannot verify spilling.");
     }
@@ -682,8 +710,8 @@ async fn test_single_mode_aggregate_with_spill() -> Result<()> {
             Arc::new(StringArray::from(
                 (0..1024)
                     .map(|_| -> String {
-                    thread_rng()
-                        .sample_iter::<char, _>(rand::distributions::Standard)
+                    rng()
+                        .sample_iter::<char, _>(rand::distr::StandardUniform)
                         .take(5)
                         .collect()
                     })
@@ -753,3 +781,334 @@ async fn test_single_mode_aggregate_with_spill() -> Result<()> {
 
     Ok(())
 }
+
+#[tokio::test]
+async fn test_high_cardinality_with_limited_memory() -> Result<()> {
+    let record_batch_size = 8192;
+    let pool_size = 2 * MB as usize;
+    let task_ctx = {
+        let memory_pool = Arc::new(FairSpillPool::new(pool_size));
+        TaskContext::default()
+            .with_session_config(SessionConfig::new().with_batch_size(record_batch_size))
+            .with_runtime(Arc::new(
+                RuntimeEnvBuilder::new()
+                    .with_memory_pool(memory_pool)
+                    .build()?,
+            ))
+    };
+
+    let record_batch_size = pool_size / 16;
+
+    // Basic test with a lot of groups that cannot all fit in memory and 1 record batch
+    // from each spill file is too much memory
+    let spill_count = run_test_high_cardinality(RunTestHighCardinalityArgs {
+        pool_size,
+        task_ctx,
+        number_of_record_batches: 100,
+        get_size_of_record_batch_to_generate: Box::pin(move |_| record_batch_size),
+        memory_behavior: Default::default(),
+    })
+    .await?;
+
+    let total_spill_files_size = spill_count * record_batch_size;
+    assert!(
+        total_spill_files_size > pool_size,
+        "Total spill files size {} should be greater than pool size {}",
+        total_spill_files_size,
+        pool_size
+    );
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch(
+) -> Result<()> {
+    let record_batch_size = 8192;
+    let pool_size = 2 * MB as usize;
+    let task_ctx = {
+        let memory_pool = Arc::new(FairSpillPool::new(pool_size));
+        TaskContext::default()
+            .with_session_config(SessionConfig::new().with_batch_size(record_batch_size))
+            .with_runtime(Arc::new(
+                RuntimeEnvBuilder::new()
+                    .with_memory_pool(memory_pool)
+                    .build()?,
+            ))
+    };
+
+    run_test_high_cardinality(RunTestHighCardinalityArgs {
+        pool_size,
+        task_ctx,
+        number_of_record_batches: 100,
+        get_size_of_record_batch_to_generate: Box::pin(move 
|i| { + if i % 25 == 1 { + pool_size / 4 + } else { + (16 * KB) as usize + } + }), + memory_behavior: Default::default(), + }) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch_and_changing_memory_reservation( +) -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + run_test_high_cardinality(RunTestHighCardinalityArgs { + pool_size, + task_ctx, + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |i| { + if i % 25 == 1 { + pool_size / 4 + } else { + (16 * KB) as usize + } + }), + memory_behavior: MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(10), + }) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch_and_take_all_memory( +) -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + run_test_high_cardinality(RunTestHighCardinalityArgs { + pool_size, + task_ctx, + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |i| { + if i % 25 == 1 { + pool_size / 4 + } else { + (16 * KB) as usize + } + }), + memory_behavior: MemoryBehavior::TakeAllMemoryAtTheBeginning, + }) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_high_cardinality_with_limited_memory_and_large_record_batch() -> Result<()> +{ + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config(SessionConfig::new().with_batch_size(record_batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + // Test that the merge degree of multi level merge sort cannot be fixed size when there is not enough memory + run_test_high_cardinality(RunTestHighCardinalityArgs { + pool_size, + task_ctx, + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |_| pool_size / 4), + memory_behavior: Default::default(), + }) + .await?; + + Ok(()) +} + +struct RunTestHighCardinalityArgs { + pool_size: usize, + task_ctx: TaskContext, + number_of_record_batches: usize, + get_size_of_record_batch_to_generate: + Pin usize + Send + 'static>>, + memory_behavior: MemoryBehavior, +} + +#[derive(Default)] +enum MemoryBehavior { + #[default] + AsIs, + TakeAllMemoryAtTheBeginning, + TakeAllMemoryAndReleaseEveryNthBatch(usize), +} + +async fn run_test_high_cardinality(args: RunTestHighCardinalityArgs) -> Result { + let RunTestHighCardinalityArgs { + pool_size, + task_ctx, + number_of_record_batches, + get_size_of_record_batch_to_generate, + memory_behavior, + } = args; + let scan_schema = Arc::new(Schema::new(vec![ + Field::new("col_0", DataType::UInt64, true), + Field::new("col_1", DataType::Utf8, true), + ])); + + let group_by = PhysicalGroupBy::new_single(vec![( + 
Arc::new(Column::new("col_0", 0)), + "col_0".to_string(), + )]); + + let aggregate_expressions = vec![Arc::new( + AggregateExprBuilder::new( + array_agg_udaf(), + vec![col("col_1", &scan_schema).unwrap()], + ) + .schema(Arc::clone(&scan_schema)) + .alias("array_agg(col_1)") + .build()?, + )]; + + let record_batch_size = task_ctx.session_config().batch_size() as u64; + + let schema = Arc::clone(&scan_schema); + let plan: Arc = + Arc::new(StreamExec::new(Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&schema), + futures::stream::iter((0..number_of_record_batches as u64).map( + move |index| { + let mut record_batch_memory_size = + get_size_of_record_batch_to_generate(index as usize); + record_batch_memory_size = record_batch_memory_size + .saturating_sub(size_of::() * record_batch_size as usize); + + let string_item_size = + record_batch_memory_size / record_batch_size as usize; + let string_array = Arc::new(StringArray::from_iter_values( + (0..record_batch_size).map(|_| "a".repeat(string_item_size)), + )); + + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + // Grouping key + Arc::new(UInt64Array::from_iter_values( + (index * record_batch_size) + ..(index * record_batch_size) + record_batch_size, + )), + // Grouping value + string_array, + ], + ) + .map_err(|err| err.into()) + }, + )), + )))); + + let aggregate_exec = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + group_by.clone(), + aggregate_expressions.clone(), + vec![None; aggregate_expressions.len()], + plan, + Arc::clone(&scan_schema), + )?); + let aggregate_final = Arc::new(AggregateExec::try_new( + AggregateMode::Final, + group_by, + aggregate_expressions.clone(), + vec![None; aggregate_expressions.len()], + aggregate_exec, + Arc::clone(&scan_schema), + )?); + + let task_ctx = Arc::new(task_ctx); + + let mut result = aggregate_final.execute(0, Arc::clone(&task_ctx))?; + + let mut number_of_groups = 0; + + let memory_pool = task_ctx.memory_pool(); + let memory_consumer = MemoryConsumer::new("mock_memory_consumer"); + let mut memory_reservation = memory_consumer.register(memory_pool); + + let mut index = 0; + let mut memory_took = false; + + while let Some(batch) = result.next().await { + match memory_behavior { + MemoryBehavior::AsIs => { + // Do nothing + } + MemoryBehavior::TakeAllMemoryAtTheBeginning => { + if !memory_took { + memory_took = true; + grow_memory_as_much_as_possible(10, &mut memory_reservation)?; + } + } + MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(n) => { + if !memory_took { + memory_took = true; + grow_memory_as_much_as_possible(pool_size, &mut memory_reservation)?; + } else if index % n == 0 { + // release memory + memory_reservation.free(); + } + } + } + + let batch = batch?; + number_of_groups += batch.num_rows(); + + index += 1; + } + + assert_eq!( + number_of_groups, + number_of_record_batches * record_batch_size as usize + ); + + let spill_count = assert_spill_count_metric(true, aggregate_final); + + Ok(spill_count) +} + +fn grow_memory_as_much_as_possible( + memory_step: usize, + memory_reservation: &mut MemoryReservation, +) -> Result { + let mut was_able_to_grow = false; + while memory_reservation.try_grow(memory_step).is_ok() { + was_able_to_grow = true; + } + + Ok(was_able_to_grow) +} diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs index 3c9fe2917251..2abfcd8417cb 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs +++ 
b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs @@ -25,7 +25,7 @@ use datafusion_catalog::TableProvider; use datafusion_common::ScalarValue; use datafusion_common::{error::Result, utils::get_available_parallelism}; use datafusion_expr::col; -use rand::{thread_rng, Rng}; +use rand::{rng, Rng}; use crate::fuzz_cases::aggregation_fuzzer::data_generator::Dataset; @@ -112,7 +112,7 @@ impl SessionContextGenerator { /// Randomly generate session context pub fn generate(&self) -> Result { - let mut rng = thread_rng(); + let mut rng = rng(); let schema = self.dataset.batches[0].schema(); let batches = self.dataset.batches.clone(); let provider = MemTable::try_new(schema, vec![batches])?; @@ -123,17 +123,17 @@ impl SessionContextGenerator { // - `skip_partial`, trigger or not trigger currently for simplicity // - `sorted`, if found a sorted dataset, will or will not push down this information // - `spilling`(TODO) - let batch_size = rng.gen_range(1..=self.max_batch_size); + let batch_size = rng.random_range(1..=self.max_batch_size); - let target_partitions = rng.gen_range(1..=self.max_target_partitions); + let target_partitions = rng.random_range(1..=self.max_target_partitions); let skip_partial_params_idx = - rng.gen_range(0..self.candidate_skip_partial_params.len()); + rng.random_range(0..self.candidate_skip_partial_params.len()); let skip_partial_params = self.candidate_skip_partial_params[skip_partial_params_idx]; let (provider, sort_hint) = - if rng.gen_bool(0.5) && !self.dataset.sort_keys.is_empty() { + if rng.random_bool(0.5) && !self.dataset.sort_keys.is_empty() { // Sort keys exist and random to push down let sort_exprs = self .dataset diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs index 82bfe199234e..753a74995d8f 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs @@ -149,14 +149,14 @@ impl DatasetGenerator { for sort_keys in self.sort_keys_set.clone() { let sort_exprs = sort_keys .iter() - .map(|key| { - let col_expr = col(key, schema)?; - Ok(PhysicalSortExpr::new_default(col_expr)) - }) - .collect::>()?; - let sorted_batch = sort_batch(&base_batch, sort_exprs.as_ref(), None)?; - - let batches = stagger_batch(sorted_batch); + .map(|key| col(key, schema).map(PhysicalSortExpr::new_default)) + .collect::>>()?; + let batch = if let Some(ordering) = LexOrdering::new(sort_exprs) { + sort_batch(&base_batch, &ordering, None)? + } else { + base_batch.clone() + }; + let batches = stagger_batch(batch); let dataset = Dataset::new(batches, sort_keys); datasets.push(dataset); } diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs index 53e9288ab4af..cfb3c1c6a1b9 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs @@ -16,15 +16,14 @@ // under the License. 
use std::sync::Arc; -use std::{collections::HashSet, str::FromStr}; use arrow::array::RecordBatch; use arrow::util::pretty::pretty_format_batches; use datafusion_common::{DataFusionError, Result}; use datafusion_common_runtime::JoinSet; -use rand::seq::SliceRandom; -use rand::{thread_rng, Rng}; +use rand::{rng, Rng}; +use crate::fuzz_cases::aggregation_fuzzer::query_builder::QueryBuilder; use crate::fuzz_cases::aggregation_fuzzer::{ check_equality_of_batches, context_generator::{SessionContextGenerator, SessionContextWithParams}, @@ -69,30 +68,16 @@ impl AggregationFuzzerBuilder { /// - 3 random queries /// - 3 random queries for each group by selected from the sort keys /// - 1 random query with no grouping - pub fn add_query_builder(mut self, mut query_builder: QueryBuilder) -> Self { - const NUM_QUERIES: usize = 3; - for _ in 0..NUM_QUERIES { - let sql = query_builder.generate_query(); - self.candidate_sqls.push(Arc::from(sql)); - } - // also add several queries limited to grouping on the group by columns only, if any - // So if the data is sorted on `a,b` only group by `a,b` or`a` or `b` - if let Some(data_gen_config) = &self.data_gen_config { - for sort_keys in &data_gen_config.sort_keys_set { - let group_by_columns = sort_keys.iter().map(|s| s.as_str()); - query_builder = query_builder.set_group_by_columns(group_by_columns); - for _ in 0..NUM_QUERIES { - let sql = query_builder.generate_query(); - self.candidate_sqls.push(Arc::from(sql)); - } - } - } - // also add a query with no grouping - query_builder = query_builder.set_group_by_columns(vec![]); - let sql = query_builder.generate_query(); - self.candidate_sqls.push(Arc::from(sql)); + pub fn add_query_builder(mut self, query_builder: QueryBuilder) -> Self { + self = self.table_name(query_builder.table_name()); + + let sqls = query_builder + .generate_queries() + .into_iter() + .map(|sql| Arc::from(sql.as_str())); + self.candidate_sqls.extend(sqls); - self.table_name(query_builder.table_name()) + self } pub fn table_name(mut self, table_name: &str) -> Self { @@ -178,7 +163,7 @@ impl AggregationFuzzer { async fn run_inner(&mut self) -> Result<()> { let mut join_set = JoinSet::new(); - let mut rng = thread_rng(); + let mut rng = rng(); // Loop to generate datasets and its query for _ in 0..self.data_gen_rounds { @@ -192,7 +177,7 @@ impl AggregationFuzzer { let query_groups = datasets .into_iter() .map(|dataset| { - let sql_idx = rng.gen_range(0..self.candidate_sqls.len()); + let sql_idx = rng.random_range(0..self.candidate_sqls.len()); let sql = self.candidate_sqls[sql_idx].clone(); QueryGroup { dataset, sql } @@ -212,10 +197,7 @@ impl AggregationFuzzer { while let Some(join_handle) = join_set.join_next().await { // propagate errors join_handle.map_err(|e| { - DataFusionError::Internal(format!( - "AggregationFuzzer task error: {:?}", - e - )) + DataFusionError::Internal(format!("AggregationFuzzer task error: {e:?}")) })??; } Ok(()) @@ -371,217 +353,3 @@ fn format_batches_with_limit(batches: &[RecordBatch]) -> impl std::fmt::Display pretty_format_batches(&to_print).unwrap() } - -/// Random aggregate query builder -/// -/// Creates queries like -/// ```sql -/// SELECT AGG(..) 
FROM table_name GROUP BY -///``` -#[derive(Debug, Default, Clone)] -pub struct QueryBuilder { - /// The name of the table to query - table_name: String, - /// Aggregate functions to be used in the query - /// (function_name, is_distinct) - aggregate_functions: Vec<(String, bool)>, - /// Columns to be used in group by - group_by_columns: Vec, - /// Possible columns for arguments in the aggregate functions - /// - /// Assumes each - arguments: Vec, -} -impl QueryBuilder { - pub fn new() -> Self { - Default::default() - } - - /// return the table name if any - pub fn table_name(&self) -> &str { - &self.table_name - } - - /// Set the table name for the query builder - pub fn with_table_name(mut self, table_name: impl Into) -> Self { - self.table_name = table_name.into(); - self - } - - /// Add a new possible aggregate function to the query builder - pub fn with_aggregate_function( - mut self, - aggregate_function: impl Into, - ) -> Self { - self.aggregate_functions - .push((aggregate_function.into(), false)); - self - } - - /// Add a new possible `DISTINCT` aggregate function to the query - /// - /// This is different than `with_aggregate_function` because only certain - /// aggregates support `DISTINCT` - pub fn with_distinct_aggregate_function( - mut self, - aggregate_function: impl Into, - ) -> Self { - self.aggregate_functions - .push((aggregate_function.into(), true)); - self - } - - /// Set the columns to be used in the group bys clauses - pub fn set_group_by_columns<'a>( - mut self, - group_by: impl IntoIterator, - ) -> Self { - self.group_by_columns = group_by.into_iter().map(String::from).collect(); - self - } - - /// Add one or more columns to be used as an argument in the aggregate functions - pub fn with_aggregate_arguments<'a>( - mut self, - arguments: impl IntoIterator, - ) -> Self { - let arguments = arguments.into_iter().map(String::from); - self.arguments.extend(arguments); - self - } - - pub fn generate_query(&self) -> String { - let group_by = self.random_group_by(); - let mut query = String::from("SELECT "); - query.push_str(&group_by.join(", ")); - if !group_by.is_empty() { - query.push_str(", "); - } - query.push_str(&self.random_aggregate_functions(&group_by).join(", ")); - query.push_str(" FROM "); - query.push_str(&self.table_name); - if !group_by.is_empty() { - query.push_str(" GROUP BY "); - query.push_str(&group_by.join(", ")); - } - query - } - - /// Generate a some random aggregate function invocations (potentially repeating). 
- /// - /// Each aggregate function invocation is of the form - /// - /// ```sql - /// function_name( argument) as alias - /// ``` - /// - /// where - /// * `function_names` are randomly selected from [`Self::aggregate_functions`] - /// * ` argument` is randomly selected from [`Self::arguments`] - /// * `alias` is a unique alias `colN` for the column (to avoid duplicate column names) - fn random_aggregate_functions(&self, group_by_cols: &[String]) -> Vec { - const MAX_NUM_FUNCTIONS: usize = 5; - let mut rng = thread_rng(); - let num_aggregate_functions = rng.gen_range(1..MAX_NUM_FUNCTIONS); - - let mut alias_gen = 1; - - let mut aggregate_functions = vec![]; - - let mut order_by_black_list: HashSet = - group_by_cols.iter().cloned().collect(); - // remove one random col - if let Some(first) = order_by_black_list.iter().next().cloned() { - order_by_black_list.remove(&first); - } - - while aggregate_functions.len() < num_aggregate_functions { - let idx = rng.gen_range(0..self.aggregate_functions.len()); - let (function_name, is_distinct) = &self.aggregate_functions[idx]; - let argument = self.random_argument(); - let alias = format!("col{}", alias_gen); - let distinct = if *is_distinct { "DISTINCT " } else { "" }; - alias_gen += 1; - - let (order_by, null_opt) = if function_name.eq("first_value") - || function_name.eq("last_value") - { - ( - self.order_by(&order_by_black_list), /* Among the order by columns, at most one group by column can be included to avoid all order by column values being identical */ - self.null_opt(), - ) - } else { - ("".to_string(), "".to_string()) - }; - - let function = format!( - "{function_name}({distinct}{argument}{order_by}) {null_opt} as {alias}" - ); - aggregate_functions.push(function); - } - aggregate_functions - } - - /// Pick a random aggregate function argument - fn random_argument(&self) -> String { - let mut rng = thread_rng(); - let idx = rng.gen_range(0..self.arguments.len()); - self.arguments[idx].clone() - } - - fn order_by(&self, black_list: &HashSet) -> String { - let mut available_columns: Vec = self - .arguments - .iter() - .filter(|col| !black_list.contains(*col)) - .cloned() - .collect(); - - available_columns.shuffle(&mut thread_rng()); - - let num_of_order_by_col = 12; - let column_count = std::cmp::min(num_of_order_by_col, available_columns.len()); - - let selected_columns = &available_columns[0..column_count]; - - let mut rng = thread_rng(); - let mut result = String::from_str(" order by ").unwrap(); - for col in selected_columns { - let order = if rng.gen_bool(0.5) { "ASC" } else { "DESC" }; - result.push_str(&format!("{} {},", col, order)); - } - - result.strip_suffix(",").unwrap().to_string() - } - - fn null_opt(&self) -> String { - if thread_rng().gen_bool(0.5) { - "RESPECT NULLS".to_string() - } else { - "IGNORE NULLS".to_string() - } - } - - /// Pick a random number of fields to group by (non-repeating) - /// - /// Limited to 3 group by columns to ensure coverage for large groups. With - /// larger numbers of columns, each group has many fewer values. 
- fn random_group_by(&self) -> Vec { - let mut rng = thread_rng(); - const MAX_GROUPS: usize = 3; - let max_groups = self.group_by_columns.len().max(MAX_GROUPS); - let num_group_by = rng.gen_range(1..max_groups); - - let mut already_used = HashSet::new(); - let mut group_by = vec![]; - while group_by.len() < num_group_by - && already_used.len() != self.group_by_columns.len() - { - let idx = rng.gen_range(0..self.group_by_columns.len()); - if already_used.insert(idx) { - group_by.push(self.group_by_columns[idx].clone()); - } - } - group_by - } -} diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs index bfb3bb096326..04b764e46a96 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs @@ -43,6 +43,7 @@ use datafusion_common::error::Result; mod context_generator; mod data_generator; mod fuzzer; +pub mod query_builder; pub use crate::fuzz_cases::record_batch_generator::ColumnDescr; pub use data_generator::DatasetGeneratorConfig; diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/query_builder.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/query_builder.rs new file mode 100644 index 000000000000..209278385b7b --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/query_builder.rs @@ -0,0 +1,384 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{collections::HashSet, str::FromStr}; + +use rand::{rng, seq::SliceRandom, Rng}; + +/// Random aggregate query builder +/// +/// Creates queries like +/// ```sql +/// SELECT AGG(..) FROM table_name GROUP BY +///``` +#[derive(Debug, Default, Clone)] +pub struct QueryBuilder { + // =================================== + // Table settings + // =================================== + /// The name of the table to query + table_name: String, + + // =================================== + // Grouping settings + // =================================== + /// Columns to be used in randomly generate `groupings` + /// + /// # Example + /// + /// Columns: + /// + /// ```text + /// [a,b,c,d] + /// ``` + /// + /// And randomly generated `groupings` (at least 1 column) + /// can be: + /// + /// ```text + /// [a] + /// [a,b] + /// [a,b,d] + /// ... + /// ``` + /// + /// So the finally generated sqls will be: + /// + /// ```text + /// SELECT aggr FROM t GROUP BY a; + /// SELECT aggr FROM t GROUP BY a,b; + /// SELECT aggr FROM t GROUP BY a,b,d; + /// ... 
+    /// ```
+    group_by_columns: Vec<String>,
+
+    /// Max number of columns in randomly generated `groupings`
+    max_group_by_columns: usize,
+
+    /// Min number of columns in randomly generated `groupings`
+    min_group_by_columns: usize,
+
+    /// The sort keys of the dataset
+    ///
+    /// Optimizations will be triggered when all or some of the
+    /// grouping columns are the sort keys of the dataset,
+    /// so it is necessary to randomly generate some `groupings` based on
+    /// the dataset sort keys for test coverage.
+    ///
+    /// # Example
+    ///
+    /// Dataset including columns [a,b,c], and sorted by [a,b]
+    ///
+    /// And we may generate sqls to try covering the sort-optimization cases like:
+    ///
+    /// ```text
+    /// SELECT aggr FROM t GROUP BY b; // no permutation case
+    /// SELECT aggr FROM t GROUP BY a,c; // partial permutation case
+    /// SELECT aggr FROM t GROUP BY a,b,c; // full permutation case
+    /// ...
+    /// ```
+    ///
+    /// More details can be found in [`GroupOrdering`].
+    ///
+    /// [`GroupOrdering`]: datafusion_physical_plan::aggregates::order::GroupOrdering
+    ///
+    dataset_sort_keys: Vec<Vec<String>>,
+
+    /// Whether to also test the no grouping case like:
+    ///
+    /// ```text
+    /// SELECT aggr FROM t;
+    /// ```
+    ///
+    no_grouping: bool,
+
+    // ====================================
+    // Aggregation function settings
+    // ====================================
+    /// Aggregate functions to be used in the query
+    /// (function_name, is_distinct)
+    aggregate_functions: Vec<(String, bool)>,
+
+    /// Possible columns for arguments in the aggregate functions
+    ///
+    /// Assumes each
+    arguments: Vec<String>,
+}
+
+impl QueryBuilder {
+    pub fn new() -> Self {
+        Self {
+            no_grouping: true,
+            max_group_by_columns: 5,
+            min_group_by_columns: 1,
+            ..Default::default()
+        }
+    }
+
+    /// return the table name if any
+    pub fn table_name(&self) -> &str {
+        &self.table_name
+    }
+
+    /// Set the table name for the query builder
+    pub fn with_table_name(mut self, table_name: impl Into<String>) -> Self {
+        self.table_name = table_name.into();
+        self
+    }
+
+    /// Add a new possible aggregate function to the query builder
+    pub fn with_aggregate_function(
+        mut self,
+        aggregate_function: impl Into<String>,
+    ) -> Self {
+        self.aggregate_functions
+            .push((aggregate_function.into(), false));
+        self
+    }
+
+    /// Add a new possible `DISTINCT` aggregate function to the query
+    ///
+    /// This is different than `with_aggregate_function` because only certain
+    /// aggregates support `DISTINCT`
+    pub fn with_distinct_aggregate_function(
+        mut self,
+        aggregate_function: impl Into<String>,
+    ) -> Self {
+        self.aggregate_functions
+            .push((aggregate_function.into(), true));
+        self
+    }
+
+    /// Set the columns to be used in the group by clauses
+    pub fn set_group_by_columns<'a>(
+        mut self,
+        group_by: impl IntoIterator<Item = &'a str>,
+    ) -> Self {
+        self.group_by_columns = group_by.into_iter().map(String::from).collect();
+        self
+    }
+
+    /// Add one or more columns to be used as an argument in the aggregate functions
+    pub fn with_aggregate_arguments<'a>(
+        mut self,
+        arguments: impl IntoIterator<Item = &'a str>,
+    ) -> Self {
+        let arguments = arguments.into_iter().map(String::from);
+        self.arguments.extend(arguments);
+        self
+    }
+
+    /// Set the max number of group by columns (default: 5); for example, if it is
+    /// set to 1, the generated sql will group by at most 1 column
+    #[allow(dead_code)]
+    pub fn with_max_group_by_columns(mut self, max_group_by_columns: usize) -> Self {
+        self.max_group_by_columns = max_group_by_columns;
+        self
+    }
+
+    #[allow(dead_code)]
+    pub fn with_min_group_by_columns(mut self, min_group_by_columns: usize) -> Self {
+        self.min_group_by_columns = min_group_by_columns;
+        self
+    }
+
+    /// Add the sort keys of the dataset, if any; the builder will then generate
+    /// queries based on them to cover the sort-optimization cases
+    pub fn with_dataset_sort_keys(mut self, dataset_sort_keys: Vec<Vec<String>>) -> Self {
+        self.dataset_sort_keys = dataset_sort_keys;
+        self
+    }
+
+    /// Set whether to also test the no grouping aggregation case (default: true)
+    #[allow(dead_code)]
+    pub fn with_no_grouping(mut self, no_grouping: bool) -> Self {
+        self.no_grouping = no_grouping;
+        self
+    }
+
+    pub fn generate_queries(mut self) -> Vec<String> {
+        const NUM_QUERIES: usize = 3;
+        let mut sqls = Vec::new();
+
+        // Add several queries that group on randomly picked columns
+        for _ in 0..NUM_QUERIES {
+            let sql = self.generate_query();
+            sqls.push(sql);
+        }
+
+        // Also add several queries limited to grouping on the group by
+        // dataset sorted columns only, if any.
+        // So if the data is sorted on `a,b`, only group by `a,b` or `a` or `b`.
+        if !self.dataset_sort_keys.is_empty() {
+            let dataset_sort_keys = self.dataset_sort_keys.clone();
+            for sort_keys in dataset_sort_keys {
+                let group_by_columns = sort_keys.iter().map(|s| s.as_str());
+                self = self.set_group_by_columns(group_by_columns);
+                for _ in 0..NUM_QUERIES {
+                    let sql = self.generate_query();
+                    sqls.push(sql);
+                }
+            }
+        }
+
+        // Also add a query with no grouping
+        if self.no_grouping {
+            self = self.set_group_by_columns(vec![]);
+            let sql = self.generate_query();
+            sqls.push(sql);
+        }
+
+        sqls
+    }
+
+    fn generate_query(&self) -> String {
+        let group_by = self.random_group_by();
+        dbg!(&group_by);
+        let mut query = String::from("SELECT ");
+        query.push_str(&group_by.join(", "));
+        if !group_by.is_empty() {
+            query.push_str(", ");
+        }
+        query.push_str(&self.random_aggregate_functions(&group_by).join(", "));
+        query.push_str(" FROM ");
+        query.push_str(&self.table_name);
+        if !group_by.is_empty() {
+            query.push_str(" GROUP BY ");
+            query.push_str(&group_by.join(", "));
+        }
+        query
+    }
+
+    /// Generate some random aggregate function invocations (potentially repeating).
+    ///
+    /// Each aggregate function invocation is of the form
+    ///
+    /// ```sql
+    /// function_name( argument) as alias
+    /// ```
+    ///
+    /// where
+    /// * `function_names` are randomly selected from [`Self::aggregate_functions`]
+    /// * ` argument` is randomly selected from [`Self::arguments`]
+    /// * `alias` is a unique alias `colN` for the column (to avoid duplicate column names)
+    fn random_aggregate_functions(&self, group_by_cols: &[String]) -> Vec<String> {
+        const MAX_NUM_FUNCTIONS: usize = 5;
+        let mut rng = rng();
+        let num_aggregate_functions = rng.random_range(1..=MAX_NUM_FUNCTIONS);
+
+        let mut alias_gen = 1;
+
+        let mut aggregate_functions = vec![];
+
+        let mut order_by_black_list: HashSet<String> =
+            group_by_cols.iter().cloned().collect();
+        // remove one random col
+        if let Some(first) = order_by_black_list.iter().next().cloned() {
+            order_by_black_list.remove(&first);
+        }
+
+        while aggregate_functions.len() < num_aggregate_functions {
+            let idx = rng.random_range(0..self.aggregate_functions.len());
+            let (function_name, is_distinct) = &self.aggregate_functions[idx];
+            let argument = self.random_argument();
+            let alias = format!("col{alias_gen}");
+            let distinct = if *is_distinct { "DISTINCT " } else { "" };
+            alias_gen += 1;
+
+            let (order_by, null_opt) = if function_name.eq("first_value")
+                || function_name.eq("last_value")
+            {
+                (
+                    self.order_by(&order_by_black_list), /* Among the order by columns, at most one group by column can be included to avoid all order by column values being identical */
+                    self.null_opt(),
+                )
+            } else {
+                ("".to_string(), "".to_string())
+            };
+
+            let function = format!(
+                "{function_name}({distinct}{argument}{order_by}) {null_opt} as {alias}"
+            );
+            aggregate_functions.push(function);
+        }
+        aggregate_functions
+    }
+
+    /// Pick a random aggregate function argument
+    fn random_argument(&self) -> String {
+        let mut rng = rng();
+        let idx = rng.random_range(0..self.arguments.len());
+        self.arguments[idx].clone()
+    }
+
+    fn order_by(&self, black_list: &HashSet<String>) -> String {
+        let mut available_columns: Vec<String> = self
+            .arguments
+            .iter()
+            .filter(|col| !black_list.contains(*col))
+            .cloned()
+            .collect();
+
+        available_columns.shuffle(&mut rng());
+
+        let num_of_order_by_col = 12;
+        let column_count = std::cmp::min(num_of_order_by_col, available_columns.len());
+
+        let selected_columns = &available_columns[0..column_count];
+
+        let mut rng = rng();
+        let mut result = String::from_str(" order by ").unwrap();
+        for col in selected_columns {
+            let order = if rng.random_bool(0.5) { "ASC" } else { "DESC" };
+            result.push_str(&format!("{col} {order},"));
+        }
+
+        result.strip_suffix(",").unwrap().to_string()
+    }
+
+    fn null_opt(&self) -> String {
+        if rng().random_bool(0.5) {
+            "RESPECT NULLS".to_string()
+        } else {
+            "IGNORE NULLS".to_string()
+        }
+    }
+
+    /// Pick a random number of fields to group by (non-repeating)
+    ///
+    /// Limited to `max_group_by_columns` group by columns to ensure coverage for large groups.
+    /// With larger numbers of columns, each group has many fewer values.
+    fn random_group_by(&self) -> Vec<String> {
+        let mut rng = rng();
+        let min_groups = self.min_group_by_columns;
+        let max_groups = self.max_group_by_columns;
+        assert!(min_groups <= max_groups);
+        let num_group_by = rng.random_range(min_groups..=max_groups);
+
+        let mut already_used = HashSet::new();
+        let mut group_by = vec![];
+        while group_by.len() < num_group_by
+            && already_used.len() != self.group_by_columns.len()
+        {
+            let idx = rng.random_range(0..self.group_by_columns.len());
+            if already_used.insert(idx) {
+                group_by.push(self.group_by_columns[idx].clone());
+            }
+        }
+        group_by
+    }
+}
diff --git a/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs b/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs
index 769deef1187d..0d500fd7f441 100644
--- a/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs
+++ b/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs
@@ -16,13 +16,16 @@
 // under the License.
 
 use crate::fuzz_cases::equivalence::utils::{
-    convert_to_orderings, create_random_schema, create_test_params, create_test_schema_2,
+    create_random_schema, create_test_params, create_test_schema_2,
     generate_table_for_eq_properties, generate_table_for_orderings,
     is_table_same_after_sort, TestScalarUDF,
 };
 use arrow::compute::SortOptions;
 use datafusion_common::Result;
 use datafusion_expr::{Operator, ScalarUDF};
+use datafusion_physical_expr::equivalence::{
+    convert_to_orderings, convert_to_sort_exprs,
+};
 use datafusion_physical_expr::expressions::{col, BinaryExpr};
 use datafusion_physical_expr::ScalarFunctionExpr;
 use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
@@ -55,30 +58,27 @@ fn test_ordering_satisfy_with_equivalence_random() -> Result<()> {
         col("f", &test_schema)?,
     ];
 
-    for n_req in 0..=col_exprs.len() {
+    for n_req in 1..=col_exprs.len() {
         for exprs in col_exprs.iter().combinations(n_req) {
-            let requirement = exprs
+            let sort_exprs = exprs
                 .into_iter()
-                .map(|expr| PhysicalSortExpr {
-                    expr: Arc::clone(expr),
-                    options: SORT_OPTIONS,
-                })
-                .collect::<LexOrdering>();
+                .map(|expr| PhysicalSortExpr::new(Arc::clone(expr), SORT_OPTIONS));
+            let Some(ordering) = LexOrdering::new(sort_exprs) else {
+                unreachable!("Test should always produce non-degenerate orderings");
+            };
             let expected = is_table_same_after_sort(
-                requirement.clone(),
-                table_data_with_properties.clone(),
+                ordering.clone(),
+                &table_data_with_properties,
             )?;
             let err_msg = format!(
-                "Error in test case requirement:{:?}, expected: {:?}, eq_properties {}",
-                requirement, expected, eq_properties
+                "Error in test case requirement:{ordering:?}, expected: {expected:?}, eq_properties {eq_properties}"
             );
             // Check whether ordering_satisfy API result and
             // experimental result matches.
assert_eq!( - eq_properties.ordering_satisfy(requirement.as_ref()), + eq_properties.ordering_satisfy(ordering)?, expected, - "{}", - err_msg + "{err_msg}" ); } } @@ -127,31 +127,28 @@ fn test_ordering_satisfy_with_equivalence_complex_random() -> Result<()> { a_plus_b, ]; - for n_req in 0..=exprs.len() { + for n_req in 1..=exprs.len() { for exprs in exprs.iter().combinations(n_req) { - let requirement = exprs + let sort_exprs = exprs .into_iter() - .map(|expr| PhysicalSortExpr { - expr: Arc::clone(expr), - options: SORT_OPTIONS, - }) - .collect::(); + .map(|expr| PhysicalSortExpr::new(Arc::clone(expr), SORT_OPTIONS)); + let Some(ordering) = LexOrdering::new(sort_exprs) else { + unreachable!("Test should always produce non-degenerate orderings"); + }; let expected = is_table_same_after_sort( - requirement.clone(), - table_data_with_properties.clone(), + ordering.clone(), + &table_data_with_properties, )?; let err_msg = format!( - "Error in test case requirement:{:?}, expected: {:?}, eq_properties: {}", - requirement, expected, eq_properties, + "Error in test case requirement:{ordering:?}, expected: {expected:?}, eq_properties: {eq_properties}", ); // Check whether ordering_satisfy API result and // experimental result matches. assert_eq!( - eq_properties.ordering_satisfy(requirement.as_ref()), + eq_properties.ordering_satisfy(ordering)?, (expected | false), - "{}", - err_msg + "{err_msg}" ); } } @@ -304,25 +301,19 @@ fn test_ordering_satisfy_with_equivalence() -> Result<()> { ]; for (cols, expected) in requirements { - let err_msg = format!("Error in test case:{cols:?}"); - let required = cols - .into_iter() - .map(|(expr, options)| PhysicalSortExpr { - expr: Arc::clone(expr), - options, - }) - .collect::(); + let err_msg = format!("Error in test case: {cols:?}"); + let sort_exprs = convert_to_sort_exprs(&cols); + let Some(ordering) = LexOrdering::new(sort_exprs) else { + unreachable!("Test should always produce non-degenerate orderings"); + }; // Check expected result with experimental result. 
assert_eq!( - is_table_same_after_sort( - required.clone(), - table_data_with_properties.clone() - )?, + is_table_same_after_sort(ordering.clone(), &table_data_with_properties)?, expected ); assert_eq!( - eq_properties.ordering_satisfy(required.as_ref()), + eq_properties.ordering_satisfy(ordering)?, expected, "{err_msg}" ); @@ -375,7 +366,7 @@ fn test_ordering_satisfy_on_data() -> Result<()> { (col_d, option_asc), ]; let ordering = convert_to_orderings(&[ordering])[0].clone(); - assert!(!is_table_same_after_sort(ordering, batch.clone())?); + assert!(!is_table_same_after_sort(ordering, &batch)?); // [a ASC, b ASC, d ASC] cannot be deduced let ordering = vec![ @@ -384,12 +375,12 @@ fn test_ordering_satisfy_on_data() -> Result<()> { (col_d, option_asc), ]; let ordering = convert_to_orderings(&[ordering])[0].clone(); - assert!(!is_table_same_after_sort(ordering, batch.clone())?); + assert!(!is_table_same_after_sort(ordering, &batch)?); // [a ASC, b ASC] can be deduced let ordering = vec![(col_a, option_asc), (col_b, option_asc)]; let ordering = convert_to_orderings(&[ordering])[0].clone(); - assert!(is_table_same_after_sort(ordering, batch.clone())?); + assert!(is_table_same_after_sort(ordering, &batch)?); Ok(()) } diff --git a/datafusion/core/tests/fuzz_cases/equivalence/projection.rs b/datafusion/core/tests/fuzz_cases/equivalence/projection.rs index a3fa1157b38f..d776796a1b75 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/projection.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/projection.rs @@ -82,16 +82,12 @@ fn project_orderings_random() -> Result<()> { // Make sure each ordering after projection is valid. for ordering in projected_eq.oeq_class().iter() { let err_msg = format!( - "Error in test case ordering:{:?}, eq_properties {}, proj_exprs: {:?}", - ordering, eq_properties, proj_exprs, + "Error in test case ordering:{ordering:?}, eq_properties {eq_properties}, proj_exprs: {proj_exprs:?}", ); // Since ordered section satisfies schema, we expect // that result will be same after sort (e.g sort was unnecessary). 
assert!( - is_table_same_after_sort( - ordering.clone(), - projected_batch.clone(), - )?, + is_table_same_after_sort(ordering.clone(), &projected_batch)?, "{}", err_msg ); @@ -148,8 +144,7 @@ fn ordering_satisfy_after_projection_random() -> Result<()> { for proj_exprs in proj_exprs.iter().combinations(n_req) { let proj_exprs = proj_exprs .into_iter() - .map(|(expr, name)| (Arc::clone(expr), name.to_string())) - .collect::>(); + .map(|(expr, name)| (Arc::clone(expr), name.to_string())); let (projected_batch, projected_eq) = apply_projection( proj_exprs.clone(), &table_data_with_properties, @@ -157,37 +152,36 @@ fn ordering_satisfy_after_projection_random() -> Result<()> { )?; let projection_mapping = - ProjectionMapping::try_new(&proj_exprs, &test_schema)?; + ProjectionMapping::try_new(proj_exprs, &test_schema)?; let projected_exprs = projection_mapping .iter() - .map(|(_source, target)| Arc::clone(target)) + .flat_map(|(_, targets)| { + targets.iter().map(|(target, _)| Arc::clone(target)) + }) .collect::>(); - for n_req in 0..=projected_exprs.len() { + for n_req in 1..=projected_exprs.len() { for exprs in projected_exprs.iter().combinations(n_req) { - let requirement = exprs - .into_iter() - .map(|expr| PhysicalSortExpr { - expr: Arc::clone(expr), - options: SORT_OPTIONS, - }) - .collect::(); - let expected = is_table_same_after_sort( - requirement.clone(), - projected_batch.clone(), - )?; + let sort_exprs = exprs.into_iter().map(|expr| { + PhysicalSortExpr::new(Arc::clone(expr), SORT_OPTIONS) + }); + let Some(ordering) = LexOrdering::new(sort_exprs) else { + unreachable!( + "Test should always produce non-degenerate orderings" + ); + }; + let expected = + is_table_same_after_sort(ordering.clone(), &projected_batch)?; let err_msg = format!( - "Error in test case requirement:{:?}, expected: {:?}, eq_properties: {}, projected_eq: {}, projection_mapping: {:?}", - requirement, expected, eq_properties, projected_eq, projection_mapping + "Error in test case requirement:{ordering:?}, expected: {expected:?}, eq_properties: {eq_properties}, projected_eq: {projected_eq}, projection_mapping: {projection_mapping:?}" ); // Check whether ordering_satisfy API result and // experimental result matches. assert_eq!( - projected_eq.ordering_satisfy(requirement.as_ref()), + projected_eq.ordering_satisfy(ordering)?, expected, - "{}", - err_msg + "{err_msg}" ); } } diff --git a/datafusion/core/tests/fuzz_cases/equivalence/properties.rs b/datafusion/core/tests/fuzz_cases/equivalence/properties.rs index 593e1c6c2dca..e35ce3a6f8c9 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/properties.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/properties.rs @@ -15,18 +15,20 @@ // specific language governing permissions and limitations // under the License. 
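The hunks above all move the equivalence fuzz tests to one calling convention: sort expressions are built with `PhysicalSortExpr::new`, wrapped into a `LexOrdering` (whose constructor now returns an `Option`, with `None` for the degenerate case the loops were changed to skip), `ordering_satisfy` is now fallible, and `is_table_same_after_sort` borrows the batch instead of cloning it. A minimal sketch of that pattern, assuming the DataFusion test-crate APIs shown in these hunks; `check_ordering_matches` is an illustrative helper name, not part of the patch:

use std::sync::Arc;

use arrow::array::RecordBatch;
use arrow::compute::SortOptions;
use datafusion_common::Result;
use datafusion_physical_expr::{EquivalenceProperties, LexOrdering};
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;

// Test helper defined in fuzz_cases/equivalence/utils.rs (shown later in this diff).
use crate::fuzz_cases::equivalence::utils::is_table_same_after_sort;

// Hypothetical helper mirroring the assertions in the fuzz tests above.
fn check_ordering_matches(
    eq_properties: &EquivalenceProperties,
    exprs: &[Arc<dyn PhysicalExpr>],
    options: SortOptions,
    batch: &RecordBatch,
) -> Result<()> {
    // Build the candidate ordering; `LexOrdering::new` yields `None` for an
    // empty/degenerate set of sort expressions, which the tests now skip.
    let sort_exprs = exprs
        .iter()
        .map(|expr| PhysicalSortExpr::new(Arc::clone(expr), options));
    let Some(ordering) = LexOrdering::new(sort_exprs) else {
        return Ok(());
    };

    // Re-sorting the borrowed batch is the "ground truth" result...
    let expected = is_table_same_after_sort(ordering.clone(), batch)?;
    // ...which the now-fallible ordering_satisfy API must agree with.
    assert_eq!(eq_properties.ordering_satisfy(ordering)?, expected);
    Ok(())
}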
+use std::sync::Arc; + use crate::fuzz_cases::equivalence::utils::{ create_random_schema, generate_table_for_eq_properties, is_table_same_after_sort, TestScalarUDF, }; + use datafusion_common::Result; use datafusion_expr::{Operator, ScalarUDF}; use datafusion_physical_expr::expressions::{col, BinaryExpr}; -use datafusion_physical_expr::{PhysicalExprRef, ScalarFunctionExpr}; -use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion_physical_expr::{LexOrdering, ScalarFunctionExpr}; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; + use itertools::Itertools; -use std::sync::Arc; #[test] fn test_find_longest_permutation_random() -> Result<()> { @@ -47,13 +49,13 @@ fn test_find_longest_permutation_random() -> Result<()> { Arc::clone(&test_fun), vec![col_a], &test_schema, - )?) as PhysicalExprRef; + )?) as _; let a_plus_b = Arc::new(BinaryExpr::new( col("a", &test_schema)?, Operator::Plus, col("b", &test_schema)?, - )) as Arc; + )) as _; let exprs = [ col("a", &test_schema)?, col("b", &test_schema)?, @@ -68,33 +70,32 @@ fn test_find_longest_permutation_random() -> Result<()> { for n_req in 0..=exprs.len() { for exprs in exprs.iter().combinations(n_req) { let exprs = exprs.into_iter().cloned().collect::>(); - let (ordering, indices) = eq_properties.find_longest_permutation(&exprs); + let (ordering, indices) = + eq_properties.find_longest_permutation(&exprs)?; // Make sure that find_longest_permutation return values are consistent let ordering2 = indices .iter() .zip(ordering.iter()) - .map(|(&idx, sort_expr)| PhysicalSortExpr { - expr: Arc::clone(&exprs[idx]), - options: sort_expr.options, + .map(|(&idx, sort_expr)| { + PhysicalSortExpr::new(Arc::clone(&exprs[idx]), sort_expr.options) }) - .collect::(); + .collect::>(); assert_eq!( ordering, ordering2, "indices and lexicographical ordering do not match" ); let err_msg = format!( - "Error in test case ordering:{:?}, eq_properties: {}", - ordering, eq_properties + "Error in test case ordering:{ordering:?}, eq_properties: {eq_properties}" ); - assert_eq!(ordering.len(), indices.len(), "{}", err_msg); + assert_eq!(ordering.len(), indices.len(), "{err_msg}"); // Since ordered section satisfies schema, we expect // that result will be same after sort (e.g sort was unnecessary). + let Some(ordering) = LexOrdering::new(ordering) else { + continue; + }; assert!( - is_table_same_after_sort( - ordering.clone(), - table_data_with_properties.clone(), - )?, + is_table_same_after_sort(ordering, &table_data_with_properties)?, "{}", err_msg ); diff --git a/datafusion/core/tests/fuzz_cases/equivalence/utils.rs b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs index d4b41b686631..ef80925d9991 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/utils.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs @@ -15,55 +15,50 @@ // specific language governing permissions and limitations // under the License. 
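The `properties.rs` hunks above make the analogous change for `find_longest_permutation`: it is now fallible and hands back the sort expressions as a plain collection, which the test promotes to a `LexOrdering` (skipping the degenerate case) before re-checking the data. A rough sketch of that flow, under the same API assumptions as the previous example; `check_longest_permutation` is illustrative only:

use std::sync::Arc;

use arrow::array::RecordBatch;
use datafusion_common::Result;
use datafusion_physical_expr::{EquivalenceProperties, LexOrdering};
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

// Test helper defined in fuzz_cases/equivalence/utils.rs (shown later in this diff).
use crate::fuzz_cases::equivalence::utils::is_table_same_after_sort;

// Mirrors the loop body in test_find_longest_permutation_random.
fn check_longest_permutation(
    eq_properties: &EquivalenceProperties,
    exprs: &[Arc<dyn PhysicalExpr>],
    batch: &RecordBatch,
) -> Result<()> {
    // Now returns a Result pairing the sort expressions with their indices.
    let (ordering, indices) = eq_properties.find_longest_permutation(exprs)?;
    assert_eq!(ordering.len(), indices.len());

    // Promote the returned sort expressions to a LexOrdering; an empty
    // permutation is skipped rather than treated as an error.
    let Some(ordering) = LexOrdering::new(ordering) else {
        return Ok(());
    };
    // The data must already be sorted by the reported permutation.
    assert!(is_table_same_after_sort(ordering, batch)?);
    Ok(())
}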
-use datafusion::physical_plan::expressions::col; -use datafusion::physical_plan::expressions::Column; -use datafusion_physical_expr::{ConstExpr, EquivalenceProperties, PhysicalSortExpr}; use std::any::Any; use std::cmp::Ordering; use std::sync::Arc; use arrow::array::{ArrayRef, Float32Array, Float64Array, RecordBatch, UInt32Array}; -use arrow::compute::SortOptions; -use arrow::compute::{lexsort_to_indices, take_record_batch, SortColumn}; +use arrow::compute::{lexsort_to_indices, take_record_batch, SortColumn, SortOptions}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::utils::{compare_rows, get_row_at_idx}; -use datafusion_common::{exec_err, plan_datafusion_err, DataFusionError, Result}; +use datafusion_common::{exec_err, plan_err, DataFusionError, Result}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; -use datafusion_physical_expr::equivalence::{EquivalenceClass, ProjectionMapping}; +use datafusion_physical_expr::equivalence::{ + convert_to_orderings, EquivalenceClass, ProjectionMapping, +}; +use datafusion_physical_expr::{ConstExpr, EquivalenceProperties}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion_physical_plan::expressions::{col, Column}; use itertools::izip; use rand::prelude::*; +/// Projects the input schema based on the given projection mapping. pub fn output_schema( mapping: &ProjectionMapping, input_schema: &Arc, ) -> Result { - // Calculate output schema - let fields: Result> = mapping - .iter() - .map(|(source, target)| { - let name = target - .as_any() - .downcast_ref::() - .ok_or_else(|| plan_datafusion_err!("Expects to have column"))? - .name(); - let field = Field::new( - name, - source.data_type(input_schema)?, - source.nullable(input_schema)?, - ); - - Ok(field) - }) - .collect(); + // Calculate output schema: + let mut fields = vec![]; + for (source, targets) in mapping.iter() { + let data_type = source.data_type(input_schema)?; + let nullable = source.nullable(input_schema)?; + for (target, _) in targets.iter() { + let Some(column) = target.as_any().downcast_ref::() else { + return plan_err!("Expects to have column"); + }; + fields.push(Field::new(column.name(), data_type.clone(), nullable)); + } + } let output_schema = Arc::new(Schema::new_with_metadata( - fields?, + fields, input_schema.metadata().clone(), )); @@ -100,9 +95,9 @@ pub fn create_random_schema(seed: u64) -> Result<(SchemaRef, EquivalenceProperti let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); // Define a and f are aliases - eq_properties.add_equal_conditions(col_a, col_f)?; + eq_properties.add_equal_conditions(Arc::clone(col_a), Arc::clone(col_f))?; // Column e has constant value. 
- eq_properties = eq_properties.with_constants([ConstExpr::from(col_e)]); + eq_properties.add_constants([ConstExpr::from(Arc::clone(col_e))])?; // Randomly order columns for sorting let mut rng = StdRng::seed_from_u64(seed); @@ -114,18 +109,18 @@ pub fn create_random_schema(seed: u64) -> Result<(SchemaRef, EquivalenceProperti }; while !remaining_exprs.is_empty() { - let n_sort_expr = rng.gen_range(0..remaining_exprs.len() + 1); + let n_sort_expr = rng.random_range(1..remaining_exprs.len() + 1); remaining_exprs.shuffle(&mut rng); - let ordering = remaining_exprs - .drain(0..n_sort_expr) - .map(|expr| PhysicalSortExpr { - expr: Arc::clone(expr), - options: options_asc, - }) - .collect(); + let ordering = + remaining_exprs + .drain(0..n_sort_expr) + .map(|expr| PhysicalSortExpr { + expr: Arc::clone(expr), + options: options_asc, + }); - eq_properties.add_new_orderings([ordering]); + eq_properties.add_ordering(ordering); } Ok((test_schema, eq_properties)) @@ -133,12 +128,12 @@ pub fn create_random_schema(seed: u64) -> Result<(SchemaRef, EquivalenceProperti // Apply projection to the input_data, return projected equivalence properties and record batch pub fn apply_projection( - proj_exprs: Vec<(Arc, String)>, + proj_exprs: impl IntoIterator, String)>, input_data: &RecordBatch, input_eq_properties: &EquivalenceProperties, ) -> Result<(RecordBatch, EquivalenceProperties)> { let input_schema = input_data.schema(); - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; + let projection_mapping = ProjectionMapping::try_new(proj_exprs, &input_schema)?; let output_schema = output_schema(&projection_mapping, &input_schema)?; let num_rows = input_data.num_rows(); @@ -168,49 +163,49 @@ fn add_equal_conditions_test() -> Result<()> { ])); let mut eq_properties = EquivalenceProperties::new(schema); - let col_a_expr = Arc::new(Column::new("a", 0)) as Arc; - let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; - let col_c_expr = Arc::new(Column::new("c", 2)) as Arc; - let col_x_expr = Arc::new(Column::new("x", 3)) as Arc; - let col_y_expr = Arc::new(Column::new("y", 4)) as Arc; + let col_a = Arc::new(Column::new("a", 0)) as _; + let col_b = Arc::new(Column::new("b", 1)) as _; + let col_c = Arc::new(Column::new("c", 2)) as _; + let col_x = Arc::new(Column::new("x", 3)) as _; + let col_y = Arc::new(Column::new("y", 4)) as _; // a and b are aliases - eq_properties.add_equal_conditions(&col_a_expr, &col_b_expr)?; + eq_properties.add_equal_conditions(Arc::clone(&col_a), Arc::clone(&col_b))?; assert_eq!(eq_properties.eq_group().len(), 1); // This new entry is redundant, size shouldn't increase - eq_properties.add_equal_conditions(&col_b_expr, &col_a_expr)?; + eq_properties.add_equal_conditions(Arc::clone(&col_b), Arc::clone(&col_a))?; assert_eq!(eq_properties.eq_group().len(), 1); let eq_groups = eq_properties.eq_group().iter().next().unwrap(); assert_eq!(eq_groups.len(), 2); - assert!(eq_groups.contains(&col_a_expr)); - assert!(eq_groups.contains(&col_b_expr)); + assert!(eq_groups.contains(&col_a)); + assert!(eq_groups.contains(&col_b)); // b and c are aliases. 
Existing equivalence class should expand, // however there shouldn't be any new equivalence class - eq_properties.add_equal_conditions(&col_b_expr, &col_c_expr)?; + eq_properties.add_equal_conditions(Arc::clone(&col_b), Arc::clone(&col_c))?; assert_eq!(eq_properties.eq_group().len(), 1); let eq_groups = eq_properties.eq_group().iter().next().unwrap(); assert_eq!(eq_groups.len(), 3); - assert!(eq_groups.contains(&col_a_expr)); - assert!(eq_groups.contains(&col_b_expr)); - assert!(eq_groups.contains(&col_c_expr)); + assert!(eq_groups.contains(&col_a)); + assert!(eq_groups.contains(&col_b)); + assert!(eq_groups.contains(&col_c)); // This is a new set of equality. Hence equivalent class count should be 2. - eq_properties.add_equal_conditions(&col_x_expr, &col_y_expr)?; + eq_properties.add_equal_conditions(Arc::clone(&col_x), Arc::clone(&col_y))?; assert_eq!(eq_properties.eq_group().len(), 2); // This equality bridges distinct equality sets. // Hence equivalent class count should decrease from 2 to 1. - eq_properties.add_equal_conditions(&col_x_expr, &col_a_expr)?; + eq_properties.add_equal_conditions(Arc::clone(&col_x), Arc::clone(&col_a))?; assert_eq!(eq_properties.eq_group().len(), 1); let eq_groups = eq_properties.eq_group().iter().next().unwrap(); assert_eq!(eq_groups.len(), 5); - assert!(eq_groups.contains(&col_a_expr)); - assert!(eq_groups.contains(&col_b_expr)); - assert!(eq_groups.contains(&col_c_expr)); - assert!(eq_groups.contains(&col_x_expr)); - assert!(eq_groups.contains(&col_y_expr)); + assert!(eq_groups.contains(&col_a)); + assert!(eq_groups.contains(&col_b)); + assert!(eq_groups.contains(&col_c)); + assert!(eq_groups.contains(&col_x)); + assert!(eq_groups.contains(&col_y)); Ok(()) } @@ -226,7 +221,7 @@ fn add_equal_conditions_test() -> Result<()> { /// already sorted according to `required_ordering` to begin with. 
pub fn is_table_same_after_sort( mut required_ordering: LexOrdering, - batch: RecordBatch, + batch: &RecordBatch, ) -> Result { // Clone the original schema and columns let original_schema = batch.schema(); @@ -327,7 +322,7 @@ pub fn create_test_params() -> Result<(SchemaRef, EquivalenceProperties)> { let col_f = &col("f", &test_schema)?; let col_g = &col("g", &test_schema)?; let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); - eq_properties.add_equal_conditions(col_a, col_c)?; + eq_properties.add_equal_conditions(Arc::clone(col_a), Arc::clone(col_c))?; let option_asc = SortOptions { descending: false, @@ -350,7 +345,7 @@ pub fn create_test_params() -> Result<(SchemaRef, EquivalenceProperties)> { ], ]; let orderings = convert_to_orderings(&orderings); - eq_properties.add_new_orderings(orderings); + eq_properties.add_orderings(orderings); Ok((test_schema, eq_properties)) } @@ -369,14 +364,14 @@ pub fn generate_table_for_eq_properties( // Utility closure to generate random array let mut generate_random_array = |num_elems: usize, max_val: usize| -> ArrayRef { let values: Vec = (0..num_elems) - .map(|_| rng.gen_range(0..max_val) as f64 / 2.0) + .map(|_| rng.random_range(0..max_val) as f64 / 2.0) .collect(); Arc::new(Float64Array::from_iter_values(values)) }; // Fill constant columns for constant in eq_properties.constants() { - let col = constant.expr().as_any().downcast_ref::().unwrap(); + let col = constant.expr.as_any().downcast_ref::().unwrap(); let (idx, _field) = schema.column_with_name(col.name()).unwrap(); let arr = Arc::new(Float64Array::from_iter_values(vec![0 as f64; n_elem])) as ArrayRef; @@ -461,7 +456,7 @@ pub fn generate_table_for_orderings( let batch = RecordBatch::try_from_iter(arrays)?; // Sort batch according to first ordering expression - let sort_columns = get_sort_columns(&batch, orderings[0].as_ref())?; + let sort_columns = get_sort_columns(&batch, &orderings[0])?; let sort_indices = lexsort_to_indices(&sort_columns, None)?; let mut batch = take_record_batch(&batch, &sort_indices)?; @@ -494,29 +489,6 @@ pub fn generate_table_for_orderings( Ok(batch) } -// Convert each tuple to PhysicalSortExpr -pub fn convert_to_sort_exprs( - in_data: &[(&Arc, SortOptions)], -) -> LexOrdering { - in_data - .iter() - .map(|(expr, options)| PhysicalSortExpr { - expr: Arc::clone(*expr), - options: *options, - }) - .collect() -} - -// Convert each inner tuple to PhysicalSortExpr -pub fn convert_to_orderings( - orderings: &[Vec<(&Arc, SortOptions)>], -) -> Vec { - orderings - .iter() - .map(|sort_exprs| convert_to_sort_exprs(sort_exprs)) - .collect() -} - // Utility function to generate random f64 array fn generate_random_f64_array( n_elems: usize, @@ -524,7 +496,7 @@ fn generate_random_f64_array( rng: &mut StdRng, ) -> ArrayRef { let values: Vec = (0..n_elems) - .map(|_| rng.gen_range(0..n_distinct) as f64 / 2.0) + .map(|_| rng.random_range(0..n_distinct) as f64 / 2.0) .collect(); Arc::new(Float64Array::from_iter_values(values)) } diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index da93dd5edf29..7250a263d89c 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -37,7 +37,7 @@ use datafusion::physical_plan::joins::{ HashJoinExec, NestedLoopJoinExec, PartitionMode, SortMergeJoinExec, }; use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion_common::ScalarValue; +use datafusion_common::{NullEquality, ScalarValue}; use 
datafusion_physical_expr::expressions::Literal; use datafusion_physical_expr::PhysicalExprRef; @@ -186,7 +186,7 @@ async fn test_full_join_1k_filtered() { } #[tokio::test] -async fn test_semi_join_1k() { +async fn test_left_semi_join_1k() { JoinFuzzTestCase::new( make_staggered_batches(1000), make_staggered_batches(1000), @@ -198,7 +198,7 @@ async fn test_semi_join_1k() { } #[tokio::test] -async fn test_semi_join_1k_filtered() { +async fn test_left_semi_join_1k_filtered() { JoinFuzzTestCase::new( make_staggered_batches(1000), make_staggered_batches(1000), @@ -209,6 +209,30 @@ async fn test_semi_join_1k_filtered() { .await } +#[tokio::test] +async fn test_right_semi_join_1k() { + JoinFuzzTestCase::new( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::RightSemi, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await +} + +#[tokio::test] +async fn test_right_semi_join_1k_filtered() { + JoinFuzzTestCase::new( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::RightSemi, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await +} + #[tokio::test] async fn test_left_anti_join_1k() { JoinFuzzTestCase::new( @@ -281,6 +305,31 @@ async fn test_left_mark_join_1k_filtered() { .await } +// todo: add JoinTestType::HjSmj after Right mark SortMergeJoin support +#[tokio::test] +async fn test_right_mark_join_1k() { + JoinFuzzTestCase::new( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::RightMark, + None, + ) + .run_test(&[NljHj], false) + .await +} + +#[tokio::test] +async fn test_right_mark_join_1k_filtered() { + JoinFuzzTestCase::new( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::RightMark, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[NljHj], false) + .await +} + type JoinFilterBuilder = Box, Arc) -> JoinFilter>; struct JoinFuzzTestCase { @@ -455,7 +504,7 @@ impl JoinFuzzTestCase { self.join_filter(), self.join_type, vec![SortOptions::default(); self.on_columns().len()], - false, + NullEquality::NullEqualsNothing, ) .unwrap(), ) @@ -472,7 +521,7 @@ impl JoinFuzzTestCase { &self.join_type, None, PartitionMode::Partitioned, - false, + NullEquality::NullEqualsNothing, ) .unwrap(), ) @@ -545,7 +594,7 @@ impl JoinFuzzTestCase { std::fs::remove_dir_all(fuzz_debug).unwrap_or(()); std::fs::create_dir_all(fuzz_debug).unwrap(); let out_dir_name = &format!("{fuzz_debug}/batch_size_{batch_size}"); - println!("Test result data mismatch found. HJ rows {}, SMJ rows {}, NLJ rows {}", hj_rows, smj_rows, nlj_rows); + println!("Test result data mismatch found. HJ rows {hj_rows}, SMJ rows {smj_rows}, NLJ rows {nlj_rows}"); println!("The debug is ON. 
Input data will be saved to {out_dir_name}"); Self::save_partitioned_batches_as_parquet( @@ -561,9 +610,9 @@ impl JoinFuzzTestCase { if join_tests.contains(&NljHj) && nlj_rows != hj_rows { println!("=============== HashJoinExec =================="); - hj_formatted_sorted.iter().for_each(|s| println!("{}", s)); + hj_formatted_sorted.iter().for_each(|s| println!("{s}")); println!("=============== NestedLoopJoinExec =================="); - nlj_formatted_sorted.iter().for_each(|s| println!("{}", s)); + nlj_formatted_sorted.iter().for_each(|s| println!("{s}")); Self::save_partitioned_batches_as_parquet( &nlj_collected, @@ -579,9 +628,9 @@ impl JoinFuzzTestCase { if join_tests.contains(&HjSmj) && smj_rows != hj_rows { println!("=============== HashJoinExec =================="); - hj_formatted_sorted.iter().for_each(|s| println!("{}", s)); + hj_formatted_sorted.iter().for_each(|s| println!("{s}")); println!("=============== SortMergeJoinExec =================="); - smj_formatted_sorted.iter().for_each(|s| println!("{}", s)); + smj_formatted_sorted.iter().for_each(|s| println!("{s}")); Self::save_partitioned_batches_as_parquet( &hj_collected, @@ -597,10 +646,10 @@ impl JoinFuzzTestCase { } if join_tests.contains(&NljHj) { - let err_msg_rowcnt = format!("NestedLoopJoinExec and HashJoinExec produced different row counts, batch_size: {}", batch_size); + let err_msg_rowcnt = format!("NestedLoopJoinExec and HashJoinExec produced different row counts, batch_size: {batch_size}"); assert_eq!(nlj_rows, hj_rows, "{}", err_msg_rowcnt.as_str()); - let err_msg_contents = format!("NestedLoopJoinExec and HashJoinExec produced different results, batch_size: {}", batch_size); + let err_msg_contents = format!("NestedLoopJoinExec and HashJoinExec produced different results, batch_size: {batch_size}"); // row level compare if any of joins returns the result // the reason is different formatting when there is no rows for (i, (nlj_line, hj_line)) in nlj_formatted_sorted @@ -671,7 +720,7 @@ impl JoinFuzzTestCase { std::fs::create_dir_all(out_path).unwrap(); input.iter().enumerate().for_each(|(idx, batch)| { - let file_path = format!("{out_path}/file_{}.parquet", idx); + let file_path = format!("{out_path}/file_{idx}.parquet"); let mut file = std::fs::File::create(&file_path).unwrap(); println!( "{}: Saving batch idx {} rows {} to parquet {}", @@ -722,11 +771,9 @@ impl JoinFuzzTestCase { path.to_str().unwrap(), datafusion::prelude::ParquetReadOptions::default(), ) - .await - .unwrap() + .await? 
.collect() - .await - .unwrap(); + .await?; batches.append(&mut batch); } @@ -739,13 +786,13 @@ impl JoinFuzzTestCase { /// two sorted int32 columns 'a', 'b' ranged from 0..99 as join columns /// two random int32 columns 'x', 'y' as other columns fn make_staggered_batches(len: usize) -> Vec { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let mut input12: Vec<(i32, i32)> = vec![(0, 0); len]; let mut input3: Vec = vec![0; len]; let mut input4: Vec = vec![0; len]; input12 .iter_mut() - .for_each(|v| *v = (rng.gen_range(0..100), rng.gen_range(0..100))); + .for_each(|v| *v = (rng.random_range(0..100), rng.random_range(0..100))); rng.fill(&mut input3[..]); rng.fill(&mut input4[..]); input12.sort_unstable(); diff --git a/datafusion/core/tests/fuzz_cases/limit_fuzz.rs b/datafusion/core/tests/fuzz_cases/limit_fuzz.rs index 987a732eb294..4c5ebf040241 100644 --- a/datafusion/core/tests/fuzz_cases/limit_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/limit_fuzz.rs @@ -24,7 +24,7 @@ use arrow::util::pretty::pretty_format_batches; use datafusion::datasource::MemTable; use datafusion::prelude::SessionContext; use datafusion_common::assert_contains; -use rand::{thread_rng, Rng}; +use rand::{rng, Rng}; use std::sync::Arc; use test_utils::stagger_batch; @@ -54,11 +54,11 @@ async fn run_limit_fuzz_test(make_data: F) where F: Fn(usize) -> SortedData, { - let mut rng = thread_rng(); + let mut rng = rng(); for size in [10, 1_0000, 10_000, 100_000] { let data = make_data(size); // test various limits including some random ones - for limit in [1, 3, 7, 17, 10000, rng.gen_range(1..size * 2)] { + for limit in [1, 3, 7, 17, 10000, rng.random_range(1..size * 2)] { // limit can be larger than the number of rows in the input run_limit_test(limit, &data).await; } @@ -97,13 +97,13 @@ impl SortedData { /// Create an i32 column of random values, with the specified number of /// rows, sorted the default fn new_i32(size: usize) -> Self { - let mut rng = thread_rng(); + let mut rng = rng(); // have some repeats (approximately 1/3 of the values are the same) let max = size as i32 / 3; let data: Vec> = (0..size) .map(|_| { // no nulls for now - Some(rng.gen_range(0..max)) + Some(rng.random_range(0..max)) }) .collect(); @@ -118,17 +118,17 @@ impl SortedData { /// Create an f64 column of random values, with the specified number of /// rows, sorted the default fn new_f64(size: usize) -> Self { - let mut rng = thread_rng(); + let mut rng = rng(); let mut data: Vec> = (0..size / 3) .map(|_| { // no nulls for now - Some(rng.gen_range(0.0..1.0f64)) + Some(rng.random_range(0.0..1.0f64)) }) .collect(); // have some repeats (approximately 1/3 of the values are the same) while data.len() < size { - data.push(data[rng.gen_range(0..data.len())]); + data.push(data[rng.random_range(0..data.len())]); } let batches = stagger_batch(f64_batch(data.iter().cloned())); @@ -142,7 +142,7 @@ impl SortedData { /// Create an string column of random values, with the specified number of /// rows, sorted the default fn new_str(size: usize) -> Self { - let mut rng = thread_rng(); + let mut rng = rng(); let mut data: Vec> = (0..size / 3) .map(|_| { // no nulls for now @@ -152,7 +152,7 @@ impl SortedData { // have some repeats (approximately 1/3 of the values are the same) while data.len() < size { - data.push(data[rng.gen_range(0..data.len())].clone()); + data.push(data[rng.random_range(0..data.len())].clone()); } let batches = stagger_batch(string_batch(data.iter())); @@ -166,7 +166,7 @@ impl SortedData { /// Create two columns of random 
values (int64, string), with the specified number of /// rows, sorted the default fn new_i64str(size: usize) -> Self { - let mut rng = thread_rng(); + let mut rng = rng(); // 100 distinct values let strings: Vec> = (0..100) @@ -180,8 +180,8 @@ impl SortedData { let data = (0..size) .map(|_| { ( - Some(rng.gen_range(0..10)), - strings[rng.gen_range(0..strings.len())].clone(), + Some(rng.random_range(0..10)), + strings[rng.random_range(0..strings.len())].clone(), ) }) .collect::>(); @@ -340,8 +340,8 @@ async fn run_limit_test(fetch: usize, data: &SortedData) { /// Return random ASCII String with len fn get_random_string(len: usize) -> String { - thread_rng() - .sample_iter(rand::distributions::Alphanumeric) + rng() + .sample_iter(rand::distr::Alphanumeric) .take(len) .map(char::from) .collect() diff --git a/datafusion/core/tests/fuzz_cases/merge_fuzz.rs b/datafusion/core/tests/fuzz_cases/merge_fuzz.rs index 92f375525066..b92dec64e3f1 100644 --- a/datafusion/core/tests/fuzz_cases/merge_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/merge_fuzz.rs @@ -31,7 +31,6 @@ use datafusion::physical_plan::{ sorts::sort_preserving_merge::SortPreservingMergeExec, }; use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion_physical_expr_common::sort_expr::LexOrdering; use test_utils::{batches_to_vec, partitions_to_sorted_vec, stagger_batch_with_seed}; @@ -109,13 +108,14 @@ async fn run_merge_test(input: Vec>) { .expect("at least one batch"); let schema = first_batch.schema(); - let sort = LexOrdering::new(vec![PhysicalSortExpr { + let sort = [PhysicalSortExpr { expr: col("x", &schema).unwrap(), options: SortOptions { descending: false, nulls_first: true, }, - }]); + }] + .into(); let exec = MemorySourceConfig::try_new_exec(&input, schema, None).unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, exec)); diff --git a/datafusion/core/tests/fuzz_cases/mod.rs b/datafusion/core/tests/fuzz_cases/mod.rs index 8ccc2a5bc131..c48695914906 100644 --- a/datafusion/core/tests/fuzz_cases/mod.rs +++ b/datafusion/core/tests/fuzz_cases/mod.rs @@ -21,6 +21,7 @@ mod join_fuzz; mod merge_fuzz; mod sort_fuzz; mod sort_query_fuzz; +mod topk_filter_pushdown; mod aggregation_fuzzer; mod equivalence; @@ -33,3 +34,4 @@ mod window_fuzz; // Utility modules mod record_batch_generator; +mod stream_exec; diff --git a/datafusion/core/tests/fuzz_cases/pruning.rs b/datafusion/core/tests/fuzz_cases/pruning.rs index 11dd961a54ee..6e624d458bd9 100644 --- a/datafusion/core/tests/fuzz_cases/pruning.rs +++ b/datafusion/core/tests/fuzz_cases/pruning.rs @@ -90,42 +90,42 @@ async fn test_utf8_not_like() { #[tokio::test] async fn test_utf8_like_prefix() { - Utf8Test::new(|value| col("a").like(lit(format!("%{}", value)))) + Utf8Test::new(|value| col("a").like(lit(format!("%{value}")))) .run() .await; } #[tokio::test] async fn test_utf8_like_suffix() { - Utf8Test::new(|value| col("a").like(lit(format!("{}%", value)))) + Utf8Test::new(|value| col("a").like(lit(format!("{value}%")))) .run() .await; } #[tokio::test] async fn test_utf8_not_like_prefix() { - Utf8Test::new(|value| col("a").not_like(lit(format!("%{}", value)))) + Utf8Test::new(|value| col("a").not_like(lit(format!("%{value}")))) .run() .await; } #[tokio::test] async fn test_utf8_not_like_ecsape() { - Utf8Test::new(|value| col("a").not_like(lit(format!("\\%{}%", value)))) + Utf8Test::new(|value| col("a").not_like(lit(format!("\\%{value}%")))) .run() .await; } #[tokio::test] async fn test_utf8_not_like_suffix() { - Utf8Test::new(|value| 
col("a").not_like(lit(format!("{}%", value)))) + Utf8Test::new(|value| col("a").not_like(lit(format!("{value}%")))) .run() .await; } #[tokio::test] async fn test_utf8_not_like_suffix_one() { - Utf8Test::new(|value| col("a").not_like(lit(format!("{}_", value)))) + Utf8Test::new(|value| col("a").not_like(lit(format!("{value}_")))) .run() .await; } @@ -226,7 +226,7 @@ impl Utf8Test { return (*files).clone(); } - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let values = Self::values(); let mut row_groups = vec![]; @@ -276,7 +276,7 @@ async fn execute_with_predicate( ctx: &SessionContext, ) -> Vec { let parquet_source = if prune_stats { - ParquetSource::default().with_predicate(Arc::clone(&schema), predicate.clone()) + ParquetSource::default().with_predicate(predicate.clone()) } else { ParquetSource::default() }; @@ -345,7 +345,7 @@ async fn write_parquet_file( /// The string values for [Utf8Test::values] static VALUES: LazyLock> = LazyLock::new(|| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let characters = [ "z", diff --git a/datafusion/core/tests/fuzz_cases/record_batch_generator.rs b/datafusion/core/tests/fuzz_cases/record_batch_generator.rs index 9a62a6397d82..4eac1482ad3f 100644 --- a/datafusion/core/tests/fuzz_cases/record_batch_generator.rs +++ b/datafusion/core/tests/fuzz_cases/record_batch_generator.rs @@ -17,13 +17,15 @@ use std::sync::Arc; -use arrow::array::{ArrayRef, RecordBatch}; +use arrow::array::{ArrayRef, DictionaryArray, PrimitiveArray, RecordBatch}; use arrow::datatypes::{ - BooleanType, DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field, - Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, - IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, - Schema, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, - Time64NanosecondType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, + ArrowPrimitiveType, BooleanType, DataType, Date32Type, Date64Type, Decimal128Type, + Decimal256Type, DurationMicrosecondType, DurationMillisecondType, + DurationNanosecondType, DurationSecondType, Field, Float32Type, Float64Type, + Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTimeType, + IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, Schema, + Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, + TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; @@ -32,7 +34,7 @@ use arrow_schema::{ DECIMAL256_MAX_SCALE, }; use datafusion_common::{arrow_datafusion_err, DataFusionError, Result}; -use rand::{rngs::StdRng, thread_rng, Rng, SeedableRng}; +use rand::{rng, rngs::StdRng, Rng, SeedableRng}; use test_utils::array_gen::{ BinaryArrayGenerator, BooleanArrayGenerator, DecimalArrayGenerator, PrimitiveArrayGenerator, StringArrayGenerator, @@ -85,16 +87,33 @@ pub fn get_supported_types_columns(rng_seed: u64) -> Vec { "interval_month_day_nano", DataType::Interval(IntervalUnit::MonthDayNano), ), + // Internal error: AggregationFuzzer task error: JoinError::Panic(Id(29108), "called `Option::unwrap()` on a `None` value", ...). 
+ // ColumnDescr::new( + // "duration_seconds", + // DataType::Duration(TimeUnit::Second), + // ), + ColumnDescr::new( + "duration_milliseconds", + DataType::Duration(TimeUnit::Millisecond), + ), + ColumnDescr::new( + "duration_microsecond", + DataType::Duration(TimeUnit::Microsecond), + ), + ColumnDescr::new( + "duration_nanosecond", + DataType::Duration(TimeUnit::Nanosecond), + ), ColumnDescr::new("decimal128", { - let precision: u8 = rng.gen_range(1..=DECIMAL128_MAX_PRECISION); - let scale: i8 = rng.gen_range( + let precision: u8 = rng.random_range(1..=DECIMAL128_MAX_PRECISION); + let scale: i8 = rng.random_range( i8::MIN..=std::cmp::min(precision as i8, DECIMAL128_MAX_SCALE), ); DataType::Decimal128(precision, scale) }), ColumnDescr::new("decimal256", { - let precision: u8 = rng.gen_range(1..=DECIMAL256_MAX_PRECISION); - let scale: i8 = rng.gen_range( + let precision: u8 = rng.random_range(1..=DECIMAL256_MAX_PRECISION); + let scale: i8 = rng.random_range( i8::MIN..=std::cmp::min(precision as i8, DECIMAL256_MAX_SCALE), ); DataType::Decimal256(precision, scale) @@ -108,6 +127,11 @@ pub fn get_supported_types_columns(rng_seed: u64) -> Vec { ColumnDescr::new("binary", DataType::Binary), ColumnDescr::new("large_binary", DataType::LargeBinary), ColumnDescr::new("binaryview", DataType::BinaryView), + ColumnDescr::new( + "dictionary_utf8_low", + DataType::Dictionary(Box::new(DataType::UInt64), Box::new(DataType::Utf8)), + ) + .with_max_num_distinct(10), ] } @@ -161,22 +185,19 @@ pub struct RecordBatchGenerator { /// If a seed is provided when constructing the generator, it will be used to /// create `rng` and the pseudo-randomly generated batches will be deterministic. - /// Otherwise, `rng` will be initialized using `thread_rng()` and the batches + /// Otherwise, `rng` will be initialized using `rng()` and the batches /// generated will be different each time. rng: StdRng, } macro_rules! generate_decimal_array { - ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT: expr, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $PRECISION: ident, $SCALE: ident, $ARROW_TYPE: ident) => {{ - let null_pct_idx = $BATCH_GEN_RNG.gen_range(0..$SELF.candidate_null_pcts.len()); - let null_pct = $SELF.candidate_null_pcts[null_pct_idx]; - + ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT: expr, $NULL_PCT:ident, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $PRECISION: ident, $SCALE: ident, $ARROW_TYPE: ident) => {{ let mut generator = DecimalArrayGenerator { precision: $PRECISION, scale: $SCALE, num_decimals: $NUM_ROWS, num_distinct_decimals: $MAX_NUM_DISTINCT, - null_pct, + null_pct: $NULL_PCT, rng: $ARRAY_GEN_RNG, }; @@ -186,17 +207,13 @@ macro_rules! generate_decimal_array { // Generating `BooleanArray` due to it being a special type in Arrow (bit-packed) macro_rules! generate_boolean_array { - ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT:expr, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $ARROW_TYPE: ident) => {{ - // Select a null percentage from the candidate percentages - let null_pct_idx = $BATCH_GEN_RNG.gen_range(0..$SELF.candidate_null_pcts.len()); - let null_pct = $SELF.candidate_null_pcts[null_pct_idx]; - + ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT:expr, $NULL_PCT:ident, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $ARROW_TYPE: ident) => {{ let num_distinct_booleans = if $MAX_NUM_DISTINCT >= 2 { 2 } else { 1 }; let mut generator = BooleanArrayGenerator { num_booleans: $NUM_ROWS, num_distinct_booleans, - null_pct, + null_pct: $NULL_PCT, rng: $ARRAY_GEN_RNG, }; @@ -205,14 +222,11 @@ macro_rules! 
generate_boolean_array { } macro_rules! generate_primitive_array { - ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT:expr, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $ARROW_TYPE:ident) => {{ - let null_pct_idx = $BATCH_GEN_RNG.gen_range(0..$SELF.candidate_null_pcts.len()); - let null_pct = $SELF.candidate_null_pcts[null_pct_idx]; - + ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT:expr, $NULL_PCT:ident, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $ARROW_TYPE:ident) => {{ let mut generator = PrimitiveArrayGenerator { num_primitives: $NUM_ROWS, num_distinct_primitives: $MAX_NUM_DISTINCT, - null_pct, + null_pct: $NULL_PCT, rng: $ARRAY_GEN_RNG, }; @@ -220,6 +234,28 @@ macro_rules! generate_primitive_array { }}; } +macro_rules! generate_dict { + ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT:expr, $NULL_PCT:ident, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $ARROW_TYPE:ident, $VALUES: ident) => {{ + debug_assert_eq!($VALUES.len(), $MAX_NUM_DISTINCT); + let keys: PrimitiveArray<$ARROW_TYPE> = (0..$NUM_ROWS) + .map(|_| { + if $BATCH_GEN_RNG.random::() < $NULL_PCT { + None + } else if $MAX_NUM_DISTINCT > 1 { + let range = 0..($MAX_NUM_DISTINCT + as <$ARROW_TYPE as ArrowPrimitiveType>::Native); + Some($ARRAY_GEN_RNG.random_range(range)) + } else { + Some(0) + } + }) + .collect(); + + let dict = DictionaryArray::new(keys, $VALUES); + Arc::new(dict) as ArrayRef + }}; +} + impl RecordBatchGenerator { /// Create a new `RecordBatchGenerator` with a random seed. The generated /// batches will be different each time. @@ -235,7 +271,7 @@ impl RecordBatchGenerator { max_rows_num, columns, candidate_null_pcts, - rng: StdRng::from_rng(thread_rng()).unwrap(), + rng: StdRng::from_rng(&mut rng()), } } @@ -247,9 +283,9 @@ impl RecordBatchGenerator { } pub fn generate(&mut self) -> Result { - let num_rows = self.rng.gen_range(self.min_rows_num..=self.max_rows_num); - let array_gen_rng = StdRng::from_seed(self.rng.gen()); - let mut batch_gen_rng = StdRng::from_seed(self.rng.gen()); + let num_rows = self.rng.random_range(self.min_rows_num..=self.max_rows_num); + let array_gen_rng = StdRng::from_seed(self.rng.random()); + let mut batch_gen_rng = StdRng::from_seed(self.rng.random()); let columns = self.columns.clone(); // Build arrays @@ -281,9 +317,28 @@ impl RecordBatchGenerator { num_rows: usize, batch_gen_rng: &mut StdRng, array_gen_rng: StdRng, + ) -> ArrayRef { + let null_pct_idx = batch_gen_rng.random_range(0..self.candidate_null_pcts.len()); + let null_pct = self.candidate_null_pcts[null_pct_idx]; + + Self::generate_array_of_type_inner( + col, + num_rows, + batch_gen_rng, + array_gen_rng, + null_pct, + ) + } + + fn generate_array_of_type_inner( + col: &ColumnDescr, + num_rows: usize, + batch_gen_rng: &mut StdRng, + array_gen_rng: StdRng, + null_pct: f64, ) -> ArrayRef { let num_distinct = if num_rows > 1 { - batch_gen_rng.gen_range(1..num_rows) + batch_gen_rng.random_range(1..num_rows) } else { num_rows }; @@ -299,6 +354,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, Int8Type @@ -309,6 +365,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, Int16Type @@ -319,6 +376,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, Int32Type @@ -329,6 +387,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, Int64Type @@ -339,6 +398,7 @@ impl RecordBatchGenerator { self, num_rows, 
max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, UInt8Type @@ -349,6 +409,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, UInt16Type @@ -359,6 +420,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, UInt32Type @@ -369,6 +431,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, UInt64Type @@ -379,6 +442,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, Float32Type @@ -389,6 +453,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, Float64Type @@ -399,6 +464,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, Date32Type @@ -409,6 +475,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, Date64Type @@ -419,6 +486,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, Time32SecondType @@ -429,6 +497,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, Time32MillisecondType @@ -439,6 +508,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, Time64MicrosecondType @@ -449,6 +519,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, Time64NanosecondType @@ -459,6 +530,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, IntervalYearMonthType @@ -469,6 +541,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, IntervalDayTimeType @@ -479,16 +552,62 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, IntervalMonthDayNanoType ) } + DataType::Duration(TimeUnit::Second) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + DurationSecondType + ) + } + DataType::Duration(TimeUnit::Millisecond) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + DurationMillisecondType + ) + } + DataType::Duration(TimeUnit::Microsecond) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + DurationMicrosecondType + ) + } + DataType::Duration(TimeUnit::Nanosecond) => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + DurationNanosecondType + ) + } DataType::Timestamp(TimeUnit::Second, None) => { generate_primitive_array!( self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, TimestampSecondType @@ -499,6 +618,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, TimestampMillisecondType @@ -509,6 +629,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, TimestampMicrosecondType @@ -519,16 +640,14 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, TimestampNanosecondType ) } DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => { - let null_pct_idx = - 
batch_gen_rng.gen_range(0..self.candidate_null_pcts.len()); - let null_pct = self.candidate_null_pcts[null_pct_idx]; - let max_len = batch_gen_rng.gen_range(1..50); + let max_len = batch_gen_rng.random_range(1..50); let mut generator = StringArrayGenerator { max_len, @@ -546,10 +665,7 @@ impl RecordBatchGenerator { } } DataType::Binary | DataType::LargeBinary | DataType::BinaryView => { - let null_pct_idx = - batch_gen_rng.gen_range(0..self.candidate_null_pcts.len()); - let null_pct = self.candidate_null_pcts[null_pct_idx]; - let max_len = batch_gen_rng.gen_range(1..100); + let max_len = batch_gen_rng.random_range(1..100); let mut generator = BinaryArrayGenerator { max_len, @@ -571,6 +687,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, precision, @@ -583,6 +700,7 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, precision, @@ -595,11 +713,43 @@ impl RecordBatchGenerator { self, num_rows, max_num_distinct, + null_pct, batch_gen_rng, array_gen_rng, BooleanType } } + DataType::Dictionary(ref key_type, ref value_type) + if key_type.is_dictionary_key_type() => + { + // We generate just num_distinct values because they will be reused by different keys + let mut array_gen_rng = array_gen_rng; + + let values = Self::generate_array_of_type_inner( + &ColumnDescr::new("values", *value_type.clone()), + num_distinct, + batch_gen_rng, + array_gen_rng.clone(), + // Once https://github.com/apache/datafusion/issues/16228 is fixed + // we can also generate nulls in values + 0.0, // null values are generated on the key level + ); + + match key_type.as_ref() { + // new key types can be added here + DataType::UInt64 => generate_dict!( + self, + num_rows, + num_distinct, + null_pct, + batch_gen_rng, + array_gen_rng, + UInt64Type, + values + ), + _ => panic!("Invalid dictionary keys type: {key_type}"), + } + } _ => { panic!("Unsupported data generator type: {}", col.column_type) } @@ -636,8 +786,8 @@ mod tests { let batch1 = gen1.generate().unwrap(); let batch2 = gen2.generate().unwrap(); - let batch1_formatted = format!("{:?}", batch1); - let batch2_formatted = format!("{:?}", batch2); + let batch1_formatted = format!("{batch1:?}"); + let batch2_formatted = format!("{batch2:?}"); assert_eq!(batch1_formatted, batch2_formatted); } diff --git a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs index 0b0f0aa2f105..08b568c37ef1 100644 --- a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs @@ -17,13 +17,17 @@ //! 
Fuzz Test for various corner cases sorting RecordBatches exceeds available memory and should spill +use std::pin::Pin; use std::sync::Arc; +use arrow::array::UInt64Array; use arrow::{ array::{as_string_array, ArrayRef, Int32Array, StringArray}, compute::SortOptions, record_batch::RecordBatch, }; +use arrow_schema::{DataType, Field, Schema}; +use datafusion::common::Result; use datafusion::datasource::memory::MemorySourceConfig; use datafusion::execution::runtime_env::RuntimeEnvBuilder; use datafusion::physical_plan::expressions::PhysicalSortExpr; @@ -31,10 +35,17 @@ use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::{collect, ExecutionPlan}; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::cast::as_int32_array; -use datafusion_execution::memory_pool::GreedyMemoryPool; +use datafusion_execution::memory_pool::{ + FairSpillPool, GreedyMemoryPool, MemoryConsumer, MemoryReservation, +}; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr_common::sort_expr::LexOrdering; +use futures::StreamExt; +use crate::fuzz_cases::stream_exec::StreamExec; +use datafusion_execution::memory_pool::units::MB; +use datafusion_execution::TaskContext; +use datafusion_physical_plan::stream::RecordBatchStreamAdapter; use rand::Rng; use test_utils::{batches_to_vec, partitions_to_sorted_vec}; @@ -232,23 +243,20 @@ impl SortTest { .expect("at least one batch"); let schema = first_batch.schema(); - let sort_ordering = LexOrdering::new( - self.sort_columns - .iter() - .map(|c| PhysicalSortExpr { - expr: col(c, &schema).unwrap(), - options: SortOptions { - descending: false, - nulls_first: true, - }, - }) - .collect(), - ); + let sort_ordering = + LexOrdering::new(self.sort_columns.iter().map(|c| PhysicalSortExpr { + expr: col(c, &schema).unwrap(), + options: SortOptions { + descending: false, + nulls_first: true, + }, + })) + .unwrap(); let exec = MemorySourceConfig::try_new_exec(&input, schema, None).unwrap(); let sort = Arc::new(SortExec::new(sort_ordering, exec)); - let session_config = SessionConfig::new(); + let session_config = SessionConfig::new().with_repartition_file_scans(false); let session_ctx = if let Some(pool_size) = self.pool_size { // Make sure there is enough space for the initial spill // reservation @@ -298,20 +306,20 @@ impl SortTest { /// Return randomly sized record batches in a field named 'x' of type `Int32` /// with randomized i32 content fn make_staggered_i32_batches(len: usize) -> Vec { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let max_batch = 1024; let mut batches = vec![]; let mut remaining = len; while remaining != 0 { - let to_read = rng.gen_range(0..=remaining.min(max_batch)); + let to_read = rng.random_range(0..=remaining.min(max_batch)); remaining -= to_read; batches.push( RecordBatch::try_from_iter(vec![( "x", Arc::new(Int32Array::from_iter_values( - (0..to_read).map(|_| rng.gen()), + (0..to_read).map(|_| rng.random()), )) as ArrayRef, )]) .unwrap(), @@ -323,20 +331,20 @@ fn make_staggered_i32_batches(len: usize) -> Vec { /// Return randomly sized record batches in a field named 'x' of type `Utf8` /// with randomized content fn make_staggered_utf8_batches(len: usize) -> Vec { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let max_batch = 1024; let mut batches = vec![]; let mut remaining = len; while remaining != 0 { - let to_read = rng.gen_range(0..=remaining.min(max_batch)); + let to_read = rng.random_range(0..=remaining.min(max_batch)); remaining -= 
to_read; batches.push( RecordBatch::try_from_iter(vec![( "x", Arc::new(StringArray::from_iter_values( - (0..to_read).map(|_| format!("test_string_{}", rng.gen::())), + (0..to_read).map(|_| format!("test_string_{}", rng.random::())), )) as ArrayRef, )]) .unwrap(), @@ -349,13 +357,13 @@ fn make_staggered_utf8_batches(len: usize) -> Vec { /// with randomized i32 content and a field named 'y' of type `Utf8` /// with randomized content fn make_staggered_i32_utf8_batches(len: usize) -> Vec { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let max_batch = 1024; let mut batches = vec![]; let mut remaining = len; while remaining != 0 { - let to_read = rng.gen_range(0..=remaining.min(max_batch)); + let to_read = rng.random_range(0..=remaining.min(max_batch)); remaining -= to_read; batches.push( @@ -363,13 +371,14 @@ fn make_staggered_i32_utf8_batches(len: usize) -> Vec { ( "x", Arc::new(Int32Array::from_iter_values( - (0..to_read).map(|_| rng.gen()), + (0..to_read).map(|_| rng.random()), )) as ArrayRef, ), ( "y", Arc::new(StringArray::from_iter_values( - (0..to_read).map(|_| format!("test_string_{}", rng.gen::())), + (0..to_read) + .map(|_| format!("test_string_{}", rng.random::())), )) as ArrayRef, ), ]) @@ -379,3 +388,336 @@ fn make_staggered_i32_utf8_batches(len: usize) -> Vec { batches } + +#[tokio::test] +async fn test_sort_with_limited_memory() -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config( + SessionConfig::new() + .with_batch_size(record_batch_size) + .with_sort_spill_reservation_bytes(1), + ) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + let record_batch_size = pool_size / 16; + + // Basic test with a lot of groups that cannot all fit in memory and 1 record batch + // from each spill file is too much memory + let spill_count = + run_sort_test_with_limited_memory(RunSortTestWithLimitedMemoryArgs { + pool_size, + task_ctx, + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |_| record_batch_size), + memory_behavior: Default::default(), + }) + .await?; + + let total_spill_files_size = spill_count * record_batch_size; + assert!( + total_spill_files_size > pool_size, + "Total spill files size {} should be greater than pool size {}", + total_spill_files_size, + pool_size + ); + + Ok(()) +} + +#[tokio::test] +async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch() -> Result<()> +{ + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config( + SessionConfig::new() + .with_batch_size(record_batch_size) + .with_sort_spill_reservation_bytes(1), + ) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + run_sort_test_with_limited_memory(RunSortTestWithLimitedMemoryArgs { + pool_size, + task_ctx, + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |i| { + if i % 25 == 1 { + pool_size / 4 + } else { + 16 * KB + } + }), + memory_behavior: Default::default(), + }) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch_and_changing_memory_reservation( +) -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * 
MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config( + SessionConfig::new() + .with_batch_size(record_batch_size) + .with_sort_spill_reservation_bytes(1), + ) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + run_sort_test_with_limited_memory(RunSortTestWithLimitedMemoryArgs { + pool_size, + task_ctx, + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |i| { + if i % 25 == 1 { + pool_size / 4 + } else { + 16 * KB + } + }), + memory_behavior: MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(10), + }) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch_and_take_all_memory( +) -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config( + SessionConfig::new() + .with_batch_size(record_batch_size) + .with_sort_spill_reservation_bytes(1), + ) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + run_sort_test_with_limited_memory(RunSortTestWithLimitedMemoryArgs { + pool_size, + task_ctx, + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |i| { + if i % 25 == 1 { + pool_size / 4 + } else { + 16 * KB + } + }), + memory_behavior: MemoryBehavior::TakeAllMemoryAtTheBeginning, + }) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn test_sort_with_limited_memory_and_large_record_batch() -> Result<()> { + let record_batch_size = 8192; + let pool_size = 2 * MB as usize; + let task_ctx = { + let memory_pool = Arc::new(FairSpillPool::new(pool_size)); + TaskContext::default() + .with_session_config( + SessionConfig::new() + .with_batch_size(record_batch_size) + .with_sort_spill_reservation_bytes(1), + ) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )) + }; + + // Test that the merge degree of multi level merge sort cannot be fixed size when there is not enough memory + run_sort_test_with_limited_memory(RunSortTestWithLimitedMemoryArgs { + pool_size, + task_ctx, + number_of_record_batches: 100, + get_size_of_record_batch_to_generate: Box::pin(move |_| pool_size / 4), + memory_behavior: Default::default(), + }) + .await?; + + Ok(()) +} + +struct RunSortTestWithLimitedMemoryArgs { + pool_size: usize, + task_ctx: TaskContext, + number_of_record_batches: usize, + get_size_of_record_batch_to_generate: + Pin usize + Send + 'static>>, + memory_behavior: MemoryBehavior, +} + +#[derive(Default)] +enum MemoryBehavior { + #[default] + AsIs, + TakeAllMemoryAtTheBeginning, + TakeAllMemoryAndReleaseEveryNthBatch(usize), +} + +async fn run_sort_test_with_limited_memory( + args: RunSortTestWithLimitedMemoryArgs, +) -> Result { + let RunSortTestWithLimitedMemoryArgs { + pool_size, + task_ctx, + number_of_record_batches, + get_size_of_record_batch_to_generate, + memory_behavior, + } = args; + let scan_schema = Arc::new(Schema::new(vec![ + Field::new("col_0", DataType::UInt64, true), + Field::new("col_1", DataType::Utf8, true), + ])); + + let record_batch_size = task_ctx.session_config().batch_size() as u64; + + let schema = Arc::clone(&scan_schema); + let plan: Arc = + Arc::new(StreamExec::new(Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&schema), + 
futures::stream::iter((0..number_of_record_batches as u64).map( + move |index| { + let mut record_batch_memory_size = + get_size_of_record_batch_to_generate(index as usize); + record_batch_memory_size = record_batch_memory_size + .saturating_sub(size_of::() * record_batch_size as usize); + + let string_item_size = + record_batch_memory_size / record_batch_size as usize; + let string_array = Arc::new(StringArray::from_iter_values( + (0..record_batch_size).map(|_| "a".repeat(string_item_size)), + )); + + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(UInt64Array::from_iter_values( + (index * record_batch_size) + ..(index * record_batch_size) + record_batch_size, + )), + string_array, + ], + ) + .map_err(|err| err.into()) + }, + )), + )))); + let sort_exec = Arc::new(SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr { + expr: col("col_0", &scan_schema).unwrap(), + options: SortOptions { + descending: false, + nulls_first: true, + }, + }]), + plan, + )); + + let task_ctx = Arc::new(task_ctx); + + let mut result = sort_exec.execute(0, Arc::clone(&task_ctx))?; + + let mut number_of_rows = 0; + + let memory_pool = task_ctx.memory_pool(); + let memory_consumer = MemoryConsumer::new("mock_memory_consumer"); + let mut memory_reservation = memory_consumer.register(memory_pool); + + let mut index = 0; + let mut memory_took = false; + + while let Some(batch) = result.next().await { + match memory_behavior { + MemoryBehavior::AsIs => { + // Do nothing + } + MemoryBehavior::TakeAllMemoryAtTheBeginning => { + if !memory_took { + memory_took = true; + grow_memory_as_much_as_possible(10, &mut memory_reservation)?; + } + } + MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(n) => { + if !memory_took { + memory_took = true; + grow_memory_as_much_as_possible(pool_size, &mut memory_reservation)?; + } else if index % n == 0 { + // release memory + memory_reservation.free(); + } + } + } + + let batch = batch?; + number_of_rows += batch.num_rows(); + + index += 1; + } + + assert_eq!( + number_of_rows, + number_of_record_batches * record_batch_size as usize + ); + + let spill_count = sort_exec.metrics().unwrap().spill_count().unwrap(); + assert!( + spill_count > 0, + "Expected spill, but did not: {number_of_record_batches:?}" + ); + + Ok(spill_count) +} + +fn grow_memory_as_much_as_possible( + memory_step: usize, + memory_reservation: &mut MemoryReservation, +) -> Result { + let mut was_able_to_grow = false; + while memory_reservation.try_grow(memory_step).is_ok() { + was_able_to_grow = true; + } + + Ok(was_able_to_grow) +} diff --git a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs index 06b93d41af36..99b20790fc46 100644 --- a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs @@ -23,6 +23,8 @@ mod sp_repartition_fuzz_tests { use arrow::compute::{concat_batches, lexsort, SortColumn, SortOptions}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use datafusion::datasource::memory::MemorySourceConfig; + use datafusion::datasource::source::DataSourceExec; use datafusion::physical_plan::{ collect, metrics::{BaselineMetrics, ExecutionPlanMetricsSet}, @@ -34,19 +36,16 @@ mod sp_repartition_fuzz_tests { }; use datafusion::prelude::SessionContext; use datafusion_common::Result; - use datafusion_execution::{ - config::SessionConfig, memory_pool::MemoryConsumer, SendableRecordBatchStream, - }; - use 
datafusion_physical_expr::{ - equivalence::{EquivalenceClass, EquivalenceProperties}, - expressions::{col, Column}, - ConstExpr, PhysicalExpr, PhysicalSortExpr, + use datafusion_execution::{config::SessionConfig, memory_pool::MemoryConsumer}; + use datafusion_physical_expr::equivalence::{ + EquivalenceClass, EquivalenceProperties, }; + use datafusion_physical_expr::expressions::{col, Column}; + use datafusion_physical_expr::ConstExpr; + use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use test_utils::add_empty_batches; - use datafusion::datasource::memory::MemorySourceConfig; - use datafusion::datasource::source::DataSourceExec; - use datafusion_physical_expr_common::sort_expr::LexOrdering; use itertools::izip; use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng}; @@ -80,9 +79,9 @@ mod sp_repartition_fuzz_tests { let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); // Define a and f are aliases - eq_properties.add_equal_conditions(col_a, col_f)?; + eq_properties.add_equal_conditions(Arc::clone(col_a), Arc::clone(col_f))?; // Column e has constant value. - eq_properties = eq_properties.with_constants([ConstExpr::from(col_e)]); + eq_properties.add_constants([ConstExpr::from(Arc::clone(col_e))])?; // Randomly order columns for sorting let mut rng = StdRng::seed_from_u64(seed); @@ -94,18 +93,18 @@ mod sp_repartition_fuzz_tests { }; while !remaining_exprs.is_empty() { - let n_sort_expr = rng.gen_range(0..remaining_exprs.len() + 1); + let n_sort_expr = rng.random_range(1..remaining_exprs.len() + 1); remaining_exprs.shuffle(&mut rng); - let ordering = remaining_exprs - .drain(0..n_sort_expr) - .map(|expr| PhysicalSortExpr { - expr: expr.clone(), - options: options_asc, - }) - .collect(); + let ordering = + remaining_exprs + .drain(0..n_sort_expr) + .map(|expr| PhysicalSortExpr { + expr: expr.clone(), + options: options_asc, + }); - eq_properties.add_new_orderings([ordering]); + eq_properties.add_ordering(ordering); } Ok((test_schema, eq_properties)) @@ -144,14 +143,14 @@ mod sp_repartition_fuzz_tests { // Utility closure to generate random array let mut generate_random_array = |num_elems: usize, max_val: usize| -> ArrayRef { let values: Vec = (0..num_elems) - .map(|_| rng.gen_range(0..max_val) as u64) + .map(|_| rng.random_range(0..max_val) as u64) .collect(); Arc::new(UInt64Array::from_iter_values(values)) }; // Fill constant columns for constant in eq_properties.constants() { - let col = constant.expr().as_any().downcast_ref::().unwrap(); + let col = constant.expr.as_any().downcast_ref::().unwrap(); let (idx, _field) = schema.column_with_name(col.name()).unwrap(); let arr = Arc::new(UInt64Array::from_iter_values(vec![0; n_elem])) as ArrayRef; @@ -227,21 +226,21 @@ mod sp_repartition_fuzz_tests { let table_data_with_properties = generate_table_for_eq_properties(&eq_properties, N_ELEM, N_DISTINCT)?; let schema = table_data_with_properties.schema(); - let streams: Vec = (0..N_PARTITION) + let streams = (0..N_PARTITION) .map(|_idx| { let batch = table_data_with_properties.clone(); Box::pin(RecordBatchStreamAdapter::new( schema.clone(), futures::stream::once(async { Ok(batch) }), - )) as SendableRecordBatchStream + )) as _ }) .collect::>(); - // Returns concatenated version of the all available orderings - let exprs = eq_properties - .oeq_class() - .output_ordering() - .unwrap_or_default(); + // Returns concatenated version of the all available orderings: + let Some(exprs) 
= eq_properties.oeq_class().output_ordering() else { + // We always should have an ordering due to the way we generate the schema: + unreachable!("No ordering found in eq_properties: {:?}", eq_properties); + }; let context = SessionContext::new().task_ctx(); let mem_reservation = @@ -261,7 +260,7 @@ mod sp_repartition_fuzz_tests { let res = concat_batches(&res[0].schema(), &res)?; for ordering in eq_properties.oeq_class().iter() { - let err_msg = format!("error in eq properties: {:?}", eq_properties); + let err_msg = format!("error in eq properties: {eq_properties:?}"); let sort_columns = ordering .iter() .map(|sort_expr| sort_expr.evaluate_to_sort_column(&res)) @@ -273,7 +272,7 @@ mod sp_repartition_fuzz_tests { let sorted_columns = lexsort(&sort_columns, None)?; // Make sure after merging ordering is still valid. - assert_eq!(orig_columns.len(), sorted_columns.len(), "{}", err_msg); + assert_eq!(orig_columns.len(), sorted_columns.len(), "{err_msg}"); assert!( izip!(orig_columns.into_iter(), sorted_columns.into_iter()) .all(|(lhs, rhs)| { lhs == rhs }), @@ -347,20 +346,16 @@ mod sp_repartition_fuzz_tests { let schema = input1[0].schema(); let session_config = SessionConfig::new().with_batch_size(50); let ctx = SessionContext::new_with_config(session_config); - let mut sort_keys = LexOrdering::default(); - for ordering_col in ["a", "b", "c"] { - sort_keys.push(PhysicalSortExpr { - expr: col(ordering_col, &schema).unwrap(), - options: SortOptions::default(), - }) - } + let sort_keys = ["a", "b", "c"].map(|ordering_col| { + PhysicalSortExpr::new_default(col(ordering_col, &schema).unwrap()) + }); let concat_input_record = concat_batches(&schema, &input1).unwrap(); let running_source = Arc::new( - MemorySourceConfig::try_new(&[input1.clone()], schema.clone(), None) + MemorySourceConfig::try_new(&[input1], schema.clone(), None) .unwrap() - .try_with_sort_information(vec![sort_keys.clone()]) + .try_with_sort_information(vec![sort_keys.clone().into()]) .unwrap(), ); let running_source = Arc::new(DataSourceExec::new(running_source)); @@ -381,7 +376,7 @@ mod sp_repartition_fuzz_tests { sort_preserving_repartition_exec_hash(intermediate, hash_exprs.clone()) }; - let final_plan = sort_preserving_merge_exec(sort_keys.clone(), intermediate); + let final_plan = sort_preserving_merge_exec(sort_keys.into(), intermediate); let task_ctx = ctx.task_ctx(); let collected_running = collect(final_plan, task_ctx.clone()).await.unwrap(); @@ -428,10 +423,9 @@ mod sp_repartition_fuzz_tests { } fn sort_preserving_merge_exec( - sort_exprs: impl IntoIterator, + sort_exprs: LexOrdering, input: Arc, ) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); Arc::new(SortPreservingMergeExec::new(sort_exprs, input)) } @@ -447,9 +441,9 @@ mod sp_repartition_fuzz_tests { let mut input123: Vec<(i64, i64, i64)> = vec![(0, 0, 0); len]; input123.iter_mut().for_each(|v| { *v = ( - rng.gen_range(0..n_distinct) as i64, - rng.gen_range(0..n_distinct) as i64, - rng.gen_range(0..n_distinct) as i64, + rng.random_range(0..n_distinct) as i64, + rng.random_range(0..n_distinct) as i64, + rng.random_range(0..n_distinct) as i64, ) }); input123.sort(); @@ -471,7 +465,7 @@ mod sp_repartition_fuzz_tests { let mut batches = vec![]; if STREAM { while remainder.num_rows() > 0 { - let batch_size = rng.gen_range(0..50); + let batch_size = rng.random_range(0..50); if remainder.num_rows() < batch_size { break; } @@ -481,7 +475,7 @@ mod sp_repartition_fuzz_tests { } } else { while remainder.num_rows() > 0 { - let batch_size = 
rng.gen_range(0..remainder.num_rows() + 1); + let batch_size = rng.random_range(0..remainder.num_rows() + 1); batches.push(remainder.slice(0, batch_size)); remainder = remainder.slice(batch_size, remainder.num_rows() - batch_size); diff --git a/datafusion/core/tests/fuzz_cases/sort_query_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_query_fuzz.rs index 1319d4817326..1f47412caf2a 100644 --- a/datafusion/core/tests/fuzz_cases/sort_query_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_query_fuzz.rs @@ -25,18 +25,17 @@ use arrow_schema::SchemaRef; use datafusion::datasource::MemTable; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::{instant::Instant, Result}; +use datafusion_execution::disk_manager::DiskManagerBuilder; use datafusion_execution::memory_pool::{ human_readable_size, MemoryPool, UnboundedMemoryPool, }; use datafusion_expr::display_schema; use datafusion_physical_plan::spill::get_record_batch_memory_size; -use rand::seq::SliceRandom; +use itertools::Itertools; use std::time::Duration; -use datafusion_execution::{ - disk_manager::DiskManagerConfig, memory_pool::FairSpillPool, - runtime_env::RuntimeEnvBuilder, -}; +use datafusion_execution::{memory_pool::FairSpillPool, runtime_env::RuntimeEnvBuilder}; +use rand::prelude::IndexedRandom; use rand::Rng; use rand::{rngs::StdRng, SeedableRng}; @@ -74,6 +73,43 @@ async fn sort_query_fuzzer_runner() { fuzzer.run().await.unwrap(); } +/// Reproduce the bug with specific seeds from the +/// [failing test case](https://github.com/apache/datafusion/issues/16452). +#[tokio::test(flavor = "multi_thread")] +async fn test_reproduce_sort_query_issue_16452() { + // Seeds from the failing test case + let init_seed = 10313160656544581998u64; + let query_seed = 15004039071976572201u64; + let config_seed_1 = 11807432710583113300u64; + let config_seed_2 = 759937414670321802u64; + + let random_seed = 1u64; // Use a fixed seed to ensure consistent behavior + + let mut test_generator = SortFuzzerTestGenerator::new( + 2000, + 3, + "sort_fuzz_table".to_string(), + get_supported_types_columns(random_seed), + false, + random_seed, + ); + + let mut results = vec![]; + + for config_seed in [config_seed_1, config_seed_2] { + let r = test_generator + .fuzzer_run(init_seed, query_seed, config_seed) + .await + .unwrap(); + + results.push(r); + } + + for (lhs, rhs) in results.iter().tuple_windows() { + check_equality_of_batches(lhs, rhs).unwrap(); + } +} + /// SortQueryFuzzer holds the runner configuration for executing sort query fuzz tests. The fuzzing details are managed inside `SortFuzzerTestGenerator`. 
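The hunks above (and throughout this patch) rename rand's `gen`/`gen_range`/`gen_bool` calls to `random`/`random_range`/`random_bool` and seed every generator explicitly so that a failing fuzz run can be replayed from its printed seeds. A minimal, self-contained sketch of that pattern, assuming the 0.9-style rand API this patch migrates to (illustrative only, not part of the diff):

use rand::{rngs::StdRng, Rng, SeedableRng};

fn main() {
    // Re-seeding with the same value reproduces the exact same "random" choices;
    // the seed below is one of the values printed by the fuzzer above.
    let mut rng = StdRng::seed_from_u64(15004039071976572201);
    let limit: usize = rng.random_range(1..=100); // was rng.gen_range(..)
    let use_desc = rng.random_bool(0.5);          // was rng.gen_bool(..)
    let next_seed: u64 = rng.random();            // was rng.gen()
    println!("{limit} {use_desc} {next_seed}");
}

Printing the seeds, as SortQueryFuzzer does, is what makes a pinned reproduction test such as test_reproduce_sort_query_issue_16452 possible.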
/// /// It defines: @@ -199,25 +235,24 @@ impl SortQueryFuzzer { // Execute until either`max_rounds` or `time_limit` is reached let max_rounds = self.max_rounds.unwrap_or(usize::MAX); for round in 0..max_rounds { - let init_seed = self.runner_rng.gen(); + let init_seed = self.runner_rng.random(); for query_i in 0..self.queries_per_round { - let query_seed = self.runner_rng.gen(); + let query_seed = self.runner_rng.random(); let mut expected_results: Option> = None; // use first config's result as the expected result for config_i in 0..self.config_variations_per_query { if self.should_stop_due_to_time_limit(start_time, round, query_i) { return Ok(()); } - let config_seed = self.runner_rng.gen(); + let config_seed = self.runner_rng.random(); println!( - "[SortQueryFuzzer] Round {}, Query {} (Config {})", - round, query_i, config_i + "[SortQueryFuzzer] Round {round}, Query {query_i} (Config {config_i})" ); println!(" Seeds:"); - println!(" init_seed = {}", init_seed); - println!(" query_seed = {}", query_seed); - println!(" config_seed = {}", config_seed); + println!(" init_seed = {init_seed}"); + println!(" query_seed = {query_seed}"); + println!(" config_seed = {config_seed}"); let results = self .test_gen @@ -300,7 +335,7 @@ impl SortFuzzerTestGenerator { let mut rng = StdRng::seed_from_u64(rng_seed); let min_ncol = min(candidate_columns.len(), 5); let max_ncol = min(candidate_columns.len(), 10); - let amount = rng.gen_range(min_ncol..=max_ncol); + let amount = rng.random_range(min_ncol..=max_ncol); let selected_columns = candidate_columns .choose_multiple(&mut rng, amount) .cloned() @@ -327,7 +362,7 @@ impl SortFuzzerTestGenerator { /// memory table should be generated with more partitions, due to https://github.com/apache/datafusion/issues/15088 fn init_partitioned_staggered_batches(&mut self, rng_seed: u64) { let mut rng = StdRng::seed_from_u64(rng_seed); - let num_partitions = rng.gen_range(1..=self.max_partitions); + let num_partitions = rng.random_range(1..=self.max_partitions); let max_batch_size = self.num_rows / num_partitions / 50; let target_partition_size = self.num_rows / num_partitions; @@ -344,7 +379,7 @@ impl SortFuzzerTestGenerator { // Generate a random batch of size between 1 and max_batch_size // Let edge case (1-row batch) more common - let (min_nrow, max_nrow) = if rng.gen_bool(0.1) { + let (min_nrow, max_nrow) = if rng.random_bool(0.1) { (1, 3) } else { (1, max_batch_size) @@ -355,7 +390,7 @@ impl SortFuzzerTestGenerator { max_nrow, self.selected_columns.clone(), ) - .with_seed(rng.gen()); + .with_seed(rng.random()); let record_batch = record_batch_generator.generate().unwrap(); num_rows += record_batch.num_rows(); @@ -373,9 +408,9 @@ impl SortFuzzerTestGenerator { } // After all partitions are created, optionally make one partition have 0/1 batch - if num_partitions > 2 && rng.gen_bool(0.1) { - let partition_index = rng.gen_range(0..num_partitions); - if rng.gen_bool(0.5) { + if num_partitions > 2 && rng.random_bool(0.1) { + let partition_index = rng.random_range(0..num_partitions); + if rng.random_bool(0.5) { // 0 batch partitions[partition_index] = Vec::new(); } else { @@ -424,7 +459,7 @@ impl SortFuzzerTestGenerator { pub fn generate_random_query(&self, rng_seed: u64) -> (String, Option) { let mut rng = StdRng::seed_from_u64(rng_seed); - let num_columns = rng.gen_range(1..=3).min(self.selected_columns.len()); + let num_columns = rng.random_range(1..=3).min(self.selected_columns.len()); let selected_columns: Vec<_> = self .selected_columns .choose_multiple(&mut 
rng, num_columns) @@ -433,37 +468,37 @@ impl SortFuzzerTestGenerator { let mut order_by_clauses = Vec::new(); for col in selected_columns { let mut clause = col.name.clone(); - if rng.gen_bool(0.5) { - let order = if rng.gen_bool(0.5) { "ASC" } else { "DESC" }; - clause.push_str(&format!(" {}", order)); + if rng.random_bool(0.5) { + let order = if rng.random_bool(0.5) { "ASC" } else { "DESC" }; + clause.push_str(&format!(" {order}")); } - if rng.gen_bool(0.5) { - let nulls = if rng.gen_bool(0.5) { + if rng.random_bool(0.5) { + let nulls = if rng.random_bool(0.5) { "NULLS FIRST" } else { "NULLS LAST" }; - clause.push_str(&format!(" {}", nulls)); + clause.push_str(&format!(" {nulls}")); } order_by_clauses.push(clause); } let dataset_size = self.dataset_state.as_ref().unwrap().dataset_size; - let limit = if rng.gen_bool(0.2) { + let limit = if rng.random_bool(0.2) { // Prefer edge cases for k like 1, dataset_size, etc. - Some(if rng.gen_bool(0.5) { + Some(if rng.random_bool(0.5) { let edge_cases = [1, 2, 3, dataset_size - 1, dataset_size, dataset_size + 1]; *edge_cases.choose(&mut rng).unwrap() } else { - rng.gen_range(1..=dataset_size) + rng.random_range(1..=dataset_size) }) } else { None }; - let limit_clause = limit.map_or(String::new(), |l| format!(" LIMIT {}", l)); + let limit_clause = limit.map_or(String::new(), |l| format!(" LIMIT {l}")); let query = format!( "SELECT * FROM {} ORDER BY {}{}", @@ -487,12 +522,12 @@ impl SortFuzzerTestGenerator { // 30% to 200% of the dataset size (if `with_memory_limit` is false, config // will use the default unbounded pool to override it later) - let memory_limit = rng.gen_range( + let memory_limit = rng.random_range( (dataset_size as f64 * 0.5) as usize..=(dataset_size as f64 * 2.0) as usize, ); // 10% to 20% of the per-partition memory limit size let per_partition_mem_limit = memory_limit / num_partitions; - let sort_spill_reservation_bytes = rng.gen_range( + let sort_spill_reservation_bytes = rng.random_range( (per_partition_mem_limit as f64 * 0.2) as usize ..=(per_partition_mem_limit as f64 * 0.3) as usize, ); @@ -505,7 +540,7 @@ impl SortFuzzerTestGenerator { 0 } else { let dataset_size = self.dataset_state.as_ref().unwrap().dataset_size; - rng.gen_range(0..=dataset_size * 2_usize) + rng.random_range(0..=dataset_size * 2_usize) }; // Set up strings for printing @@ -522,13 +557,10 @@ impl SortFuzzerTestGenerator { println!(" Config: "); println!(" Dataset size: {}", human_readable_size(dataset_size)); - println!(" Number of partitions: {}", num_partitions); + println!(" Number of partitions: {num_partitions}"); println!(" Batch size: {}", init_state.approx_batch_num_rows / 2); - println!(" Memory limit: {}", memory_limit_str); - println!( - " Per partition memory limit: {}", - per_partition_limit_str - ); + println!(" Memory limit: {memory_limit_str}"); + println!(" Per partition memory limit: {per_partition_limit_str}"); println!( " Sort spill reservation bytes: {}", human_readable_size(sort_spill_reservation_bytes) @@ -552,7 +584,7 @@ impl SortFuzzerTestGenerator { let runtime = RuntimeEnvBuilder::new() .with_memory_pool(memory_pool) - .with_disk_manager(DiskManagerConfig::NewOs) + .with_disk_manager_builder(DiskManagerBuilder::default()) .build_arc()?; let ctx = SessionContext::new_with_config_rt(config, runtime); @@ -575,7 +607,7 @@ impl SortFuzzerTestGenerator { self.init_partitioned_staggered_batches(dataset_seed); let (query_str, limit) = self.generate_random_query(query_seed); println!(" Query:"); - println!(" {}", query_str); + println!(" 
{query_str}"); // ==== Execute the query ==== diff --git a/datafusion/core/tests/fuzz_cases/stream_exec.rs b/datafusion/core/tests/fuzz_cases/stream_exec.rs new file mode 100644 index 000000000000..6e71b9988d79 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/stream_exec.rs @@ -0,0 +1,115 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_schema::SchemaRef; +use datafusion_common::DataFusionError; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; +use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; +use datafusion_physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, +}; +use std::any::Any; +use std::fmt::{Debug, Formatter}; +use std::sync::{Arc, Mutex}; + +/// Execution plan that return the stream on the call to `execute`. further calls to `execute` will +/// return an error +pub struct StreamExec { + /// the results to send back + stream: Mutex>, + cache: PlanProperties, +} + +impl Debug for StreamExec { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "StreamExec") + } +} + +impl StreamExec { + pub fn new(stream: SendableRecordBatchStream) -> Self { + let cache = Self::compute_properties(stream.schema()); + Self { + stream: Mutex::new(Some(stream)), + cache, + } + } + + /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. 
+ fn compute_properties(schema: SchemaRef) -> PlanProperties { + PlanProperties::new( + EquivalenceProperties::new(schema), + Partitioning::UnknownPartitioning(1), + EmissionType::Incremental, + Boundedness::Bounded, + ) + } +} + +impl DisplayAs for StreamExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "StreamExec:") + } + DisplayFormatType::TreeRender => { + write!(f, "") + } + } + } +} + +impl ExecutionPlan for StreamExec { + fn name(&self) -> &'static str { + Self::static_name() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.cache + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> datafusion_common::Result> { + unimplemented!() + } + + /// Returns a stream which yields data + fn execute( + &self, + partition: usize, + _context: Arc, + ) -> datafusion_common::Result { + assert_eq!(partition, 0); + + let stream = self.stream.lock().unwrap().take(); + + stream.ok_or(DataFusionError::Internal( + "Stream already consumed".to_string(), + )) + } +} diff --git a/datafusion/core/tests/fuzz_cases/topk_filter_pushdown.rs b/datafusion/core/tests/fuzz_cases/topk_filter_pushdown.rs new file mode 100644 index 000000000000..a5934882cbcc --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/topk_filter_pushdown.rs @@ -0,0 +1,387 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
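StreamExec above hands out its inner stream exactly once: the stream lives in a Mutex<Option<...>>, the first `execute` call takes it, and any later call finds None and returns the "Stream already consumed" error. A minimal sketch of that take-once pattern in plain std Rust (illustrative only; the names below are not from the patch):

use std::sync::Mutex;

/// Holds a value that may be handed out at most once (hypothetical helper).
struct OneShot<T>(Mutex<Option<T>>);

impl<T> OneShot<T> {
    fn new(value: T) -> Self {
        Self(Mutex::new(Some(value)))
    }

    /// First call returns the value; every later call returns None.
    fn take(&self) -> Option<T> {
        self.0.lock().unwrap().take()
    }
}

fn main() {
    let stream = OneShot::new("pretend this is a SendableRecordBatchStream");
    assert!(stream.take().is_some()); // first execute() succeeds
    assert!(stream.take().is_none()); // a second execute() would report "Stream already consumed"
}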
+ +use std::collections::HashMap; +use std::sync::{Arc, LazyLock}; + +use arrow::array::{Int32Array, StringArray, StringDictionaryBuilder}; +use arrow::datatypes::Int32Type; +use arrow::record_batch::RecordBatch; +use arrow::util::pretty::pretty_format_batches; +use arrow_schema::{DataType, Field, Schema}; +use datafusion::datasource::listing::{ListingOptions, ListingTable, ListingTableConfig}; +use datafusion::prelude::{SessionConfig, SessionContext}; +use datafusion_datasource::ListingTableUrl; +use datafusion_datasource_parquet::ParquetFormat; +use datafusion_execution::object_store::ObjectStoreUrl; +use itertools::Itertools; +use object_store::memory::InMemory; +use object_store::path::Path; +use object_store::{ObjectStore, PutPayload}; +use parquet::arrow::ArrowWriter; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use tokio::sync::Mutex; +use tokio::task::JoinSet; + +#[derive(Clone)] +struct TestDataSet { + store: Arc, + schema: Arc, +} + +/// List of in memory parquet files with UTF8 data +// Use a mutex rather than LazyLock to allow for async initialization +static TESTFILES: LazyLock>> = + LazyLock::new(|| Mutex::new(vec![])); + +async fn test_files() -> Vec { + let files_mutex = &TESTFILES; + let mut files = files_mutex.lock().await; + if !files.is_empty() { + return (*files).clone(); + } + + let mut rng = StdRng::seed_from_u64(0); + + for nulls_in_ids in [false, true] { + for nulls_in_names in [false, true] { + for nulls_in_departments in [false, true] { + let store = Arc::new(InMemory::new()); + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, nulls_in_ids), + Field::new("name", DataType::Utf8, nulls_in_names), + Field::new( + "department", + DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Utf8), + ), + nulls_in_departments, + ), + ])); + + let name_choices = if nulls_in_names { + [Some("Alice"), Some("Bob"), None, Some("David"), None] + } else { + [ + Some("Alice"), + Some("Bob"), + Some("Charlie"), + Some("David"), + Some("Eve"), + ] + }; + + let department_choices = if nulls_in_departments { + [ + Some("Theater"), + Some("Engineering"), + None, + Some("Arts"), + None, + ] + } else { + [ + Some("Theater"), + Some("Engineering"), + Some("Healthcare"), + Some("Arts"), + Some("Music"), + ] + }; + + // Generate 5 files, some with overlapping or repeated ids some without + for i in 0..5 { + let num_batches = rng.random_range(1..3); + let mut batches = Vec::with_capacity(num_batches); + for _ in 0..num_batches { + let num_rows = 25; + let ids = Int32Array::from_iter((0..num_rows).map(|file| { + if nulls_in_ids { + if rng.random_bool(1.0 / 10.0) { + None + } else { + Some(rng.random_range(file..file + 5)) + } + } else { + Some(rng.random_range(file..file + 5)) + } + })); + let names = StringArray::from_iter((0..num_rows).map(|_| { + // randomly select a name + let idx = rng.random_range(0..name_choices.len()); + name_choices[idx].map(|s| s.to_string()) + })); + let mut departments = StringDictionaryBuilder::::new(); + for _ in 0..num_rows { + // randomly select a department + let idx = rng.random_range(0..department_choices.len()); + departments.append_option(department_choices[idx].as_ref()); + } + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(ids), + Arc::new(names), + Arc::new(departments.finish()), + ], + ) + .unwrap(); + batches.push(batch); + } + let mut buf = vec![]; + { + let mut writer = + ArrowWriter::try_new(&mut buf, schema.clone(), None).unwrap(); + for batch in batches { + 
writer.write(&batch).unwrap(); + writer.flush().unwrap(); + } + writer.flush().unwrap(); + writer.finish().unwrap(); + } + let payload = PutPayload::from(buf); + let path = Path::from(format!("file_{i}.parquet")); + store.put(&path, payload).await.unwrap(); + } + files.push(TestDataSet { store, schema }); + } + } + } + (*files).clone() +} + +struct RunResult { + results: Vec, + explain_plan: String, +} + +async fn run_query_with_config( + query: &str, + config: SessionConfig, + dataset: TestDataSet, +) -> RunResult { + let store = dataset.store; + let schema = dataset.schema; + let ctx = SessionContext::new_with_config(config); + let url = ObjectStoreUrl::parse("memory://").unwrap(); + ctx.register_object_store(url.as_ref(), store.clone()); + + let format = Arc::new( + ParquetFormat::default() + .with_options(ctx.state().table_options().parquet.clone()), + ); + let options = ListingOptions::new(format); + let table_path = ListingTableUrl::parse("memory:///").unwrap(); + let config = ListingTableConfig::new(table_path) + .with_listing_options(options) + .with_schema(schema); + let table = Arc::new(ListingTable::try_new(config).unwrap()); + + ctx.register_table("test_table", table).unwrap(); + + let results = ctx.sql(query).await.unwrap().collect().await.unwrap(); + let explain_batches = ctx + .sql(&format!("EXPLAIN ANALYZE {query}")) + .await + .unwrap() + .collect() + .await + .unwrap(); + let explain_plan = pretty_format_batches(&explain_batches).unwrap().to_string(); + RunResult { + results, + explain_plan, + } +} + +#[derive(Debug)] +struct RunQueryResult { + query: String, + result: Vec, + expected: Vec, +} + +impl RunQueryResult { + fn expected_formated(&self) -> String { + format!("{}", pretty_format_batches(&self.expected).unwrap()) + } + + fn result_formated(&self) -> String { + format!("{}", pretty_format_batches(&self.result).unwrap()) + } + + fn is_ok(&self) -> bool { + self.expected_formated() == self.result_formated() + } +} + +/// Iterate over each line in the plan and check that one of them has `DataSourceExec` and `DynamicFilterPhysicalExpr` in the same line. 
+fn has_dynamic_filter_expr_pushdown(plan: &str) -> bool { + for line in plan.lines() { + if line.contains("DataSourceExec") && line.contains("DynamicFilterPhysicalExpr") { + return true; + } + } + false +} + +async fn run_query( + query: String, + cfg: SessionConfig, + dataset: TestDataSet, +) -> RunQueryResult { + let cfg_with_dynamic_filters = cfg + .clone() + .set_bool("datafusion.optimizer.enable_dynamic_filter_pushdown", true); + let cfg_without_dynamic_filters = cfg + .clone() + .set_bool("datafusion.optimizer.enable_dynamic_filter_pushdown", false); + + let expected_result = + run_query_with_config(&query, cfg_without_dynamic_filters, dataset.clone()).await; + let result = + run_query_with_config(&query, cfg_with_dynamic_filters, dataset.clone()).await; + // Check that dynamic filters were actually pushed down + if !has_dynamic_filter_expr_pushdown(&result.explain_plan) { + panic!( + "Dynamic filter was not pushed down in query: {query}\n\n{}", + result.explain_plan + ); + } + + RunQueryResult { + query: query.to_string(), + result: result.results, + expected: expected_result.results, + } +} + +struct TestCase { + query: String, + cfg: SessionConfig, + dataset: TestDataSet, +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_fuzz_topk_filter_pushdown() { + let order_columns = ["id", "name", "department"]; + let order_directions = ["ASC", "DESC"]; + let null_orders = ["NULLS FIRST", "NULLS LAST"]; + + let start = datafusion_common::instant::Instant::now(); + let mut orders: HashMap> = HashMap::new(); + for order_column in &order_columns { + for order_direction in &order_directions { + for null_order in &null_orders { + // if there is a vec for this column insert the order, otherwise create a new vec + let ordering = format!("{order_column} {order_direction} {null_order}"); + match orders.get_mut(*order_column) { + Some(order_vec) => { + order_vec.push(ordering); + } + None => { + orders.insert(order_column.to_string(), vec![ordering]); + } + } + } + } + } + + let mut queries = vec![]; + + for limit in [1, 10] { + for num_order_by_columns in [1, 2, 3] { + for order_columns in ["id", "name", "department"] + .iter() + .combinations(num_order_by_columns) + { + for orderings in order_columns + .iter() + .map(|col| orders.get(**col).unwrap()) + .multi_cartesian_product() + { + let query = format!( + "SELECT * FROM test_table ORDER BY {} LIMIT {}", + orderings.into_iter().join(", "), + limit + ); + queries.push(query); + } + } + } + } + + queries.sort_unstable(); + println!( + "Generated {} queries in {:?}", + queries.len(), + start.elapsed() + ); + + let start = datafusion_common::instant::Instant::now(); + let datasets = test_files().await; + println!("Generated test files in {:?}", start.elapsed()); + + let mut test_cases = vec![]; + for enable_filter_pushdown in [true, false] { + for query in &queries { + for dataset in &datasets { + let mut cfg = SessionConfig::new(); + cfg = cfg.set_bool( + "datafusion.optimizer.enable_dynamic_filter_pushdown", + enable_filter_pushdown, + ); + test_cases.push(TestCase { + query: query.to_string(), + cfg, + dataset: dataset.clone(), + }); + } + } + } + + let start = datafusion_common::instant::Instant::now(); + let mut join_set = JoinSet::new(); + for tc in test_cases { + join_set.spawn(run_query(tc.query, tc.cfg, tc.dataset)); + } + let mut results = join_set.join_all().await; + results.sort_unstable_by(|a, b| a.query.cmp(&b.query)); + println!("Ran {} test cases in {:?}", results.len(), start.elapsed()); + + let failures = results + .iter() 
+ .filter(|result| !result.is_ok()) + .collect::>(); + + for failure in &failures { + println!("Failure:"); + println!("Query:\n{}", failure.query); + println!("\nExpected:\n{}", failure.expected_formated()); + println!("\nResult:\n{}", failure.result_formated()); + println!("\n\n"); + } + + if !failures.is_empty() { + panic!("Some test cases failed"); + } else { + println!("All test cases passed"); + } +} diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 6b166dd32782..316d3ba5a926 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -35,7 +35,7 @@ use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::HashMap; use datafusion_common::{Result, ScalarValue}; use datafusion_common_runtime::SpawnedTask; -use datafusion_expr::type_coercion::functions::data_types_with_aggregate_udf; +use datafusion_expr::type_coercion::functions::fields_with_aggregate_udf; use datafusion_expr::{ WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; @@ -51,7 +51,7 @@ use datafusion_physical_expr::expressions::{cast, col, lit}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; use datafusion_physical_expr_common::sort_expr::LexOrdering; -use rand::distributions::Alphanumeric; +use rand::distr::Alphanumeric; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use test_utils::add_empty_batches; @@ -252,7 +252,6 @@ async fn bounded_window_causal_non_causal() -> Result<()> { ]; let partitionby_exprs = vec![]; - let orderby_exprs = LexOrdering::default(); // Window frame starts with "UNBOUNDED PRECEDING": let start_bound = WindowFrameBound::Preceding(ScalarValue::UInt64(None)); @@ -285,7 +284,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { fn_name.to_string(), &args, &partitionby_exprs, - orderby_exprs.as_ref(), + &[], Arc::new(window_frame), &extended_schema, false, @@ -398,8 +397,8 @@ fn get_random_function( WindowFunctionDefinition::WindowUDF(lead_udwf()), vec![ arg.clone(), - lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))), - lit(ScalarValue::Int64(Some(rng.gen_range(1..1000)))), + lit(ScalarValue::Int64(Some(rng.random_range(1..10)))), + lit(ScalarValue::Int64(Some(rng.random_range(1..1000)))), ], ), ); @@ -409,8 +408,8 @@ fn get_random_function( WindowFunctionDefinition::WindowUDF(lag_udwf()), vec![ arg.clone(), - lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))), - lit(ScalarValue::Int64(Some(rng.gen_range(1..1000)))), + lit(ScalarValue::Int64(Some(rng.random_range(1..10)))), + lit(ScalarValue::Int64(Some(rng.random_range(1..1000)))), ], ), ); @@ -435,12 +434,12 @@ fn get_random_function( WindowFunctionDefinition::WindowUDF(nth_value_udwf()), vec![ arg.clone(), - lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))), + lit(ScalarValue::Int64(Some(rng.random_range(1..10)))), ], ), ); - let rand_fn_idx = rng.gen_range(0..window_fn_map.len()); + let rand_fn_idx = rng.random_range(0..window_fn_map.len()); let fn_name = window_fn_map.keys().collect::>()[rand_fn_idx]; let (window_fn, args) = window_fn_map.values().collect::>()[rand_fn_idx]; let mut args = args.clone(); @@ -448,9 +447,9 @@ fn get_random_function( if !args.is_empty() { // Do type coercion first argument let a = args[0].clone(); - let dt = a.data_type(schema.as_ref()).unwrap(); - let coerced = data_types_with_aggregate_udf(&[dt], udf).unwrap(); - args[0] = cast(a, schema, coerced[0].clone()).unwrap(); + let dt = 
a.return_field(schema.as_ref()).unwrap(); + let coerced = fields_with_aggregate_udf(&[dt], udf).unwrap(); + args[0] = cast(a, schema, coerced[0].data_type().clone()).unwrap(); } } @@ -463,12 +462,12 @@ fn get_random_window_frame(rng: &mut StdRng, is_linear: bool) -> WindowFrame { is_preceding: bool, } let first_bound = Utils { - val: rng.gen_range(0..10), - is_preceding: rng.gen_range(0..2) == 0, + val: rng.random_range(0..10), + is_preceding: rng.random_range(0..2) == 0, }; let second_bound = Utils { - val: rng.gen_range(0..10), - is_preceding: rng.gen_range(0..2) == 0, + val: rng.random_range(0..10), + is_preceding: rng.random_range(0..2) == 0, }; let (start_bound, end_bound) = if first_bound.is_preceding == second_bound.is_preceding { @@ -485,7 +484,7 @@ fn get_random_window_frame(rng: &mut StdRng, is_linear: bool) -> WindowFrame { (second_bound, first_bound) }; // 0 means Range, 1 means Rows, 2 means GROUPS - let rand_num = rng.gen_range(0..3); + let rand_num = rng.random_range(0..3); let units = if rand_num < 1 { WindowFrameUnits::Range } else if rand_num < 2 { @@ -517,7 +516,7 @@ fn get_random_window_frame(rng: &mut StdRng, is_linear: bool) -> WindowFrame { }; let mut window_frame = WindowFrame::new_bounds(units, start_bound, end_bound); // with 10% use unbounded preceding in tests - if rng.gen_range(0..10) == 0 { + if rng.random_range(0..10) == 0 { window_frame.start_bound = WindowFrameBound::Preceding(ScalarValue::Int32(None)); } @@ -545,7 +544,7 @@ fn get_random_window_frame(rng: &mut StdRng, is_linear: bool) -> WindowFrame { }; let mut window_frame = WindowFrame::new_bounds(units, start_bound, end_bound); // with 10% use unbounded preceding in tests - if rng.gen_range(0..10) == 0 { + if rng.random_range(0..10) == 0 { window_frame.start_bound = WindowFrameBound::Preceding(ScalarValue::UInt64(None)); } @@ -569,7 +568,7 @@ fn convert_bound_to_current_row_if_applicable( match bound { WindowFrameBound::Preceding(value) | WindowFrameBound::Following(value) => { if let Ok(zero) = ScalarValue::new_zero(&value.data_type()) { - if value == &zero && rng.gen_range(0..2) == 0 { + if value == &zero && rng.random_range(0..2) == 0 { *bound = WindowFrameBound::CurrentRow; } } @@ -594,7 +593,7 @@ async fn run_window_test( let ctx = SessionContext::new_with_config(session_config); let (window_fn, args, fn_name) = get_random_function(&schema, &mut rng, is_linear); let window_frame = get_random_window_frame(&mut rng, is_linear); - let mut orderby_exprs = LexOrdering::default(); + let mut orderby_exprs = vec![]; for column in &orderby_columns { orderby_exprs.push(PhysicalSortExpr { expr: col(column, &schema)?, @@ -602,13 +601,13 @@ async fn run_window_test( }) } if orderby_exprs.len() > 1 && !window_frame.can_accept_multi_orderby() { - orderby_exprs = LexOrdering::new(orderby_exprs[0..1].to_vec()); + orderby_exprs.truncate(1); } let mut partitionby_exprs = vec![]; for column in &partition_by_columns { partitionby_exprs.push(col(column, &schema)?); } - let mut sort_keys = LexOrdering::default(); + let mut sort_keys = vec![]; for partition_by_expr in &partitionby_exprs { sort_keys.push(PhysicalSortExpr { expr: partition_by_expr.clone(), @@ -622,7 +621,7 @@ async fn run_window_test( } let concat_input_record = concat_batches(&schema, &input1)?; - let source_sort_keys = LexOrdering::new(vec![ + let source_sort_keys: LexOrdering = [ PhysicalSortExpr { expr: col("a", &schema)?, options: Default::default(), @@ -635,7 +634,8 @@ async fn run_window_test( expr: col("c", &schema)?, options: Default::default(), 
}, - ]); + ] + .into(); let mut exec1 = DataSourceExec::from_data_source( MemorySourceConfig::try_new(&[vec![concat_input_record]], schema.clone(), None)? .try_with_sort_information(vec![source_sort_keys.clone()])?, @@ -643,7 +643,9 @@ async fn run_window_test( // Table is ordered according to ORDER BY a, b, c In linear test we use PARTITION BY b, ORDER BY a // For WindowAggExec to produce correct result it need table to be ordered by b,a. Hence add a sort. if is_linear { - exec1 = Arc::new(SortExec::new(sort_keys, exec1)) as _; + if let Some(ordering) = LexOrdering::new(sort_keys) { + exec1 = Arc::new(SortExec::new(ordering, exec1)) as _; + } } let extended_schema = schema_add_window_field(&args, &schema, &window_fn, &fn_name)?; @@ -654,7 +656,7 @@ async fn run_window_test( fn_name.clone(), &args, &partitionby_exprs, - orderby_exprs.as_ref(), + &orderby_exprs.clone(), Arc::new(window_frame.clone()), &extended_schema, false, @@ -663,8 +665,8 @@ async fn run_window_test( false, )?) as _; let exec2 = DataSourceExec::from_data_source( - MemorySourceConfig::try_new(&[input1.clone()], schema.clone(), None)? - .try_with_sort_information(vec![source_sort_keys.clone()])?, + MemorySourceConfig::try_new(&[input1], schema, None)? + .try_with_sort_information(vec![source_sort_keys])?, ); let running_window_exec = Arc::new(BoundedWindowAggExec::try_new( vec![create_window_expr( @@ -672,7 +674,7 @@ async fn run_window_test( fn_name, &args, &partitionby_exprs, - orderby_exprs.as_ref(), + &orderby_exprs, Arc::new(window_frame.clone()), &extended_schema, false, @@ -728,7 +730,7 @@ async fn run_window_test( for (line1, line2) in usual_formatted_sorted.iter().zip(running_formatted_sorted) { - println!("{:?} --- {:?}", line1, line2); + println!("{line1:?} --- {line2:?}"); } unreachable!(); } @@ -758,9 +760,9 @@ pub(crate) fn make_staggered_batches( let mut input5: Vec = vec!["".to_string(); len]; input123.iter_mut().for_each(|v| { *v = ( - rng.gen_range(0..n_distinct) as i32, - rng.gen_range(0..n_distinct) as i32, - rng.gen_range(0..n_distinct) as i32, + rng.random_range(0..n_distinct) as i32, + rng.random_range(0..n_distinct) as i32, + rng.random_range(0..n_distinct) as i32, ) }); input123.sort(); @@ -788,7 +790,7 @@ pub(crate) fn make_staggered_batches( let mut batches = vec![]; if STREAM { while remainder.num_rows() > 0 { - let batch_size = rng.gen_range(0..50); + let batch_size = rng.random_range(0..50); if remainder.num_rows() < batch_size { batches.push(remainder); break; @@ -798,7 +800,7 @@ pub(crate) fn make_staggered_batches( } } else { while remainder.num_rows() > 0 { - let batch_size = rng.gen_range(0..remainder.num_rows() + 1); + let batch_size = rng.random_range(0..remainder.num_rows() + 1); batches.push(remainder.slice(0, batch_size)); remainder = remainder.slice(batch_size, remainder.num_rows() - batch_size); } diff --git a/datafusion/core/tests/integration_tests/schema_adapter_integration_tests.rs b/datafusion/core/tests/integration_tests/schema_adapter_integration_tests.rs new file mode 100644 index 000000000000..833af04680db --- /dev/null +++ b/datafusion/core/tests/integration_tests/schema_adapter_integration_tests.rs @@ -0,0 +1,445 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Integration test for schema adapter factory functionality + +use std::any::Any; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::record_batch::RecordBatch; +use datafusion::datasource::object_store::ObjectStoreUrl; +use datafusion::datasource::physical_plan::arrow_file::ArrowSource; +use datafusion::prelude::*; +use datafusion_common::Result; +use datafusion_datasource::file::FileSource; +use datafusion_datasource::file_scan_config::FileScanConfigBuilder; +use datafusion_datasource::schema_adapter::{SchemaAdapter, SchemaAdapterFactory}; +use datafusion_datasource::source::DataSourceExec; +use datafusion_datasource::PartitionedFile; +use std::sync::Arc; +use tempfile::TempDir; + +#[cfg(feature = "parquet")] +use datafusion_datasource_parquet::ParquetSource; +#[cfg(feature = "parquet")] +use parquet::arrow::ArrowWriter; +#[cfg(feature = "parquet")] +use parquet::file::properties::WriterProperties; + +#[cfg(feature = "csv")] +use datafusion_datasource_csv::CsvSource; + +/// A schema adapter factory that transforms column names to uppercase +#[derive(Debug)] +struct UppercaseAdapterFactory {} + +impl SchemaAdapterFactory for UppercaseAdapterFactory { + fn create(&self, schema: &Schema) -> Result> { + Ok(Box::new(UppercaseAdapter { + input_schema: Arc::new(schema.clone()), + })) + } +} + +/// Schema adapter that transforms column names to uppercase +#[derive(Debug)] +struct UppercaseAdapter { + input_schema: SchemaRef, +} + +impl SchemaAdapter for UppercaseAdapter { + fn adapt(&self, record_batch: RecordBatch) -> Result { + // In a real adapter, we might transform the data too + // For this test, we're just passing through the batch + Ok(record_batch) + } + + fn output_schema(&self) -> SchemaRef { + let fields = self + .input_schema + .fields() + .iter() + .map(|f| { + Field::new( + f.name().to_uppercase().as_str(), + f.data_type().clone(), + f.is_nullable(), + ) + }) + .collect(); + + Arc::new(Schema::new(fields)) + } +} + +#[cfg(feature = "parquet")] +#[tokio::test] +async fn test_parquet_integration_with_schema_adapter() -> Result<()> { + // Create a temporary directory for our test file + let tmp_dir = TempDir::new()?; + let file_path = tmp_dir.path().join("test.parquet"); + let file_path_str = file_path.to_str().unwrap(); + + // Create test data + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(arrow::array::Int32Array::from(vec![1, 2, 3])), + Arc::new(arrow::array::StringArray::from(vec!["a", "b", "c"])), + ], + )?; + + // Write test parquet file + let file = std::fs::File::create(file_path_str)?; + let props = WriterProperties::builder().build(); + let mut writer = ArrowWriter::try_new(file, schema.clone(), Some(props))?; + writer.write(&batch)?; + writer.close()?; + + // Create a session context + let ctx = 
SessionContext::new(); + + // Create a ParquetSource with the adapter factory + let source = ParquetSource::default() + .with_schema_adapter_factory(Arc::new(UppercaseAdapterFactory {})); + + // Create a scan config + let config = FileScanConfigBuilder::new( + ObjectStoreUrl::parse(&format!("file://{}", file_path_str))?, + schema.clone(), + ) + .with_source(source) + .build(); + + // Create a data source executor + let exec = DataSourceExec::from_data_source(config); + + // Collect results + let task_ctx = ctx.task_ctx(); + let stream = exec.execute(0, task_ctx)?; + let batches = datafusion::physical_plan::common::collect(stream).await?; + + // There should be one batch + assert_eq!(batches.len(), 1); + + // Verify the schema has uppercase column names + let result_schema = batches[0].schema(); + assert_eq!(result_schema.field(0).name(), "ID"); + assert_eq!(result_schema.field(1).name(), "NAME"); + + Ok(()) +} + +#[tokio::test] +async fn test_multi_source_schema_adapter_reuse() -> Result<()> { + // This test verifies that the same schema adapter factory can be reused + // across different file source types. This is important for ensuring that: + // 1. The schema adapter factory interface works uniformly across all source types + // 2. The factory can be shared and cloned efficiently using Arc + // 3. Various data source implementations correctly implement the schema adapter factory pattern + + // Create a test factory + let factory = Arc::new(UppercaseAdapterFactory {}); + + // Apply the same adapter to different source types + let arrow_source = + ArrowSource::default().with_schema_adapter_factory(factory.clone()); + + #[cfg(feature = "parquet")] + let parquet_source = + ParquetSource::default().with_schema_adapter_factory(factory.clone()); + + #[cfg(feature = "csv")] + let csv_source = CsvSource::default().with_schema_adapter_factory(factory.clone()); + + // Verify adapters were properly set + assert!(arrow_source.schema_adapter_factory().is_some()); + + #[cfg(feature = "parquet")] + assert!(parquet_source.schema_adapter_factory().is_some()); + + #[cfg(feature = "csv")] + assert!(csv_source.schema_adapter_factory().is_some()); + + Ok(()) +} + +// Helper function to test From for Arc implementations +fn test_from_impl> + Default>(expected_file_type: &str) { + let source = T::default(); + let file_source: Arc = source.into(); + assert_eq!(file_source.file_type(), expected_file_type); +} + +#[test] +fn test_from_implementations() { + // Test From implementation for various sources + test_from_impl::("arrow"); + + #[cfg(feature = "parquet")] + test_from_impl::("parquet"); + + #[cfg(feature = "csv")] + test_from_impl::("csv"); + + #[cfg(feature = "json")] + test_from_impl::("json"); +} + +/// A simple test schema adapter factory that doesn't modify the schema +#[derive(Debug)] +struct TestSchemaAdapterFactory {} + +impl SchemaAdapterFactory for TestSchemaAdapterFactory { + fn create(&self, schema: &Schema) -> Result> { + Ok(Box::new(TestSchemaAdapter { + input_schema: Arc::new(schema.clone()), + })) + } +} + +/// A test schema adapter that passes through data unmodified +#[derive(Debug)] +struct TestSchemaAdapter { + input_schema: SchemaRef, +} + +impl SchemaAdapter for TestSchemaAdapter { + fn adapt(&self, record_batch: RecordBatch) -> Result { + // Just pass through the batch unmodified + Ok(record_batch) + } + + fn output_schema(&self) -> SchemaRef { + self.input_schema.clone() + } +} + +#[cfg(feature = "parquet")] +#[test] +fn test_schema_adapter_preservation() { + // Create a test schema 
+ let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ])); + + // Create source with schema adapter factory + let source = ParquetSource::default(); + let factory = Arc::new(TestSchemaAdapterFactory {}); + let file_source = source.with_schema_adapter_factory(factory); + + // Create a FileScanConfig with the source + let config_builder = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), schema.clone()) + .with_source(file_source.clone()) + // Add a file to make it valid + .with_file(PartitionedFile::new("test.parquet", 100)); + + let config = config_builder.build(); + + // Verify the schema adapter factory is present in the file source + assert!(config.source().schema_adapter_factory().is_some()); +} + + +/// A test source for testing schema adapters +#[derive(Debug, Clone)] +struct TestSource { + schema_adapter_factory: Option>, +} + +impl TestSource { + fn new() -> Self { + Self { + schema_adapter_factory: None, + } + } +} + +impl FileSource for TestSource { + fn file_type(&self) -> &str { + "test" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn create_file_opener( + &self, + _store: Arc, + _conf: &FileScanConfig, + _index: usize, + ) -> Arc { + unimplemented!("Not needed for this test") + } + + fn with_batch_size(&self, _batch_size: usize) -> Arc { + Arc::new(self.clone()) + } + + fn with_schema(&self, _schema: SchemaRef) -> Arc { + Arc::new(self.clone()) + } + + fn with_projection(&self, _projection: &FileScanConfig) -> Arc { + Arc::new(self.clone()) + } + + fn with_statistics(&self, _statistics: Statistics) -> Arc { + Arc::new(self.clone()) + } + + fn metrics(&self) -> &ExecutionPlanMetricsSet { + unimplemented!("Not needed for this test") + } + + fn statistics(&self) -> Result { + Ok(Statistics::default()) + } + + fn with_schema_adapter_factory( + &self, + schema_adapter_factory: Arc, + ) -> Result> { + Ok(Arc::new(Self { + schema_adapter_factory: Some(schema_adapter_factory), + })) + } + + fn schema_adapter_factory(&self) -> Option> { + self.schema_adapter_factory.clone() + } +} + +/// A test schema adapter factory +#[derive(Debug)] +struct TestSchemaAdapterFactory {} + +impl SchemaAdapterFactory for TestSchemaAdapterFactory { + fn create( + &self, + projected_table_schema: SchemaRef, + _table_schema: SchemaRef, + ) -> Box { + Box::new(TestSchemaAdapter { + table_schema: projected_table_schema, + }) + } +} + +/// A test schema adapter implementation +#[derive(Debug)] +struct TestSchemaAdapter { + table_schema: SchemaRef, +} + +impl SchemaAdapter for TestSchemaAdapter { + fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option { + let field = self.table_schema.field(index); + file_schema.fields.find(field.name()).map(|(i, _)| i) + } + + fn map_schema( + &self, + file_schema: &Schema, + ) -> Result<(Arc, Vec)> { + let mut projection = Vec::with_capacity(file_schema.fields().len()); + for (file_idx, file_field) in file_schema.fields().iter().enumerate() { + if self.table_schema.fields().find(file_field.name()).is_some() { + projection.push(file_idx); + } + } + + Ok((Arc::new(TestSchemaMapping {}), projection)) + } +} + +/// A test schema mapper implementation +#[derive(Debug)] +struct TestSchemaMapping {} + +impl SchemaMapper for TestSchemaMapping { + fn map_batch(&self, batch: RecordBatch) -> Result { + // For testing, just return the original batch + Ok(batch) + } + + fn map_column_statistics( + &self, + stats: &[ColumnStatistics], + ) -> Result> { + // For testing, 
just return the input statistics + Ok(stats.to_vec()) + } +} + +#[test] +fn test_schema_adapter() { + // This test verifies the functionality of the SchemaAdapter and SchemaAdapterFactory + // components used in DataFusion's file sources. + // + // The test specifically checks: + // 1. Creating and attaching a schema adapter factory to a file source + // 2. Creating a schema adapter using the factory + // 3. The schema adapter's ability to map column indices between a table schema and a file schema + // 4. The schema adapter's ability to create a projection that selects only the columns + // from the file schema that are present in the table schema + // + // Schema adapters are used when the schema of data in files doesn't exactly match + // the schema expected by the query engine, allowing for field mapping and data transformation. + + // Create a test schema + let table_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ])); + + // Create a file schema + let file_schema = Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + Field::new("extra", DataType::Int64, true), + ]); + + // Create a TestSource + let source = TestSource::new(); + assert!(source.schema_adapter_factory().is_none()); + + // Add a schema adapter factory + let factory = Arc::new(TestSchemaAdapterFactory {}); + let source_with_adapter = source.with_schema_adapter_factory(factory).unwrap(); + assert!(source_with_adapter.schema_adapter_factory().is_some()); + + // Create a schema adapter + let adapter_factory = source_with_adapter.schema_adapter_factory().unwrap(); + let adapter = + adapter_factory.create(Arc::clone(&table_schema), Arc::clone(&table_schema)); + + // Test mapping column index + assert_eq!(adapter.map_column_index(0, &file_schema), Some(0)); + assert_eq!(adapter.map_column_index(1, &file_schema), Some(1)); + + // Test creating schema mapper + let (_mapper, projection) = adapter.map_schema(&file_schema).unwrap(); + assert_eq!(projection, vec![0, 1]); +} diff --git a/datafusion/core/tests/macro_hygiene/mod.rs b/datafusion/core/tests/macro_hygiene/mod.rs index 9196efec972c..e5396ce2194e 100644 --- a/datafusion/core/tests/macro_hygiene/mod.rs +++ b/datafusion/core/tests/macro_hygiene/mod.rs @@ -65,3 +65,40 @@ mod config_namespace { } } } + +mod config_field { + // NO other imports! 
+ use datafusion_common::config_field; + + #[test] + fn test_macro() { + #[derive(Debug)] + struct E; + + impl std::fmt::Display for E { + fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + unimplemented!() + } + } + + impl std::error::Error for E {} + + struct S; + + impl std::str::FromStr for S { + type Err = E; + + fn from_str(_s: &str) -> Result { + unimplemented!() + } + } + + impl std::fmt::Display for S { + fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + unimplemented!() + } + } + + config_field!(S); + } +} diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs index 01342d1604fc..2b262d4326cc 100644 --- a/datafusion/core/tests/memory_limit/mod.rs +++ b/datafusion/core/tests/memory_limit/mod.rs @@ -28,10 +28,10 @@ use arrow::compute::SortOptions; use arrow::datatypes::{Int32Type, SchemaRef}; use arrow_schema::{DataType, Field, Schema}; use datafusion::assert_batches_eq; +use datafusion::config::SpillCompression; use datafusion::datasource::memory::MemorySourceConfig; use datafusion::datasource::source::DataSourceExec; use datafusion::datasource::{MemTable, TableProvider}; -use datafusion::execution::disk_manager::DiskManagerConfig; use datafusion::execution::runtime_env::RuntimeEnvBuilder; use datafusion::execution::session_state::SessionStateBuilder; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; @@ -41,11 +41,12 @@ use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_catalog::streaming::StreamingTable; use datafusion_catalog::Session; use datafusion_common::{assert_contains, Result}; +use datafusion_execution::disk_manager::{DiskManagerBuilder, DiskManagerMode}; use datafusion_execution::memory_pool::{ FairSpillPool, GreedyMemoryPool, MemoryPool, TrackConsumersPool, }; use datafusion_execution::runtime_env::RuntimeEnv; -use datafusion_execution::{DiskManager, TaskContext}; +use datafusion_execution::TaskContext; use datafusion_expr::{Expr, TableType}; use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr}; use datafusion_physical_optimizer::join_selection::JoinSelection; @@ -84,7 +85,7 @@ async fn group_by_none() { TestCase::new() .with_query("select median(request_bytes) from t") .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: AggregateStream" + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n AggregateStream" ]) .with_memory_limit(2_000) .run() @@ -96,7 +97,7 @@ async fn group_by_row_hash() { TestCase::new() .with_query("select count(*) from t GROUP BY response_bytes") .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: GroupedHashAggregateStream" + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n GroupedHashAggregateStream" ]) .with_memory_limit(2_000) .run() @@ -109,7 +110,7 @@ async fn group_by_hash() { // group by dict column .with_query("select count(*) from t GROUP BY service, host, pod, container") .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: GroupedHashAggregateStream" + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n GroupedHashAggregateStream" ]) .with_memory_limit(1_000) .run() @@ -122,7 +123,7 @@ async fn 
join_by_key_multiple_partitions() { TestCase::new() .with_query("select t1.* from t t1 JOIN t t2 ON t1.service = t2.service") .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: HashJoinInput", + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n HashJoinInput", ]) .with_memory_limit(1_000) .with_config(config) @@ -136,7 +137,7 @@ async fn join_by_key_single_partition() { TestCase::new() .with_query("select t1.* from t t1 JOIN t t2 ON t1.service = t2.service") .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: HashJoinInput", + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n HashJoinInput", ]) .with_memory_limit(1_000) .with_config(config) @@ -149,7 +150,7 @@ async fn join_by_expression() { TestCase::new() .with_query("select t1.* from t t1 JOIN t t2 ON t1.service != t2.service") .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: NestedLoopJoinLoad[0]", + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n NestedLoopJoinLoad[0]", ]) .with_memory_limit(1_000) .run() @@ -161,7 +162,7 @@ async fn cross_join() { TestCase::new() .with_query("select t1.*, t2.* from t t1 CROSS JOIN t t2") .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: CrossJoinExec", + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n CrossJoinExec", ]) .with_memory_limit(1_000) .run() @@ -204,7 +205,7 @@ async fn sort_merge_join_spill() { ) .with_memory_limit(1_000) .with_config(config) - .with_disk_manager_config(DiskManagerConfig::NewOs) + .with_disk_manager_builder(DiskManagerBuilder::default()) .with_scenario(Scenario::AccessLogStreaming) .run() .await @@ -217,7 +218,7 @@ async fn symmetric_hash_join() { "select t1.* from t t1 JOIN t t2 ON t1.pod = t2.pod AND t1.time = t2.time", ) .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: SymmetricHashJoinStream", + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n SymmetricHashJoinStream", ]) .with_memory_limit(1_000) .with_scenario(Scenario::AccessLogStreaming) @@ -235,7 +236,7 @@ async fn sort_preserving_merge() { // so only a merge is needed .with_query("select * from t ORDER BY a ASC NULLS LAST, b ASC NULLS LAST LIMIT 10") .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: SortPreservingMergeExec", + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n SortPreservingMergeExec", ]) // provide insufficient memory to merge .with_memory_limit(partition_size / 2) @@ -288,7 +289,7 @@ async fn sort_spill_reservation() { .with_memory_limit(mem_limit) // use a single partition so only a sort is needed .with_scenario(scenario) - .with_disk_manager_config(DiskManagerConfig::NewOs) + .with_disk_manager_builder(DiskManagerBuilder::default()) .with_expected_plan( // It is important that this plan only has a SortExec, not // also merge, so we can ensure the sort could finish @@ -315,7 
+316,7 @@ async fn sort_spill_reservation() { test.clone() .with_expected_errors(vec![ "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:", - "bytes for ExternalSorterMerge", + "B for ExternalSorterMerge", ]) .with_config(config) .run() @@ -344,7 +345,7 @@ async fn oom_recursive_cte() { SELECT * FROM nodes;", ) .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: RecursiveQuery", + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n RecursiveQuery", ]) .with_memory_limit(2_000) .run() @@ -354,7 +355,7 @@ async fn oom_recursive_cte() { #[tokio::test] async fn oom_parquet_sink() { let dir = tempfile::tempdir().unwrap(); - let path = dir.into_path().join("test.parquet"); + let path = dir.path().join("test.parquet"); let _ = File::create(path.clone()).await.unwrap(); TestCase::new() @@ -378,7 +379,7 @@ async fn oom_parquet_sink() { #[tokio::test] async fn oom_with_tracked_consumer_pool() { let dir = tempfile::tempdir().unwrap(); - let path = dir.into_path().join("test.parquet"); + let path = dir.path().join("test.parquet"); let _ = File::create(path.clone()).await.unwrap(); TestCase::new() @@ -396,7 +397,7 @@ async fn oom_with_tracked_consumer_pool() { .with_expected_errors(vec![ "Failed to allocate additional", "for ParquetSink(ArrowColumnWriter)", - "Additional allocation failed with top memory consumers (across reservations) as: ParquetSink(ArrowColumnWriter)" + "Additional allocation failed with top memory consumers (across reservations) as:\n ParquetSink(ArrowColumnWriter)" ]) .with_memory_pool(Arc::new( TrackConsumersPool::new( @@ -408,6 +409,19 @@ async fn oom_with_tracked_consumer_pool() { .await } +#[tokio::test] +async fn oom_grouped_hash_aggregate() { + TestCase::new() + .with_query("SELECT COUNT(*), SUM(request_bytes) FROM t GROUP BY host") + .with_expected_errors(vec![ + "Failed to allocate additional", + "GroupedHashAggregateStream[0] (count(1), sum(t.request_bytes))", + ]) + .with_memory_limit(1_000) + .run() + .await +} + /// For regression case: if spilled `StringViewArray`'s buffer will be referenced by /// other batches which are also need to be spilled, then the spill writer will /// repeatedly write out the same buffer, and after reading back, each batch's size @@ -417,7 +431,7 @@ async fn oom_with_tracked_consumer_pool() { /// If there is memory explosion for spilled record batch, this test will fail. 
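The two hunks above swap `DiskManagerConfig` for the builder API that the rest of this diff introduces. A minimal sketch of the new configuration path (not part of the patch), assuming the `DiskManagerBuilder`, `DiskManagerMode`, and `RuntimeEnvBuilder::with_disk_manager_builder` APIs shown in `setup_context` below; the 1 MiB cap is an arbitrary example value:

```rust
use std::sync::Arc;

use datafusion_common::Result;
use datafusion_execution::disk_manager::{DiskManagerBuilder, DiskManagerMode};
use datafusion_execution::runtime_env::{RuntimeEnv, RuntimeEnvBuilder};

/// Sketch: a runtime that may spill to the OS temp directory,
/// capped at 1 MiB of temporary spill files.
fn build_spill_runtime() -> Result<Arc<RuntimeEnv>> {
    let disk = DiskManagerBuilder::default()
        .with_mode(DiskManagerMode::OsTmpDirectory)
        .with_max_temp_directory_size(1024 * 1024);

    RuntimeEnvBuilder::new()
        .with_disk_manager_builder(disk)
        .build_arc()
}
```

Compared to the old `DiskManagerConfig::NewOs`/`Disabled` variants, the builder keeps the mode selection and the temp-directory size cap in one place.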
#[tokio::test] async fn test_stringview_external_sort() { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let array_length = 1000; let num_batches = 200; // Batches contain two columns: random 100-byte string, and random i32 @@ -427,7 +441,7 @@ async fn test_stringview_external_sort() { let strings: Vec = (0..array_length) .map(|_| { (0..100) - .map(|_| rng.gen_range(0..=u8::MAX) as char) + .map(|_| rng.random_range(0..=u8::MAX) as char) .collect() }) .collect(); @@ -435,8 +449,9 @@ async fn test_stringview_external_sort() { let string_array = StringViewArray::from(strings); let array_ref: ArrayRef = Arc::new(string_array); - let random_numbers: Vec = - (0..array_length).map(|_| rng.gen_range(0..=1000)).collect(); + let random_numbers: Vec = (0..array_length) + .map(|_| rng.random_range(0..=1000)) + .collect(); let int_array = Int32Array::from(random_numbers); let int_array_ref: ArrayRef = Arc::new(int_array); @@ -458,7 +473,9 @@ async fn test_stringview_external_sort() { .with_memory_pool(Arc::new(FairSpillPool::new(60 * 1024 * 1024))); let runtime = builder.build_arc().unwrap(); - let config = SessionConfig::new().with_sort_spill_reservation_bytes(40 * 1024 * 1024); + let config = SessionConfig::new() + .with_sort_spill_reservation_bytes(40 * 1024 * 1024) + .with_repartition_file_scans(false); let ctx = SessionContext::new_with_config_rt(config, runtime); ctx.register_table("t", Arc::new(table)).unwrap(); @@ -529,16 +546,16 @@ async fn test_external_sort_zero_merge_reservation() { // Tests for disk limit (`max_temp_directory_size` in `DiskManager`) // ------------------------------------------------------------------ -// Create a new `SessionContext` with speicified disk limit and memory pool limit +// Create a new `SessionContext` with speicified disk limit, memory pool limit, and spill compression codec async fn setup_context( disk_limit: u64, memory_pool_limit: usize, + spill_compression: SpillCompression, ) -> Result { - let disk_manager = DiskManager::try_new(DiskManagerConfig::NewOs)?; - - let disk_manager = Arc::try_unwrap(disk_manager) - .expect("DiskManager should be a single instance") - .with_max_temp_directory_size(disk_limit)?; + let disk_manager = DiskManagerBuilder::default() + .with_mode(DiskManagerMode::OsTmpDirectory) + .with_max_temp_directory_size(disk_limit) + .build()?; let runtime = RuntimeEnvBuilder::new() .with_memory_pool(Arc::new(FairSpillPool::new(memory_pool_limit))) @@ -555,6 +572,7 @@ async fn setup_context( let config = SessionConfig::new() .with_sort_spill_reservation_bytes(64 * 1024) // 256KB .with_sort_in_place_threshold_bytes(0) + .with_spill_compression(spill_compression) .with_batch_size(64) // To reduce test memory usage .with_target_partitions(1); @@ -565,7 +583,8 @@ async fn setup_context( /// (specified by `max_temp_directory_size` in `DiskManager`) #[tokio::test] async fn test_disk_spill_limit_reached() -> Result<()> { - let ctx = setup_context(1024 * 1024, 1024 * 1024).await?; // 1MB disk limit, 1MB memory limit + let spill_compression = SpillCompression::Uncompressed; + let ctx = setup_context(1024 * 1024, 1024 * 1024, spill_compression).await?; // 1MB disk limit, 1MB memory limit let df = ctx .sql("select * from generate_series(1, 1000000000000) as t1(v1) order by v1") @@ -587,7 +606,8 @@ async fn test_disk_spill_limit_reached() -> Result<()> { #[tokio::test] async fn test_disk_spill_limit_not_reached() -> Result<()> { let disk_spill_limit = 1024 * 1024; // 1MB - let ctx = setup_context(disk_spill_limit, 128 * 1024).await?; 
// 1MB disk limit, 128KB memory limit + let spill_compression = SpillCompression::Uncompressed; + let ctx = setup_context(disk_spill_limit, 128 * 1024, spill_compression).await?; // 1MB disk limit, 128KB memory limit let df = ctx .sql("select * from generate_series(1, 10000) as t1(v1) order by v1") .await .unwrap(); @@ -603,7 +623,43 @@ async fn test_disk_spill_limit_not_reached() -> Result<()> { let spill_count = plan.metrics().unwrap().spill_count().unwrap(); let spilled_bytes = plan.metrics().unwrap().spilled_bytes().unwrap(); - println!("spill count {}, spill bytes {}", spill_count, spilled_bytes); + println!("spill count {spill_count}, spill bytes {spilled_bytes}"); + assert!(spill_count > 0); + assert!((spilled_bytes as u64) < disk_spill_limit); + + // Verify that all temporary files have been properly cleaned up by checking + // that the total disk usage tracked by the disk manager is zero + let current_disk_usage = ctx.runtime_env().disk_manager.used_disk_space(); + assert_eq!(current_disk_usage, 0); + + Ok(()) +} + +/// External query should succeed using zstd as spill compression codec and +/// all temporary spill files are properly cleaned up after execution. +/// Note: This test does not inspect file contents (e.g. magic number), +/// as spill files are automatically deleted on drop. +#[tokio::test] +async fn test_spill_file_compressed_with_zstd() -> Result<()> { + let disk_spill_limit = 1024 * 1024; // 1MB + let spill_compression = SpillCompression::Zstd; + let ctx = setup_context(disk_spill_limit, 128 * 1024, spill_compression).await?; // 1MB disk limit, 128KB memory limit, zstd + + let df = ctx + .sql("select * from generate_series(1, 100000) as t1(v1) order by v1") + .await + .unwrap(); + let plan = df.create_physical_plan().await.unwrap(); + + let task_ctx = ctx.task_ctx(); + let _ = collect_batches(Arc::clone(&plan), task_ctx) + .await + .expect("Query execution failed"); + + let spill_count = plan.metrics().unwrap().spill_count().unwrap(); + let spilled_bytes = plan.metrics().unwrap().spilled_bytes().unwrap(); + + println!("spill count {spill_count}"); assert!(spill_count > 0); assert!((spilled_bytes as u64) < disk_spill_limit); @@ -615,6 +671,41 @@ async fn test_disk_spill_limit_not_reached() -> Result<()> { Ok(()) } +/// External query should succeed using lz4_frame as spill compression codec and +/// all temporary spill files are properly cleaned up after execution. +/// Note: This test does not inspect file contents (e.g. magic number), +/// as spill files are automatically deleted on drop. 
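These compression tests all funnel through `setup_context`; for a plain session the same knob would be set directly on `SessionConfig`. A minimal sketch (not part of the patch) using the `with_spill_compression` setter added in this diff; the reservation size is an arbitrary example value:

```rust
use datafusion::config::SpillCompression;
use datafusion::prelude::{SessionConfig, SessionContext};

/// Sketch: ask spilling operators (e.g. external sort) to zstd-compress their spill files.
fn zstd_spill_session() -> SessionContext {
    let config = SessionConfig::new()
        .with_spill_compression(SpillCompression::Zstd)
        .with_sort_spill_reservation_bytes(64 * 1024);
    SessionContext::new_with_config(config)
}
```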
+#[tokio::test] +async fn test_spill_file_compressed_with_lz4_frame() -> Result<()> { + let disk_spill_limit = 1024 * 1024; // 1MB + let spill_compression = SpillCompression::Lz4Frame; + let ctx = setup_context(disk_spill_limit, 128 * 1024, spill_compression).await?; // 1MB disk limit, 128KB memory limit, lz4_frame + + let df = ctx + .sql("select * from generate_series(1, 100000) as t1(v1) order by v1") + .await + .unwrap(); + let plan = df.create_physical_plan().await.unwrap(); + + let task_ctx = ctx.task_ctx(); + let _ = collect_batches(Arc::clone(&plan), task_ctx) + .await + .expect("Query execution failed"); + + let spill_count = plan.metrics().unwrap().spill_count().unwrap(); + let spilled_bytes = plan.metrics().unwrap().spilled_bytes().unwrap(); + + println!("spill count {spill_count}"); + assert!(spill_count > 0); + assert!((spilled_bytes as u64) < disk_spill_limit); + + // Verify that all temporary files have been properly cleaned up by checking + // that the total disk usage tracked by the disk manager is zero + let current_disk_usage = ctx.runtime_env().disk_manager.used_disk_space(); + assert_eq!(current_disk_usage, 0); + + Ok(()) +} /// Run the query with the specified memory limit, /// and verifies the expected errors are returned #[derive(Clone, Debug)] @@ -627,7 +718,7 @@ struct TestCase { scenario: Scenario, /// How should the disk manager (that allows spilling) be /// configured? Defaults to `Disabled` - disk_manager_config: DiskManagerConfig, + disk_manager_builder: DiskManagerBuilder, /// Expected explain plan, if non-empty expected_plan: Vec, /// Is the plan expected to pass? Defaults to false @@ -643,7 +734,8 @@ impl TestCase { config: SessionConfig::new(), memory_pool: None, scenario: Scenario::AccessLog, - disk_manager_config: DiskManagerConfig::Disabled, + disk_manager_builder: DiskManagerBuilder::default() + .with_mode(DiskManagerMode::Disabled), expected_plan: vec![], expected_success: false, } @@ -700,11 +792,11 @@ impl TestCase { /// Specify if the disk manager should be enabled. 
If true, /// operators that support it can spill - pub fn with_disk_manager_config( + pub fn with_disk_manager_builder( mut self, - disk_manager_config: DiskManagerConfig, + disk_manager_builder: DiskManagerBuilder, ) -> Self { - self.disk_manager_config = disk_manager_config; + self.disk_manager_builder = disk_manager_builder; self } @@ -723,7 +815,7 @@ impl TestCase { memory_pool, config, scenario, - disk_manager_config, + disk_manager_builder, expected_plan, expected_success, } = self; @@ -732,7 +824,7 @@ impl TestCase { let mut builder = RuntimeEnvBuilder::new() // disk manager setting controls the spilling - .with_disk_manager(disk_manager_config) + .with_disk_manager_builder(disk_manager_builder) .with_memory_limit(memory_limit, MEMORY_FRACTION); if let Some(pool) = memory_pool { @@ -874,16 +966,11 @@ impl Scenario { descending: false, nulls_first: false, }; - let sort_information = vec![LexOrdering::new(vec![ - PhysicalSortExpr { - expr: col("a", &schema).unwrap(), - options, - }, - PhysicalSortExpr { - expr: col("b", &schema).unwrap(), - options, - }, - ])]; + let sort_information = vec![[ + PhysicalSortExpr::new(col("a", &schema).unwrap(), options), + PhysicalSortExpr::new(col("b", &schema).unwrap(), options), + ] + .into()]; let table = SortedTableProvider::new(batches, sort_information); Arc::new(table) diff --git a/datafusion/core/tests/optimizer/mod.rs b/datafusion/core/tests/optimizer/mod.rs index 585540bd5875..3b39c9adfa32 100644 --- a/datafusion/core/tests/optimizer/mod.rs +++ b/datafusion/core/tests/optimizer/mod.rs @@ -18,6 +18,7 @@ //! Tests for the DataFusion SQL query planner that require functions from the //! datafusion-functions crate. +use insta::assert_snapshot; use std::any::Any; use std::collections::HashMap; use std::sync::Arc; @@ -56,9 +57,14 @@ fn init() { #[test] fn select_arrow_cast() { let sql = "SELECT arrow_cast(1234, 'Float64') as f64, arrow_cast('foo', 'LargeUtf8') as large"; - let expected = "Projection: Float64(1234) AS f64, LargeUtf8(\"foo\") AS large\ - \n EmptyRelation"; - quick_test(sql, expected); + let plan = test_sql(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: Float64(1234) AS f64, LargeUtf8("foo") AS large + EmptyRelation + "# + ); } #[test] fn timestamp_nano_ts_none_predicates() -> Result<()> { @@ -68,11 +74,15 @@ fn timestamp_nano_ts_none_predicates() -> Result<()> { // a scan should have the now()... predicate folded to a single // constant and compared to the column without a cast so it can be // pushed down / pruned - let expected = - "Projection: test.col_int32\ - \n Filter: test.col_ts_nano_none < TimestampNanosecond(1666612093000000000, None)\ - \n TableScan: test projection=[col_int32, col_ts_nano_none]"; - quick_test(sql, expected); + let plan = test_sql(sql).unwrap(); + assert_snapshot!( + plan, + @r" + Projection: test.col_int32 + Filter: test.col_ts_nano_none < TimestampNanosecond(1666612093000000000, None) + TableScan: test projection=[col_int32, col_ts_nano_none] + " + ); Ok(()) } @@ -84,10 +94,15 @@ fn timestamp_nano_ts_utc_predicates() { // a scan should have the now()... 
predicate folded to a single // constant and compared to the column without a cast so it can be // pushed down / pruned - let expected = - "Projection: test.col_int32\n Filter: test.col_ts_nano_utc < TimestampNanosecond(1666612093000000000, Some(\"+00:00\"))\ - \n TableScan: test projection=[col_int32, col_ts_nano_utc]"; - quick_test(sql, expected); + let plan = test_sql(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: test.col_int32 + Filter: test.col_ts_nano_utc < TimestampNanosecond(1666612093000000000, Some("+00:00")) + TableScan: test projection=[col_int32, col_ts_nano_utc] + "# + ); } #[test] @@ -95,10 +110,14 @@ fn concat_literals() -> Result<()> { let sql = "SELECT concat(true, col_int32, false, null, 'hello', col_utf8, 12, 3.4) \ AS col FROM test"; - let expected = - "Projection: concat(Utf8(\"true\"), CAST(test.col_int32 AS Utf8), Utf8(\"falsehello\"), test.col_utf8, Utf8(\"123.4\")) AS col\ - \n TableScan: test projection=[col_int32, col_utf8]"; - quick_test(sql, expected); + let plan = test_sql(sql).unwrap(); + assert_snapshot!( + plan, + @r#" + Projection: concat(Utf8("true"), CAST(test.col_int32 AS Utf8), Utf8("falsehello"), test.col_utf8, Utf8("123.4")) AS col + TableScan: test projection=[col_int32, col_utf8] + "# + ); Ok(()) } @@ -107,16 +126,15 @@ fn concat_ws_literals() -> Result<()> { let sql = "SELECT concat_ws('-', true, col_int32, false, null, 'hello', col_utf8, 12, '', 3.4) \ AS col FROM test"; - let expected = - "Projection: concat_ws(Utf8(\"-\"), Utf8(\"true\"), CAST(test.col_int32 AS Utf8), Utf8(\"false-hello\"), test.col_utf8, Utf8(\"12--3.4\")) AS col\ - \n TableScan: test projection=[col_int32, col_utf8]"; - quick_test(sql, expected); - Ok(()) -} - -fn quick_test(sql: &str, expected_plan: &str) { let plan = test_sql(sql).unwrap(); - assert_eq!(expected_plan, format!("{}", plan)); + assert_snapshot!( + plan, + @r#" + Projection: concat_ws(Utf8("-"), Utf8("true"), CAST(test.col_int32 AS Utf8), Utf8("false-hello"), test.col_utf8, Utf8("12--3.4")) AS col + TableScan: test projection=[col_int32, col_utf8] + "# + ); + Ok(()) } fn test_sql(sql: &str) -> Result { @@ -342,8 +360,7 @@ where let expected = lit(ScalarValue::from(expected_value.clone())); assert_eq!( output, expected, - "{} simplified to {}, but expected {}", - expr, output, expected + "{expr} simplified to {output}, but expected {expected}" ); } } @@ -352,8 +369,7 @@ fn validate_unchanged_cases(rewriter: &mut GuaranteeRewriter, cases: &[Expr]) { let output = expr.clone().rewrite(rewriter).data().unwrap(); assert_eq!( &output, expr, - "{} was simplified to {}, but expected it to be unchanged", - expr, output + "{expr} was simplified to {output}, but expected it to be unchanged" ); } } diff --git a/datafusion/core/tests/parquet/custom_reader.rs b/datafusion/core/tests/parquet/custom_reader.rs index 761a78a29fd3..5fc3513ff745 100644 --- a/datafusion/core/tests/parquet/custom_reader.rs +++ b/datafusion/core/tests/parquet/custom_reader.rs @@ -241,6 +241,7 @@ impl AsyncFileReader for ParquetFileReader { self.store.as_ref(), &self.meta, self.metadata_size_hint, + None, ) .await .map_err(|e| { diff --git a/datafusion/core/tests/parquet/encryption.rs b/datafusion/core/tests/parquet/encryption.rs new file mode 100644 index 000000000000..203c985428bc --- /dev/null +++ b/datafusion/core/tests/parquet/encryption.rs @@ -0,0 +1,129 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! non trivial integration testing for parquet predicate pushdown +//! +//! Testing hints: If you run this test with --nocapture it will tell you where +//! the generated parquet file went. You can then test it and try out various queries +//! datafusion-cli like: +//! +//! ```sql +//! create external table data stored as parquet location 'data.parquet'; +//! select * from data limit 10; +//! ``` + +use arrow::record_batch::RecordBatch; +use datafusion::prelude::{ParquetReadOptions, SessionContext}; +use std::fs::File; +use std::path::{Path, PathBuf}; +use std::sync::Arc; + +use parquet::arrow::ArrowWriter; +use parquet::encryption::decrypt::FileDecryptionProperties; +use parquet::encryption::encrypt::FileEncryptionProperties; +use parquet::file::properties::WriterProperties; +use tempfile::TempDir; + +async fn read_parquet_test_data<'a, T: Into>( + path: T, + ctx: &SessionContext, + options: ParquetReadOptions<'a>, +) -> Vec { + ctx.read_parquet(path.into(), options) + .await + .unwrap() + .collect() + .await + .unwrap() +} + +pub fn write_batches( + path: PathBuf, + props: WriterProperties, + batches: impl IntoIterator, +) -> datafusion_common::Result { + let mut batches = batches.into_iter(); + let first_batch = batches.next().expect("need at least one record batch"); + let schema = first_batch.schema(); + + let file = File::create(&path)?; + let mut writer = ArrowWriter::try_new(file, Arc::clone(&schema), Some(props))?; + + writer.write(&first_batch)?; + let mut num_rows = first_batch.num_rows(); + + for batch in batches { + writer.write(&batch)?; + num_rows += batch.num_rows(); + } + writer.close()?; + Ok(num_rows) +} + +#[tokio::test] +async fn round_trip_encryption() { + let ctx: SessionContext = SessionContext::new(); + + let options = ParquetReadOptions::default(); + let batches = read_parquet_test_data( + "tests/data/filter_pushdown/single_file.gz.parquet", + &ctx, + options, + ) + .await; + + let schema = batches[0].schema(); + let footer_key = b"0123456789012345".to_vec(); // 128bit/16 + let column_key = b"1234567890123450".to_vec(); // 128bit/16 + + let mut encrypt = FileEncryptionProperties::builder(footer_key.clone()); + let mut decrypt = FileDecryptionProperties::builder(footer_key.clone()); + + for field in schema.fields.iter() { + encrypt = encrypt.with_column_key(field.name().as_str(), column_key.clone()); + decrypt = decrypt.with_column_key(field.name().as_str(), column_key.clone()); + } + let encrypt = encrypt.build().unwrap(); + let decrypt = decrypt.build().unwrap(); + + // Write encrypted parquet + let props = WriterProperties::builder() + .with_file_encryption_properties(encrypt) + .build(); + + let tempdir = TempDir::new_in(Path::new(".")).unwrap(); + let tempfile = tempdir.path().join("data.parquet"); + let num_rows_written = write_batches(tempfile.clone(), props, 
batches).unwrap(); + + // Read encrypted parquet + let ctx: SessionContext = SessionContext::new(); + let options = + ParquetReadOptions::default().file_decryption_properties((&decrypt).into()); + + let encrypted_batches = read_parquet_test_data( + tempfile.into_os_string().into_string().unwrap(), + &ctx, + options, + ) + .await; + + let num_rows_read = encrypted_batches + .iter() + .fold(0, |acc, x| acc + x.num_rows()); + + assert_eq!(num_rows_written, num_rows_read); +} diff --git a/datafusion/core/tests/parquet/external_access_plan.rs b/datafusion/core/tests/parquet/external_access_plan.rs index bbef073345b7..a5397c5a397c 100644 --- a/datafusion/core/tests/parquet/external_access_plan.rs +++ b/datafusion/core/tests/parquet/external_access_plan.rs @@ -346,7 +346,7 @@ impl TestFull { let source = if let Some(predicate) = predicate { let df_schema = DFSchema::try_from(schema.clone())?; let predicate = ctx.create_physical_expr(predicate, &df_schema)?; - Arc::new(ParquetSource::default().with_predicate(schema.clone(), predicate)) + Arc::new(ParquetSource::default().with_predicate(predicate)) } else { Arc::new(ParquetSource::default()) }; diff --git a/datafusion/core/tests/parquet/file_statistics.rs b/datafusion/core/tests/parquet/file_statistics.rs index 7e98ebed6c9a..a60beaf665e5 100644 --- a/datafusion/core/tests/parquet/file_statistics.rs +++ b/datafusion/core/tests/parquet/file_statistics.rs @@ -28,6 +28,7 @@ use datafusion::execution::context::SessionState; use datafusion::execution::session_state::SessionStateBuilder; use datafusion::prelude::SessionContext; use datafusion_common::stats::Precision; +use datafusion_common::DFSchema; use datafusion_execution::cache::cache_manager::CacheManagerConfig; use datafusion_execution::cache::cache_unit::{ DefaultFileStatisticsCache, DefaultListFilesCache, @@ -37,6 +38,10 @@ use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_expr::{col, lit, Expr}; use datafusion::datasource::physical_plan::FileScanConfig; +use datafusion_physical_optimizer::filter_pushdown::FilterPushdown; +use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_plan::filter::FilterExec; +use datafusion_physical_plan::ExecutionPlan; use tempfile::tempdir; #[tokio::test] @@ -45,18 +50,53 @@ async fn check_stats_precision_with_filter_pushdown() { let filename = format!("{}/{}", testdata, "alltypes_plain.parquet"); let table_path = ListingTableUrl::parse(filename).unwrap(); - let opt = ListingOptions::new(Arc::new(ParquetFormat::default())); + let opt = + ListingOptions::new(Arc::new(ParquetFormat::default())).with_collect_stat(true); let table = get_listing_table(&table_path, None, &opt).await; + let (_, _, state) = get_cache_runtime_state(); + let mut options = state.config().options().clone(); + options.execution.parquet.pushdown_filters = true; + // Scan without filter, stats are exact let exec = table.scan(&state, None, &[], None).await.unwrap(); - assert_eq!(exec.statistics().unwrap().num_rows, Precision::Exact(8)); + assert_eq!( + exec.partition_statistics(None).unwrap().num_rows, + Precision::Exact(8), + "Stats without filter should be exact" + ); - // Scan with filter pushdown, stats are inexact - let filter = Expr::gt(col("id"), lit(1)); + // This is a filter that cannot be evaluated by the table provider scanning + // (it is not a partition filter). Therefore; it will be pushed down to the + // source operator after the appropriate optimizer pass. 
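One pattern repeated throughout the statistics hunks below: `ExecutionPlan::statistics()` is replaced by `partition_statistics(None)`, where `None` asks for statistics of the plan as a whole rather than of a single partition. A small sketch of the call shape (not part of the patch); the helper name is hypothetical:

```rust
use datafusion_common::stats::Precision;
use datafusion_common::Result;
use datafusion_physical_plan::ExecutionPlan;

/// Sketch: assert that a plan reports an exact row count across all partitions.
fn assert_exact_row_count(plan: &dyn ExecutionPlan, expected: usize) -> Result<()> {
    let stats = plan.partition_statistics(None)?;
    assert_eq!(stats.num_rows, Precision::Exact(expected));
    Ok(())
}
```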
+ let filter_expr = Expr::gt(col("id"), lit(1)); + let exec_with_filter = table + .scan(&state, None, &[filter_expr.clone()], None) + .await + .unwrap(); + + let ctx = SessionContext::new(); + let df_schema = DFSchema::try_from(table.schema()).unwrap(); + let physical_filter = ctx.create_physical_expr(filter_expr, &df_schema).unwrap(); - let exec = table.scan(&state, None, &[filter], None).await.unwrap(); - assert_eq!(exec.statistics().unwrap().num_rows, Precision::Inexact(8)); + let filtered_exec = + Arc::new(FilterExec::try_new(physical_filter, exec_with_filter).unwrap()) + as Arc; + + let optimized_exec = FilterPushdown::new() + .optimize(filtered_exec, &options) + .unwrap(); + + assert!( + optimized_exec.as_any().is::(), + "Sanity check that the pushdown did what we expected" + ); + // Scan with filter pushdown, stats are inexact + assert_eq!( + optimized_exec.partition_statistics(None).unwrap().num_rows, + Precision::Inexact(8), + "Stats after filter pushdown should be inexact" + ); } #[tokio::test] @@ -70,7 +110,8 @@ async fn load_table_stats_with_session_level_cache() { // Create a separate DefaultFileStatisticsCache let (cache2, _, state2) = get_cache_runtime_state(); - let opt = ListingOptions::new(Arc::new(ParquetFormat::default())); + let opt = + ListingOptions::new(Arc::new(ParquetFormat::default())).with_collect_stat(true); let table1 = get_listing_table(&table_path, Some(cache1), &opt).await; let table2 = get_listing_table(&table_path, Some(cache2), &opt).await; @@ -79,9 +120,12 @@ async fn load_table_stats_with_session_level_cache() { assert_eq!(get_static_cache_size(&state1), 0); let exec1 = table1.scan(&state1, None, &[], None).await.unwrap(); - assert_eq!(exec1.statistics().unwrap().num_rows, Precision::Exact(8)); assert_eq!( - exec1.statistics().unwrap().total_byte_size, + exec1.partition_statistics(None).unwrap().num_rows, + Precision::Exact(8) + ); + assert_eq!( + exec1.partition_statistics(None).unwrap().total_byte_size, // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 Precision::Exact(671), ); @@ -91,9 +135,12 @@ async fn load_table_stats_with_session_level_cache() { //check session 1 cache result not show in session 2 assert_eq!(get_static_cache_size(&state2), 0); let exec2 = table2.scan(&state2, None, &[], None).await.unwrap(); - assert_eq!(exec2.statistics().unwrap().num_rows, Precision::Exact(8)); assert_eq!( - exec2.statistics().unwrap().total_byte_size, + exec2.partition_statistics(None).unwrap().num_rows, + Precision::Exact(8) + ); + assert_eq!( + exec2.partition_statistics(None).unwrap().total_byte_size, // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 Precision::Exact(671), ); @@ -103,9 +150,12 @@ async fn load_table_stats_with_session_level_cache() { //check session 1 cache result not show in session 2 assert_eq!(get_static_cache_size(&state1), 1); let exec3 = table1.scan(&state1, None, &[], None).await.unwrap(); - assert_eq!(exec3.statistics().unwrap().num_rows, Precision::Exact(8)); assert_eq!( - exec3.statistics().unwrap().total_byte_size, + exec3.partition_statistics(None).unwrap().num_rows, + Precision::Exact(8) + ); + assert_eq!( + exec3.partition_statistics(None).unwrap().total_byte_size, // TODO correct byte size: https://github.com/apache/datafusion/issues/14936 Precision::Exact(671), ); @@ -117,23 +167,15 @@ async fn load_table_stats_with_session_level_cache() { async fn list_files_with_session_level_cache() { let p_name = "alltypes_plain.parquet"; let testdata = 
datafusion::test_util::parquet_test_data(); - let filename = format!("{}/{}", testdata, p_name); + let filename = format!("{testdata}/{p_name}"); - let temp_path1 = tempdir() - .unwrap() - .into_path() - .into_os_string() - .into_string() - .unwrap(); - let temp_filename1 = format!("{}/{}", temp_path1, p_name); + let temp_dir1 = tempdir().unwrap(); + let temp_path1 = temp_dir1.path().to_str().unwrap(); + let temp_filename1 = format!("{temp_path1}/{p_name}"); - let temp_path2 = tempdir() - .unwrap() - .into_path() - .into_os_string() - .into_string() - .unwrap(); - let temp_filename2 = format!("{}/{}", temp_path2, p_name); + let temp_dir2 = tempdir().unwrap(); + let temp_path2 = temp_dir2.path().to_str().unwrap(); + let temp_filename2 = format!("{temp_path2}/{p_name}"); fs::copy(filename.clone(), temp_filename1).expect("panic"); fs::copy(filename, temp_filename2).expect("panic"); diff --git a/datafusion/core/tests/parquet/filter_pushdown.rs b/datafusion/core/tests/parquet/filter_pushdown.rs index 02fb59740493..b8d570916c7c 100644 --- a/datafusion/core/tests/parquet/filter_pushdown.rs +++ b/datafusion/core/tests/parquet/filter_pushdown.rs @@ -32,50 +32,45 @@ use arrow::compute::concat_batches; use arrow::record_batch::RecordBatch; use datafusion::physical_plan::collect; use datafusion::physical_plan::metrics::MetricsSet; -use datafusion::prelude::{col, lit, lit_timestamp_nano, Expr, SessionContext}; +use datafusion::prelude::{ + col, lit, lit_timestamp_nano, Expr, ParquetReadOptions, SessionContext, +}; use datafusion::test_util::parquet::{ParquetScanOptions, TestParquetFile}; -use datafusion_common::instant::Instant; use datafusion_expr::utils::{conjunction, disjunction, split_conjunction}; use itertools::Itertools; use parquet::file::properties::WriterProperties; use tempfile::TempDir; -use test_utils::AccessLogGenerator; /// how many rows of generated data to write to our parquet file (arbitrary) const NUM_ROWS: usize = 4096; -fn generate_file(tempdir: &TempDir, props: WriterProperties) -> TestParquetFile { - // Tune down the generator for smaller files - let generator = AccessLogGenerator::new() - .with_row_limit(NUM_ROWS) - .with_pods_per_host(1..4) - .with_containers_per_pod(1..2) - .with_entries_per_container(128..256); - - let file = tempdir.path().join("data.parquet"); - - let start = Instant::now(); - println!("Writing test data to {file:?}"); - let test_parquet_file = TestParquetFile::try_new(file, props, generator).unwrap(); - println!( - "Completed generating test data in {:?}", - Instant::now() - start - ); - test_parquet_file +async fn read_parquet_test_data>(path: T) -> Vec { + let ctx: SessionContext = SessionContext::new(); + ctx.read_parquet(path.into(), ParquetReadOptions::default()) + .await + .unwrap() + .collect() + .await + .unwrap() } #[tokio::test] async fn single_file() { - // Only create the parquet file once as it is fairly large + let batches = + read_parquet_test_data("tests/data/filter_pushdown/single_file.gz.parquet").await; - let tempdir = TempDir::new_in(Path::new(".")).unwrap(); - // Set row group size smaller so can test with fewer rows + // Set the row group size smaller so can test with fewer rows let props = WriterProperties::builder() .set_max_row_group_size(1024) .build(); - let test_parquet_file = generate_file(&tempdir, props); + // Only create the parquet file once as it is fairly large + let tempdir = TempDir::new_in(Path::new(".")).unwrap(); + + let test_parquet_file = + TestParquetFile::try_new(tempdir.path().join("data.parquet"), props, 
batches) + .unwrap(); let case = TestCase::new(&test_parquet_file) .with_name("selective") // request_method = 'GET' @@ -224,16 +219,25 @@ async fn single_file() { } #[tokio::test] +#[allow(dead_code)] async fn single_file_small_data_pages() { + let batches = read_parquet_test_data( + "tests/data/filter_pushdown/single_file_small_pages.gz.parquet", + ) + .await; + let tempdir = TempDir::new_in(Path::new(".")).unwrap(); - // Set low row count limit to improve page filtering + // Set a low row count limit to improve page filtering let props = WriterProperties::builder() .set_max_row_group_size(2048) .set_data_page_row_count_limit(512) .set_write_batch_size(512) .build(); - let test_parquet_file = generate_file(&tempdir, props); + + let test_parquet_file = + TestParquetFile::try_new(tempdir.path().join("data.parquet"), props, batches) + .unwrap(); // The statistics on the 'pod' column are as follows: // diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index 87a5ed33f127..94d6d152a384 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -43,6 +43,7 @@ use std::sync::Arc; use tempfile::NamedTempFile; mod custom_reader; +mod encryption; mod external_access_plan; mod file_statistics; mod filter_pushdown; @@ -152,6 +153,10 @@ impl TestOutput { self.metric_value("row_groups_pruned_statistics") } + fn files_ranges_pruned_statistics(&self) -> Option { + self.metric_value("files_ranges_pruned_statistics") + } + /// The number of row_groups matched by bloom filter or statistics fn row_groups_matched(&self) -> Option { self.row_groups_matched_bloom_filter() @@ -192,6 +197,8 @@ impl ContextWithParquet { unit: Unit, mut config: SessionConfig, ) -> Self { + // Use a single partition for deterministic results no matter how many CPUs the host has + config = config.with_target_partitions(1); let file = match unit { Unit::RowGroup(row_per_group) => { config = config.with_parquet_bloom_filter_pruning(true); diff --git a/datafusion/core/tests/parquet/page_pruning.rs b/datafusion/core/tests/parquet/page_pruning.rs index f693485cbe01..9da879a32f6b 100644 --- a/datafusion/core/tests/parquet/page_pruning.rs +++ b/datafusion/core/tests/parquet/page_pruning.rs @@ -77,7 +77,7 @@ async fn get_parquet_exec(state: &SessionState, filter: Expr) -> DataSourceExec let source = Arc::new( ParquetSource::default() - .with_predicate(Arc::clone(&schema), predicate) + .with_predicate(predicate) .with_enable_page_index(true), ); let base_config = FileScanConfigBuilder::new(object_store_url, schema, source) diff --git a/datafusion/core/tests/parquet/row_group_pruning.rs b/datafusion/core/tests/parquet/row_group_pruning.rs index 5a85f47c015a..8613cd481be1 100644 --- a/datafusion/core/tests/parquet/row_group_pruning.rs +++ b/datafusion/core/tests/parquet/row_group_pruning.rs @@ -31,6 +31,7 @@ struct RowGroupPruningTest { expected_errors: Option, expected_row_group_matched_by_statistics: Option, expected_row_group_pruned_by_statistics: Option, + expected_files_pruned_by_statistics: Option, expected_row_group_matched_by_bloom_filter: Option, expected_row_group_pruned_by_bloom_filter: Option, expected_results: usize, @@ -44,6 +45,7 @@ impl RowGroupPruningTest { expected_errors: None, expected_row_group_matched_by_statistics: None, expected_row_group_pruned_by_statistics: None, + expected_files_pruned_by_statistics: None, expected_row_group_matched_by_bloom_filter: None, expected_row_group_pruned_by_bloom_filter: None, expected_results: 0, @@ -80,6 +82,11 @@ 
impl RowGroupPruningTest { self } + fn with_pruned_files(mut self, pruned_files: Option) -> Self { + self.expected_files_pruned_by_statistics = pruned_files; + self + } + // Set the expected matched row groups by bloom filter fn with_matched_by_bloom_filter(mut self, matched_by_bf: Option) -> Self { self.expected_row_group_matched_by_bloom_filter = matched_by_bf; @@ -121,6 +128,11 @@ impl RowGroupPruningTest { self.expected_row_group_pruned_by_statistics, "mismatched row_groups_pruned_statistics", ); + assert_eq!( + output.files_ranges_pruned_statistics(), + self.expected_files_pruned_by_statistics, + "mismatched files_ranges_pruned_statistics", + ); assert_eq!( output.row_groups_matched_bloom_filter(), self.expected_row_group_matched_by_bloom_filter, @@ -148,6 +160,7 @@ async fn prune_timestamps_nanos() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(10) @@ -165,6 +178,7 @@ async fn prune_timestamps_micros() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(10) @@ -182,6 +196,7 @@ async fn prune_timestamps_millis() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(10) @@ -199,6 +214,7 @@ async fn prune_timestamps_seconds() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(10) @@ -214,6 +230,7 @@ async fn prune_date32() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(1) @@ -256,6 +273,7 @@ async fn prune_disabled() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(10) @@ -301,6 +319,7 @@ macro_rules! int_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(11) @@ -315,6 +334,7 @@ macro_rules! int_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(11) @@ -330,6 +350,7 @@ macro_rules! int_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(1) @@ -344,6 +365,7 @@ macro_rules! int_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(1) @@ -359,6 +381,7 @@ macro_rules! 
int_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(0)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(3) @@ -374,6 +397,7 @@ macro_rules! int_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(0)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) @@ -389,6 +413,7 @@ macro_rules! int_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(0)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(9) @@ -405,6 +430,7 @@ macro_rules! int_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(1) @@ -421,7 +447,8 @@ macro_rules! int_tests { .with_query(&format!("SELECT * FROM t where i{} in (100)", $bits)) .with_expected_errors(Some(0)) .with_matched_by_stats(Some(0)) - .with_pruned_by_stats(Some(4)) + .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(1)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(0) @@ -438,6 +465,7 @@ macro_rules! int_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(4)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(4)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(19) @@ -467,6 +495,7 @@ macro_rules! uint_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(11) @@ -482,6 +511,7 @@ macro_rules! uint_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(1) @@ -496,6 +526,7 @@ macro_rules! uint_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(1) @@ -511,6 +542,7 @@ macro_rules! uint_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(0)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) @@ -526,6 +558,7 @@ macro_rules! uint_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(0)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) @@ -542,6 +575,7 @@ macro_rules! uint_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(1) @@ -559,6 +593,7 @@ macro_rules! 
uint_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(0)) .with_pruned_by_stats(Some(4)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(0) @@ -575,6 +610,7 @@ macro_rules! uint_tests { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(4)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(4)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(19) @@ -604,6 +640,7 @@ async fn prune_int32_eq_large_in_list() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(1)) .with_expected_rows(0) @@ -626,6 +663,7 @@ async fn prune_uint32_eq_large_in_list() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(1)) .with_expected_rows(0) @@ -641,6 +679,7 @@ async fn prune_f64_lt() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(11) @@ -652,6 +691,7 @@ async fn prune_f64_lt() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(11) @@ -669,6 +709,7 @@ async fn prune_f64_scalar_fun_and_gt() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(2)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(1) @@ -685,6 +726,7 @@ async fn prune_f64_scalar_fun() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(0)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(1) @@ -701,6 +743,7 @@ async fn prune_f64_complex_expr() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(0)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(9) @@ -717,6 +760,7 @@ async fn prune_f64_complex_expr_subtract() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(0)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(9) @@ -735,6 +779,7 @@ async fn prune_decimal_lt() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(6) @@ -746,6 +791,7 @@ async fn prune_decimal_lt() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(8) @@ -757,6 +803,7 @@ async fn prune_decimal_lt() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) 
.with_expected_rows(6) @@ -768,6 +815,7 @@ async fn prune_decimal_lt() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(8) @@ -786,6 +834,7 @@ async fn prune_decimal_eq() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) @@ -797,6 +846,7 @@ async fn prune_decimal_eq() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) @@ -809,6 +859,7 @@ async fn prune_decimal_eq() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) @@ -820,6 +871,7 @@ async fn prune_decimal_eq() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) @@ -839,6 +891,7 @@ async fn prune_decimal_in_list() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(5) @@ -850,6 +903,7 @@ async fn prune_decimal_in_list() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(6) @@ -861,6 +915,7 @@ async fn prune_decimal_in_list() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(5) @@ -872,6 +927,7 @@ async fn prune_decimal_in_list() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(6) @@ -885,6 +941,7 @@ async fn prune_decimal_in_list() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(2)) .with_expected_rows(1) @@ -898,6 +955,7 @@ async fn prune_decimal_in_list() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(2)) .with_expected_rows(1) @@ -911,6 +969,7 @@ async fn prune_decimal_in_list() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(2)) .with_expected_rows(1) @@ -929,6 +988,7 @@ async fn prune_string_eq_match() { // false positive on 'all backends' batch: 'backend five' < 'backend one' < 'backend three' .with_matched_by_stats(Some(2)) 
.with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(1)) .with_expected_rows(1) @@ -947,6 +1007,7 @@ async fn prune_string_eq_no_match() { // false positive on 'all backends' batch: 'backend five' < 'backend one' < 'backend three' .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(2)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(1)) .with_expected_rows(0) @@ -963,6 +1024,7 @@ async fn prune_string_eq_no_match() { // false positive on 'mixed' batch: 'backend one' < 'frontend nine' < 'frontend six' .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(2)) .with_expected_rows(0) @@ -980,6 +1042,7 @@ async fn prune_string_neq() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(3)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(14) @@ -998,6 +1061,7 @@ async fn prune_string_lt() { // matches 'all backends' only .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(2)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(3) @@ -1012,6 +1076,7 @@ async fn prune_string_lt() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) // all backends from 'mixed' and 'all backends' @@ -1031,6 +1096,7 @@ async fn prune_binary_eq_match() { // false positive on 'all backends' batch: 'backend five' < 'backend one' < 'backend three' .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(1)) .with_expected_rows(1) @@ -1049,6 +1115,7 @@ async fn prune_binary_eq_no_match() { // false positive on 'all backends' batch: 'backend five' < 'backend one' < 'backend three' .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(2)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(1)) .with_expected_rows(0) @@ -1065,6 +1132,7 @@ async fn prune_binary_eq_no_match() { // false positive on 'mixed' batch: 'backend one' < 'frontend nine' < 'frontend six' .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(2)) .with_expected_rows(0) @@ -1082,6 +1150,7 @@ async fn prune_binary_neq() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(3)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(14) @@ -1100,6 +1169,7 @@ async fn prune_binary_lt() { // matches 'all backends' only .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(2)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(3) @@ -1114,6 +1184,7 @@ async fn prune_binary_lt() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) // all backends from 'mixed' and 
'all backends' @@ -1133,6 +1204,7 @@ async fn prune_fixedsizebinary_eq_match() { // false positive on 'all frontends' batch: 'fe1' < 'fe6' < 'fe7' .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(1)) .with_expected_rows(1) @@ -1148,6 +1220,7 @@ async fn prune_fixedsizebinary_eq_match() { // false positive on 'all frontends' batch: 'fe1' < 'fe6' < 'fe7' .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(1)) .with_expected_rows(1) @@ -1166,6 +1239,7 @@ async fn prune_fixedsizebinary_eq_no_match() { // false positive on 'mixed' batch: 'be1' < 'be9' < 'fe4' .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(2)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(1)) .with_expected_rows(0) @@ -1183,6 +1257,7 @@ async fn prune_fixedsizebinary_neq() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(3)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(14) @@ -1201,6 +1276,7 @@ async fn prune_fixedsizebinary_lt() { // matches 'all backends' only .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(2)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) @@ -1215,6 +1291,7 @@ async fn prune_fixedsizebinary_lt() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) // all backends from 'mixed' and 'all backends' @@ -1235,6 +1312,7 @@ async fn prune_periods_in_column_names() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(7) @@ -1246,6 +1324,7 @@ async fn prune_periods_in_column_names() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(2)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(5) @@ -1257,6 +1336,7 @@ async fn prune_periods_in_column_names() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(2)) + .with_pruned_files(Some(0)) .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) @@ -1277,6 +1357,7 @@ async fn test_row_group_with_null_values() { .with_query("SELECT * FROM t WHERE \"i8\" <= 5") .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_pruned_by_stats(Some(2)) .with_expected_rows(5) .with_matched_by_bloom_filter(Some(0)) @@ -1290,6 +1371,7 @@ async fn test_row_group_with_null_values() { .with_query("SELECT * FROM t WHERE \"i8\" is Null") .with_expected_errors(Some(0)) .with_matched_by_stats(Some(2)) + .with_pruned_files(Some(0)) .with_pruned_by_stats(Some(1)) .with_expected_rows(10) .with_matched_by_bloom_filter(Some(0)) @@ -1303,6 +1385,7 @@ async fn test_row_group_with_null_values() { .with_query("SELECT * FROM t WHERE \"i16\" is Not Null") .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) + 
.with_pruned_files(Some(0)) .with_pruned_by_stats(Some(2)) .with_expected_rows(5) .with_matched_by_bloom_filter(Some(0)) @@ -1316,7 +1399,8 @@ async fn test_row_group_with_null_values() { .with_query("SELECT * FROM t WHERE \"i32\" > 7") .with_expected_errors(Some(0)) .with_matched_by_stats(Some(0)) - .with_pruned_by_stats(Some(3)) + .with_pruned_by_stats(Some(0)) + .with_pruned_files(Some(1)) .with_expected_rows(0) .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) @@ -1332,6 +1416,7 @@ async fn test_bloom_filter_utf8_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(1) .with_pruned_by_bloom_filter(Some(0)) .with_matched_by_bloom_filter(Some(1)) @@ -1344,6 +1429,7 @@ async fn test_bloom_filter_utf8_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(0) .with_pruned_by_bloom_filter(Some(1)) .with_matched_by_bloom_filter(Some(0)) @@ -1356,6 +1442,7 @@ async fn test_bloom_filter_utf8_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(1) .with_pruned_by_bloom_filter(Some(0)) .with_matched_by_bloom_filter(Some(1)) @@ -1368,6 +1455,7 @@ async fn test_bloom_filter_utf8_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(0) .with_pruned_by_bloom_filter(Some(1)) .with_matched_by_bloom_filter(Some(0)) @@ -1383,6 +1471,7 @@ async fn test_bloom_filter_integer_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(1) .with_pruned_by_bloom_filter(Some(0)) .with_matched_by_bloom_filter(Some(1)) @@ -1395,6 +1484,7 @@ async fn test_bloom_filter_integer_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(0) .with_pruned_by_bloom_filter(Some(1)) .with_matched_by_bloom_filter(Some(0)) @@ -1407,6 +1497,7 @@ async fn test_bloom_filter_integer_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(1) .with_pruned_by_bloom_filter(Some(0)) .with_matched_by_bloom_filter(Some(1)) @@ -1419,6 +1510,7 @@ async fn test_bloom_filter_integer_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(0) .with_pruned_by_bloom_filter(Some(1)) .with_matched_by_bloom_filter(Some(0)) @@ -1434,6 +1526,7 @@ async fn test_bloom_filter_unsigned_integer_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(1) .with_pruned_by_bloom_filter(Some(0)) .with_matched_by_bloom_filter(Some(1)) @@ -1446,6 +1539,7 @@ async fn test_bloom_filter_unsigned_integer_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(0) .with_pruned_by_bloom_filter(Some(1)) .with_matched_by_bloom_filter(Some(0)) @@ -1461,6 +1555,7 @@ async fn test_bloom_filter_binary_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) 
.with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(1) .with_pruned_by_bloom_filter(Some(0)) .with_matched_by_bloom_filter(Some(1)) @@ -1473,6 +1568,7 @@ async fn test_bloom_filter_binary_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(0) .with_pruned_by_bloom_filter(Some(1)) .with_matched_by_bloom_filter(Some(0)) @@ -1485,6 +1581,7 @@ async fn test_bloom_filter_binary_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(1) .with_pruned_by_bloom_filter(Some(0)) .with_matched_by_bloom_filter(Some(1)) @@ -1499,6 +1596,7 @@ async fn test_bloom_filter_binary_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(0) .with_pruned_by_bloom_filter(Some(1)) .with_matched_by_bloom_filter(Some(0)) @@ -1514,6 +1612,7 @@ async fn test_bloom_filter_decimal_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(1) .with_pruned_by_bloom_filter(Some(0)) .with_matched_by_bloom_filter(Some(1)) @@ -1526,6 +1625,7 @@ async fn test_bloom_filter_decimal_dict() { .with_expected_errors(Some(0)) .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(1)) + .with_pruned_files(Some(0)) .with_expected_rows(0) .with_pruned_by_bloom_filter(Some(1)) .with_matched_by_bloom_filter(Some(0)) diff --git a/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs index 568be0d18f24..de8f7b36a3a7 100644 --- a/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs @@ -136,7 +136,7 @@ fn aggregations_not_combined() -> datafusion_common::Result<()> { let plan = final_aggregate_exec( repartition_exec(partial_aggregate_exec( - parquet_exec(&schema), + parquet_exec(schema.clone()), PhysicalGroupBy::default(), aggr_expr.clone(), )), @@ -157,7 +157,7 @@ fn aggregations_not_combined() -> datafusion_common::Result<()> { let plan = final_aggregate_exec( partial_aggregate_exec( - parquet_exec(&schema), + parquet_exec(schema), PhysicalGroupBy::default(), aggr_expr1, ), @@ -183,7 +183,7 @@ fn aggregations_combined() -> datafusion_common::Result<()> { let plan = final_aggregate_exec( partial_aggregate_exec( - parquet_exec(&schema), + parquet_exec(schema), PhysicalGroupBy::default(), aggr_expr.clone(), ), @@ -215,11 +215,8 @@ fn aggregations_with_group_combined() -> datafusion_common::Result<()> { vec![(col("c", &schema)?, "c".to_string())]; let partial_group_by = PhysicalGroupBy::new_single(groups); - let partial_agg = partial_aggregate_exec( - parquet_exec(&schema), - partial_group_by, - aggr_expr.clone(), - ); + let partial_agg = + partial_aggregate_exec(parquet_exec(schema), partial_group_by, aggr_expr.clone()); let groups: Vec<(Arc, String)> = vec![(col("c", &partial_agg.schema())?, "c".to_string())]; @@ -245,11 +242,8 @@ fn aggregations_with_limit_combined() -> datafusion_common::Result<()> { vec![(col("c", &schema)?, "c".to_string())]; let partial_group_by = PhysicalGroupBy::new_single(groups); - let partial_agg = partial_aggregate_exec( - parquet_exec(&schema), - partial_group_by, - aggr_expr.clone(), - ); + let partial_agg = + 
partial_aggregate_exec(parquet_exec(schema), partial_group_by, aggr_expr.clone()); let groups: Vec<(Arc, String)> = vec![(col("c", &partial_agg.schema())?, "c".to_string())]; diff --git a/datafusion/core/tests/physical_optimizer/enforce_distribution.rs b/datafusion/core/tests/physical_optimizer/enforce_distribution.rs index 9898f6204e88..fd847763124a 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_distribution.rs @@ -19,31 +19,35 @@ use std::fmt::Debug; use std::ops::Deref; use std::sync::Arc; -use crate::physical_optimizer::test_utils::parquet_exec_with_sort; use crate::physical_optimizer::test_utils::{ - check_integrity, coalesce_partitions_exec, repartition_exec, schema, - sort_merge_join_exec, sort_preserving_merge_exec, + check_integrity, coalesce_partitions_exec, parquet_exec_with_sort, + parquet_exec_with_stats, repartition_exec, schema, sort_exec, + sort_exec_with_preserve_partitioning, sort_merge_join_exec, + sort_preserving_merge_exec, union_exec, }; +use arrow::array::{RecordBatch, UInt64Array, UInt8Array}; use arrow::compute::SortOptions; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; use datafusion::config::ConfigOptions; use datafusion::datasource::file_format::file_compression_type::FileCompressionType; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::{CsvSource, ParquetSource}; use datafusion::datasource::source::DataSourceExec; +use datafusion::datasource::MemTable; +use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::error::Result; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::ScalarValue; use datafusion_datasource::file_groups::FileGroup; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_expr::{JoinType, Operator}; -use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal}; -use datafusion_physical_expr::PhysicalExpr; -use datafusion_physical_expr::{ - expressions::binary, expressions::lit, LexOrdering, PhysicalSortExpr, +use datafusion_physical_expr::expressions::{binary, lit, BinaryExpr, Column, Literal}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::{ + LexOrdering, OrderingRequirements, PhysicalSortExpr, }; -use datafusion_physical_expr_common::sort_expr::LexRequirement; use datafusion_physical_optimizer::enforce_distribution::*; use datafusion_physical_optimizer::enforce_sorting::EnforceSorting; use datafusion_physical_optimizer::output_requirements::OutputRequirements; @@ -52,19 +56,18 @@ use datafusion_physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, }; use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; +use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::execution_plan::ExecutionPlan; use datafusion_physical_plan::expressions::col; use datafusion_physical_plan::filter::FilterExec; use datafusion_physical_plan::joins::utils::JoinOn; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::projection::ProjectionExec; -use datafusion_physical_plan::sorts::sort::SortExec; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_physical_plan::union::UnionExec; -use 
datafusion_physical_plan::ExecutionPlanProperties; -use datafusion_physical_plan::PlanProperties; use datafusion_physical_plan::{ - get_plan_string, DisplayAs, DisplayFormatType, Statistics, + get_plan_string, DisplayAs, DisplayFormatType, ExecutionPlanProperties, + PlanProperties, Statistics, }; /// Models operators like BoundedWindowExec that require an input @@ -140,12 +143,8 @@ impl ExecutionPlan for SortRequiredExec { } // model that it requires the output ordering of its input - fn required_input_ordering(&self) -> Vec> { - if self.expr.is_empty() { - vec![None] - } else { - vec![Some(LexRequirement::from(self.expr.clone()))] - } + fn required_input_ordering(&self) -> Vec> { + vec![Some(OrderingRequirements::from(self.expr.clone()))] } fn with_new_children( @@ -169,12 +168,12 @@ impl ExecutionPlan for SortRequiredExec { } fn statistics(&self) -> Result { - self.input.statistics() + self.input.partition_statistics(None) } } fn parquet_exec() -> Arc { - parquet_exec_with_sort(vec![]) + parquet_exec_with_sort(schema(), vec![]) } fn parquet_exec_multiple() -> Arc { @@ -320,16 +319,6 @@ fn filter_exec(input: Arc) -> Arc { Arc::new(FilterExec::try_new(predicate, input).unwrap()) } -fn sort_exec( - sort_exprs: LexOrdering, - input: Arc, - preserve_partitioning: bool, -) -> Arc { - let new_sort = SortExec::new(sort_exprs, input) - .with_preserve_partitioning(preserve_partitioning); - Arc::new(new_sort) -} - fn limit_exec(input: Arc) -> Arc { Arc::new(GlobalLimitExec::new( Arc::new(LocalLimitExec::new(input, 100)), @@ -338,10 +327,6 @@ fn limit_exec(input: Arc) -> Arc { )) } -fn union_exec(input: Vec>) -> Arc { - Arc::new(UnionExec::new(input)) -} - fn sort_required_exec_with_req( input: Arc, sort_exprs: LexOrdering, @@ -524,8 +509,7 @@ impl TestConfig { assert_eq!( &expected_lines, &actual_lines, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected_lines, actual_lines + "\n\nexpected:\n\n{expected_lines:#?}\nactual:\n\n{actual_lines:#?}\n\n" ); Ok(optimized) @@ -643,7 +627,7 @@ fn multi_hash_joins() -> Result<()> { test_config.run(&expected, top_join.clone(), &DISTRIB_DISTRIB_SORT)?; test_config.run(&expected, top_join, &SORT_DISTRIB_DISTRIB)?; } - JoinType::RightSemi | JoinType::RightAnti => {} + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => {} } match join_type { @@ -652,7 +636,8 @@ fn multi_hash_joins() -> Result<()> { | JoinType::Right | JoinType::Full | JoinType::RightSemi - | JoinType::RightAnti => { + | JoinType::RightAnti + | JoinType::RightMark => { // This time we use (b1 == c) for top join // Join on (b1 == c) let top_join_on = vec![( @@ -1736,10 +1721,11 @@ fn smj_join_key_ordering() -> Result<()> { fn merge_does_not_need_sort() -> Result<()> { // see https://github.com/apache/datafusion/issues/4331 let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("a", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); // Scan some sorted parquet files let exec = parquet_exec_multiple_sorted(vec![sort_key.clone()]); @@ -1936,11 +1922,12 @@ fn repartition_unsorted_limit() -> Result<()> { #[test] fn repartition_sorted_limit() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); - let plan = limit_exec(sort_exec(sort_key, parquet_exec(), false)); 
+ }] + .into(); + let plan = limit_exec(sort_exec(sort_key, parquet_exec())); let expected = &[ "GlobalLimitExec: skip=0, fetch=100", @@ -1960,12 +1947,13 @@ fn repartition_sorted_limit() -> Result<()> { #[test] fn repartition_sorted_limit_with_filter() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let plan = sort_required_exec_with_req( - filter_exec(sort_exec(sort_key.clone(), parquet_exec(), false)), + filter_exec(sort_exec(sort_key.clone(), parquet_exec())), sort_key, ); @@ -2043,10 +2031,11 @@ fn repartition_ignores_union() -> Result<()> { fn repartition_through_sort_preserving_merge() -> Result<()> { // sort preserving merge with non-sorted input let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let plan = sort_preserving_merge_exec(sort_key, parquet_exec()); // need resort as the data was not sorted correctly @@ -2066,10 +2055,11 @@ fn repartition_through_sort_preserving_merge() -> Result<()> { fn repartition_ignores_sort_preserving_merge() -> Result<()> { // sort preserving merge already sorted input, let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let plan = sort_preserving_merge_exec( sort_key.clone(), parquet_exec_multiple_sorted(vec![sort_key]), @@ -2101,11 +2091,15 @@ fn repartition_ignores_sort_preserving_merge() -> Result<()> { fn repartition_ignores_sort_preserving_merge_with_union() -> Result<()> { // 2 sorted parquet files unioned (partitions are concatenated, sort is preserved) let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); - let input = union_exec(vec![parquet_exec_with_sort(vec![sort_key.clone()]); 2]); + }] + .into(); + let input = union_exec(vec![ + parquet_exec_with_sort(schema, vec![sort_key.clone()]); + 2 + ]); let plan = sort_preserving_merge_exec(sort_key, input); // Test: run EnforceDistribution, then EnforceSort. 
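// The hunks above repeatedly replace `LexOrdering::new(vec![PhysicalSortExpr { .. }])`
// with the shorter `[PhysicalSortExpr { .. }].into()`. A minimal, self-contained sketch of
// the conversion pattern that makes an array literal plus `.into()` work (illustrative
// stand-in types only, not the actual datafusion-physical-expr API):

#[derive(Debug, Clone, PartialEq)]
struct SortExprStandIn {
    column: String,
    descending: bool,
}

#[derive(Debug, Clone, PartialEq)]
struct LexOrderingStandIn(Vec<SortExprStandIn>);

// Implementing `From<[T; N]>` is what lets call sites write `[expr].into()`
// instead of wrapping the expressions in a `Vec` and calling a constructor.
impl<const N: usize> From<[SortExprStandIn; N]> for LexOrderingStandIn {
    fn from(exprs: [SortExprStandIn; N]) -> Self {
        LexOrderingStandIn(Vec::from(exprs))
    }
}

#[test]
fn demo_ordering_from_array() {
    let by_a = SortExprStandIn {
        column: "a".to_string(),
        descending: false,
    };
    // Same shape as the test code above: a fixed-size array converted with `.into()`.
    let ordering: LexOrderingStandIn = [by_a.clone()].into();
    assert_eq!(ordering, LexOrderingStandIn(vec![by_a]));
}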
@@ -2139,12 +2133,13 @@ fn repartition_does_not_destroy_sort() -> Result<()> { // SortRequired // Parquet(sorted) let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("d", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("d", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let plan = sort_required_exec_with_req( - filter_exec(parquet_exec_with_sort(vec![sort_key.clone()])), + filter_exec(parquet_exec_with_sort(schema, vec![sort_key.clone()])), sort_key, ); @@ -2177,12 +2172,13 @@ fn repartition_does_not_destroy_sort_more_complex() -> Result<()> { // Parquet(unsorted) let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let input1 = sort_required_exec_with_req( - parquet_exec_with_sort(vec![sort_key.clone()]), + parquet_exec_with_sort(schema, vec![sort_key.clone()]), sort_key, ); let input2 = filter_exec(parquet_exec()); @@ -2213,18 +2209,19 @@ fn repartition_transitively_with_projection() -> Result<()> { let schema = schema(); let proj_exprs = vec![( Arc::new(BinaryExpr::new( - col("a", &schema).unwrap(), + col("a", &schema)?, Operator::Plus, - col("b", &schema).unwrap(), - )) as Arc, + col("b", &schema)?, + )) as _, "sum".to_string(), )]; // non sorted input let proj = Arc::new(ProjectionExec::try_new(proj_exprs, parquet_exec())?); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("sum", &proj.schema()).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("sum", &proj.schema())?, options: SortOptions::default(), - }]); + }] + .into(); let plan = sort_preserving_merge_exec(sort_key, proj); // Test: run EnforceDistribution, then EnforceSort. 
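// The `repartition_does_not_destroy_sort*` tests above check that the optimizer does not
// insert a repartition where it would break an ordering a downstream operator requires.
// A self-contained illustration of the hazard with plain integers (conceptual only, not
// the optimizer's actual logic): round-robin scattering keeps each partition locally
// sorted, but the single globally sorted stream the consumer needed is gone.

fn round_robin(rows: &[i32], partitions: usize) -> Vec<Vec<i32>> {
    let mut out = vec![Vec::new(); partitions];
    for (i, &row) in rows.iter().enumerate() {
        out[i % partitions].push(row);
    }
    out
}

#[test]
fn demo_round_robin_breaks_global_order() {
    let sorted_input = [1, 2, 3, 4, 5, 6];
    let parts = round_robin(&sorted_input, 2);
    // Each partition is still sorted on its own...
    assert!(parts.iter().all(|p| p.windows(2).all(|w| w[0] <= w[1])));
    // ...but chaining the partitions back together is not, so a plan that requires the
    // original order must either skip the repartition or restore the order afterwards.
    let rechained: Vec<i32> = parts.into_iter().flatten().collect();
    assert_ne!(rechained, sorted_input.to_vec());
}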
@@ -2256,10 +2253,11 @@ fn repartition_transitively_with_projection() -> Result<()> { #[test] fn repartition_ignores_transitively_with_projection() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let alias = vec![ ("a".to_string(), "a".to_string()), ("b".to_string(), "b".to_string()), @@ -2291,10 +2289,11 @@ fn repartition_ignores_transitively_with_projection() -> Result<()> { #[test] fn repartition_transitively_past_sort_with_projection() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let alias = vec![ ("a".to_string(), "a".to_string()), ("b".to_string(), "b".to_string()), @@ -2302,10 +2301,9 @@ fn repartition_transitively_past_sort_with_projection() -> Result<()> { ]; let plan = sort_preserving_merge_exec( sort_key.clone(), - sort_exec( + sort_exec_with_preserve_partitioning( sort_key, projection_exec_with_alias(parquet_exec(), alias), - true, ), ); @@ -2326,11 +2324,12 @@ fn repartition_transitively_past_sort_with_projection() -> Result<()> { #[test] fn repartition_transitively_past_sort_with_filter() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("a", &schema)?, options: SortOptions::default(), - }]); - let plan = sort_exec(sort_key, filter_exec(parquet_exec()), false); + }] + .into(); + let plan = sort_exec(sort_key, filter_exec(parquet_exec())); // Test: run EnforceDistribution, then EnforceSort. let expected = &[ @@ -2362,10 +2361,11 @@ fn repartition_transitively_past_sort_with_filter() -> Result<()> { #[cfg(feature = "parquet")] fn repartition_transitively_past_sort_with_projection_and_filter() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("a", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let plan = sort_exec( sort_key, projection_exec_with_alias( @@ -2376,7 +2376,6 @@ fn repartition_transitively_past_sort_with_projection_and_filter() -> Result<()> ("c".to_string(), "c".to_string()), ], ), - false, ); // Test: run EnforceDistribution, then EnforceSort. 
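// The call sites above drop the old three-argument local helper
// `sort_exec(sort_exprs, input, preserve_partitioning)` in favour of the shared
// `sort_exec` / `sort_exec_with_preserve_partitioning` helpers from test_utils. Their
// presumed shape, mirroring the local helper deleted earlier in this file's diff
// (a sketch under that assumption, not the actual test_utils source):

use std::sync::Arc;

use datafusion_physical_expr_common::sort_expr::LexOrdering;
use datafusion_physical_plan::execution_plan::ExecutionPlan;
use datafusion_physical_plan::sorts::sort::SortExec;

fn sort_exec_sketch(
    ordering: LexOrdering,
    input: Arc<dyn ExecutionPlan>,
) -> Arc<dyn ExecutionPlan> {
    // Single sorted output partition: rendered as `preserve_partitioning=[false]` above.
    Arc::new(SortExec::new(ordering, input))
}

fn sort_exec_with_preserve_partitioning_sketch(
    ordering: LexOrdering,
    input: Arc<dyn ExecutionPlan>,
) -> Arc<dyn ExecutionPlan> {
    // Sorts each input partition independently: `preserve_partitioning=[true]` above.
    Arc::new(SortExec::new(ordering, input).with_preserve_partitioning(true))
}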
@@ -2447,10 +2446,11 @@ fn parallelization_single_partition() -> Result<()> { #[test] fn parallelization_multiple_files() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("a", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let plan = filter_exec(parquet_exec_multiple_sorted(vec![sort_key.clone()])); let plan = sort_required_exec_with_req(plan, sort_key); @@ -2637,12 +2637,13 @@ fn parallelization_two_partitions_into_four() -> Result<()> { #[test] fn parallelization_sorted_limit() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); - let plan_parquet = limit_exec(sort_exec(sort_key.clone(), parquet_exec(), false)); - let plan_csv = limit_exec(sort_exec(sort_key, csv_exec(), false)); + }] + .into(); + let plan_parquet = limit_exec(sort_exec(sort_key.clone(), parquet_exec())); + let plan_csv = limit_exec(sort_exec(sort_key, csv_exec())); let test_config = TestConfig::default(); @@ -2680,16 +2681,14 @@ fn parallelization_sorted_limit() -> Result<()> { #[test] fn parallelization_limit_with_filter() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); - let plan_parquet = limit_exec(filter_exec(sort_exec( - sort_key.clone(), - parquet_exec(), - false, - ))); - let plan_csv = limit_exec(filter_exec(sort_exec(sort_key, csv_exec(), false))); + }] + .into(); + let plan_parquet = + limit_exec(filter_exec(sort_exec(sort_key.clone(), parquet_exec()))); + let plan_csv = limit_exec(filter_exec(sort_exec(sort_key, csv_exec()))); let test_config = TestConfig::default(); @@ -2834,14 +2833,15 @@ fn parallelization_union_inputs() -> Result<()> { #[test] fn parallelization_prior_to_sort_preserving_merge() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); // sort preserving merge already sorted input, let plan_parquet = sort_preserving_merge_exec( sort_key.clone(), - parquet_exec_with_sort(vec![sort_key.clone()]), + parquet_exec_with_sort(schema, vec![sort_key.clone()]), ); let plan_csv = sort_preserving_merge_exec(sort_key.clone(), csv_exec_with_sort(vec![sort_key])); @@ -2875,13 +2875,17 @@ fn parallelization_prior_to_sort_preserving_merge() -> Result<()> { #[test] fn parallelization_sort_preserving_merge_with_union() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); // 2 sorted parquet files unioned (partitions are concatenated, sort is preserved) let input_parquet = - union_exec(vec![parquet_exec_with_sort(vec![sort_key.clone()]); 2]); + union_exec(vec![ + parquet_exec_with_sort(schema, vec![sort_key.clone()]); + 2 + ]); let input_csv = union_exec(vec![csv_exec_with_sort(vec![sort_key.clone()]); 2]); let plan_parquet = 
sort_preserving_merge_exec(sort_key.clone(), input_parquet); let plan_csv = sort_preserving_merge_exec(sort_key, input_csv); @@ -2948,14 +2952,15 @@ fn parallelization_sort_preserving_merge_with_union() -> Result<()> { #[test] fn parallelization_does_not_benefit() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); // SortRequired // Parquet(sorted) let plan_parquet = sort_required_exec_with_req( - parquet_exec_with_sort(vec![sort_key.clone()]), + parquet_exec_with_sort(schema, vec![sort_key.clone()]), sort_key.clone(), ); let plan_csv = @@ -2993,22 +2998,26 @@ fn parallelization_does_not_benefit() -> Result<()> { fn parallelization_ignores_transitively_with_projection_parquet() -> Result<()> { // sorted input let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); //Projection(a as a2, b as b2) let alias_pairs: Vec<(String, String)> = vec![ ("a".to_string(), "a2".to_string()), ("c".to_string(), "c2".to_string()), ]; - let proj_parquet = - projection_exec_with_alias(parquet_exec_with_sort(vec![sort_key]), alias_pairs); - let sort_key_after_projection = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c2", &proj_parquet.schema()).unwrap(), + let proj_parquet = projection_exec_with_alias( + parquet_exec_with_sort(schema, vec![sort_key]), + alias_pairs, + ); + let sort_key_after_projection = [PhysicalSortExpr { + expr: col("c2", &proj_parquet.schema())?, options: SortOptions::default(), - }]); + }] + .into(); let plan_parquet = sort_preserving_merge_exec(sort_key_after_projection, proj_parquet); let expected = &[ @@ -3039,10 +3048,11 @@ fn parallelization_ignores_transitively_with_projection_parquet() -> Result<()> fn parallelization_ignores_transitively_with_projection_csv() -> Result<()> { // sorted input let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); //Projection(a as a2, b as b2) let alias_pairs: Vec<(String, String)> = vec![ @@ -3052,10 +3062,11 @@ fn parallelization_ignores_transitively_with_projection_csv() -> Result<()> { let proj_csv = projection_exec_with_alias(csv_exec_with_sort(vec![sort_key]), alias_pairs); - let sort_key_after_projection = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c2", &proj_csv.schema()).unwrap(), + let sort_key_after_projection = [PhysicalSortExpr { + expr: col("c2", &proj_csv.schema())?, options: SortOptions::default(), - }]); + }] + .into(); let plan_csv = sort_preserving_merge_exec(sort_key_after_projection, proj_csv); let expected = &[ "SortPreservingMergeExec: [c2@1 ASC]", @@ -3108,10 +3119,11 @@ fn remove_redundant_roundrobins() -> Result<()> { #[test] fn remove_unnecessary_spm_after_filter() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let input = parquet_exec_multiple_sorted(vec![sort_key.clone()]); let physical_plan = 
sort_preserving_merge_exec(sort_key, filter_exec(input)); @@ -3138,10 +3150,11 @@ fn remove_unnecessary_spm_after_filter() -> Result<()> { #[test] fn preserve_ordering_through_repartition() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("d", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("d", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let input = parquet_exec_multiple_sorted(vec![sort_key.clone()]); let physical_plan = sort_preserving_merge_exec(sort_key, filter_exec(input)); @@ -3163,10 +3176,11 @@ fn preserve_ordering_through_repartition() -> Result<()> { #[test] fn do_not_preserve_ordering_through_repartition() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("a", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let input = parquet_exec_multiple_sorted(vec![sort_key.clone()]); let physical_plan = sort_preserving_merge_exec(sort_key, filter_exec(input)); @@ -3202,10 +3216,11 @@ fn do_not_preserve_ordering_through_repartition() -> Result<()> { #[test] fn no_need_for_sort_after_filter() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let input = parquet_exec_multiple_sorted(vec![sort_key.clone()]); let physical_plan = sort_preserving_merge_exec(sort_key, filter_exec(input)); @@ -3227,16 +3242,18 @@ fn no_need_for_sort_after_filter() -> Result<()> { #[test] fn do_not_preserve_ordering_through_repartition2() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let input = parquet_exec_multiple_sorted(vec![sort_key]); - let sort_req = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema).unwrap(), + let sort_req = [PhysicalSortExpr { + expr: col("a", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let physical_plan = sort_preserving_merge_exec(sort_req, filter_exec(input)); let test_config = TestConfig::default(); @@ -3272,10 +3289,11 @@ fn do_not_preserve_ordering_through_repartition2() -> Result<()> { #[test] fn do_not_preserve_ordering_through_repartition3() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let input = parquet_exec_multiple_sorted(vec![sort_key]); let physical_plan = filter_exec(input); @@ -3294,10 +3312,11 @@ fn do_not_preserve_ordering_through_repartition3() -> Result<()> { #[test] fn do_not_put_sort_when_input_is_invalid() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("a", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let input = parquet_exec(); let physical_plan = sort_required_exec_with_req(filter_exec(input), sort_key); let expected = &[ @@ -3331,10 +3350,11 @@ fn 
do_not_put_sort_when_input_is_invalid() -> Result<()> { #[test] fn put_sort_when_input_is_valid() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema).unwrap(), + let sort_key: LexOrdering = [PhysicalSortExpr { + expr: col("a", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let input = parquet_exec_multiple_sorted(vec![sort_key.clone()]); let physical_plan = sort_required_exec_with_req(filter_exec(input), sort_key); @@ -3368,12 +3388,13 @@ fn put_sort_when_input_is_valid() -> Result<()> { #[test] fn do_not_add_unnecessary_hash() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let alias = vec![("a".to_string(), "a".to_string())]; - let input = parquet_exec_with_sort(vec![sort_key]); + let input = parquet_exec_with_sort(schema, vec![sort_key]); let physical_plan = aggregate_exec_with_alias(input, alias); // TestConfig: @@ -3394,10 +3415,11 @@ fn do_not_add_unnecessary_hash() -> Result<()> { #[test] fn do_not_add_unnecessary_hash2() -> Result<()> { let schema = schema(); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + let sort_key = [PhysicalSortExpr { + expr: col("c", &schema)?, options: SortOptions::default(), - }]); + }] + .into(); let alias = vec![("a".to_string(), "a".to_string())]; let input = parquet_exec_multiple_sorted(vec![sort_key]); let aggregate = aggregate_exec_with_alias(input, alias.clone()); @@ -3471,3 +3493,140 @@ fn optimize_away_unnecessary_repartition2() -> Result<()> { Ok(()) } + +/// Ensures that `DataSourceExec` has been repartitioned into `target_partitions` file groups +#[tokio::test] +async fn test_distribute_sort_parquet() -> Result<()> { + let test_config: TestConfig = + TestConfig::default().with_prefer_repartition_file_scans(1000); + assert!( + test_config.config.optimizer.repartition_file_scans, + "should enable scans to be repartitioned" + ); + + let schema = schema(); + let sort_key = [PhysicalSortExpr::new_default(col("c", &schema)?)].into(); + let physical_plan = sort_exec(sort_key, parquet_exec_with_stats(10000 * 8192)); + + // prior to optimization, this is the starting plan + let starting = &[ + "SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ]; + plans_matches_expected!(starting, physical_plan.clone()); + + // what the enforce distribution run does. 
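// The expected plans below show the single 81,920,000-byte file registered by
// `parquet_exec_with_stats(10000 * 8192)` being split into 10 contiguous byte ranges of
// 8,192,000 bytes each. A self-contained sketch of that arithmetic (illustration only,
// not DataFusion's actual file-scan repartitioning code):

/// Split `total_len` bytes into `parts` contiguous, non-overlapping ranges.
fn split_file_evenly(total_len: u64, parts: u64) -> Vec<(u64, u64)> {
    let chunk = total_len / parts; // assumes `parts` divides `total_len` exactly
    (0..parts).map(|i| (i * chunk, (i + 1) * chunk)).collect()
}

#[test]
fn demo_even_file_groups() {
    let groups = split_file_evenly(10_000 * 8192, 10);
    // Matches the first and last groups asserted below:
    // `[x:0..8192000]` and `[x:73728000..81920000]`.
    assert_eq!(groups.first(), Some(&(0, 8_192_000)));
    assert_eq!(groups.last(), Some(&(73_728_000, 81_920_000)));
}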
+ let expected = &[ + "SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", + " CoalescePartitionsExec", + " DataSourceExec: file_groups={10 groups: [[x:0..8192000], [x:8192000..16384000], [x:16384000..24576000], [x:24576000..32768000], [x:32768000..40960000], [x:40960000..49152000], [x:49152000..57344000], [x:57344000..65536000], [x:65536000..73728000], [x:73728000..81920000]]}, projection=[a, b, c, d, e], file_type=parquet", + ]; + test_config.run(expected, physical_plan.clone(), &[Run::Distribution])?; + + // what the sort parallelization (in enforce sorting), does after the enforce distribution changes + let expected = &[ + "SortPreservingMergeExec: [c@2 ASC]", + " SortExec: expr=[c@2 ASC], preserve_partitioning=[true]", + " DataSourceExec: file_groups={10 groups: [[x:0..8192000], [x:8192000..16384000], [x:16384000..24576000], [x:24576000..32768000], [x:32768000..40960000], [x:40960000..49152000], [x:49152000..57344000], [x:57344000..65536000], [x:65536000..73728000], [x:73728000..81920000]]}, projection=[a, b, c, d, e], file_type=parquet", + ]; + test_config.run(expected, physical_plan, &[Run::Distribution, Run::Sorting])?; + Ok(()) +} + +/// Ensures that `DataSourceExec` has been repartitioned into `target_partitions` memtable groups +#[tokio::test] +async fn test_distribute_sort_memtable() -> Result<()> { + let test_config: TestConfig = + TestConfig::default().with_prefer_repartition_file_scans(1000); + assert!( + test_config.config.optimizer.repartition_file_scans, + "should enable scans to be repartitioned" + ); + + let mem_table = create_memtable()?; + let session_config = SessionConfig::new() + .with_repartition_file_min_size(1000) + .with_target_partitions(3); + let ctx = SessionContext::new_with_config(session_config); + ctx.register_table("users", Arc::new(mem_table))?; + + let dataframe = ctx.sql("SELECT * FROM users order by id;").await?; + let physical_plan = dataframe.create_physical_plan().await?; + + // this is the final, optimized plan + let expected = &[ + "SortPreservingMergeExec: [id@0 ASC NULLS LAST]", + " SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[true]", + " DataSourceExec: partitions=3, partition_sizes=[34, 33, 33]", + ]; + plans_matches_expected!(expected, physical_plan); + + Ok(()) +} + +/// Create a [`MemTable`] with 100 batches of 8192 rows each, in 1 partition +fn create_memtable() -> Result { + let mut batches = Vec::with_capacity(100); + for _ in 0..100 { + batches.push(create_record_batch()?); + } + let partitions = vec![batches]; + MemTable::try_new(get_schema(), partitions) +} + +fn create_record_batch() -> Result { + let id_array = UInt8Array::from(vec![1; 8192]); + let account_array = UInt64Array::from(vec![9000; 8192]); + + Ok(RecordBatch::try_new( + get_schema(), + vec![Arc::new(id_array), Arc::new(account_array)], + ) + .unwrap()) +} + +fn get_schema() -> SchemaRef { + SchemaRef::new(Schema::new(vec![ + Field::new("id", DataType::UInt8, false), + Field::new("bank_account", DataType::UInt64, true), + ])) +} +#[test] +fn test_replace_order_preserving_variants_with_fetch() -> Result<()> { + // Create a base plan + let parquet_exec = parquet_exec(); + + let sort_expr = PhysicalSortExpr::new_default(Arc::new(Column::new("id", 0))); + + // Create a SortPreservingMergeExec with fetch=5 + let spm_exec = Arc::new( + SortPreservingMergeExec::new([sort_expr].into(), parquet_exec.clone()) + .with_fetch(Some(5)), + ); + + // Create distribution context + let dist_context = DistributionContext::new( + spm_exec, + true, + 
vec![DistributionContext::new(parquet_exec, false, vec![])], + ); + + // Apply the function + let result = replace_order_preserving_variants(dist_context)?; + + // Verify the plan was transformed to CoalescePartitionsExec + result + .plan + .as_any() + .downcast_ref::() + .expect("Expected CoalescePartitionsExec"); + + // Verify fetch was preserved + assert_eq!( + result.plan.fetch(), + Some(5), + "Fetch value was not preserved after transformation" + ); + + Ok(()) +} diff --git a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs index d4b84a52f401..38bc10a967e2 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs @@ -18,13 +18,14 @@ use std::sync::Arc; use crate::physical_optimizer::test_utils::{ - aggregate_exec, bounded_window_exec, check_integrity, coalesce_batches_exec, - coalesce_partitions_exec, create_test_schema, create_test_schema2, - create_test_schema3, filter_exec, global_limit_exec, hash_join_exec, limit_exec, - local_limit_exec, memory_exec, parquet_exec, repartition_exec, sort_exec, - sort_exec_with_fetch, sort_expr, sort_expr_options, sort_merge_join_exec, - sort_preserving_merge_exec, sort_preserving_merge_exec_with_fetch, - spr_repartition_exec, stream_exec_ordered, union_exec, RequirementsTestExec, + aggregate_exec, bounded_window_exec, bounded_window_exec_with_partition, + check_integrity, coalesce_batches_exec, coalesce_partitions_exec, create_test_schema, + create_test_schema2, create_test_schema3, filter_exec, global_limit_exec, + hash_join_exec, local_limit_exec, memory_exec, parquet_exec, parquet_exec_with_sort, + projection_exec, repartition_exec, sort_exec, sort_exec_with_fetch, sort_expr, + sort_expr_options, sort_merge_join_exec, sort_preserving_merge_exec, + sort_preserving_merge_exec_with_fetch, spr_repartition_exec, stream_exec_ordered, + union_exec, RequirementsTestExec, }; use arrow::compute::SortOptions; @@ -32,89 +33,53 @@ use arrow::datatypes::{DataType, SchemaRef}; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{TreeNode, TransformedResult}; use datafusion_common::{Result, ScalarValue}; +use datafusion_datasource::file_scan_config::FileScanConfigBuilder; +use datafusion_datasource::source::DataSourceExec; +use datafusion_expr_common::operator::Operator; use datafusion_expr::{JoinType, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition}; use datafusion_execution::object_store::ObjectStoreUrl; +use datafusion_functions_aggregate::average::avg_udaf; +use datafusion_functions_aggregate::count::count_udaf; +use datafusion_functions_aggregate::min_max::{max_udaf, min_udaf}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; -use datafusion_physical_expr::expressions::{col, Column, NotExpr}; -use datafusion_physical_expr::Partitioning; -use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; +use datafusion_physical_expr_common::sort_expr::{ + LexOrdering, PhysicalSortExpr, PhysicalSortRequirement, OrderingRequirements +}; +use datafusion_physical_expr::{Distribution, Partitioning}; +use datafusion_physical_expr::expressions::{col, BinaryExpr, Column, NotExpr}; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::repartition::RepartitionExec; use 
datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_physical_plan::sorts::sort::SortExec; use datafusion_physical_plan::windows::{create_window_expr, BoundedWindowAggExec, WindowAggExec}; use datafusion_physical_plan::{displayable, get_plan_string, ExecutionPlan, InputOrderMode}; -use datafusion::datasource::physical_plan::{CsvSource, ParquetSource}; +use datafusion::datasource::physical_plan::CsvSource; use datafusion::datasource::listing::PartitionedFile; use datafusion_physical_optimizer::enforce_sorting::{EnforceSorting, PlanWithCorrespondingCoalescePartitions, PlanWithCorrespondingSort, parallelize_sorts, ensure_sorting}; use datafusion_physical_optimizer::enforce_sorting::replace_with_order_preserving_variants::{replace_with_order_preserving_variants, OrderPreservationContext}; use datafusion_physical_optimizer::enforce_sorting::sort_pushdown::{SortPushDown, assign_initial_requirements, pushdown_sorts}; use datafusion_physical_optimizer::enforce_distribution::EnforceDistribution; +use datafusion_physical_optimizer::output_requirements::OutputRequirementExec; use datafusion_physical_optimizer::PhysicalOptimizerRule; -use datafusion_functions_aggregate::average::avg_udaf; -use datafusion_functions_aggregate::count::count_udaf; -use datafusion_functions_aggregate::min_max::{max_udaf, min_udaf}; -use datafusion_datasource::file_scan_config::FileScanConfigBuilder; -use datafusion_datasource::source::DataSourceExec; use rstest::rstest; -/// Create a csv exec for tests -fn csv_exec_ordered( - schema: &SchemaRef, - sort_exprs: impl IntoIterator, -) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - - let config = FileScanConfigBuilder::new( - ObjectStoreUrl::parse("test:///").unwrap(), - schema.clone(), - Arc::new(CsvSource::new(true, 0, b'"')), - ) - .with_file(PartitionedFile::new("file_path".to_string(), 100)) - .with_output_ordering(vec![sort_exprs]) - .build(); - - DataSourceExec::from_data_source(config) -} - -/// Created a sorted parquet exec -pub fn parquet_exec_sorted( - schema: &SchemaRef, - sort_exprs: impl IntoIterator, -) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - - let source = Arc::new(ParquetSource::default()); - let config = FileScanConfigBuilder::new( - ObjectStoreUrl::parse("test:///").unwrap(), - schema.clone(), - source, - ) - .with_file(PartitionedFile::new("x".to_string(), 100)) - .with_output_ordering(vec![sort_exprs]) - .build(); - - DataSourceExec::from_data_source(config) -} - /// Create a sorted Csv exec fn csv_exec_sorted( schema: &SchemaRef, sort_exprs: impl IntoIterator, ) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - - let config = FileScanConfigBuilder::new( + let mut builder = FileScanConfigBuilder::new( ObjectStoreUrl::parse("test:///").unwrap(), schema.clone(), Arc::new(CsvSource::new(false, 0, 0)), ) - .with_file(PartitionedFile::new("x".to_string(), 100)) - .with_output_ordering(vec![sort_exprs]) - .build(); + .with_file(PartitionedFile::new("x".to_string(), 100)); + if let Some(ordering) = LexOrdering::new(sort_exprs) { + builder = builder.with_output_ordering(vec![ordering]); + } + let config = builder.build(); DataSourceExec::from_data_source(config) } @@ -163,7 +128,7 @@ macro_rules! 
assert_optimized { plan_with_pipeline_fixer, false, true, - &config, + &config, ) }) .data() @@ -210,24 +175,27 @@ async fn test_remove_unnecessary_sort5() -> Result<()> { let left_schema = create_test_schema2()?; let right_schema = create_test_schema3()?; let left_input = memory_exec(&left_schema); - let parquet_sort_exprs = vec![sort_expr("a", &right_schema)]; - let right_input = parquet_exec_sorted(&right_schema, parquet_sort_exprs); - + let parquet_ordering = [sort_expr("a", &right_schema)].into(); + let right_input = + parquet_exec_with_sort(right_schema.clone(), vec![parquet_ordering]); let on = vec![( Arc::new(Column::new_with_schema("col_a", &left_schema)?) as _, Arc::new(Column::new_with_schema("c", &right_schema)?) as _, )]; let join = hash_join_exec(left_input, right_input, on, None, &JoinType::Inner)?; - let physical_plan = sort_exec(vec![sort_expr("a", &join.schema())], join); + let physical_plan = sort_exec([sort_expr("a", &join.schema())].into(), join); - let expected_input = ["SortExec: expr=[a@2 ASC], preserve_partitioning=[false]", + let expected_input = [ + "SortExec: expr=[a@2 ASC], preserve_partitioning=[false]", " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col_a@0, c@2)]", " DataSourceExec: partitions=1, partition_sizes=[0]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet"]; - - let expected_optimized = ["HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col_a@0, c@2)]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet", + ]; + let expected_optimized = [ + "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col_a@0, c@2)]", " DataSourceExec: partitions=1, partition_sizes=[0]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) @@ -236,41 +204,40 @@ async fn test_remove_unnecessary_sort5() -> Result<()> { #[tokio::test] async fn test_do_not_remove_sort_with_limit() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs = vec![ + let source1 = parquet_exec(schema.clone()); + let ordering: LexOrdering = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; - let sort = sort_exec(sort_exprs.clone(), source1); - let limit = limit_exec(sort); - - let parquet_sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let source2 = parquet_exec_sorted(&schema, parquet_sort_exprs); - + ] + .into(); + let sort = sort_exec(ordering.clone(), source1); + let limit = local_limit_exec(sort, 100); + let parquet_ordering = [sort_expr("nullable_col", &schema)].into(); + let source2 = parquet_exec_with_sort(schema, vec![parquet_ordering]); let union = union_exec(vec![source2, limit]); let repartition = repartition_exec(union); - let physical_plan = sort_preserving_merge_exec(sort_exprs, repartition); + let physical_plan = sort_preserving_merge_exec(ordering, repartition); - let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", + let expected_input = [ + "SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", " UnionExec", " 
DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", - " GlobalLimitExec: skip=0, fetch=100", - " LocalLimitExec: fetch=100", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - + " LocalLimitExec: fetch=100", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; // We should keep the bottom `SortExec`. - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", + let expected_optimized = [ + "SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[true]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", " UnionExec", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", - " GlobalLimitExec: skip=0, fetch=100", - " LocalLimitExec: fetch=100", - " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; + " LocalLimitExec: fetch=100", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) @@ -279,18 +246,15 @@ async fn test_do_not_remove_sort_with_limit() -> Result<()> { #[tokio::test] async fn test_union_inputs_sorted() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let sort = sort_exec(sort_exprs.clone(), source1); - - let source2 = parquet_exec_sorted(&schema, sort_exprs.clone()); - + let source1 = parquet_exec(schema.clone()); + let ordering: LexOrdering = [sort_expr("nullable_col", &schema)].into(); + let sort = sort_exec(ordering.clone(), source1); + let source2 = parquet_exec_with_sort(schema, vec![ordering.clone()]); let union = union_exec(vec![source2, sort]); - let physical_plan = sort_preserving_merge_exec(sort_exprs, union); + let physical_plan = sort_preserving_merge_exec(ordering, union); // one input to the union is already sorted, one is not. 
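// The union assertions below rely on the invariant behind `SortPreservingMergeExec`: when
// every input partition is already sorted on the merge expressions, a streaming k-way
// merge yields globally sorted output without re-sorting, so no extra sort is needed at
// the union output. A self-contained illustration with plain integers (conceptual only,
// unrelated to the actual DataFusion implementation):

use std::cmp::Reverse;
use std::collections::BinaryHeap;

/// Merge partitions that are each already sorted into one sorted stream.
fn merge_sorted_partitions(partitions: Vec<Vec<i32>>) -> Vec<i32> {
    // Min-heap of (next value, partition index, offset within that partition).
    let mut heap = BinaryHeap::new();
    for (p, part) in partitions.iter().enumerate() {
        if let Some(&v) = part.first() {
            heap.push(Reverse((v, p, 0usize)));
        }
    }
    let mut out = Vec::new();
    while let Some(Reverse((v, p, i))) = heap.pop() {
        out.push(v);
        if let Some(&next) = partitions[p].get(i + 1) {
            heap.push(Reverse((next, p, i + 1)));
        }
    }
    out
}

#[test]
fn demo_merge_of_sorted_inputs() {
    // One input sorted at the source, one sorted by an explicit sort: merging suffices.
    let merged = merge_sorted_partitions(vec![vec![1, 4, 7], vec![2, 3, 9]]);
    assert_eq!(merged, vec![1, 2, 3, 4, 7, 9]);
}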
- let expected_input = vec![ + let expected_input = [ "SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", @@ -298,8 +262,7 @@ async fn test_union_inputs_sorted() -> Result<()> { " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", ]; // should not add a sort at the output of the union, input plan should not be changed - let expected_optimized = expected_input.clone(); - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + assert_optimized!(expected_input, expected_input, physical_plan, true); Ok(()) } @@ -307,22 +270,20 @@ async fn test_union_inputs_sorted() -> Result<()> { #[tokio::test] async fn test_union_inputs_different_sorted() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let sort = sort_exec(sort_exprs.clone(), source1); - - let parquet_sort_exprs = vec![ + let source1 = parquet_exec(schema.clone()); + let ordering: LexOrdering = [sort_expr("nullable_col", &schema)].into(); + let sort = sort_exec(ordering.clone(), source1); + let parquet_ordering = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; - let source2 = parquet_exec_sorted(&schema, parquet_sort_exprs); - + ] + .into(); + let source2 = parquet_exec_with_sort(schema, vec![parquet_ordering]); let union = union_exec(vec![source2, sort]); - let physical_plan = sort_preserving_merge_exec(sort_exprs, union); + let physical_plan = sort_preserving_merge_exec(ordering, union); // one input to the union is already sorted, one is not. - let expected_input = vec![ + let expected_input = [ "SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC, non_nullable_col@1 ASC], file_type=parquet", @@ -330,8 +291,7 @@ async fn test_union_inputs_different_sorted() -> Result<()> { " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", ]; // should not add a sort at the output of the union, input plan should not be changed - let expected_optimized = expected_input.clone(); - assert_optimized!(expected_input, expected_optimized, physical_plan, true); + assert_optimized!(expected_input, expected_input, physical_plan, true); Ok(()) } @@ -339,35 +299,36 @@ async fn test_union_inputs_different_sorted() -> Result<()> { #[tokio::test] async fn test_union_inputs_different_sorted2() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs = vec![ + let source1 = parquet_exec(schema.clone()); + let sort_exprs: LexOrdering = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; + ] + .into(); let sort = sort_exec(sort_exprs.clone(), source1); - - let parquet_sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let source2 = parquet_exec_sorted(&schema, parquet_sort_exprs); - + let parquet_ordering = [sort_expr("nullable_col", &schema)].into(); + let source2 = parquet_exec_with_sort(schema, vec![parquet_ordering]); let union = union_exec(vec![source2, sort]); let physical_plan = sort_preserving_merge_exec(sort_exprs, union); // Input is an invalid plan. 
In this case rule should add required sorting in appropriate places. // First DataSourceExec has output ordering(nullable_col@0 ASC). However, it doesn't satisfy the // required ordering of SortPreservingMergeExec. - let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", + let expected_input = [ + "SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", " UnionExec", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; + let expected_optimized = [ + "SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) @@ -376,40 +337,42 @@ async fn test_union_inputs_different_sorted2() -> Result<()> { #[tokio::test] async fn test_union_inputs_different_sorted3() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs1 = vec![ + let source1 = parquet_exec(schema.clone()); + let ordering1 = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; - let sort1 = sort_exec(sort_exprs1, source1.clone()); - let sort_exprs2 = vec![sort_expr("nullable_col", &schema)]; - let sort2 = sort_exec(sort_exprs2, source1); - - let parquet_sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let source2 = parquet_exec_sorted(&schema, parquet_sort_exprs.clone()); - + ] + .into(); + let sort1 = sort_exec(ordering1, source1.clone()); + let ordering2 = [sort_expr("nullable_col", &schema)].into(); + let sort2 = sort_exec(ordering2, source1); + let parquet_ordering: LexOrdering = [sort_expr("nullable_col", &schema)].into(); + let source2 = parquet_exec_with_sort(schema, vec![parquet_ordering.clone()]); let union = union_exec(vec![sort1, source2, sort2]); - let physical_plan = sort_preserving_merge_exec(parquet_sort_exprs, union); + let physical_plan = sort_preserving_merge_exec(parquet_ordering, union); // First input to the union is not Sorted (SortExec is finer than required ordering by the SortPreservingMergeExec above). // Second input to the union is already Sorted (matches with the required ordering by the SortPreservingMergeExec above). // Third input to the union is not Sorted (SortExec is matches required ordering by the SortPreservingMergeExec above). 
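// The adjustment asserted below ("not unnecessarily fine") hinges on when one
// lexicographic ordering satisfies another: output sorted on
// [nullable_col, non_nullable_col] also satisfies a requirement on [nullable_col]
// (the requirement is a prefix), but not the other way around, so the finer sort is
// wasted work and gets relaxed. A simplified, self-contained sketch of that prefix check
// (illustration only; the real rule works on physical expressions, not column names):

/// Does the ordering actually produced satisfy the ordering that is required?
fn ordering_satisfies(produced: &[&str], required: &[&str]) -> bool {
    produced.len() >= required.len()
        && produced.iter().zip(required).all(|(p, r)| p == r)
}

#[test]
fn demo_prefix_satisfaction() {
    let fine = ["nullable_col", "non_nullable_col"];
    let coarse = ["nullable_col"];
    // The finer ordering satisfies the coarser requirement, but not vice versa, which is
    // why only the cheaper [nullable_col] sort survives in the optimized plan below.
    assert!(ordering_satisfies(&fine, &coarse));
    assert!(!ordering_satisfies(&coarse, &fine));
}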
- let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC]", + let expected_input = [ + "SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; // should adjust sorting in the first input of the union such that it is not unnecessarily fine - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC]", + let expected_optimized = [ + "SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) @@ -418,40 +381,42 @@ async fn test_union_inputs_different_sorted3() -> Result<()> { #[tokio::test] async fn test_union_inputs_different_sorted4() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs1 = vec![ + let source1 = parquet_exec(schema.clone()); + let ordering1 = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; - let sort_exprs2 = vec![sort_expr("nullable_col", &schema)]; - let sort1 = sort_exec(sort_exprs2.clone(), source1.clone()); - let sort2 = sort_exec(sort_exprs2.clone(), source1); - - let source2 = parquet_exec_sorted(&schema, sort_exprs2); - + ] + .into(); + let ordering2: LexOrdering = [sort_expr("nullable_col", &schema)].into(); + let sort1 = sort_exec(ordering2.clone(), source1.clone()); + let sort2 = sort_exec(ordering2.clone(), source1); + let source2 = parquet_exec_with_sort(schema, vec![ordering2]); let union = union_exec(vec![sort1, source2, sort2]); - let physical_plan = sort_preserving_merge_exec(sort_exprs1, union); + let physical_plan = sort_preserving_merge_exec(ordering1, union); // Ordering requirement of the `SortPreservingMergeExec` is not met. // Should modify the plan to ensure that all three inputs to the // `UnionExec` satisfy the ordering, OR add a single sort after // the `UnionExec` (both of which are equally good for this example). 
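The setup rewrites in these hunks follow one pattern applied throughout the file: an ordering is now built by converting a plain array of sort expressions into a LexOrdering with .into(), and a sorted parquet source is created with parquet_exec_with_sort, which takes an owned schema plus the orderings the source claims to satisfy. A condensed sketch of that shape, assuming the test helpers (sort_expr, sort_exec, parquet_exec, parquet_exec_with_sort, union_exec, sort_preserving_merge_exec) defined elsewhere in this test module and a `schema` in scope:

// Two-column lexicographic ordering, built from an array via `.into()`.
let ordering: LexOrdering = [
    sort_expr("nullable_col", &schema),
    sort_expr("non_nullable_col", &schema),
]
.into();
// A source that already declares this ordering (owned schema, list of orderings).
let sorted_source = parquet_exec_with_sort(schema.clone(), vec![ordering.clone()]);
// An unsorted source that gets an explicit SortExec on the same ordering.
let sorted_input = sort_exec(ordering.clone(), parquet_exec(schema));
// Merge the union of both while preserving the ordering.
let physical_plan =
    sort_preserving_merge_exec(ordering, union_exec(vec![sorted_source, sorted_input]));

Compared with the previous Vec<PhysicalSortExpr> plumbing, this keeps the orderings typed as LexOrdering end to end.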
- let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", + let expected_input = [ + "SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; + let expected_optimized = [ + "SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) @@ -460,13 +425,13 @@ async fn test_union_inputs_different_sorted4() -> Result<()> { #[tokio::test] async fn test_union_inputs_different_sorted5() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs1 = vec![ + let source1 = parquet_exec(schema.clone()); + let ordering1 = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; - let sort_exprs2 = vec![ + ] + .into(); + let ordering2 = [ sort_expr("nullable_col", &schema), sort_expr_options( "non_nullable_col", @@ -476,29 +441,33 @@ async fn test_union_inputs_different_sorted5() -> Result<()> { nulls_first: false, }, ), - ]; - let sort_exprs3 = vec![sort_expr("nullable_col", &schema)]; - let sort1 = sort_exec(sort_exprs1, source1.clone()); - let sort2 = sort_exec(sort_exprs2, source1); - + ] + .into(); + let ordering3 = [sort_expr("nullable_col", &schema)].into(); + let sort1 = sort_exec(ordering1, source1.clone()); + let sort2 = sort_exec(ordering2, source1); let union = union_exec(vec![sort1, sort2]); - let physical_plan = sort_preserving_merge_exec(sort_exprs3, union); + let physical_plan = sort_preserving_merge_exec(ordering3, union); // The `UnionExec` doesn't preserve any of the inputs ordering in the // example below. However, we should be able to change the unnecessarily // fine `SortExec`s below with required `SortExec`s that are absolutely necessary. 
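The mixed-direction orderings in this test and the ones that follow come from sort_expr_options, which pairs a column with an explicit SortOptions; descending: true with nulls_first: false is what renders as DESC NULLS LAST in the expected plans. A small sketch under the same test-helper assumptions as above (the SortOptions import path is not shown in this diff):

// "non_nullable_col DESC NULLS LAST" as an explicit sort expression.
let desc_nulls_last = sort_expr_options(
    "non_nullable_col",
    &schema,
    SortOptions {
        descending: true,
        nulls_first: false,
    },
);
// Combine with a default ascending key to get the two-column ordering used above.
let ordering: LexOrdering = [sort_expr("nullable_col", &schema), desc_nulls_last].into();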
- let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC]", + let expected_input = [ + "SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; + let expected_optimized = [ + "SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) @@ -507,22 +476,20 @@ async fn test_union_inputs_different_sorted5() -> Result<()> { #[tokio::test] async fn test_union_inputs_different_sorted6() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs1 = vec![sort_expr("nullable_col", &schema)]; - let sort1 = sort_exec(sort_exprs1, source1.clone()); - let sort_exprs2 = vec![ + let source1 = parquet_exec(schema.clone()); + let ordering1 = [sort_expr("nullable_col", &schema)].into(); + let sort1 = sort_exec(ordering1, source1.clone()); + let ordering2 = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; + ] + .into(); let repartition = repartition_exec(source1); - let spm = sort_preserving_merge_exec(sort_exprs2, repartition); - - let parquet_sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let source2 = parquet_exec_sorted(&schema, parquet_sort_exprs.clone()); - + let spm = sort_preserving_merge_exec(ordering2, repartition); + let parquet_ordering: LexOrdering = [sort_expr("nullable_col", &schema)].into(); + let source2 = parquet_exec_with_sort(schema, vec![parquet_ordering.clone()]); let union = union_exec(vec![sort1, source2, spm]); - let physical_plan = sort_preserving_merge_exec(parquet_sort_exprs, union); + let physical_plan = sort_preserving_merge_exec(parquet_ordering, union); // The plan is not valid as it is -- the input ordering requirement // of the `SortPreservingMergeExec` under the third child of the @@ -530,24 +497,28 @@ async fn test_union_inputs_different_sorted6() -> Result<()> { // At the same time, this ordering requirement is unnecessarily fine. // The final plan should be valid AND the ordering of the third child // shouldn't be finer than necessary. 
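The repair the rule makes here is visible in the expected plans just below: rather than demanding a finer merge under the union, it sorts each partition in place (preserve_partitioning=[true]) and lets the SortPreservingMergeExec above combine the partitions. For reference, a hedged sketch of building that shape directly; the module paths and the `ordering` and `partitioned_input` placeholders are assumptions, not code from this diff:

use std::sync::Arc;
use datafusion_physical_plan::sorts::sort::SortExec;
use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
use datafusion_physical_plan::ExecutionPlan;

// Sort every partition of `partitioned_input` on `ordering`, keeping the partitioning.
let per_partition_sort: Arc<dyn ExecutionPlan> = Arc::new(
    SortExec::new(ordering.clone(), partitioned_input).with_preserve_partitioning(true),
);
// Merge the sorted partitions into a single ordered stream.
let merged: Arc<dyn ExecutionPlan> =
    Arc::new(SortPreservingMergeExec::new(ordering, per_partition_sort));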
- let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC]", + let expected_input = [ + "SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", " SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; // Should adjust the requirement in the third input of the union so // that it is not unnecessarily fine. - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC]", + let expected_optimized = [ + "SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[true]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) @@ -556,33 +527,36 @@ async fn test_union_inputs_different_sorted6() -> Result<()> { #[tokio::test] async fn test_union_inputs_different_sorted7() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs1 = vec![ + let source1 = parquet_exec(schema.clone()); + let ordering1: LexOrdering = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; - let sort_exprs3 = vec![sort_expr("nullable_col", &schema)]; - let sort1 = sort_exec(sort_exprs1.clone(), source1.clone()); - let sort2 = sort_exec(sort_exprs1, source1); - + ] + .into(); + let sort1 = sort_exec(ordering1.clone(), source1.clone()); + let sort2 = sort_exec(ordering1, source1); let union = union_exec(vec![sort1, sort2]); - let physical_plan = sort_preserving_merge_exec(sort_exprs3, union); + let ordering2 = [sort_expr("nullable_col", &schema)].into(); + let physical_plan = sort_preserving_merge_exec(ordering2, union); // Union has unnecessarily fine ordering below it. We should be able to replace them with absolutely necessary ordering. 
- let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC]", + let expected_input = [ + "SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; // Union preserves the inputs ordering and we should not change any of the SortExecs under UnionExec - let expected_output = ["SortPreservingMergeExec: [nullable_col@0 ASC]", + let expected_output = [ + "SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; assert_optimized!(expected_input, expected_output, physical_plan, true); Ok(()) @@ -591,13 +565,13 @@ async fn test_union_inputs_different_sorted7() -> Result<()> { #[tokio::test] async fn test_union_inputs_different_sorted8() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs1 = vec![ + let source1 = parquet_exec(schema.clone()); + let ordering1 = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; - let sort_exprs2 = vec![ + ] + .into(); + let ordering2 = [ sort_expr_options( "nullable_col", &schema, @@ -614,74 +588,475 @@ async fn test_union_inputs_different_sorted8() -> Result<()> { nulls_first: false, }, ), - ]; - let sort1 = sort_exec(sort_exprs1, source1.clone()); - let sort2 = sort_exec(sort_exprs2, source1); - + ] + .into(); + let sort1 = sort_exec(ordering1, source1.clone()); + let sort2 = sort_exec(ordering2, source1); let physical_plan = union_exec(vec![sort1, sort2]); // The `UnionExec` doesn't preserve any of the inputs ordering in the // example below. - let expected_input = ["UnionExec", + let expected_input = [ + "UnionExec", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", " SortExec: expr=[nullable_col@0 DESC NULLS LAST, non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; // Since `UnionExec` doesn't preserve ordering in the plan above. // We shouldn't keep SortExecs in the plan. 
- let expected_optimized = ["UnionExec", + let expected_optimized = [ + "UnionExec", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) } #[tokio::test] -async fn test_window_multi_path_sort() -> Result<()> { +async fn test_soft_hard_requirements_remove_soft_requirement() -> Result<()> { let schema = create_test_schema()?; + let source = parquet_exec(schema.clone()); + let sort_exprs = [sort_expr_options( + "nullable_col", + &schema, + SortOptions { + descending: true, + nulls_first: false, + }, + )] + .into(); + let sort = sort_exec(sort_exprs, source); + let partition_bys = &[col("nullable_col", &schema)?]; + let physical_plan = + bounded_window_exec_with_partition("nullable_col", vec![], partition_bys, sort); - let sort_exprs1 = vec![ - sort_expr("nullable_col", &schema), - sort_expr("non_nullable_col", &schema), + let expected_input = [ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; + // TODO When sort pushdown respects to the alternatives, and removes soft SortExecs this should be changed + // let expected_optimized = [ + // "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Linear]", + // " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + // ]; + let expected_optimized = [ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan, true); + Ok(()) +} + +#[tokio::test] +async fn test_soft_hard_requirements_remove_soft_requirement_without_pushdowns( +) -> Result<()> { + let schema = create_test_schema()?; + let source = parquet_exec(schema.clone()); + let ordering = [sort_expr_options( + "nullable_col", + &schema, + SortOptions { + descending: true, + nulls_first: false, + }, + )] + .into(); + let sort = sort_exec(ordering, source.clone()); + let proj_exprs = vec![( + Arc::new(BinaryExpr::new( + col("nullable_col", &schema)?, + Operator::Plus, + col("non_nullable_col", &schema)?, + )) as _, + "count".to_string(), + )]; + let partition_bys = &[col("nullable_col", &schema)?]; + let bounded_window = + 
bounded_window_exec_with_partition("nullable_col", vec![], partition_bys, sort); + let physical_plan = projection_exec(proj_exprs, bounded_window)?; + + let expected_input = [ + "ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as count]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; + // TODO When sort pushdown respects to the alternatives, and removes soft SortExecs this should be changed + // let expected_optimized = [ + // "ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as count]", + // " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Linear]", + // " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + // ]; + let expected_optimized = [ + "ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as count]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan, true); + + let ordering = [sort_expr_options( + "nullable_col", + &schema, + SortOptions { + descending: true, + nulls_first: false, + }, + )] + .into(); + let sort = sort_exec(ordering, source); + let proj_exprs = vec![( + Arc::new(BinaryExpr::new( + col("nullable_col", &schema)?, + Operator::Plus, + col("non_nullable_col", &schema)?, + )) as _, + "nullable_col".to_string(), + )]; + let partition_bys = &[col("nullable_col", &schema)?]; + let projection = projection_exec(proj_exprs, sort)?; + let physical_plan = bounded_window_exec_with_partition( + "nullable_col", + vec![], + partition_bys, + projection, + ); + + let expected_input = [ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col]", + " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; + // TODO When sort pushdown respects to the alternatives, and removes soft SortExecs this should be changed + // let expected_optimized = [ + // "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), 
frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Linear]", + // " ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col]", + // " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + // ]; + let expected_optimized = [ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", + " ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col]", + " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan, true); + Ok(()) +} + +#[tokio::test] +async fn test_soft_hard_requirements_multiple_soft_requirements() -> Result<()> { + let schema = create_test_schema()?; + let source = parquet_exec(schema.clone()); + let ordering = [sort_expr_options( + "nullable_col", + &schema, + SortOptions { + descending: true, + nulls_first: false, + }, + )] + .into(); + let sort = sort_exec(ordering, source.clone()); + let proj_exprs = vec![( + Arc::new(BinaryExpr::new( + col("nullable_col", &schema)?, + Operator::Plus, + col("non_nullable_col", &schema)?, + )) as _, + "nullable_col".to_string(), + )]; + let partition_bys = &[col("nullable_col", &schema)?]; + let projection = projection_exec(proj_exprs, sort)?; + let bounded_window = bounded_window_exec_with_partition( + "nullable_col", + vec![], + partition_bys, + projection, + ); + let physical_plan = bounded_window_exec_with_partition( + "count", + vec![], + partition_bys, + bounded_window, + ); + + let expected_input = [ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col]", + " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; + // TODO When sort pushdown respects to the alternatives, and removes soft SortExecs this should be changed + // let expected_optimized = [ + // "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Linear]", + // " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, 
start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Linear]", + // " ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col]", + // " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + // ]; + let expected_optimized = [ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", + " ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col]", + " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", ]; - let sort_exprs2 = vec![sort_expr("nullable_col", &schema)]; - // reverse sorting of sort_exprs2 - let sort_exprs3 = vec![sort_expr_options( + assert_optimized!(expected_input, expected_optimized, physical_plan, true); + + let ordering = [sort_expr_options( + "nullable_col", + &schema, + SortOptions { + descending: true, + nulls_first: false, + }, + )] + .into(); + let sort = sort_exec(ordering, source); + let proj_exprs = vec![( + Arc::new(BinaryExpr::new( + col("nullable_col", &schema)?, + Operator::Plus, + col("non_nullable_col", &schema)?, + )) as _, + "nullable_col".to_string(), + )]; + let partition_bys = &[col("nullable_col", &schema)?]; + let projection = projection_exec(proj_exprs, sort)?; + let bounded_window = bounded_window_exec_with_partition( + "nullable_col", + vec![], + partition_bys, + projection, + ); + + let ordering2: LexOrdering = [sort_expr_options( + "nullable_col", + &schema, + SortOptions { + descending: true, + nulls_first: false, + }, + )] + .into(); + let sort2 = sort_exec(ordering2.clone(), bounded_window); + let sort3 = sort_exec(ordering2, sort2); + let physical_plan = + bounded_window_exec_with_partition("count", vec![], partition_bys, sort3); + + let expected_input = [ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", + " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col]", + " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; + // TODO When sort 
pushdown respects to the alternatives, and removes soft SortExecs this should be changed + // let expected_optimized = [ + // "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Linear]", + // " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Linear]", + // " ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col]", + // " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + // ]; + let expected_optimized = [ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", + " ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col]", + " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan, true); + Ok(()) +} +#[tokio::test] +async fn test_soft_hard_requirements_multiple_sorts() -> Result<()> { + let schema = create_test_schema()?; + let source = parquet_exec(schema.clone()); + let ordering = [sort_expr_options( "nullable_col", &schema, SortOptions { descending: true, nulls_first: false, }, + )] + .into(); + let sort = sort_exec(ordering, source); + let proj_exprs = vec![( + Arc::new(BinaryExpr::new( + col("nullable_col", &schema)?, + Operator::Plus, + col("non_nullable_col", &schema)?, + )) as _, + "nullable_col".to_string(), )]; - let source1 = parquet_exec_sorted(&schema, sort_exprs1); - let source2 = parquet_exec_sorted(&schema, sort_exprs2); - let sort1 = sort_exec(sort_exprs3.clone(), source1); - let sort2 = sort_exec(sort_exprs3.clone(), source2); + let partition_bys = &[col("nullable_col", &schema)?]; + let projection = projection_exec(proj_exprs, sort)?; + let bounded_window = bounded_window_exec_with_partition( + "nullable_col", + vec![], + partition_bys, + projection, + ); + let ordering2: LexOrdering = [sort_expr_options( + "nullable_col", + &schema, + SortOptions { + descending: true, + nulls_first: false, + }, + )] + .into(); + let sort2 = sort_exec(ordering2.clone(), bounded_window); + let physical_plan = sort_exec(ordering2, sort2); + let expected_input = [ + "SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", + " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, 
dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col]", + " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; + // TODO When sort pushdown respects to the alternatives, and removes soft SortExecs this should be changed + // let expected_optimized = [ + // "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Linear]", + // " ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col]", + // " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + // ]; + let expected_optimized = [ + "SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", + " ProjectionExec: expr=[nullable_col@0 + non_nullable_col@1 as nullable_col]", + " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan, true); + Ok(()) +} + +#[tokio::test] +async fn test_soft_hard_requirements_with_multiple_soft_requirements_and_output_requirement( +) -> Result<()> { + let schema = create_test_schema()?; + let source = parquet_exec(schema.clone()); + let ordering = [sort_expr_options( + "nullable_col", + &schema, + SortOptions { + descending: true, + nulls_first: false, + }, + )] + .into(); + let sort = sort_exec(ordering, source); + let partition_bys1 = &[col("nullable_col", &schema)?]; + let bounded_window = + bounded_window_exec_with_partition("nullable_col", vec![], partition_bys1, sort); + let partition_bys2 = &[col("non_nullable_col", &schema)?]; + let bounded_window2 = bounded_window_exec_with_partition( + "non_nullable_col", + vec![], + partition_bys2, + bounded_window, + ); + let requirement = [PhysicalSortRequirement::new( + col("non_nullable_col", &schema)?, + Some(SortOptions::new(false, true)), + )] + .into(); + let physical_plan = Arc::new(OutputRequirementExec::new( + bounded_window2, + Some(OrderingRequirements::new(requirement)), + Distribution::SinglePartition, + )); + + let expected_input = [ + "OutputRequirementExec", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: 
Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; + // TODO When sort pushdown respects to the alternatives, and removes soft SortExecs this should be changed + // let expected_optimized = [ + // "OutputRequirementExec", + // " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + // " SortExec: expr=[non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false]", + // " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Linear]", + // " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + // ]; + let expected_optimized = [ + "OutputRequirementExec", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " SortExec: expr=[non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan, true); + Ok(()) +} + +#[tokio::test] +async fn test_window_multi_path_sort() -> Result<()> { + let schema = create_test_schema()?; + let ordering1 = [ + sort_expr("nullable_col", &schema), + sort_expr("non_nullable_col", &schema), + ] + .into(); + let ordering2 = [sort_expr("nullable_col", &schema)].into(); + // Reverse of the above + let ordering3: LexOrdering = [sort_expr_options( + "nullable_col", + &schema, + SortOptions { + descending: true, + nulls_first: false, + }, + )] + .into(); + let source1 = parquet_exec_with_sort(schema.clone(), vec![ordering1]); + let source2 = parquet_exec_with_sort(schema, vec![ordering2]); + let sort1 = sort_exec(ordering3.clone(), source1); + let sort2 = sort_exec(ordering3.clone(), source2); let union = union_exec(vec![sort1, sort2]); - let spm = sort_preserving_merge_exec(sort_exprs3.clone(), union); - let physical_plan = bounded_window_exec("nullable_col", sort_exprs3, spm); + let spm = sort_preserving_merge_exec(ordering3.clone(), union); + let physical_plan = bounded_window_exec("nullable_col", ordering3, spm); // The `WindowAggExec` gets its sorting from multiple children jointly. // During the removal of `SortExec`s, it should be able to remove the // corresponding SortExecs together. 
Also, the inputs of these `SortExec`s // are not necessarily the same to be able to remove them. let expected_input = [ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", " SortPreservingMergeExec: [nullable_col@0 DESC NULLS LAST]", " UnionExec", " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC, non_nullable_col@1 ASC], file_type=parquet", " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", + ]; let expected_optimized = [ - "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", + "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }]", " SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC, non_nullable_col@1 ASC], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) @@ -690,35 +1065,38 @@ async fn test_window_multi_path_sort() -> Result<()> { #[tokio::test] async fn test_window_multi_path_sort2() -> Result<()> { let schema = create_test_schema()?; - - let sort_exprs1 = LexOrdering::new(vec![ + let ordering1: LexOrdering = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]); - let sort_exprs2 = vec![sort_expr("nullable_col", &schema)]; - let source1 = parquet_exec_sorted(&schema, sort_exprs2.clone()); - let source2 = parquet_exec_sorted(&schema, sort_exprs2.clone()); - let sort1 = sort_exec(sort_exprs1.clone(), source1); - let sort2 = sort_exec(sort_exprs1.clone(), source2); - + ] + .into(); + let ordering2: LexOrdering = [sort_expr("nullable_col", &schema)].into(); + let source1 = parquet_exec_with_sort(schema.clone(), vec![ordering2.clone()]); + let source2 = parquet_exec_with_sort(schema, vec![ordering2.clone()]); + let sort1 = sort_exec(ordering1.clone(), source1); + let 
sort2 = sort_exec(ordering1.clone(), source2); let union = union_exec(vec![sort1, sort2]); - let spm = Arc::new(SortPreservingMergeExec::new(sort_exprs1, union)) as _; - let physical_plan = bounded_window_exec("nullable_col", sort_exprs2, spm); + let spm = Arc::new(SortPreservingMergeExec::new(ordering1, union)) as _; + let physical_plan = bounded_window_exec("nullable_col", ordering2, spm); // The `WindowAggExec` can get its required sorting from the leaf nodes directly. // The unnecessary SortExecs should be removed - let expected_input = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + let expected_input = [ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", " SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet"]; - let expected_optimized = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", + ]; + let expected_optimized = [ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", " SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC], file_type=parquet", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) @@ -727,13 +1105,13 @@ async fn test_window_multi_path_sort2() -> Result<()> { #[tokio::test] async fn test_union_inputs_different_sorted_with_limit() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); - let sort_exprs1 = vec![ + let source1 = parquet_exec(schema.clone()); + let ordering1 = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; - let sort_exprs2 = vec![ + ] + 
.into(); + let ordering2 = [ sort_expr("nullable_col", &schema), sort_expr_options( "non_nullable_col", @@ -743,34 +1121,37 @@ async fn test_union_inputs_different_sorted_with_limit() -> Result<()> { nulls_first: false, }, ), - ]; - let sort_exprs3 = vec![sort_expr("nullable_col", &schema)]; - let sort1 = sort_exec(sort_exprs1, source1.clone()); - - let sort2 = sort_exec(sort_exprs2, source1); - let limit = local_limit_exec(sort2); - let limit = global_limit_exec(limit); - + ] + .into(); + let sort1 = sort_exec(ordering1, source1.clone()); + let sort2 = sort_exec(ordering2, source1); + let limit = local_limit_exec(sort2, 100); + let limit = global_limit_exec(limit, 0, Some(100)); let union = union_exec(vec![sort1, limit]); - let physical_plan = sort_preserving_merge_exec(sort_exprs3, union); + let ordering3 = [sort_expr("nullable_col", &schema)].into(); + let physical_plan = sort_preserving_merge_exec(ordering3, union); // Should not change the unnecessarily fine `SortExec`s because there is `LimitExec` - let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC]", + let expected_input = [ + "SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", " GlobalLimitExec: skip=0, fetch=100", " LocalLimitExec: fetch=100", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; + let expected_optimized = [ + "SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", " GlobalLimitExec: skip=0, fetch=100", " LocalLimitExec: fetch=100", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) @@ -781,13 +1162,13 @@ async fn test_sort_merge_join_order_by_left() -> Result<()> { let left_schema = create_test_schema()?; let right_schema = create_test_schema2()?; - let left = parquet_exec(&left_schema); - let right = parquet_exec(&right_schema); + let left = parquet_exec(left_schema); + let right = parquet_exec(right_schema); // Join on (nullable_col == col_a) let join_on = vec![( - Arc::new(Column::new_with_schema("nullable_col", &left.schema()).unwrap()) as _, - Arc::new(Column::new_with_schema("col_a", &right.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("nullable_col", &left.schema())?) as _, + Arc::new(Column::new_with_schema("col_a", &right.schema())?) 
as _, )]; let join_types = vec![ @@ -801,11 +1182,12 @@ async fn test_sort_merge_join_order_by_left() -> Result<()> { for join_type in join_types { let join = sort_merge_join_exec(left.clone(), right.clone(), &join_on, &join_type); - let sort_exprs = vec![ + let ordering = [ sort_expr("nullable_col", &join.schema()), sort_expr("non_nullable_col", &join.schema()), - ]; - let physical_plan = sort_preserving_merge_exec(sort_exprs.clone(), join); + ] + .into(); + let physical_plan = sort_preserving_merge_exec(ordering, join); let join_plan = format!( "SortMergeJoin: join_type={join_type}, on=[(nullable_col@0, col_a@0)]" @@ -853,13 +1235,13 @@ async fn test_sort_merge_join_order_by_right() -> Result<()> { let left_schema = create_test_schema()?; let right_schema = create_test_schema2()?; - let left = parquet_exec(&left_schema); - let right = parquet_exec(&right_schema); + let left = parquet_exec(left_schema); + let right = parquet_exec(right_schema); // Join on (nullable_col == col_a) let join_on = vec![( - Arc::new(Column::new_with_schema("nullable_col", &left.schema()).unwrap()) as _, - Arc::new(Column::new_with_schema("col_a", &right.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("nullable_col", &left.schema())?) as _, + Arc::new(Column::new_with_schema("col_a", &right.schema())?) as _, )]; let join_types = vec![ @@ -872,11 +1254,12 @@ async fn test_sort_merge_join_order_by_right() -> Result<()> { for join_type in join_types { let join = sort_merge_join_exec(left.clone(), right.clone(), &join_on, &join_type); - let sort_exprs = vec![ + let ordering = [ sort_expr("col_a", &join.schema()), sort_expr("col_b", &join.schema()), - ]; - let physical_plan = sort_preserving_merge_exec(sort_exprs, join); + ] + .into(); + let physical_plan = sort_preserving_merge_exec(ordering, join); let join_plan = format!( "SortMergeJoin: join_type={join_type}, on=[(nullable_col@0, col_a@0)]" @@ -925,58 +1308,65 @@ async fn test_sort_merge_join_complex_order_by() -> Result<()> { let left_schema = create_test_schema()?; let right_schema = create_test_schema2()?; - let left = parquet_exec(&left_schema); - let right = parquet_exec(&right_schema); + let left = parquet_exec(left_schema); + let right = parquet_exec(right_schema); // Join on (nullable_col == col_a) let join_on = vec![( - Arc::new(Column::new_with_schema("nullable_col", &left.schema()).unwrap()) as _, - Arc::new(Column::new_with_schema("col_a", &right.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("nullable_col", &left.schema())?) as _, + Arc::new(Column::new_with_schema("col_a", &right.schema())?) 
as _, )]; let join = sort_merge_join_exec(left, right, &join_on, &JoinType::Inner); // order by (col_b, col_a) - let sort_exprs1 = vec![ + let ordering = [ sort_expr("col_b", &join.schema()), sort_expr("col_a", &join.schema()), - ]; - let physical_plan = sort_preserving_merge_exec(sort_exprs1, join.clone()); + ] + .into(); + let physical_plan = sort_preserving_merge_exec(ordering, join.clone()); - let expected_input = ["SortPreservingMergeExec: [col_b@3 ASC, col_a@2 ASC]", + let expected_input = [ + "SortPreservingMergeExec: [col_b@3 ASC, col_a@2 ASC]", " SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet"]; - + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet", + ]; // can not push down the sort requirements, need to add SortExec - let expected_optimized = ["SortExec: expr=[col_b@3 ASC, col_a@2 ASC], preserve_partitioning=[false]", + let expected_optimized = [ + "SortExec: expr=[col_b@3 ASC, nullable_col@0 ASC], preserve_partitioning=[false]", " SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)]", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", " SortExec: expr=[col_a@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); // order by (nullable_col, col_b, col_a) - let sort_exprs2 = vec![ + let ordering2 = [ sort_expr("nullable_col", &join.schema()), sort_expr("col_b", &join.schema()), sort_expr("col_a", &join.schema()), - ]; - let physical_plan = sort_preserving_merge_exec(sort_exprs2, join); + ] + .into(); + let physical_plan = sort_preserving_merge_exec(ordering2, join); - let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC, col_b@3 ASC, col_a@2 ASC]", + let expected_input = [ + "SortPreservingMergeExec: [nullable_col@0 ASC, col_b@3 ASC, col_a@2 ASC]", " SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet"]; - - // can not push down the sort requirements, need to add SortExec - let expected_optimized = ["SortExec: expr=[nullable_col@0 ASC, col_b@3 ASC, col_a@2 ASC], preserve_partitioning=[false]", - " SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)]", - " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", - " SortExec: expr=[col_a@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet", + ]; + // Can push down the sort requirements since col_a = nullable_col + let expected_optimized = [ + "SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)]", + " SortExec: 
expr=[nullable_col@0 ASC], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + " SortExec: expr=[col_a@0 ASC, col_b@1 ASC], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b], file_type=parquet", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) @@ -985,33 +1375,34 @@ async fn test_sort_merge_join_complex_order_by() -> Result<()> { #[tokio::test] async fn test_multilayer_coalesce_partitions() -> Result<()> { let schema = create_test_schema()?; - - let source1 = parquet_exec(&schema); + let source1 = parquet_exec(schema.clone()); let repartition = repartition_exec(source1); - let coalesce = Arc::new(CoalescePartitionsExec::new(repartition)) as _; + let coalesce = coalesce_partitions_exec(repartition) as _; // Add dummy layer propagating Sort above, to test whether sort can be removed from multi layer before let filter = filter_exec( - Arc::new(NotExpr::new( - col("non_nullable_col", schema.as_ref()).unwrap(), - )), + Arc::new(NotExpr::new(col("non_nullable_col", schema.as_ref())?)), coalesce, ); - let sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let physical_plan = sort_exec(sort_exprs, filter); + let ordering = [sort_expr("nullable_col", &schema)].into(); + let physical_plan = sort_exec(ordering, filter); // CoalescePartitionsExec and SortExec are not directly consecutive. In this case // we should be able to parallelize Sorting also (given that executors in between don't require) // single partition. - let expected_input = ["SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", + let expected_input = [ + "SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " FilterExec: NOT non_nullable_col@1", " CoalescePartitionsExec", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; + let expected_optimized = [ + "SortPreservingMergeExec: [nullable_col@0 ASC]", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[true]", " FilterExec: NOT non_nullable_col@1", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], file_type=parquet", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) @@ -1020,26 +1411,30 @@ async fn test_multilayer_coalesce_partitions() -> Result<()> { #[tokio::test] async fn test_with_lost_ordering_bounded() -> Result<()> { let schema = create_test_schema3()?; - let sort_exprs = vec![sort_expr("a", &schema)]; + let sort_exprs = [sort_expr("a", &schema)]; let source = csv_exec_sorted(&schema, sort_exprs); let repartition_rr = repartition_exec(source); let repartition_hash = Arc::new(RepartitionExec::try_new( repartition_rr, - Partitioning::Hash(vec![col("c", &schema).unwrap()], 10), + Partitioning::Hash(vec![col("c", &schema)?], 10), )?) 
as _; let coalesce_partitions = coalesce_partitions_exec(repartition_hash); - let physical_plan = sort_exec(vec![sort_expr("a", &schema)], coalesce_partitions); + let physical_plan = sort_exec([sort_expr("a", &schema)].into(), coalesce_partitions); - let expected_input = ["SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + let expected_input = [ + "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=false"]; - let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=false", + ]; + let expected_optimized = [ + "SortPreservingMergeExec: [a@0 ASC]", " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=false"]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=false", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) @@ -1051,20 +1446,20 @@ async fn test_with_lost_ordering_unbounded_bounded( #[values(false, true)] source_unbounded: bool, ) -> Result<()> { let schema = create_test_schema3()?; - let sort_exprs = vec![sort_expr("a", &schema)]; + let sort_exprs = [sort_expr("a", &schema)]; // create either bounded or unbounded source let source = if source_unbounded { - stream_exec_ordered(&schema, sort_exprs) + stream_exec_ordered(&schema, sort_exprs.clone().into()) } else { - csv_exec_ordered(&schema, sort_exprs) + csv_exec_sorted(&schema, sort_exprs.clone()) }; let repartition_rr = repartition_exec(source); let repartition_hash = Arc::new(RepartitionExec::try_new( repartition_rr, - Partitioning::Hash(vec![col("c", &schema).unwrap()], 10), + Partitioning::Hash(vec![col("c", &schema)?], 10), )?) 
as _; let coalesce_partitions = coalesce_partitions_exec(repartition_hash); - let physical_plan = sort_exec(vec![sort_expr("a", &schema)], coalesce_partitions); + let physical_plan = sort_exec(sort_exprs.into(), coalesce_partitions); // Expected inputs unbounded and bounded let expected_input_unbounded = vec![ @@ -1079,7 +1474,7 @@ async fn test_with_lost_ordering_unbounded_bounded( " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[file_path]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=true", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=false", ]; // Expected unbounded result (same for with and without flag) @@ -1096,14 +1491,14 @@ async fn test_with_lost_ordering_unbounded_bounded( " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[file_path]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=true", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=false", ]; let expected_optimized_bounded_parallelize_sort = vec![ "SortPreservingMergeExec: [a@0 ASC]", " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[file_path]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=true", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=csv, has_header=false", ]; let (expected_input, expected_optimized, expected_optimized_sort_parallelize) = if source_unbounded { @@ -1138,20 +1533,24 @@ async fn test_with_lost_ordering_unbounded_bounded( #[tokio::test] async fn test_do_not_pushdown_through_spm() -> Result<()> { let schema = create_test_schema3()?; - let sort_exprs = vec![sort_expr("a", &schema), sort_expr("b", &schema)]; + let sort_exprs = [sort_expr("a", &schema), sort_expr("b", &schema)]; let source = csv_exec_sorted(&schema, sort_exprs.clone()); let repartition_rr = repartition_exec(source); - let spm = sort_preserving_merge_exec(sort_exprs, repartition_rr); - let physical_plan = sort_exec(vec![sort_expr("b", &schema)], spm); + let spm = sort_preserving_merge_exec(sort_exprs.into(), repartition_rr); + let physical_plan = sort_exec([sort_expr("b", &schema)].into(), spm); - let expected_input = ["SortExec: expr=[b@1 ASC], preserve_partitioning=[false]", + let expected_input = [ + "SortExec: expr=[b@1 ASC], preserve_partitioning=[false]", " SortPreservingMergeExec: [a@0 ASC, b@1 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=csv, has_header=false",]; - let expected_optimized = ["SortExec: expr=[b@1 ASC], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=csv, has_header=false", + ]; + let 
expected_optimized = [ + "SortExec: expr=[b@1 ASC], preserve_partitioning=[false]", " SortPreservingMergeExec: [a@0 ASC, b@1 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=csv, has_header=false",]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=csv, has_header=false", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, false); Ok(()) @@ -1160,27 +1559,31 @@ async fn test_do_not_pushdown_through_spm() -> Result<()> { #[tokio::test] async fn test_pushdown_through_spm() -> Result<()> { let schema = create_test_schema3()?; - let sort_exprs = vec![sort_expr("a", &schema), sort_expr("b", &schema)]; + let sort_exprs = [sort_expr("a", &schema), sort_expr("b", &schema)]; let source = csv_exec_sorted(&schema, sort_exprs.clone()); let repartition_rr = repartition_exec(source); - let spm = sort_preserving_merge_exec(sort_exprs, repartition_rr); + let spm = sort_preserving_merge_exec(sort_exprs.into(), repartition_rr); let physical_plan = sort_exec( - vec![ + [ sort_expr("a", &schema), sort_expr("b", &schema), sort_expr("c", &schema), - ], + ] + .into(), spm, ); - let expected_input = ["SortExec: expr=[a@0 ASC, b@1 ASC, c@2 ASC], preserve_partitioning=[false]", + let expected_input = [ + "SortExec: expr=[a@0 ASC, b@1 ASC, c@2 ASC], preserve_partitioning=[false]", " SortPreservingMergeExec: [a@0 ASC, b@1 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=csv, has_header=false",]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=csv, has_header=false", + ]; let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC, b@1 ASC]", " SortExec: expr=[a@0 ASC, b@1 ASC, c@2 ASC], preserve_partitioning=[true]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=csv, has_header=false",]; + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], file_type=csv, has_header=false", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, false); Ok(()) @@ -1189,17 +1592,16 @@ async fn test_pushdown_through_spm() -> Result<()> { #[tokio::test] async fn test_window_multi_layer_requirement() -> Result<()> { let schema = create_test_schema3()?; - let sort_exprs = vec![sort_expr("a", &schema), sort_expr("b", &schema)]; + let sort_exprs = [sort_expr("a", &schema), sort_expr("b", &schema)]; let source = csv_exec_sorted(&schema, vec![]); - let sort = sort_exec(sort_exprs.clone(), source); + let sort = sort_exec(sort_exprs.clone().into(), source); let repartition = repartition_exec(sort); let repartition = spr_repartition_exec(repartition); - let spm = sort_preserving_merge_exec(sort_exprs.clone(), repartition); - + let spm = sort_preserving_merge_exec(sort_exprs.clone().into(), repartition); let physical_plan = bounded_window_exec("a", sort_exprs, spm); let expected_input = [ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), 
frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", " SortPreservingMergeExec: [a@0 ASC, b@1 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC, b@1 ASC", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", @@ -1207,7 +1609,7 @@ async fn test_window_multi_layer_requirement() -> Result<()> { " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", ]; let expected_optimized = [ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", " SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[false]", " CoalescePartitionsExec", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10", @@ -1221,15 +1623,15 @@ async fn test_window_multi_layer_requirement() -> Result<()> { #[tokio::test] async fn test_not_replaced_with_partial_sort_for_bounded_input() -> Result<()> { let schema = create_test_schema3()?; - let input_sort_exprs = vec![sort_expr("b", &schema), sort_expr("c", &schema)]; - let parquet_input = parquet_exec_sorted(&schema, input_sort_exprs); - + let parquet_ordering = [sort_expr("b", &schema), sort_expr("c", &schema)].into(); + let parquet_input = parquet_exec_with_sort(schema.clone(), vec![parquet_ordering]); let physical_plan = sort_exec( - vec![ + [ sort_expr("a", &schema), sort_expr("b", &schema), sort_expr("c", &schema), - ], + ] + .into(), parquet_input, ); let expected_input = [ @@ -1333,8 +1735,8 @@ macro_rules! 
assert_optimized { async fn test_remove_unnecessary_sort() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let input = sort_exec(vec![sort_expr("non_nullable_col", &schema)], source); - let physical_plan = sort_exec(vec![sort_expr("nullable_col", &schema)], input); + let input = sort_exec([sort_expr("non_nullable_col", &schema)].into(), source); + let physical_plan = sort_exec([sort_expr("nullable_col", &schema)].into(), input); let expected_input = [ "SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", @@ -1354,57 +1756,54 @@ async fn test_remove_unnecessary_sort() -> Result<()> { async fn test_remove_unnecessary_sort_window_multilayer() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - - let sort_exprs = vec![sort_expr_options( + let ordering: LexOrdering = [sort_expr_options( "non_nullable_col", &source.schema(), SortOptions { descending: true, nulls_first: true, }, - )]; - let sort = sort_exec(sort_exprs.clone(), source); + )] + .into(); + let sort = sort_exec(ordering.clone(), source); // Add dummy layer propagating Sort above, to test whether sort can be removed from multi layer before - let coalesce_batches = coalesce_batches_exec(sort); - - let window_agg = - bounded_window_exec("non_nullable_col", sort_exprs, coalesce_batches); - - let sort_exprs = vec![sort_expr_options( + let coalesce_batches = coalesce_batches_exec(sort, 128); + let window_agg = bounded_window_exec("non_nullable_col", ordering, coalesce_batches); + let ordering2: LexOrdering = [sort_expr_options( "non_nullable_col", &window_agg.schema(), SortOptions { descending: false, nulls_first: false, }, - )]; - - let sort = sort_exec(sort_exprs.clone(), window_agg); - + )] + .into(); + let sort = sort_exec(ordering2.clone(), window_agg); // Add dummy layer propagating Sort above, to test whether sort can be removed from multi layer before let filter = filter_exec( - Arc::new(NotExpr::new( - col("non_nullable_col", schema.as_ref()).unwrap(), - )), + Arc::new(NotExpr::new(col("non_nullable_col", schema.as_ref())?)), sort, ); + let physical_plan = bounded_window_exec("non_nullable_col", ordering2, filter); - let physical_plan = bounded_window_exec("non_nullable_col", sort_exprs, filter); - - let expected_input = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + let expected_input = [ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", " FilterExec: NOT non_nullable_col@1", " SortExec: expr=[non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, 
is_causal: false }], mode=[Sorted]", " CoalesceBatchesExec: target_batch_size=128", " SortExec: expr=[non_nullable_col@1 DESC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]"]; + " DataSourceExec: partitions=1, partition_sizes=[0]" + ]; - let expected_optimized = ["WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", + let expected_optimized = [ + "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }]", " FilterExec: NOT non_nullable_col@1", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", " CoalesceBatchesExec: target_batch_size=128", " SortExec: expr=[non_nullable_col@1 DESC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]"]; + " DataSourceExec: partitions=1, partition_sizes=[0]" + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) @@ -1414,10 +1813,8 @@ async fn test_remove_unnecessary_sort_window_multilayer() -> Result<()> { async fn test_add_required_sort() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - - let sort_exprs = vec![sort_expr("nullable_col", &schema)]; - - let physical_plan = sort_preserving_merge_exec(sort_exprs, source); + let ordering = [sort_expr("nullable_col", &schema)].into(); + let physical_plan = sort_preserving_merge_exec(ordering, source); let expected_input = [ "SortPreservingMergeExec: [nullable_col@0 ASC]", @@ -1436,13 +1833,12 @@ async fn test_add_required_sort() -> Result<()> { async fn test_remove_unnecessary_sort1() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let sort = sort_exec(sort_exprs.clone(), source); - let spm = sort_preserving_merge_exec(sort_exprs, sort); + let ordering: LexOrdering = [sort_expr("nullable_col", &schema)].into(); + let sort = sort_exec(ordering.clone(), source); + let spm = sort_preserving_merge_exec(ordering.clone(), sort); + let sort = sort_exec(ordering.clone(), spm); + let physical_plan = sort_preserving_merge_exec(ordering, sort); - let sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let sort = sort_exec(sort_exprs.clone(), spm); - let physical_plan = sort_preserving_merge_exec(sort_exprs, sort); let expected_input = [ "SortPreservingMergeExec: [nullable_col@0 ASC]", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", @@ -1463,19 +1859,18 @@ async fn test_remove_unnecessary_sort1() -> Result<()> { async fn test_remove_unnecessary_sort2() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let 
sort_exprs = vec![sort_expr("non_nullable_col", &schema)]; - let sort = sort_exec(sort_exprs.clone(), source); - let spm = sort_preserving_merge_exec(sort_exprs, sort); - - let sort_exprs = vec![ + let ordering: LexOrdering = [sort_expr("non_nullable_col", &schema)].into(); + let sort = sort_exec(ordering.clone(), source); + let spm = sort_preserving_merge_exec(ordering, sort); + let ordering2: LexOrdering = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; - let sort2 = sort_exec(sort_exprs.clone(), spm); - let spm2 = sort_preserving_merge_exec(sort_exprs, sort2); - - let sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let sort3 = sort_exec(sort_exprs, spm2); + ] + .into(); + let sort2 = sort_exec(ordering2.clone(), spm); + let spm2 = sort_preserving_merge_exec(ordering2, sort2); + let ordering3 = [sort_expr("nullable_col", &schema)].into(); + let sort3 = sort_exec(ordering3, spm2); let physical_plan = repartition_exec(repartition_exec(sort3)); let expected_input = [ @@ -1488,7 +1883,6 @@ async fn test_remove_unnecessary_sort2() -> Result<()> { " SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", " DataSourceExec: partitions=1, partition_sizes=[0]", ]; - let expected_optimized = [ "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", @@ -1503,21 +1897,20 @@ async fn test_remove_unnecessary_sort2() -> Result<()> { async fn test_remove_unnecessary_sort3() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let sort_exprs = vec![sort_expr("non_nullable_col", &schema)]; - let sort = sort_exec(sort_exprs.clone(), source); - let spm = sort_preserving_merge_exec(sort_exprs, sort); - - let sort_exprs = LexOrdering::new(vec![ + let ordering: LexOrdering = [sort_expr("non_nullable_col", &schema)].into(); + let sort = sort_exec(ordering.clone(), source); + let spm = sort_preserving_merge_exec(ordering, sort); + let repartition_exec = repartition_exec(spm); + let ordering2: LexOrdering = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]); - let repartition_exec = repartition_exec(spm); + ] + .into(); let sort2 = Arc::new( - SortExec::new(sort_exprs.clone(), repartition_exec) + SortExec::new(ordering2.clone(), repartition_exec) .with_preserve_partitioning(true), ) as _; - let spm2 = sort_preserving_merge_exec(sort_exprs, sort2); - + let spm2 = sort_preserving_merge_exec(ordering2, sort2); let physical_plan = aggregate_exec(spm2); // When removing a `SortPreservingMergeExec`, make sure that partitioning @@ -1532,7 +1925,6 @@ async fn test_remove_unnecessary_sort3() -> Result<()> { " SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", " DataSourceExec: partitions=1, partition_sizes=[0]", ]; - let expected_optimized = [ "AggregateExec: mode=Final, gby=[], aggr=[]", " CoalescePartitionsExec", @@ -1548,34 +1940,29 @@ async fn test_remove_unnecessary_sort3() -> Result<()> { async fn test_remove_unnecessary_sort4() -> Result<()> { let schema = create_test_schema()?; let source1 = repartition_exec(memory_exec(&schema)); - let source2 = repartition_exec(memory_exec(&schema)); let union = union_exec(vec![source1, source2]); - - let sort_exprs = LexOrdering::new(vec![sort_expr("non_nullable_col", &schema)]); - // let sort = sort_exec(sort_exprs.clone(), union); - let sort = Arc::new( - SortExec::new(sort_exprs.clone(), union).with_preserve_partitioning(true), - ) as _; 
- let spm = sort_preserving_merge_exec(sort_exprs, sort); - + let ordering: LexOrdering = [sort_expr("non_nullable_col", &schema)].into(); + let sort = + Arc::new(SortExec::new(ordering.clone(), union).with_preserve_partitioning(true)) + as _; + let spm = sort_preserving_merge_exec(ordering, sort); let filter = filter_exec( - Arc::new(NotExpr::new( - col("non_nullable_col", schema.as_ref()).unwrap(), - )), + Arc::new(NotExpr::new(col("non_nullable_col", schema.as_ref())?)), spm, ); - - let sort_exprs = vec![ + let ordering2 = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; - let physical_plan = sort_exec(sort_exprs, filter); + ] + .into(); + let physical_plan = sort_exec(ordering2, filter); // When removing a `SortPreservingMergeExec`, make sure that partitioning // requirements are not violated. In some cases, we may need to replace // it with a `CoalescePartitionsExec` instead of directly removing it. - let expected_input = ["SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", + let expected_input = [ + "SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " FilterExec: NOT non_nullable_col@1", " SortPreservingMergeExec: [non_nullable_col@1 ASC]", " SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[true]", @@ -1583,16 +1970,18 @@ async fn test_remove_unnecessary_sort4() -> Result<()> { " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " DataSourceExec: partitions=1, partition_sizes=[0]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[0]"]; - - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", + " DataSourceExec: partitions=1, partition_sizes=[0]", + ]; + let expected_optimized = [ + "SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[true]", " FilterExec: NOT non_nullable_col@1", " UnionExec", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " DataSourceExec: partitions=1, partition_sizes=[0]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " DataSourceExec: partitions=1, partition_sizes=[0]"]; + " DataSourceExec: partitions=1, partition_sizes=[0]", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) @@ -1602,18 +1991,17 @@ async fn test_remove_unnecessary_sort4() -> Result<()> { async fn test_remove_unnecessary_sort6() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let input = Arc::new( - SortExec::new( - LexOrdering::new(vec![sort_expr("non_nullable_col", &schema)]), - source, - ) - .with_fetch(Some(2)), + let input = sort_exec_with_fetch( + [sort_expr("non_nullable_col", &schema)].into(), + Some(2), + source, ); let physical_plan = sort_exec( - vec![ + [ sort_expr("non_nullable_col", &schema), sort_expr("nullable_col", &schema), - ], + ] + .into(), input, ); @@ -1635,21 +2023,19 @@ async fn test_remove_unnecessary_sort6() -> Result<()> { async fn test_remove_unnecessary_sort7() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let input = Arc::new(SortExec::new( - LexOrdering::new(vec![ + let input = sort_exec( + [ sort_expr("non_nullable_col", &schema), sort_expr("nullable_col", &schema), - ]), + ] + .into(), source, - )); - - let 
physical_plan = Arc::new( - SortExec::new( - LexOrdering::new(vec![sort_expr("non_nullable_col", &schema)]), - input, - ) - .with_fetch(Some(2)), - ) as Arc; + ); + let physical_plan = sort_exec_with_fetch( + [sort_expr("non_nullable_col", &schema)].into(), + Some(2), + input, + ); let expected_input = [ "SortExec: TopK(fetch=2), expr=[non_nullable_col@1 ASC], preserve_partitioning=[false], sort_prefix=[non_nullable_col@1 ASC]", @@ -1670,16 +2056,14 @@ async fn test_remove_unnecessary_sort7() -> Result<()> { async fn test_remove_unnecessary_sort8() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let input = Arc::new(SortExec::new( - LexOrdering::new(vec![sort_expr("non_nullable_col", &schema)]), - source, - )); + let input = sort_exec([sort_expr("non_nullable_col", &schema)].into(), source); let limit = Arc::new(LocalLimitExec::new(input, 2)); let physical_plan = sort_exec( - vec![ + [ sort_expr("non_nullable_col", &schema), sort_expr("nullable_col", &schema), - ], + ] + .into(), limit, ); @@ -1703,13 +2087,9 @@ async fn test_remove_unnecessary_sort8() -> Result<()> { async fn test_do_not_pushdown_through_limit() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - // let input = sort_exec(vec![sort_expr("non_nullable_col", &schema)], source); - let input = Arc::new(SortExec::new( - LexOrdering::new(vec![sort_expr("non_nullable_col", &schema)]), - source, - )); + let input = sort_exec([sort_expr("non_nullable_col", &schema)].into(), source); let limit = Arc::new(GlobalLimitExec::new(input, 0, Some(5))) as _; - let physical_plan = sort_exec(vec![sort_expr("nullable_col", &schema)], limit); + let physical_plan = sort_exec([sort_expr("nullable_col", &schema)].into(), limit); let expected_input = [ "SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", @@ -1732,12 +2112,11 @@ async fn test_do_not_pushdown_through_limit() -> Result<()> { async fn test_remove_unnecessary_spm1() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let input = - sort_preserving_merge_exec(vec![sort_expr("non_nullable_col", &schema)], source); - let input2 = - sort_preserving_merge_exec(vec![sort_expr("non_nullable_col", &schema)], input); + let ordering: LexOrdering = [sort_expr("non_nullable_col", &schema)].into(); + let input = sort_preserving_merge_exec(ordering.clone(), source); + let input2 = sort_preserving_merge_exec(ordering, input); let physical_plan = - sort_preserving_merge_exec(vec![sort_expr("nullable_col", &schema)], input2); + sort_preserving_merge_exec([sort_expr("nullable_col", &schema)].into(), input2); let expected_input = [ "SortPreservingMergeExec: [nullable_col@0 ASC]", @@ -1759,7 +2138,7 @@ async fn test_remove_unnecessary_spm2() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); let input = sort_preserving_merge_exec_with_fetch( - vec![sort_expr("non_nullable_col", &schema)], + [sort_expr("non_nullable_col", &schema)].into(), source, 100, ); @@ -1782,12 +2161,13 @@ async fn test_remove_unnecessary_spm2() -> Result<()> { async fn test_change_wrong_sorting() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let sort_exprs = vec![ + let sort_exprs = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), ]; - let sort = sort_exec(vec![sort_exprs[0].clone()], source); - let physical_plan = sort_preserving_merge_exec(sort_exprs, sort); + let sort = 
sort_exec([sort_exprs[0].clone()].into(), source); + let physical_plan = sort_preserving_merge_exec(sort_exprs.into(), sort); + let expected_input = [ "SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", @@ -1806,13 +2186,13 @@ async fn test_change_wrong_sorting() -> Result<()> { async fn test_change_wrong_sorting2() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - let sort_exprs = vec![ + let sort_exprs = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), ]; - let spm1 = sort_preserving_merge_exec(sort_exprs.clone(), source); - let sort2 = sort_exec(vec![sort_exprs[0].clone()], spm1); - let physical_plan = sort_preserving_merge_exec(vec![sort_exprs[1].clone()], sort2); + let spm1 = sort_preserving_merge_exec(sort_exprs.clone().into(), source); + let sort2 = sort_exec([sort_exprs[0].clone()].into(), spm1); + let physical_plan = sort_preserving_merge_exec([sort_exprs[1].clone()].into(), sort2); let expected_input = [ "SortPreservingMergeExec: [non_nullable_col@1 ASC]", @@ -1833,31 +2213,31 @@ async fn test_change_wrong_sorting2() -> Result<()> { async fn test_multiple_sort_window_exec() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); - - let sort_exprs1 = vec![sort_expr("nullable_col", &schema)]; - let sort_exprs2 = vec![ + let ordering1 = [sort_expr("nullable_col", &schema)]; + let sort1 = sort_exec(ordering1.clone().into(), source); + let window_agg1 = bounded_window_exec("non_nullable_col", ordering1.clone(), sort1); + let ordering2 = [ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), ]; + let window_agg2 = bounded_window_exec("non_nullable_col", ordering2, window_agg1); + let physical_plan = bounded_window_exec("non_nullable_col", ordering1, window_agg2); - let sort1 = sort_exec(sort_exprs1.clone(), source); - let window_agg1 = bounded_window_exec("non_nullable_col", sort_exprs1.clone(), sort1); - let window_agg2 = bounded_window_exec("non_nullable_col", sort_exprs2, window_agg1); - // let filter_exec = sort_exec; - let physical_plan = bounded_window_exec("non_nullable_col", sort_exprs1, window_agg2); - - let expected_input = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + let expected_input = [ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, 
dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]"]; - - let expected_optimized = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " DataSourceExec: partitions=1, partition_sizes=[0]", + ]; + let expected_optimized = [ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", - " DataSourceExec: partitions=1, partition_sizes=[0]"]; + " DataSourceExec: partitions=1, partition_sizes=[0]", + ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); Ok(()) @@ -1871,20 +2251,17 @@ async fn test_multiple_sort_window_exec() -> Result<()> { // EnforceDistribution may invalidate ordering invariant. 
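
The commutativity check below applies `EnforceDistribution` and `EnforceSorting` in both orders to the same original plan and asserts that the rendered plans match. A minimal sketch of that pattern follows; it assumes monolithic `datafusion` crate paths (they differ between versions), uses hypothetical `optimize_and_render` / `assert_rules_commute` helpers, takes the plan under test as a parameter rather than building it, and uses `displayable(...).indent(true)` in place of the file's `get_plan_string` helper.

```rust
// Sketch only: verifies that two physical-optimizer rule orderings yield the
// same plan text. Module paths follow the monolithic `datafusion` crate and
// may differ across versions; `optimize_and_render` and `assert_rules_commute`
// are illustrative helpers, not part of the test file.
use std::sync::Arc;

use datafusion::common::Result;
use datafusion::config::ConfigOptions;
use datafusion::physical_optimizer::enforce_distribution::EnforceDistribution;
use datafusion::physical_optimizer::enforce_sorting::EnforceSorting;
use datafusion::physical_optimizer::PhysicalOptimizerRule;
use datafusion::physical_plan::{displayable, ExecutionPlan};

/// Applies `rules` left to right and returns the indented plan rendering.
fn optimize_and_render(
    plan: Arc<dyn ExecutionPlan>,
    rules: &[Arc<dyn PhysicalOptimizerRule + Send + Sync>],
    config: &ConfigOptions,
) -> Result<String> {
    let mut plan = plan;
    for rule in rules {
        plan = rule.optimize(plan, config)?;
    }
    Ok(displayable(plan.as_ref()).indent(true).to_string())
}

/// `orig_plan` is built elsewhere (memory source + bounded window +
/// repartition + sort in the test); only the commutativity check is shown.
fn assert_rules_commute(orig_plan: Arc<dyn ExecutionPlan>) -> Result<()> {
    let config = ConfigOptions::new();
    let dist_then_sort: Vec<Arc<dyn PhysicalOptimizerRule + Send + Sync>> = vec![
        Arc::new(EnforceDistribution::new()),
        Arc::new(EnforceSorting::new()),
    ];
    let sort_dist_sort: Vec<Arc<dyn PhysicalOptimizerRule + Send + Sync>> = vec![
        Arc::new(EnforceSorting::new()),
        Arc::new(EnforceDistribution::new()),
        Arc::new(EnforceSorting::new()),
    ];
    let first = optimize_and_render(Arc::clone(&orig_plan), &dist_then_sort, &config)?;
    let second = optimize_and_render(orig_plan, &sort_dist_sort, &config)?;
    assert_eq!(first, second);
    Ok(())
}
```

Comparing the indented renderings rather than plan pointers mirrors the `assert_eq!(get_plan_string(&first_plan), get_plan_string(&second_plan))` form used in the test itself.
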
async fn test_commutativity() -> Result<()> { let schema = create_test_schema()?; - let config = ConfigOptions::new(); - let memory_exec = memory_exec(&schema); - let sort_exprs = LexOrdering::new(vec![sort_expr("nullable_col", &schema)]); + let sort_exprs = [sort_expr("nullable_col", &schema)]; let window = bounded_window_exec("nullable_col", sort_exprs.clone(), memory_exec); let repartition = repartition_exec(window); + let orig_plan = sort_exec(sort_exprs.into(), repartition); - let orig_plan = - Arc::new(SortExec::new(sort_exprs, repartition)) as Arc; let actual = get_plan_string(&orig_plan); let expected_input = vec![ "SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", " DataSourceExec: partitions=1, partition_sizes=[0]", ]; assert_eq!( @@ -1892,26 +2269,25 @@ async fn test_commutativity() -> Result<()> { "\n**Original Plan Mismatch\n\nexpected:\n\n{expected_input:#?}\nactual:\n\n{actual:#?}\n\n" ); - let mut plan = orig_plan.clone(); + let config = ConfigOptions::new(); let rules = vec![ Arc::new(EnforceDistribution::new()) as Arc, Arc::new(EnforceSorting::new()) as Arc, ]; + let mut first_plan = orig_plan.clone(); for rule in rules { - plan = rule.optimize(plan, &config)?; + first_plan = rule.optimize(first_plan, &config)?; } - let first_plan = plan.clone(); - let mut plan = orig_plan.clone(); let rules = vec![ Arc::new(EnforceSorting::new()) as Arc, Arc::new(EnforceDistribution::new()) as Arc, Arc::new(EnforceSorting::new()) as Arc, ]; + let mut second_plan = orig_plan.clone(); for rule in rules { - plan = rule.optimize(plan, &config)?; + second_plan = rule.optimize(second_plan, &config)?; } - let second_plan = plan.clone(); assert_eq!(get_plan_string(&first_plan), get_plan_string(&second_plan)); Ok(()) @@ -1922,15 +2298,15 @@ async fn test_coalesce_propagate() -> Result<()> { let schema = create_test_schema()?; let source = memory_exec(&schema); let repartition = repartition_exec(source); - let coalesce_partitions = Arc::new(CoalescePartitionsExec::new(repartition)); + let coalesce_partitions = coalesce_partitions_exec(repartition); let repartition = repartition_exec(coalesce_partitions); - let sort_exprs = LexOrdering::new(vec![sort_expr("nullable_col", &schema)]); + let ordering: LexOrdering = [sort_expr("nullable_col", &schema)].into(); // Add local sort let sort = Arc::new( - SortExec::new(sort_exprs.clone(), repartition).with_preserve_partitioning(true), + SortExec::new(ordering.clone(), repartition).with_preserve_partitioning(true), ) as _; - let spm = sort_preserving_merge_exec(sort_exprs.clone(), sort); - let sort = sort_exec(sort_exprs, spm); + let spm = sort_preserving_merge_exec(ordering.clone(), sort); + let sort = sort_exec(ordering, spm); let physical_plan = sort.clone(); // Sort Parallelize rule should end Coalesce + Sort linkage when Sort is Global Sort @@ -1958,15 +2334,15 @@ async fn test_coalesce_propagate() -> 
Result<()> { #[tokio::test] async fn test_replace_with_partial_sort2() -> Result<()> { let schema = create_test_schema3()?; - let input_sort_exprs = vec![sort_expr("a", &schema), sort_expr("c", &schema)]; - let unbounded_input = stream_exec_ordered(&schema, input_sort_exprs); - + let input_ordering = [sort_expr("a", &schema), sort_expr("c", &schema)].into(); + let unbounded_input = stream_exec_ordered(&schema, input_ordering); let physical_plan = sort_exec( - vec![ + [ sort_expr("a", &schema), sort_expr("c", &schema), sort_expr("d", &schema), - ], + ] + .into(), unbounded_input, ); @@ -1985,76 +2361,65 @@ async fn test_replace_with_partial_sort2() -> Result<()> { #[tokio::test] async fn test_push_with_required_input_ordering_prohibited() -> Result<()> { - // SortExec: expr=[b] <-- can't push this down - // RequiredInputOrder expr=[a] <-- this requires input sorted by a, and preserves the input order - // SortExec: expr=[a] - // DataSourceExec let schema = create_test_schema3()?; - let sort_exprs_a = LexOrdering::new(vec![sort_expr("a", &schema)]); - let sort_exprs_b = LexOrdering::new(vec![sort_expr("b", &schema)]); + let ordering_a: LexOrdering = [sort_expr("a", &schema)].into(); + let ordering_b: LexOrdering = [sort_expr("b", &schema)].into(); let plan = memory_exec(&schema); - let plan = sort_exec(sort_exprs_a.clone(), plan); + let plan = sort_exec(ordering_a.clone(), plan); let plan = RequirementsTestExec::new(plan) - .with_required_input_ordering(sort_exprs_a) + .with_required_input_ordering(Some(ordering_a)) .with_maintains_input_order(true) .into_arc(); - let plan = sort_exec(sort_exprs_b, plan); + let plan = sort_exec(ordering_b, plan); let expected_input = [ - "SortExec: expr=[b@1 ASC], preserve_partitioning=[false]", - " RequiredInputOrderingExec", + "SortExec: expr=[b@1 ASC], preserve_partitioning=[false]", // <-- can't push this down + " RequiredInputOrderingExec", // <-- this requires input sorted by a, and preserves the input order " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", " DataSourceExec: partitions=1, partition_sizes=[0]", ]; // should not be able to push shorts - let expected_no_change = expected_input; - assert_optimized!(expected_input, expected_no_change, plan, true); + assert_optimized!(expected_input, expected_input, plan, true); Ok(()) } // test when the required input ordering is satisfied so could push through #[tokio::test] async fn test_push_with_required_input_ordering_allowed() -> Result<()> { - // SortExec: expr=[a,b] <-- can push this down (as it is compatible with the required input ordering) - // RequiredInputOrder expr=[a] <-- this requires input sorted by a, and preserves the input order - // SortExec: expr=[a] - // DataSourceExec let schema = create_test_schema3()?; - let sort_exprs_a = LexOrdering::new(vec![sort_expr("a", &schema)]); - let sort_exprs_ab = - LexOrdering::new(vec![sort_expr("a", &schema), sort_expr("b", &schema)]); + let ordering_a: LexOrdering = [sort_expr("a", &schema)].into(); + let ordering_ab = [sort_expr("a", &schema), sort_expr("b", &schema)].into(); let plan = memory_exec(&schema); - let plan = sort_exec(sort_exprs_a.clone(), plan); + let plan = sort_exec(ordering_a.clone(), plan); let plan = RequirementsTestExec::new(plan) - .with_required_input_ordering(sort_exprs_a) + .with_required_input_ordering(Some(ordering_a)) .with_maintains_input_order(true) .into_arc(); - let plan = sort_exec(sort_exprs_ab, plan); + let plan = sort_exec(ordering_ab, plan); let expected_input = [ - "SortExec: expr=[a@0 ASC, b@1 ASC], 
preserve_partitioning=[false]", - " RequiredInputOrderingExec", + "SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[false]", // <-- can push this down (as it is compatible with the required input ordering) + " RequiredInputOrderingExec", // <-- this requires input sorted by a, and preserves the input order " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", " DataSourceExec: partitions=1, partition_sizes=[0]", ]; - // should able to push shorts - let expected = [ + // Should be able to push down + let expected_optimized = [ "RequiredInputOrderingExec", " SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[false]", " DataSourceExec: partitions=1, partition_sizes=[0]", ]; - assert_optimized!(expected_input, expected, plan, true); + assert_optimized!(expected_input, expected_optimized, plan, true); Ok(()) } #[tokio::test] async fn test_replace_with_partial_sort() -> Result<()> { let schema = create_test_schema3()?; - let input_sort_exprs = vec![sort_expr("a", &schema)]; - let unbounded_input = stream_exec_ordered(&schema, input_sort_exprs); - + let input_ordering = [sort_expr("a", &schema)].into(); + let unbounded_input = stream_exec_ordered(&schema, input_ordering); let physical_plan = sort_exec( - vec![sort_expr("a", &schema), sort_expr("c", &schema)], + [sort_expr("a", &schema), sort_expr("c", &schema)].into(), unbounded_input, ); @@ -2073,38 +2438,38 @@ async fn test_replace_with_partial_sort() -> Result<()> { #[tokio::test] async fn test_not_replaced_with_partial_sort_for_unbounded_input() -> Result<()> { let schema = create_test_schema3()?; - let input_sort_exprs = vec![sort_expr("b", &schema), sort_expr("c", &schema)]; - let unbounded_input = stream_exec_ordered(&schema, input_sort_exprs); - + let input_ordering = [sort_expr("b", &schema), sort_expr("c", &schema)].into(); + let unbounded_input = stream_exec_ordered(&schema, input_ordering); let physical_plan = sort_exec( - vec![ + [ sort_expr("a", &schema), sort_expr("b", &schema), sort_expr("c", &schema), - ], + ] + .into(), unbounded_input, ); let expected_input = [ "SortExec: expr=[a@0 ASC, b@1 ASC, c@2 ASC], preserve_partitioning=[false]", " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC]" ]; - let expected_no_change = expected_input; - assert_optimized!(expected_input, expected_no_change, physical_plan, true); + assert_optimized!(expected_input, expected_input, physical_plan, true); Ok(()) } #[tokio::test] async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { let input_schema = create_test_schema()?; - let sort_exprs = vec![sort_expr_options( + let ordering = [sort_expr_options( "nullable_col", &input_schema, SortOptions { descending: false, nulls_first: false, }, - )]; - let source = parquet_exec_sorted(&input_schema, sort_exprs); + )] + .into(); + let source = parquet_exec_with_sort(input_schema.clone(), vec![ordering]) as _; // Function definition - Alias of the resulting column - Arguments of the function #[derive(Clone)] @@ -2460,11 +2825,11 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { required_sort_columns: vec![("nullable_col", true, false), ("count", false, false)], initial_plan: vec![ "SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, 
start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", + " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], expected_plan: vec![ - "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", + "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], }, @@ -2476,11 +2841,11 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { required_sort_columns: vec![("max", false, true), ("nullable_col", true, false)], initial_plan: vec![ "SortExec: expr=[max@2 DESC, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", + " WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], expected_plan: vec![ - "WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", + "WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], }, @@ -2492,11 +2857,11 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { required_sort_columns: vec![("min", true, true), ("nullable_col", true, false)], initial_plan: vec![ "SortExec: expr=[min@2 ASC, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", + " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, 
metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], expected_plan: vec![ - "WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", + "WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], }, @@ -2508,12 +2873,12 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { required_sort_columns: vec![("avg", false, false), ("nullable_col", true, false)], initial_plan: vec![ "SortExec: expr=[avg@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", + " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], expected_plan: vec![ "SortExec: expr=[avg@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", + " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], }, @@ -2529,12 +2894,12 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { required_sort_columns: vec![("nullable_col", true, false), ("count", true, false)], initial_plan: vec![ "SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", + " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: 
CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], expected_plan: vec![ "SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", + " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], }, @@ -2546,11 +2911,11 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { required_sort_columns: vec![("nullable_col", true, false), ("max", false, true)], initial_plan: vec![ "SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 DESC], preserve_partitioning=[false]", - " WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", + " WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], expected_plan: vec![ - "WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", + "WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], }, @@ -2562,12 +2927,12 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { required_sort_columns: vec![("min", true, false), ("nullable_col", true, false)], initial_plan: vec![ "SortExec: expr=[min@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", + " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }]", " DataSourceExec: 
file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], expected_plan: vec![ "SortExec: expr=[min@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", + " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], }, @@ -2579,12 +2944,12 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { required_sort_columns: vec![("avg", false, false), ("nullable_col", true, false)], initial_plan: vec![ "SortExec: expr=[avg@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", + " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], expected_plan: vec![ "SortExec: expr=[avg@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", + " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], }, @@ -2600,11 +2965,11 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { required_sort_columns: vec![("nullable_col", true, false), ("count", false, false)], initial_plan: vec![ "SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", + " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), 
is_causal: false }]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], expected_plan: vec![ - "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", + "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], }, @@ -2616,12 +2981,12 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { required_sort_columns: vec![("nullable_col", true, false), ("max", true, false)], initial_plan: vec![ "SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", + " WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], expected_plan: vec![ "SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", + " WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], }, @@ -2633,12 +2998,12 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { required_sort_columns: vec![("min", false, false)], initial_plan: vec![ "SortExec: expr=[min@2 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", + " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], 
output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], expected_plan: vec![ "SortExec: expr=[min@2 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", + " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], }, @@ -2650,12 +3015,12 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { required_sort_columns: vec![("avg", false, false)], initial_plan: vec![ "SortExec: expr=[avg@2 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", + " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], expected_plan: vec![ "SortExec: expr=[avg@2 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", + " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], }, @@ -2671,12 +3036,12 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { required_sort_columns: vec![("count", false, false), ("nullable_col", true, false)], initial_plan: vec![ "SortExec: expr=[count@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", + " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], 
expected_plan: vec![ "SortExec: expr=[count@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", + " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet" ], }, @@ -2688,11 +3053,11 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { required_sort_columns: vec![("nullable_col", true, false), ("max", false, true)], initial_plan: vec![ "SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 DESC], preserve_partitioning=[false]", - " WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", + " WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], expected_plan: vec![ - "WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", + "WindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], }, @@ -2704,12 +3069,12 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { required_sort_columns: vec![("min", false, false)], initial_plan: vec![ "SortExec: expr=[min@2 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", + " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], expected_plan: vec![ "SortExec: expr=[min@2 DESC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", 
data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", + " WindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], }, @@ -2721,12 +3086,12 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { required_sort_columns: vec![("nullable_col", true, false), ("avg", true, false)], initial_plan: vec![ "SortExec: expr=[nullable_col@0 ASC NULLS LAST, avg@2 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", + " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet" ], expected_plan: vec![ "SortExec: expr=[nullable_col@0 ASC NULLS LAST, avg@2 ASC NULLS LAST], preserve_partitioning=[false]", - " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", + " WindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet" ], }, @@ -2742,11 +3107,11 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { required_sort_columns: vec![("nullable_col", true, false), ("count", true, false)], initial_plan: vec![ "SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], expected_plan: vec![ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, 
nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], }, @@ -2758,12 +3123,12 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { required_sort_columns: vec![("max", false, false), ("nullable_col", true, false)], initial_plan: vec![ "SortExec: expr=[max@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], expected_plan: vec![ "SortExec: expr=[max@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet" ], }, @@ -2775,11 +3140,11 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { required_sort_columns: vec![("min", false, false), ("nullable_col", true, false)], initial_plan: vec![ "SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet" ], expected_plan: 
vec![ - "BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", + "BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], }, @@ -2791,12 +3156,12 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { required_sort_columns: vec![("nullable_col", true, false), ("avg", true, false)], initial_plan: vec![ "SortExec: expr=[nullable_col@0 ASC NULLS LAST, avg@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], expected_plan: vec![ "SortExec: expr=[nullable_col@0 ASC NULLS LAST, avg@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], }, @@ -2812,11 +3177,11 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { required_sort_columns: vec![("nullable_col", true, false), ("count", true, true)], initial_plan: vec![ "SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], 
output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], expected_plan: vec![ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], }, @@ -2828,11 +3193,11 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { required_sort_columns: vec![("max", true, false), ("nullable_col", true, false)], initial_plan: vec![ "SortExec: expr=[max@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], expected_plan: vec![ - "BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", + "BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], }, @@ -2844,12 +3209,12 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { required_sort_columns: vec![("min", false, true), ("nullable_col", true, false)], initial_plan: vec![ "SortExec: expr=[min@2 DESC, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], 
output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], expected_plan: vec![ "SortExec: expr=[min@2 DESC, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], }, @@ -2861,12 +3226,12 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { required_sort_columns: vec![("avg", true, false)], initial_plan: vec![ "SortExec: expr=[avg@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], expected_plan: vec![ "SortExec: expr=[avg@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], }, @@ -2882,11 +3247,11 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { required_sort_columns: vec![("nullable_col", true, false), ("count", true, false)], initial_plan: vec![ "SortExec: expr=[nullable_col@0 ASC NULLS LAST, count@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], 
mode=[Sorted]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], expected_plan: vec![ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], }, @@ -2898,12 +3263,12 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { required_sort_columns: vec![("max", true, false), ("nullable_col", true, false)], initial_plan: vec![ "SortExec: expr=[max@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet" ], expected_plan: vec![ "SortExec: expr=[max@2 ASC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet" ], }, @@ -2915,12 +3280,12 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { required_sort_columns: vec![("min", false, false), ("nullable_col", true, false)], initial_plan: vec![ "SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: 
Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], expected_plan: vec![ "SortExec: expr=[min@2 DESC NULLS LAST, nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], }, @@ -2932,12 +3297,12 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { required_sort_columns: vec![("nullable_col", true, false), ("avg", true, false)], initial_plan: vec![ "SortExec: expr=[nullable_col@0 ASC NULLS LAST, avg@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], expected_plan: vec![ "SortExec: expr=[nullable_col@0 ASC NULLS LAST, avg@2 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], }, @@ -2953,12 +3318,12 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { required_sort_columns: vec![ ("count", true, true)], initial_plan: vec![ "SortExec: expr=[count@2 ASC], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { 
name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], expected_plan: vec![ "SortExec: expr=[count@2 ASC], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], }, // Case 45: @@ -2969,12 +3334,12 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { required_sort_columns: vec![("nullable_col", true, false), ("max", false, false)], initial_plan: vec![ "SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 DESC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], expected_plan: vec![ "SortExec: expr=[nullable_col@0 ASC NULLS LAST, max@2 DESC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[max: Ok(Field { name: \"max\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], }, @@ -2986,11 +3351,11 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { required_sort_columns: vec![("nullable_col", true, false), ("min", false, false)], initial_plan: vec![ "SortExec: expr=[nullable_col@0 ASC NULLS LAST, min@2 DESC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), 
frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], expected_plan: vec![ - "BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", + "BoundedWindowAggExec: wdw=[min: Ok(Field { name: \"min\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], }, @@ -3002,11 +3367,11 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { required_sort_columns: vec![("nullable_col", true, false)], initial_plan: vec![ "SortExec: expr=[nullable_col@0 ASC NULLS LAST], preserve_partitioning=[false]", - " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], expected_plan: vec![ - "BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", + "BoundedWindowAggExec: wdw=[avg: Ok(Field { name: \"avg\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC NULLS LAST], file_type=parquet", ], }, @@ -3307,7 +3672,7 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { case.func.1, &case.func.2, &partition_by, - &LexOrdering::default(), + &[], case.window_frame, input_schema.as_ref(), false, @@ -3341,7 +3706,8 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { ) }) .collect::<Vec<_>>(); - let physical_plan = sort_exec(sort_expr, window_exec); + let ordering = LexOrdering::new(sort_expr).unwrap(); + let physical_plan = sort_exec(ordering, window_exec);
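+        // Check the rendered plan against `case.initial_plan` before optimization and against `case.expected_plan` after the sort-enforcement pass.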
assert_optimized!( case.initial_plan, @@ -3358,11 +3724,11 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { #[test] fn test_removes_unused_orthogonal_sort() -> Result<()> { let schema = create_test_schema3()?; - let input_sort_exprs = vec![sort_expr("b", &schema), sort_expr("c", &schema)]; - let unbounded_input = stream_exec_ordered(&schema, input_sort_exprs.clone()); - - let orthogonal_sort = sort_exec(vec![sort_expr("a", &schema)], unbounded_input); - let output_sort = sort_exec(input_sort_exprs, orthogonal_sort); // same sort as data source + let input_ordering: LexOrdering = + [sort_expr("b", &schema), sort_expr("c", &schema)].into(); + let unbounded_input = stream_exec_ordered(&schema, input_ordering.clone()); + let orthogonal_sort = sort_exec([sort_expr("a", &schema)].into(), unbounded_input); + let output_sort = sort_exec(input_ordering, orthogonal_sort); // same sort as data source // Test scenario/input has an orthogonal sort: let expected_input = [ @@ -3370,7 +3736,7 @@ fn test_removes_unused_orthogonal_sort() -> Result<()> { " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC]" ]; - assert_eq!(get_plan_string(&output_sort), expected_input,); + assert_eq!(get_plan_string(&output_sort), expected_input); // Test: should remove orthogonal sort, and the uppermost (unneeded) sort: let expected_optimized = [ @@ -3384,12 +3750,12 @@ fn test_removes_unused_orthogonal_sort() -> Result<()> { #[test] fn test_keeps_used_orthogonal_sort() -> Result<()> { let schema = create_test_schema3()?; - let input_sort_exprs = vec![sort_expr("b", &schema), sort_expr("c", &schema)]; - let unbounded_input = stream_exec_ordered(&schema, input_sort_exprs.clone()); - + let input_ordering: LexOrdering = + [sort_expr("b", &schema), sort_expr("c", &schema)].into(); + let unbounded_input = stream_exec_ordered(&schema, input_ordering.clone()); let orthogonal_sort = - sort_exec_with_fetch(vec![sort_expr("a", &schema)], Some(3), unbounded_input); // has fetch, so this orthogonal sort changes the output - let output_sort = sort_exec(input_sort_exprs, orthogonal_sort); + sort_exec_with_fetch([sort_expr("a", &schema)].into(), Some(3), unbounded_input); // has fetch, so this orthogonal sort changes the output + let output_sort = sort_exec(input_ordering, orthogonal_sort); // Test scenario/input has an orthogonal sort: let expected_input = [ @@ -3397,7 +3763,7 @@ fn test_keeps_used_orthogonal_sort() -> Result<()> { " SortExec: TopK(fetch=3), expr=[a@0 ASC], preserve_partitioning=[false]", " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC]" ]; - assert_eq!(get_plan_string(&output_sort), expected_input,); + assert_eq!(get_plan_string(&output_sort), expected_input); // Test: should keep the orthogonal sort, since it modifies the output: let expected_optimized = expected_input; @@ -3409,15 +3775,17 @@ fn test_keeps_used_orthogonal_sort() -> Result<()> { #[test] fn test_handles_multiple_orthogonal_sorts() -> Result<()> { let schema = create_test_schema3()?; - let input_sort_exprs = vec![sort_expr("b", &schema), sort_expr("c", &schema)]; - let unbounded_input = stream_exec_ordered(&schema, input_sort_exprs.clone()); - - let orthogonal_sort_0 = sort_exec(vec![sort_expr("c", &schema)], unbounded_input); // has no fetch, so can be removed + let input_ordering: LexOrdering = + [sort_expr("b", 
&schema), sort_expr("c", &schema)].into(); + let unbounded_input = stream_exec_ordered(&schema, input_ordering.clone()); + let ordering0: LexOrdering = [sort_expr("c", &schema)].into(); + let orthogonal_sort_0 = sort_exec(ordering0.clone(), unbounded_input); // has no fetch, so can be removed + let ordering1: LexOrdering = [sort_expr("a", &schema)].into(); let orthogonal_sort_1 = - sort_exec_with_fetch(vec![sort_expr("a", &schema)], Some(3), orthogonal_sort_0); // has fetch, so this orthogonal sort changes the output - let orthogonal_sort_2 = sort_exec(vec![sort_expr("c", &schema)], orthogonal_sort_1); // has no fetch, so can be removed - let orthogonal_sort_3 = sort_exec(vec![sort_expr("a", &schema)], orthogonal_sort_2); // has no fetch, so can be removed - let output_sort = sort_exec(input_sort_exprs, orthogonal_sort_3); // final sort + sort_exec_with_fetch(ordering1.clone(), Some(3), orthogonal_sort_0); // has fetch, so this orthogonal sort changes the output + let orthogonal_sort_2 = sort_exec(ordering0, orthogonal_sort_1); // has no fetch, so can be removed + let orthogonal_sort_3 = sort_exec(ordering1, orthogonal_sort_2); // has no fetch, so can be removed + let output_sort = sort_exec(input_ordering, orthogonal_sort_3); // final sort // Test scenario/input has an orthogonal sort: let expected_input = [ @@ -3428,7 +3796,7 @@ fn test_handles_multiple_orthogonal_sorts() -> Result<()> { " SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC]", ]; - assert_eq!(get_plan_string(&output_sort), expected_input,); + assert_eq!(get_plan_string(&output_sort), expected_input); // Test: should keep only the needed orthogonal sort, and remove the unneeded ones: let expected_optimized = [ @@ -3440,3 +3808,38 @@ fn test_handles_multiple_orthogonal_sorts() -> Result<()> { Ok(()) } + +#[test] +fn test_parallelize_sort_preserves_fetch() -> Result<()> { + // Create a schema + let schema = create_test_schema3()?; + let parquet_exec = parquet_exec(schema); + let coalesced = coalesce_partitions_exec(parquet_exec.clone()); + let top_coalesced = coalesce_partitions_exec(coalesced.clone()) + .with_fetch(Some(10)) + .unwrap(); + + let requirements = PlanWithCorrespondingCoalescePartitions::new( + top_coalesced, + true, + vec![PlanWithCorrespondingCoalescePartitions::new( + coalesced, + true, + vec![PlanWithCorrespondingCoalescePartitions::new( + parquet_exec, + false, + vec![], + )], + )], + ); + + let res = parallelize_sorts(requirements)?; + + // Verify fetch was preserved + assert_eq!( + res.data.plan.fetch(), + Some(10), + "Fetch value was not preserved after transformation" + ); + Ok(()) +} diff --git a/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs b/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs new file mode 100644 index 000000000000..f1ef365c9220 --- /dev/null +++ b/datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs @@ -0,0 +1,521 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::{Arc, LazyLock}; + +use arrow::{ + array::record_batch, + datatypes::{DataType, Field, Schema, SchemaRef}, + util::pretty::pretty_format_batches, +}; +use arrow_schema::SortOptions; +use datafusion::{ + logical_expr::Operator, + physical_plan::{ + expressions::{BinaryExpr, Column, Literal}, + PhysicalExpr, + }, + prelude::{ParquetReadOptions, SessionConfig, SessionContext}, + scalar::ScalarValue, +}; +use datafusion_common::config::ConfigOptions; +use datafusion_execution::object_store::ObjectStoreUrl; +use datafusion_functions_aggregate::count::count_udaf; +use datafusion_physical_expr::{aggregate::AggregateExprBuilder, Partitioning}; +use datafusion_physical_expr::{expressions::col, LexOrdering, PhysicalSortExpr}; +use datafusion_physical_optimizer::{ + filter_pushdown::FilterPushdown, PhysicalOptimizerRule, +}; +use datafusion_physical_plan::{ + aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}, + coalesce_batches::CoalesceBatchesExec, + filter::FilterExec, + repartition::RepartitionExec, + sorts::sort::SortExec, + ExecutionPlan, +}; + +use futures::StreamExt; +use object_store::{memory::InMemory, ObjectStore}; +use util::{format_plan_for_test, OptimizationTest, TestNode, TestScanBuilder}; + +mod util; + +#[test] +fn test_pushdown_into_scan() { + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new(FilterExec::try_new(predicate, scan).unwrap()); + + // expect the predicate to be pushed down into the DataSource + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo + " + ); +} + +/// Show that we can use config options to determine how to do pushdown. 
+#[test] +fn test_pushdown_into_scan_with_config_options() { + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new(FilterExec::try_new(predicate, scan).unwrap()) as _; + + let mut cfg = ConfigOptions::default(); + insta::assert_snapshot!( + OptimizationTest::new( + Arc::clone(&plan), + FilterPushdown::new(), + false + ), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + " + ); + + cfg.execution.parquet.pushdown_filters = true; + insta::assert_snapshot!( + OptimizationTest::new( + plan, + FilterPushdown::new(), + true + ), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo + " + ); +} + +#[test] +fn test_filter_collapse() { + // filter should be pushed down into the parquet scan with two filters + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + let predicate1 = col_lit_predicate("a", "foo", &schema()); + let filter1 = Arc::new(FilterExec::try_new(predicate1, scan).unwrap()); + let predicate2 = col_lit_predicate("b", "bar", &schema()); + let plan = Arc::new(FilterExec::try_new(predicate2, filter1).unwrap()); + + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: b@1 = bar + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo AND b@1 = bar + " + ); +} + +#[test] +fn test_filter_with_projection() { + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + let projection = vec![1, 0]; + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new( + FilterExec::try_new(predicate, Arc::clone(&scan)) + .unwrap() + .with_projection(Some(projection)) + .unwrap(), + ); + + // expect the predicate to be pushed down into the DataSource but the FilterExec to be converted to ProjectionExec + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = foo, projection=[b@1, a@0] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - ProjectionExec: expr=[b@1 as b, a@0 as a] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo + ", + ); + + // add a test where the filter is on a column that isn't included in the output + let projection = vec![1]; + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new( + FilterExec::try_new(predicate, scan) + .unwrap() + .with_projection(Some(projection)) + .unwrap(), + ); + insta::assert_snapshot!( + 
OptimizationTest::new(plan, FilterPushdown::new(),true), + @r" + OptimizationTest: + input: + - FilterExec: a@0 = foo, projection=[b@1] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - ProjectionExec: expr=[b@1 as b] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo + " + ); +} + +#[test] +fn test_push_down_through_transparent_nodes() { + // expect the predicate to be pushed down into the DataSource + let scan = TestScanBuilder::new(schema()).with_support(true).build(); + let coalesce = Arc::new(CoalesceBatchesExec::new(scan, 1)); + let predicate = col_lit_predicate("a", "foo", &schema()); + let filter = Arc::new(FilterExec::try_new(predicate, coalesce).unwrap()); + let repartition = Arc::new( + RepartitionExec::try_new(filter, Partitioning::RoundRobinBatch(1)).unwrap(), + ); + let predicate = col_lit_predicate("b", "bar", &schema()); + let plan = Arc::new(FilterExec::try_new(predicate, repartition).unwrap()); + + // expect the predicate to be pushed down into the DataSource + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(),true), + @r" + OptimizationTest: + input: + - FilterExec: b@1 = bar + - RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=1 + - FilterExec: a@0 = foo + - CoalesceBatchesExec: target_batch_size=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - RepartitionExec: partitioning=RoundRobinBatch(1), input_partitions=1 + - CoalesceBatchesExec: target_batch_size=1 + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo AND b@1 = bar + " + ); +} + +#[test] +fn test_no_pushdown_through_aggregates() { + // There are 2 important points here: + // 1. The outer filter **is not** pushed down at all because we haven't implemented pushdown support + // yet for AggregateExec. + // 2. The inner filter **is** pushed down into the DataSource. 
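// In other words: the rule presumably treats AggregateExec as a barrier because it does not yet
// implement the pushdown hooks (`gather_filters_for_pushdown` / `handle_child_pushdown_result`,
// the ones exercised by `TestNode` in util.rs). The outer `b = bar` predicate therefore stays
// above the aggregate, while the inner `a = foo` predicate, already below it, still reaches the scan.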
+ let scan = TestScanBuilder::new(schema()).with_support(true).build();
+
+ let coalesce = Arc::new(CoalesceBatchesExec::new(scan, 10));
+
+ let filter = Arc::new(
+ FilterExec::try_new(col_lit_predicate("a", "foo", &schema()), coalesce).unwrap(),
+ );
+
+ let aggregate_expr =
+ vec![
+ AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema()).unwrap()])
+ .schema(schema())
+ .alias("cnt")
+ .build()
+ .map(Arc::new)
+ .unwrap(),
+ ];
+ let group_by = PhysicalGroupBy::new_single(vec![
+ (col("a", &schema()).unwrap(), "a".to_string()),
+ (col("b", &schema()).unwrap(), "b".to_string()),
+ ]);
+ let aggregate = Arc::new(
+ AggregateExec::try_new(
+ AggregateMode::Final,
+ group_by,
+ aggregate_expr.clone(),
+ vec![None],
+ filter,
+ schema(),
+ )
+ .unwrap(),
+ );
+
+ let coalesce = Arc::new(CoalesceBatchesExec::new(aggregate, 100));
+
+ let predicate = col_lit_predicate("b", "bar", &schema());
+ let plan = Arc::new(FilterExec::try_new(predicate, coalesce).unwrap());
+
+ // expect the predicate to be pushed down into the DataSource
+ insta::assert_snapshot!(
+ OptimizationTest::new(plan, FilterPushdown::new(), true),
+ @r"
+ OptimizationTest:
+ input:
+ - FilterExec: b@1 = bar
+ - CoalesceBatchesExec: target_batch_size=100
+ - AggregateExec: mode=Final, gby=[a@0 as a, b@1 as b], aggr=[cnt], ordering_mode=PartiallySorted([0])
+ - FilterExec: a@0 = foo
+ - CoalesceBatchesExec: target_batch_size=10
+ - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
+ output:
+ Ok:
+ - FilterExec: b@1 = bar
+ - CoalesceBatchesExec: target_batch_size=100
+ - AggregateExec: mode=Final, gby=[a@0 as a, b@1 as b], aggr=[cnt]
+ - CoalesceBatchesExec: target_batch_size=10
+ - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo
+ "
+ );
+}
+
+/// Test various combinations of handling of child pushdown results
+/// in an ExecutionPlan in combination with support/not support in a DataSource.
+#[test]
+fn test_node_handles_child_pushdown_result() {
+ // If we set `with_support(true)` + `inject_filter = true` then the filter is pushed down to the DataSource
+ // and no FilterExec is created.
+ let scan = TestScanBuilder::new(schema()).with_support(true).build();
+ let predicate = col_lit_predicate("a", "foo", &schema());
+ let plan = Arc::new(TestNode::new(true, Arc::clone(&scan), predicate));
+ insta::assert_snapshot!(
+ OptimizationTest::new(plan, FilterPushdown::new(), true),
+ @r"
+ OptimizationTest:
+ input:
+ - TestInsertExec { inject_filter: true }
+ - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
+ output:
+ Ok:
+ - TestInsertExec { inject_filter: true }
+ - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = foo
+ ",
+ );
+
+ // If we set `with_support(false)` + `inject_filter = true` then the filter is not pushed down to the DataSource
+ // and a FilterExec is created.
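// For reference, the full matrix covered by this test (as shown by the snapshots):
//   with_support(true),  inject_filter = true  -> predicate absorbed by the DataSource; no FilterExec added
//   with_support(false), inject_filter = true  -> TestNode re-inserts the unsupported predicate as a FilterExec above the scan
//   with_support(false), inject_filter = false -> predicate is neither pushed down nor re-inserted; no FilterExec appears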
+ let scan = TestScanBuilder::new(schema()).with_support(false).build(); + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new(TestNode::new(true, Arc::clone(&scan), predicate)); + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - TestInsertExec { inject_filter: true } + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + output: + Ok: + - TestInsertExec { inject_filter: false } + - FilterExec: a@0 = foo + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + ", + ); + + // If we set `with_support(false)` + `inject_filter = false` then the filter is not pushed down to the DataSource + // and no FilterExec is created. + let scan = TestScanBuilder::new(schema()).with_support(false).build(); + let predicate = col_lit_predicate("a", "foo", &schema()); + let plan = Arc::new(TestNode::new(false, Arc::clone(&scan), predicate)); + insta::assert_snapshot!( + OptimizationTest::new(plan, FilterPushdown::new(), true), + @r" + OptimizationTest: + input: + - TestInsertExec { inject_filter: false } + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + output: + Ok: + - TestInsertExec { inject_filter: false } + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=false + ", + ); +} + +#[tokio::test] +async fn test_topk_dynamic_filter_pushdown() { + let batches = vec![ + record_batch!( + ("a", Utf8, ["aa", "ab"]), + ("b", Utf8, ["bd", "bc"]), + ("c", Float64, [1.0, 2.0]) + ) + .unwrap(), + record_batch!( + ("a", Utf8, ["ac", "ad"]), + ("b", Utf8, ["bb", "ba"]), + ("c", Float64, [2.0, 1.0]) + ) + .unwrap(), + ]; + let scan = TestScanBuilder::new(schema()) + .with_support(true) + .with_batches(batches) + .build(); + let plan = Arc::new( + SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr::new( + col("b", &schema()).unwrap(), + SortOptions::new(true, false), // descending, nulls_first + )]) + .unwrap(), + Arc::clone(&scan), + ) + .with_fetch(Some(1)), + ) as Arc; + + // expect the predicate to be pushed down into the DataSource + insta::assert_snapshot!( + OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new_post_optimization(), true), + @r" + OptimizationTest: + input: + - SortExec: TopK(fetch=1), expr=[b@1 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true + output: + Ok: + - SortExec: TopK(fetch=1), expr=[b@1 DESC NULLS LAST], preserve_partitioning=[false] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilterPhysicalExpr [ true ] + " + ); + + // Actually apply the optimization to the plan and put some data through it to check that the filter is updated to reflect the TopK state + let mut config = ConfigOptions::default(); + config.execution.parquet.pushdown_filters = true; + let plan = FilterPushdown::new_post_optimization() + .optimize(plan, &config) + .unwrap(); + let config = SessionConfig::new().with_batch_size(2); + let session_ctx = SessionContext::new_with_config(config); + session_ctx.register_object_store( + ObjectStoreUrl::parse("test://").unwrap().as_ref(), + Arc::new(InMemory::new()), + ); + let 
state = session_ctx.state(); + let task_ctx = state.task_ctx(); + let mut stream = plan.execute(0, Arc::clone(&task_ctx)).unwrap(); + // Iterate one batch + stream.next().await.unwrap().unwrap(); + // Now check what our filter looks like + insta::assert_snapshot!( + format!("{}", format_plan_for_test(&plan)), + @r" + - SortExec: TopK(fetch=1), expr=[b@1 DESC NULLS LAST], preserve_partitioning=[false], filter=[b@1 > bd] + - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=DynamicFilterPhysicalExpr [ b@1 > bd ] + " + ); +} + +/// Integration test for dynamic filter pushdown with TopK. +/// We use an integration test because there are complex interactions in the optimizer rules +/// that the unit tests applying a single optimizer rule do not cover. +#[tokio::test] +async fn test_topk_dynamic_filter_pushdown_integration() { + let store = Arc::new(InMemory::new()) as Arc; + let mut cfg = SessionConfig::new(); + cfg.options_mut().execution.parquet.pushdown_filters = true; + cfg.options_mut().execution.parquet.max_row_group_size = 128; + let ctx = SessionContext::new_with_config(cfg); + ctx.register_object_store( + ObjectStoreUrl::parse("memory://").unwrap().as_ref(), + Arc::clone(&store), + ); + ctx.sql( + r" +COPY ( + SELECT 1372708800 + value AS t + FROM generate_series(0, 99999) + ORDER BY t + ) TO 'memory:///1.parquet' +STORED AS PARQUET; + ", + ) + .await + .unwrap() + .collect() + .await + .unwrap(); + + // Register the file with the context + ctx.register_parquet( + "topk_pushdown", + "memory:///1.parquet", + ParquetReadOptions::default(), + ) + .await + .unwrap(); + + // Create a TopK query that will use dynamic filter pushdown + let df = ctx + .sql(r"EXPLAIN ANALYZE SELECT t FROM topk_pushdown ORDER BY t LIMIT 10;") + .await + .unwrap(); + let batches = df.collect().await.unwrap(); + let explain = format!("{}", pretty_format_batches(&batches).unwrap()); + + assert!(explain.contains("output_rows=128")); // Read 1 row group + assert!(explain.contains("t@0 < 1372708809")); // Dynamic filter was applied + assert!( + explain.contains("pushdown_rows_matched=128, pushdown_rows_pruned=99872"), + "{explain}" + ); + // Pushdown pruned most rows +} + +/// Schema: +/// a: String +/// b: String +/// c: f64 +static TEST_SCHEMA: LazyLock = LazyLock::new(|| { + let fields = vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Float64, false), + ]; + Arc::new(Schema::new(fields)) +}); + +fn schema() -> SchemaRef { + Arc::clone(&TEST_SCHEMA) +} + +/// Returns a predicate that is a binary expression col = lit +fn col_lit_predicate( + column_name: &str, + scalar_value: impl Into, + schema: &Schema, +) -> Arc { + let scalar_value = scalar_value.into(); + Arc::new(BinaryExpr::new( + Arc::new(Column::new_with_schema(column_name, schema).unwrap()), + Operator::Eq, + Arc::new(Literal::new(scalar_value)), + )) +} diff --git a/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs b/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs new file mode 100644 index 000000000000..e793af8ed4b0 --- /dev/null +++ b/datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs @@ -0,0 +1,562 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::SchemaRef; +use arrow::error::ArrowError; +use arrow::{array::RecordBatch, compute::concat_batches}; +use datafusion::{datasource::object_store::ObjectStoreUrl, physical_plan::PhysicalExpr}; +use datafusion_common::{config::ConfigOptions, internal_err, Result, Statistics}; +use datafusion_datasource::{ + file::FileSource, file_meta::FileMeta, file_scan_config::FileScanConfig, + file_scan_config::FileScanConfigBuilder, file_stream::FileOpenFuture, + file_stream::FileOpener, schema_adapter::DefaultSchemaAdapterFactory, + schema_adapter::SchemaAdapterFactory, source::DataSourceExec, PartitionedFile, +}; +use datafusion_physical_expr::conjunction; +use datafusion_physical_expr_common::physical_expr::fmt_sql; +use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_plan::filter_pushdown::FilterPushdownPhase; +use datafusion_physical_plan::{ + displayable, + filter::FilterExec, + filter_pushdown::{ + ChildPushdownResult, FilterDescription, FilterPushdownPropagation, + PredicateSupport, PredicateSupports, + }, + metrics::ExecutionPlanMetricsSet, + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, +}; +use futures::stream::BoxStream; +use futures::{FutureExt, Stream}; +use object_store::ObjectStore; +use std::{ + any::Any, + fmt::{Display, Formatter}, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; +pub struct TestOpener { + batches: Vec, + batch_size: Option, + schema: Option, + projection: Option>, +} + +impl FileOpener for TestOpener { + fn open( + &self, + _file_meta: FileMeta, + _file: PartitionedFile, + ) -> Result { + let mut batches = self.batches.clone(); + if let Some(batch_size) = self.batch_size { + let batch = concat_batches(&batches[0].schema(), &batches)?; + let mut new_batches = Vec::new(); + for i in (0..batch.num_rows()).step_by(batch_size) { + let end = std::cmp::min(i + batch_size, batch.num_rows()); + let batch = batch.slice(i, end - i); + new_batches.push(batch); + } + batches = new_batches.into_iter().collect(); + } + if let Some(schema) = &self.schema { + let factory = DefaultSchemaAdapterFactory::from_schema(Arc::clone(schema)); + let (mapper, projection) = factory.map_schema(&batches[0].schema()).unwrap(); + let mut new_batches = Vec::new(); + for batch in batches { + let batch = batch.project(&projection).unwrap(); + let batch = mapper.map_batch(batch).unwrap(); + new_batches.push(batch); + } + batches = new_batches; + } + if let Some(projection) = &self.projection { + batches = batches + .into_iter() + .map(|batch| batch.project(projection).unwrap()) + .collect(); + } + + let stream = TestStream::new(batches); + + Ok((async { + let stream: BoxStream<'static, Result> = + Box::pin(stream); + Ok(stream) + }) + .boxed()) + } +} + +/// A placeholder data source that accepts filter pushdown +#[derive(Clone, Default)] +pub struct TestSource { + support: bool, + predicate: Option>, + statistics: Option, + batch_size: Option, + 
batches: Vec, + schema: Option, + metrics: ExecutionPlanMetricsSet, + projection: Option>, + schema_adapter_factory: Option>, +} + +impl TestSource { + fn new(support: bool, batches: Vec) -> Self { + Self { + support, + metrics: ExecutionPlanMetricsSet::new(), + batches, + ..Default::default() + } + } +} + +impl FileSource for TestSource { + fn create_file_opener( + &self, + _object_store: Arc, + _base_config: &FileScanConfig, + _partition: usize, + ) -> Arc { + Arc::new(TestOpener { + batches: self.batches.clone(), + batch_size: self.batch_size, + schema: self.schema.clone(), + projection: self.projection.clone(), + }) + } + + fn as_any(&self) -> &dyn Any { + todo!("should not be called") + } + + fn with_batch_size(&self, batch_size: usize) -> Arc { + Arc::new(TestSource { + batch_size: Some(batch_size), + ..self.clone() + }) + } + + fn with_schema(&self, schema: SchemaRef) -> Arc { + Arc::new(TestSource { + schema: Some(schema), + ..self.clone() + }) + } + + fn with_projection(&self, config: &FileScanConfig) -> Arc { + Arc::new(TestSource { + projection: config.projection.clone(), + ..self.clone() + }) + } + + fn with_statistics(&self, statistics: Statistics) -> Arc { + Arc::new(TestSource { + statistics: Some(statistics), + ..self.clone() + }) + } + + fn metrics(&self) -> &ExecutionPlanMetricsSet { + &self.metrics + } + + fn statistics(&self) -> Result { + Ok(self + .statistics + .as_ref() + .expect("statistics not set") + .clone()) + } + + fn file_type(&self) -> &str { + "test" + } + + fn fmt_extra(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + let support = format!(", pushdown_supported={}", self.support); + + let predicate_string = self + .predicate + .as_ref() + .map(|p| format!(", predicate={p}")) + .unwrap_or_default(); + + write!(f, "{support}{predicate_string}") + } + DisplayFormatType::TreeRender => { + if let Some(predicate) = &self.predicate { + writeln!(f, "pushdown_supported={}", fmt_sql(predicate.as_ref()))?; + writeln!(f, "predicate={}", fmt_sql(predicate.as_ref()))?; + } + Ok(()) + } + } + } + + fn try_pushdown_filters( + &self, + mut filters: Vec>, + config: &ConfigOptions, + ) -> Result>> { + if self.support && config.execution.parquet.pushdown_filters { + if let Some(internal) = self.predicate.as_ref() { + filters.push(Arc::clone(internal)); + } + let new_node = Arc::new(TestSource { + predicate: Some(conjunction(filters.clone())), + ..self.clone() + }); + Ok(FilterPushdownPropagation { + filters: PredicateSupports::all_supported(filters), + updated_node: Some(new_node), + }) + } else { + Ok(FilterPushdownPropagation::unsupported(filters)) + } + } + + fn with_schema_adapter_factory( + &self, + schema_adapter_factory: Arc, + ) -> Result> { + Ok(Arc::new(Self { + schema_adapter_factory: Some(schema_adapter_factory), + ..self.clone() + })) + } + + fn schema_adapter_factory(&self) -> Option> { + self.schema_adapter_factory.clone() + } +} + +#[derive(Debug, Clone)] +pub struct TestScanBuilder { + support: bool, + batches: Vec, + schema: SchemaRef, +} + +impl TestScanBuilder { + pub fn new(schema: SchemaRef) -> Self { + Self { + support: false, + batches: vec![], + schema, + } + } + + pub fn with_support(mut self, support: bool) -> Self { + self.support = support; + self + } + + pub fn with_batches(mut self, batches: Vec) -> Self { + self.batches = batches; + self + } + + pub fn build(self) -> Arc { + let source = Arc::new(TestSource::new(self.support, self.batches)); + let 
base_config = FileScanConfigBuilder::new( + ObjectStoreUrl::parse("test://").unwrap(), + Arc::clone(&self.schema), + source, + ) + .with_file(PartitionedFile::new("test.parquet", 123)) + .build(); + DataSourceExec::from_data_source(base_config) + } +} + +/// Index into the data that has been returned so far +#[derive(Debug, Default, Clone)] +pub struct BatchIndex { + inner: Arc>, +} + +impl BatchIndex { + /// Return the current index + pub fn value(&self) -> usize { + let inner = self.inner.lock().unwrap(); + *inner + } + + // increment the current index by one + pub fn incr(&self) { + let mut inner = self.inner.lock().unwrap(); + *inner += 1; + } +} + +/// Iterator over batches +#[derive(Debug, Default)] +pub struct TestStream { + /// Vector of record batches + data: Vec, + /// Index into the data that has been returned so far + index: BatchIndex, +} + +impl TestStream { + /// Create an iterator for a vector of record batches. Assumes at + /// least one entry in data (for the schema) + pub fn new(data: Vec) -> Self { + // check that there is at least one entry in data and that all batches have the same schema + assert!(!data.is_empty(), "data must not be empty"); + assert!( + data.iter().all(|batch| batch.schema() == data[0].schema()), + "all batches must have the same schema" + ); + Self { + data, + ..Default::default() + } + } +} + +impl Stream for TestStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + let next_batch = self.index.value(); + + Poll::Ready(if next_batch < self.data.len() { + let next_batch = self.index.value(); + self.index.incr(); + Some(Ok(self.data[next_batch].clone())) + } else { + None + }) + } + + fn size_hint(&self) -> (usize, Option) { + (self.data.len(), Some(self.data.len())) + } +} + +/// A harness for testing physical optimizers. 
+/// +/// You can use this to test the output of a physical optimizer rule using insta snapshots +#[derive(Debug)] +pub struct OptimizationTest { + input: Vec, + output: Result, String>, +} + +impl OptimizationTest { + pub fn new( + input_plan: Arc, + opt: O, + allow_pushdown_filters: bool, + ) -> Self + where + O: PhysicalOptimizerRule, + { + let mut parquet_pushdown_config = ConfigOptions::default(); + parquet_pushdown_config.execution.parquet.pushdown_filters = + allow_pushdown_filters; + + let input = format_execution_plan(&input_plan); + let input_schema = input_plan.schema(); + + let output_result = opt.optimize(input_plan, &parquet_pushdown_config); + let output = output_result + .and_then(|plan| { + if opt.schema_check() && (plan.schema() != input_schema) { + internal_err!( + "Schema mismatch:\n\nBefore:\n{:?}\n\nAfter:\n{:?}", + input_schema, + plan.schema() + ) + } else { + Ok(plan) + } + }) + .map(|plan| format_execution_plan(&plan)) + .map_err(|e| e.to_string()); + + Self { input, output } + } +} + +impl Display for OptimizationTest { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + writeln!(f, "OptimizationTest:")?; + writeln!(f, " input:")?; + for line in &self.input { + writeln!(f, " - {line}")?; + } + writeln!(f, " output:")?; + match &self.output { + Ok(output) => { + writeln!(f, " Ok:")?; + for line in output { + writeln!(f, " - {line}")?; + } + } + Err(err) => { + writeln!(f, " Err: {err}")?; + } + } + Ok(()) + } +} + +pub fn format_execution_plan(plan: &Arc) -> Vec { + format_lines(&displayable(plan.as_ref()).indent(false).to_string()) +} + +fn format_lines(s: &str) -> Vec { + s.trim().split('\n').map(|s| s.to_string()).collect() +} + +pub fn format_plan_for_test(plan: &Arc) -> String { + let mut out = String::new(); + for line in format_execution_plan(plan) { + out.push_str(&format!(" - {line}\n")); + } + out.push('\n'); + out +} + +#[derive(Debug)] +pub(crate) struct TestNode { + inject_filter: bool, + input: Arc, + predicate: Arc, +} + +impl TestNode { + pub fn new( + inject_filter: bool, + input: Arc, + predicate: Arc, + ) -> Self { + Self { + inject_filter, + input, + predicate, + } + } +} + +impl DisplayAs for TestNode { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + write!( + f, + "TestInsertExec {{ inject_filter: {} }}", + self.inject_filter + ) + } +} + +impl ExecutionPlan for TestNode { + fn name(&self) -> &str { + "TestInsertExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + self.input.properties() + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + assert!(children.len() == 1); + Ok(Arc::new(TestNode::new( + self.inject_filter, + children[0].clone(), + self.predicate.clone(), + ))) + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + unimplemented!("TestInsertExec is a stub for testing.") + } + + fn gather_filters_for_pushdown( + &self, + _phase: FilterPushdownPhase, + parent_filters: Vec>, + _config: &ConfigOptions, + ) -> Result { + Ok(FilterDescription::new_with_child_count(1) + .all_parent_filters_supported(parent_filters) + .with_self_filter(Arc::clone(&self.predicate))) + } + + fn handle_child_pushdown_result( + &self, + _phase: FilterPushdownPhase, + child_pushdown_result: ChildPushdownResult, + _config: &ConfigOptions, + ) -> Result>> { + if self.inject_filter { + // Add a FilterExec if our own filter was not handled by the 
child + + // We have 1 child + assert_eq!(child_pushdown_result.self_filters.len(), 1); + let self_pushdown_result = child_pushdown_result.self_filters[0].clone(); + // And pushed down 1 filter + assert_eq!(self_pushdown_result.len(), 1); + let self_pushdown_result = self_pushdown_result.into_inner(); + + match &self_pushdown_result[0] { + PredicateSupport::Unsupported(filter) => { + // We have a filter to push down + let new_child = + FilterExec::try_new(Arc::clone(filter), Arc::clone(&self.input))?; + let new_self = + TestNode::new(false, Arc::new(new_child), self.predicate.clone()); + let mut res = + FilterPushdownPropagation::transparent(child_pushdown_result); + res.updated_node = Some(Arc::new(new_self) as Arc); + Ok(res) + } + PredicateSupport::Supported(_) => { + let res = + FilterPushdownPropagation::transparent(child_pushdown_result); + Ok(res) + } + } + } else { + let res = FilterPushdownPropagation::transparent(child_pushdown_result); + Ok(res) + } + } +} diff --git a/datafusion/core/tests/physical_optimizer/join_selection.rs b/datafusion/core/tests/physical_optimizer/join_selection.rs index d3b6ec700bee..3477ac77123c 100644 --- a/datafusion/core/tests/physical_optimizer/join_selection.rs +++ b/datafusion/core/tests/physical_optimizer/join_selection.rs @@ -25,8 +25,8 @@ use std::{ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::config::ConfigOptions; -use datafusion_common::JoinSide; use datafusion_common::{stats::Precision, ColumnStatistics, JoinType, ScalarValue}; +use datafusion_common::{JoinSide, NullEquality}; use datafusion_common::{Result, Statistics}; use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext}; use datafusion_expr::Operator; @@ -222,7 +222,7 @@ async fn test_join_with_swap() { &JoinType::Left, None, PartitionMode::CollectLeft, - false, + NullEquality::NullEqualsNothing, ) .unwrap(), ); @@ -251,11 +251,19 @@ async fn test_join_with_swap() { .expect("The type of the plan should not be changed"); assert_eq!( - swapped_join.left().statistics().unwrap().total_byte_size, + swapped_join + .left() + .partition_statistics(None) + .unwrap() + .total_byte_size, Precision::Inexact(8192) ); assert_eq!( - swapped_join.right().statistics().unwrap().total_byte_size, + swapped_join + .right() + .partition_statistics(None) + .unwrap() + .total_byte_size, Precision::Inexact(2097152) ); } @@ -276,7 +284,7 @@ async fn test_left_join_no_swap() { &JoinType::Left, None, PartitionMode::CollectLeft, - false, + NullEquality::NullEqualsNothing, ) .unwrap(), ); @@ -291,11 +299,19 @@ async fn test_left_join_no_swap() { .expect("The type of the plan should not be changed"); assert_eq!( - swapped_join.left().statistics().unwrap().total_byte_size, + swapped_join + .left() + .partition_statistics(None) + .unwrap() + .total_byte_size, Precision::Inexact(8192) ); assert_eq!( - swapped_join.right().statistics().unwrap().total_byte_size, + swapped_join + .right() + .partition_statistics(None) + .unwrap() + .total_byte_size, Precision::Inexact(2097152) ); } @@ -317,7 +333,7 @@ async fn test_join_with_swap_semi() { &join_type, None, PartitionMode::Partitioned, - false, + NullEquality::NullEqualsNothing, ) .unwrap(); @@ -336,11 +352,19 @@ async fn test_join_with_swap_semi() { assert_eq!(swapped_join.schema().fields().len(), 1); assert_eq!( - swapped_join.left().statistics().unwrap().total_byte_size, + swapped_join + .left() + .partition_statistics(None) + .unwrap() + .total_byte_size, 
Precision::Inexact(8192) ); assert_eq!( - swapped_join.right().statistics().unwrap().total_byte_size, + swapped_join + .right() + .partition_statistics(None) + .unwrap() + .total_byte_size, Precision::Inexact(2097152) ); assert_eq!(original_schema, swapped_join.schema()); @@ -384,7 +408,7 @@ async fn test_nested_join_swap() { &JoinType::Inner, None, PartitionMode::CollectLeft, - false, + NullEquality::NullEqualsNothing, ) .unwrap(); let child_schema = child_join.schema(); @@ -401,7 +425,7 @@ async fn test_nested_join_swap() { &JoinType::Left, None, PartitionMode::CollectLeft, - false, + NullEquality::NullEqualsNothing, ) .unwrap(); @@ -440,7 +464,7 @@ async fn test_join_no_swap() { &JoinType::Inner, None, PartitionMode::CollectLeft, - false, + NullEquality::NullEqualsNothing, ) .unwrap(), ); @@ -455,11 +479,19 @@ async fn test_join_no_swap() { .expect("The type of the plan should not be changed"); assert_eq!( - swapped_join.left().statistics().unwrap().total_byte_size, + swapped_join + .left() + .partition_statistics(None) + .unwrap() + .total_byte_size, Precision::Inexact(8192) ); assert_eq!( - swapped_join.right().statistics().unwrap().total_byte_size, + swapped_join + .right() + .partition_statistics(None) + .unwrap() + .total_byte_size, Precision::Inexact(2097152) ); } @@ -524,11 +556,19 @@ async fn test_nl_join_with_swap(join_type: JoinType) { ); assert_eq!( - swapped_join.left().statistics().unwrap().total_byte_size, + swapped_join + .left() + .partition_statistics(None) + .unwrap() + .total_byte_size, Precision::Inexact(8192) ); assert_eq!( - swapped_join.right().statistics().unwrap().total_byte_size, + swapped_join + .right() + .partition_statistics(None) + .unwrap() + .total_byte_size, Precision::Inexact(2097152) ); } @@ -589,11 +629,19 @@ async fn test_nl_join_with_swap_no_proj(join_type: JoinType) { ); assert_eq!( - swapped_join.left().statistics().unwrap().total_byte_size, + swapped_join + .left() + .partition_statistics(None) + .unwrap() + .total_byte_size, Precision::Inexact(8192) ); assert_eq!( - swapped_join.right().statistics().unwrap().total_byte_size, + swapped_join + .right() + .partition_statistics(None) + .unwrap() + .total_byte_size, Precision::Inexact(2097152) ); } @@ -642,7 +690,7 @@ async fn test_hash_join_swap_on_joins_with_projections( &join_type, Some(projection), PartitionMode::Partitioned, - false, + NullEquality::NullEqualsNothing, )?); let swapped = join @@ -803,7 +851,7 @@ fn check_join_partition_mode( &JoinType::Inner, None, PartitionMode::Auto, - false, + NullEquality::NullEqualsNothing, ) .unwrap(), ); @@ -1067,6 +1115,14 @@ impl ExecutionPlan for StatisticsExec { fn statistics(&self) -> Result { Ok(self.stats.clone()) } + + fn partition_statistics(&self, partition: Option) -> Result { + Ok(if partition.is_some() { + Statistics::new_unknown(&self.schema) + } else { + self.stats.clone() + }) + } } #[test] @@ -1442,7 +1498,7 @@ async fn test_join_with_maybe_swap_unbounded_case(t: TestCase) -> Result<()> { &t.initial_join_type, None, t.initial_mode, - false, + NullEquality::NullEqualsNothing, )?) 
as _; let optimized_join_plan = hash_join_swap_subrule(join, &ConfigOptions::new())?; diff --git a/datafusion/core/tests/physical_optimizer/limit_pushdown.rs b/datafusion/core/tests/physical_optimizer/limit_pushdown.rs index dd2c1960a658..56d48901f284 100644 --- a/datafusion/core/tests/physical_optimizer/limit_pushdown.rs +++ b/datafusion/core/tests/physical_optimizer/limit_pushdown.rs @@ -17,28 +17,26 @@ use std::sync::Arc; +use crate::physical_optimizer::test_utils::{ + coalesce_batches_exec, coalesce_partitions_exec, global_limit_exec, local_limit_exec, + sort_exec, sort_preserving_merge_exec, stream_exec, +}; + use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::config::ConfigOptions; use datafusion_common::error::Result; -use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::Operator; -use datafusion_physical_expr::expressions::BinaryExpr; -use datafusion_physical_expr::expressions::{col, lit}; -use datafusion_physical_expr::{Partitioning, PhysicalSortExpr}; +use datafusion_physical_expr::expressions::{col, lit, BinaryExpr}; +use datafusion_physical_expr::Partitioning; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use datafusion_physical_optimizer::limit_pushdown::LimitPushdown; use datafusion_physical_optimizer::PhysicalOptimizerRule; -use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; -use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::empty::EmptyExec; use datafusion_physical_plan::filter::FilterExec; -use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::projection::ProjectionExec; use datafusion_physical_plan::repartition::RepartitionExec; -use datafusion_physical_plan::sorts::sort::SortExec; -use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; -use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; -use datafusion_physical_plan::{get_plan_string, ExecutionPlan, ExecutionPlanProperties}; +use datafusion_physical_plan::{get_plan_string, ExecutionPlan}; fn create_schema() -> SchemaRef { Arc::new(Schema::new(vec![ @@ -48,48 +46,6 @@ fn create_schema() -> SchemaRef { ])) } -fn streaming_table_exec(schema: SchemaRef) -> Result> { - Ok(Arc::new(StreamingTableExec::try_new( - Arc::clone(&schema), - vec![Arc::new(DummyStreamPartition { schema }) as _], - None, - None, - true, - None, - )?)) -} - -fn global_limit_exec( - input: Arc, - skip: usize, - fetch: Option, -) -> Arc { - Arc::new(GlobalLimitExec::new(input, skip, fetch)) -} - -fn local_limit_exec( - input: Arc, - fetch: usize, -) -> Arc { - Arc::new(LocalLimitExec::new(input, fetch)) -} - -fn sort_exec( - sort_exprs: impl IntoIterator, - input: Arc, -) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - Arc::new(SortExec::new(sort_exprs, input)) -} - -fn sort_preserving_merge_exec( - sort_exprs: impl IntoIterator, - input: Arc, -) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - Arc::new(SortPreservingMergeExec::new(sort_exprs, input)) -} - fn projection_exec( schema: SchemaRef, input: Arc, @@ -118,16 +74,6 @@ fn filter_exec( )?)) } -fn coalesce_batches_exec(input: Arc) -> Arc { - Arc::new(CoalesceBatchesExec::new(input, 8192)) -} - -fn coalesce_partitions_exec( - local_limit: Arc, -) -> Arc { - Arc::new(CoalescePartitionsExec::new(local_limit)) -} - fn repartition_exec( streaming_table: Arc, ) -> 
Result> { @@ -141,24 +87,11 @@ fn empty_exec(schema: SchemaRef) -> Arc { Arc::new(EmptyExec::new(schema)) } -#[derive(Debug)] -struct DummyStreamPartition { - schema: SchemaRef, -} -impl PartitionStream for DummyStreamPartition { - fn schema(&self) -> &SchemaRef { - &self.schema - } - fn execute(&self, _ctx: Arc) -> SendableRecordBatchStream { - unreachable!() - } -} - #[test] fn transforms_streaming_table_exec_into_fetching_version_when_skip_is_zero() -> Result<()> { let schema = create_schema(); - let streaming_table = streaming_table_exec(schema)?; + let streaming_table = stream_exec(&schema); let global_limit = global_limit_exec(streaming_table, 0, Some(5)); let initial = get_plan_string(&global_limit); @@ -183,7 +116,7 @@ fn transforms_streaming_table_exec_into_fetching_version_when_skip_is_zero() -> fn transforms_streaming_table_exec_into_fetching_version_and_keeps_the_global_limit_when_skip_is_nonzero( ) -> Result<()> { let schema = create_schema(); - let streaming_table = streaming_table_exec(schema)?; + let streaming_table = stream_exec(&schema); let global_limit = global_limit_exec(streaming_table, 2, Some(5)); let initial = get_plan_string(&global_limit); @@ -209,10 +142,10 @@ fn transforms_streaming_table_exec_into_fetching_version_and_keeps_the_global_li fn transforms_coalesce_batches_exec_into_fetching_version_and_removes_local_limit( ) -> Result<()> { let schema = create_schema(); - let streaming_table = streaming_table_exec(Arc::clone(&schema))?; + let streaming_table = stream_exec(&schema); let repartition = repartition_exec(streaming_table)?; let filter = filter_exec(schema, repartition)?; - let coalesce_batches = coalesce_batches_exec(filter); + let coalesce_batches = coalesce_batches_exec(filter, 8192); let local_limit = local_limit_exec(coalesce_batches, 5); let coalesce_partitions = coalesce_partitions_exec(local_limit); let global_limit = global_limit_exec(coalesce_partitions, 0, Some(5)); @@ -247,7 +180,7 @@ fn transforms_coalesce_batches_exec_into_fetching_version_and_removes_local_limi #[test] fn pushes_global_limit_exec_through_projection_exec() -> Result<()> { let schema = create_schema(); - let streaming_table = streaming_table_exec(Arc::clone(&schema))?; + let streaming_table = stream_exec(&schema); let filter = filter_exec(Arc::clone(&schema), streaming_table)?; let projection = projection_exec(schema, filter)?; let global_limit = global_limit_exec(projection, 0, Some(5)); @@ -279,8 +212,8 @@ fn pushes_global_limit_exec_through_projection_exec() -> Result<()> { fn pushes_global_limit_exec_through_projection_exec_and_transforms_coalesce_batches_exec_into_fetching_version( ) -> Result<()> { let schema = create_schema(); - let streaming_table = streaming_table_exec(Arc::clone(&schema)).unwrap(); - let coalesce_batches = coalesce_batches_exec(streaming_table); + let streaming_table = stream_exec(&schema); + let coalesce_batches = coalesce_batches_exec(streaming_table, 8192); let projection = projection_exec(schema, coalesce_batches)?; let global_limit = global_limit_exec(projection, 0, Some(5)); @@ -310,18 +243,17 @@ fn pushes_global_limit_exec_through_projection_exec_and_transforms_coalesce_batc #[test] fn pushes_global_limit_into_multiple_fetch_plans() -> Result<()> { let schema = create_schema(); - let streaming_table = streaming_table_exec(Arc::clone(&schema)).unwrap(); - let coalesce_batches = coalesce_batches_exec(streaming_table); + let streaming_table = stream_exec(&schema); + let coalesce_batches = coalesce_batches_exec(streaming_table, 8192); let 
projection = projection_exec(Arc::clone(&schema), coalesce_batches)?; let repartition = repartition_exec(projection)?; - let sort = sort_exec( - vec![PhysicalSortExpr { - expr: col("c1", &schema)?, - options: SortOptions::default(), - }], - repartition, - ); - let spm = sort_preserving_merge_exec(sort.output_ordering().unwrap().to_vec(), sort); + let ordering: LexOrdering = [PhysicalSortExpr { + expr: col("c1", &schema)?, + options: SortOptions::default(), + }] + .into(); + let sort = sort_exec(ordering.clone(), repartition); + let spm = sort_preserving_merge_exec(ordering, sort); let global_limit = global_limit_exec(spm, 0, Some(5)); let initial = get_plan_string(&global_limit); @@ -357,7 +289,7 @@ fn pushes_global_limit_into_multiple_fetch_plans() -> Result<()> { fn keeps_pushed_local_limit_exec_when_there_are_multiple_input_partitions() -> Result<()> { let schema = create_schema(); - let streaming_table = streaming_table_exec(Arc::clone(&schema))?; + let streaming_table = stream_exec(&schema); let repartition = repartition_exec(streaming_table)?; let filter = filter_exec(schema, repartition)?; let coalesce_partitions = coalesce_partitions_exec(filter); diff --git a/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs index f9810eab8f59..409d392f7819 100644 --- a/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs +++ b/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs @@ -30,9 +30,8 @@ use datafusion::prelude::SessionContext; use datafusion_common::Result; use datafusion_execution::config::SessionConfig; use datafusion_expr::Operator; -use datafusion_physical_expr::expressions::cast; -use datafusion_physical_expr::{expressions, expressions::col, PhysicalSortExpr}; -use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_expr::expressions::{self, cast, col}; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use datafusion_physical_plan::{ aggregates::{AggregateExec, AggregateMode}, collect, @@ -236,12 +235,13 @@ async fn test_distinct_cols_different_than_group_by_cols() -> Result<()> { #[test] fn test_has_order_by() -> Result<()> { - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("a", &schema()).unwrap(), + let schema = schema(); + let sort_key = [PhysicalSortExpr { + expr: col("a", &schema)?, options: SortOptions::default(), - }]); - let source = parquet_exec_with_sort(vec![sort_key]); - let schema = source.schema(); + }] + .into(); + let source = parquet_exec_with_sort(schema.clone(), vec![sort_key]); // `SELECT a FROM DataSourceExec WHERE a > 1 GROUP BY a LIMIT 10;`, Single AggregateExec // the `a > 1` filter is applied in the AggregateExec @@ -263,7 +263,7 @@ fn test_has_order_by() -> Result<()> { "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], ordering_mode=Sorted", "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], file_type=parquet", ]; - let plan: Arc = Arc::new(limit_exec); + let plan = Arc::new(limit_exec) as _; assert_plan_matches_expected(&plan, &expected)?; Ok(()) } diff --git a/datafusion/core/tests/physical_optimizer/mod.rs b/datafusion/core/tests/physical_optimizer/mod.rs index 7d5d07715eeb..777c26e80e90 100644 --- a/datafusion/core/tests/physical_optimizer/mod.rs +++ b/datafusion/core/tests/physical_optimizer/mod.rs @@ -21,10 +21,13 @@ mod aggregate_statistics; mod combine_partial_final_agg; mod 
enforce_distribution; mod enforce_sorting; +mod filter_pushdown; mod join_selection; mod limit_pushdown; mod limited_distinct_aggregation; +mod partition_statistics; mod projection_pushdown; mod replace_with_order_preserving_variants; mod sanity_checker; mod test_utils; +mod window_optimize; diff --git a/datafusion/core/tests/physical_optimizer/partition_statistics.rs b/datafusion/core/tests/physical_optimizer/partition_statistics.rs new file mode 100644 index 000000000000..90124e0fcfc7 --- /dev/null +++ b/datafusion/core/tests/physical_optimizer/partition_statistics.rs @@ -0,0 +1,731 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use arrow::array::{Int32Array, RecordBatch}; + use arrow_schema::{DataType, Field, Schema, SortOptions}; + use datafusion::datasource::listing::ListingTable; + use datafusion::prelude::SessionContext; + use datafusion_catalog::TableProvider; + use datafusion_common::stats::Precision; + use datafusion_common::Result; + use datafusion_common::{ColumnStatistics, ScalarValue, Statistics}; + use datafusion_execution::config::SessionConfig; + use datafusion_execution::TaskContext; + use datafusion_expr_common::operator::Operator; + use datafusion_functions_aggregate::count::count_udaf; + use datafusion_physical_expr::aggregate::AggregateExprBuilder; + use datafusion_physical_expr::expressions::{binary, col, lit, Column}; + use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; + use datafusion_physical_plan::aggregates::{ + AggregateExec, AggregateMode, PhysicalGroupBy, + }; + use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; + use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; + use datafusion_physical_plan::empty::EmptyExec; + use datafusion_physical_plan::filter::FilterExec; + use datafusion_physical_plan::joins::CrossJoinExec; + use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; + use datafusion_physical_plan::projection::ProjectionExec; + use datafusion_physical_plan::sorts::sort::SortExec; + use datafusion_physical_plan::union::UnionExec; + use datafusion_physical_plan::{ + execute_stream_partitioned, get_plan_string, ExecutionPlan, + ExecutionPlanProperties, + }; + + use futures::TryStreamExt; + + /// Creates a test table with statistics from the test data directory. 
+ /// + /// This function: + /// - Creates an external table from './tests/data/test_statistics_per_partition' + /// - If we set the `target_partition` to 2, the data contains 2 partitions, each with 2 rows + /// - Each partition has an "id" column (INT) with the following values: + /// - First partition: [3, 4] + /// - Second partition: [1, 2] + /// - Each row is 110 bytes in size + /// + /// @param target_partition Optional parameter to set the target partitions + /// @return ExecutionPlan representing the scan of the table with statistics + async fn create_scan_exec_with_statistics( + create_table_sql: Option<&str>, + target_partition: Option, + ) -> Arc { + let mut session_config = SessionConfig::new().with_collect_statistics(true); + if let Some(partition) = target_partition { + session_config = session_config.with_target_partitions(partition); + } + let ctx = SessionContext::new_with_config(session_config); + // Create table with partition + let create_table_sql = create_table_sql.unwrap_or( + "CREATE EXTERNAL TABLE t1 (id INT NOT NULL, date DATE) \ + STORED AS PARQUET LOCATION './tests/data/test_statistics_per_partition'\ + PARTITIONED BY (date) \ + WITH ORDER (id ASC);", + ); + // Get table name from `create_table_sql` + let table_name = create_table_sql + .split_whitespace() + .nth(3) + .unwrap_or("t1") + .to_string(); + ctx.sql(create_table_sql) + .await + .unwrap() + .collect() + .await + .unwrap(); + let table = ctx.table_provider(table_name.as_str()).await.unwrap(); + let listing_table = table + .as_any() + .downcast_ref::() + .unwrap() + .clone(); + listing_table + .scan(&ctx.state(), None, &[], None) + .await + .unwrap() + } + + /// Helper function to create expected statistics for a partition with Int32 column + fn create_partition_statistics( + num_rows: usize, + total_byte_size: usize, + min_value: i32, + max_value: i32, + include_date_column: bool, + ) -> Statistics { + let mut column_stats = vec![ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(Some(max_value))), + min_value: Precision::Exact(ScalarValue::Int32(Some(min_value))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + }]; + + if include_date_column { + column_stats.push(ColumnStatistics { + null_count: Precision::Absent, + max_value: Precision::Absent, + min_value: Precision::Absent, + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + }); + } + + Statistics { + num_rows: Precision::Exact(num_rows), + total_byte_size: Precision::Exact(total_byte_size), + column_statistics: column_stats, + } + } + + #[derive(PartialEq, Eq, Debug)] + enum ExpectedStatistics { + Empty, // row_count == 0 + NonEmpty(i32, i32, usize), // (min_id, max_id, row_count) + } + + /// Helper function to validate that statistics from statistics_by_partition match the actual data + async fn validate_statistics_with_data( + plan: Arc, + expected_stats: Vec, + id_column_index: usize, + ) -> Result<()> { + let ctx = TaskContext::default(); + let partitions = execute_stream_partitioned(plan, Arc::new(ctx))?; + + let mut actual_stats = Vec::new(); + for partition_stream in partitions.into_iter() { + let result: Vec = partition_stream.try_collect().await?; + + let mut min_id = i32::MAX; + let mut max_id = i32::MIN; + let mut row_count = 0; + + for batch in result { + if batch.num_columns() > id_column_index { + let id_array = batch + .column(id_column_index) + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..batch.num_rows() { + let id_value = 
id_array.value(i); + min_id = min_id.min(id_value); + max_id = max_id.max(id_value); + row_count += 1; + } + } + } + + if row_count == 0 { + actual_stats.push(ExpectedStatistics::Empty); + } else { + actual_stats + .push(ExpectedStatistics::NonEmpty(min_id, max_id, row_count)); + } + } + + // Compare actual data with expected statistics + assert_eq!( + actual_stats.len(), + expected_stats.len(), + "Number of partitions with data doesn't match expected" + ); + for i in 0..actual_stats.len() { + assert_eq!( + actual_stats[i], expected_stats[i], + "Partition {i} data doesn't match statistics" + ); + } + + Ok(()) + } + + #[tokio::test] + async fn test_statistics_by_partition_of_data_source() -> Result<()> { + let scan = create_scan_exec_with_statistics(None, Some(2)).await; + let statistics = (0..scan.output_partitioning().partition_count()) + .map(|idx| scan.partition_statistics(Some(idx))) + .collect::>>()?; + let expected_statistic_partition_1 = + create_partition_statistics(2, 110, 3, 4, true); + let expected_statistic_partition_2 = + create_partition_statistics(2, 110, 1, 2, true); + // Check the statistics of each partition + assert_eq!(statistics.len(), 2); + assert_eq!(statistics[0], expected_statistic_partition_1); + assert_eq!(statistics[1], expected_statistic_partition_2); + + // Check the statistics_by_partition with real results + let expected_stats = vec![ + ExpectedStatistics::NonEmpty(3, 4, 2), // (min_id, max_id, row_count) for first partition + ExpectedStatistics::NonEmpty(1, 2, 2), // (min_id, max_id, row_count) for second partition + ]; + validate_statistics_with_data(scan, expected_stats, 0).await?; + + Ok(()) + } + + #[tokio::test] + async fn test_statistics_by_partition_of_projection() -> Result<()> { + let scan = create_scan_exec_with_statistics(None, Some(2)).await; + // Add projection execution plan + let exprs: Vec<(Arc, String)> = + vec![(Arc::new(Column::new("id", 0)), "id".to_string())]; + let projection: Arc = + Arc::new(ProjectionExec::try_new(exprs, scan)?); + let statistics = (0..projection.output_partitioning().partition_count()) + .map(|idx| projection.partition_statistics(Some(idx))) + .collect::>>()?; + let expected_statistic_partition_1 = + create_partition_statistics(2, 8, 3, 4, false); + let expected_statistic_partition_2 = + create_partition_statistics(2, 8, 1, 2, false); + // Check the statistics of each partition + assert_eq!(statistics.len(), 2); + assert_eq!(statistics[0], expected_statistic_partition_1); + assert_eq!(statistics[1], expected_statistic_partition_2); + + // Check the statistics_by_partition with real results + let expected_stats = vec![ + ExpectedStatistics::NonEmpty(3, 4, 2), + ExpectedStatistics::NonEmpty(1, 2, 2), + ]; + validate_statistics_with_data(projection, expected_stats, 0).await?; + Ok(()) + } + + #[tokio::test] + async fn test_statistics_by_partition_of_sort() -> Result<()> { + let scan_1 = create_scan_exec_with_statistics(None, Some(1)).await; + // Add sort execution plan + let ordering = [PhysicalSortExpr::new( + Arc::new(Column::new("id", 0)), + SortOptions::new(false, false), + )]; + let sort = SortExec::new(ordering.clone().into(), scan_1); + let sort_exec: Arc = Arc::new(sort); + let statistics = (0..sort_exec.output_partitioning().partition_count()) + .map(|idx| sort_exec.partition_statistics(Some(idx))) + .collect::>>()?; + let expected_statistic_partition = + create_partition_statistics(4, 220, 1, 4, true); + assert_eq!(statistics.len(), 1); + assert_eq!(statistics[0], expected_statistic_partition); + // Check 
the statistics_by_partition with real results + let expected_stats = vec![ExpectedStatistics::NonEmpty(1, 4, 4)]; + validate_statistics_with_data(sort_exec.clone(), expected_stats, 0).await?; + + // Sort with preserve_partitioning + let scan_2 = create_scan_exec_with_statistics(None, Some(2)).await; + // Add sort execution plan + let sort_exec: Arc = Arc::new( + SortExec::new(ordering.into(), scan_2).with_preserve_partitioning(true), + ); + let expected_statistic_partition_1 = + create_partition_statistics(2, 110, 3, 4, true); + let expected_statistic_partition_2 = + create_partition_statistics(2, 110, 1, 2, true); + let statistics = (0..sort_exec.output_partitioning().partition_count()) + .map(|idx| sort_exec.partition_statistics(Some(idx))) + .collect::>>()?; + assert_eq!(statistics.len(), 2); + assert_eq!(statistics[0], expected_statistic_partition_1); + assert_eq!(statistics[1], expected_statistic_partition_2); + + // Check the statistics_by_partition with real results + let expected_stats = vec![ + ExpectedStatistics::NonEmpty(3, 4, 2), + ExpectedStatistics::NonEmpty(1, 2, 2), + ]; + validate_statistics_with_data(sort_exec, expected_stats, 0).await?; + Ok(()) + } + + #[tokio::test] + async fn test_statistics_by_partition_of_filter() -> Result<()> { + let scan = create_scan_exec_with_statistics(None, Some(2)).await; + let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]); + let predicate = binary( + Arc::new(Column::new("id", 0)), + Operator::Lt, + lit(1i32), + &schema, + )?; + let filter: Arc = + Arc::new(FilterExec::try_new(predicate, scan)?); + let full_statistics = filter.partition_statistics(None)?; + let expected_full_statistic = Statistics { + num_rows: Precision::Inexact(0), + total_byte_size: Precision::Inexact(0), + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Null), + min_value: Precision::Exact(ScalarValue::Null), + sum_value: Precision::Exact(ScalarValue::Null), + distinct_count: Precision::Exact(0), + }, + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Null), + min_value: Precision::Exact(ScalarValue::Null), + sum_value: Precision::Exact(ScalarValue::Null), + distinct_count: Precision::Exact(0), + }, + ], + }; + assert_eq!(full_statistics, expected_full_statistic); + + let statistics = (0..filter.output_partitioning().partition_count()) + .map(|idx| filter.partition_statistics(Some(idx))) + .collect::>>()?; + assert_eq!(statistics.len(), 2); + assert_eq!(statistics[0], expected_full_statistic); + assert_eq!(statistics[1], expected_full_statistic); + Ok(()) + } + + #[tokio::test] + async fn test_statistic_by_partition_of_union() -> Result<()> { + let scan = create_scan_exec_with_statistics(None, Some(2)).await; + let union_exec: Arc = + Arc::new(UnionExec::new(vec![scan.clone(), scan])); + let statistics = (0..union_exec.output_partitioning().partition_count()) + .map(|idx| union_exec.partition_statistics(Some(idx))) + .collect::>>()?; + // Check that we have 4 partitions (2 from each scan) + assert_eq!(statistics.len(), 4); + let expected_statistic_partition_1 = + create_partition_statistics(2, 110, 3, 4, true); + let expected_statistic_partition_2 = + create_partition_statistics(2, 110, 1, 2, true); + // Verify first partition (from first scan) + assert_eq!(statistics[0], expected_statistic_partition_1); + // Verify second partition (from first scan) + assert_eq!(statistics[1], expected_statistic_partition_2); + // Verify 
third partition (from second scan - same as first partition) + assert_eq!(statistics[2], expected_statistic_partition_1); + // Verify fourth partition (from second scan - same as second partition) + assert_eq!(statistics[3], expected_statistic_partition_2); + + // Check the statistics_by_partition with real results + let expected_stats = vec![ + ExpectedStatistics::NonEmpty(3, 4, 2), + ExpectedStatistics::NonEmpty(1, 2, 2), + ExpectedStatistics::NonEmpty(3, 4, 2), + ExpectedStatistics::NonEmpty(1, 2, 2), + ]; + validate_statistics_with_data(union_exec, expected_stats, 0).await?; + Ok(()) + } + + #[tokio::test] + async fn test_statistic_by_partition_of_cross_join() -> Result<()> { + let left_scan = create_scan_exec_with_statistics(None, Some(1)).await; + let right_create_table_sql = "CREATE EXTERNAL TABLE t2 (id INT NOT NULL) \ + STORED AS PARQUET LOCATION './tests/data/test_statistics_per_partition'\ + WITH ORDER (id ASC);"; + let right_scan = + create_scan_exec_with_statistics(Some(right_create_table_sql), Some(2)).await; + let cross_join: Arc = + Arc::new(CrossJoinExec::new(left_scan, right_scan)); + let statistics = (0..cross_join.output_partitioning().partition_count()) + .map(|idx| cross_join.partition_statistics(Some(idx))) + .collect::>>()?; + // Check that we have 2 partitions + assert_eq!(statistics.len(), 2); + let mut expected_statistic_partition_1 = + create_partition_statistics(8, 48400, 1, 4, true); + expected_statistic_partition_1 + .column_statistics + .push(ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(Some(4))), + min_value: Precision::Exact(ScalarValue::Int32(Some(3))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + }); + let mut expected_statistic_partition_2 = + create_partition_statistics(8, 48400, 1, 4, true); + expected_statistic_partition_2 + .column_statistics + .push(ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Int32(Some(2))), + min_value: Precision::Exact(ScalarValue::Int32(Some(1))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + }); + assert_eq!(statistics[0], expected_statistic_partition_1); + assert_eq!(statistics[1], expected_statistic_partition_2); + + // Check the statistics_by_partition with real results + let expected_stats = vec![ + ExpectedStatistics::NonEmpty(1, 4, 8), + ExpectedStatistics::NonEmpty(1, 4, 8), + ]; + validate_statistics_with_data(cross_join, expected_stats, 0).await?; + Ok(()) + } + + #[tokio::test] + async fn test_statistic_by_partition_of_coalesce_batches() -> Result<()> { + let scan = create_scan_exec_with_statistics(None, Some(2)).await; + dbg!(scan.partition_statistics(Some(0))?); + let coalesce_batches: Arc = + Arc::new(CoalesceBatchesExec::new(scan, 2)); + let expected_statistic_partition_1 = + create_partition_statistics(2, 110, 3, 4, true); + let expected_statistic_partition_2 = + create_partition_statistics(2, 110, 1, 2, true); + let statistics = (0..coalesce_batches.output_partitioning().partition_count()) + .map(|idx| coalesce_batches.partition_statistics(Some(idx))) + .collect::>>()?; + assert_eq!(statistics.len(), 2); + assert_eq!(statistics[0], expected_statistic_partition_1); + assert_eq!(statistics[1], expected_statistic_partition_2); + + // Check the statistics_by_partition with real results + let expected_stats = vec![ + ExpectedStatistics::NonEmpty(3, 4, 2), + ExpectedStatistics::NonEmpty(1, 2, 2), + ]; + validate_statistics_with_data(coalesce_batches, 
expected_stats, 0).await?; + Ok(()) + } + + #[tokio::test] + async fn test_statistic_by_partition_of_coalesce_partitions() -> Result<()> { + let scan = create_scan_exec_with_statistics(None, Some(2)).await; + let coalesce_partitions: Arc = + Arc::new(CoalescePartitionsExec::new(scan)); + let expected_statistic_partition = + create_partition_statistics(4, 220, 1, 4, true); + let statistics = (0..coalesce_partitions.output_partitioning().partition_count()) + .map(|idx| coalesce_partitions.partition_statistics(Some(idx))) + .collect::>>()?; + assert_eq!(statistics.len(), 1); + assert_eq!(statistics[0], expected_statistic_partition); + + // Check the statistics_by_partition with real results + let expected_stats = vec![ExpectedStatistics::NonEmpty(1, 4, 4)]; + validate_statistics_with_data(coalesce_partitions, expected_stats, 0).await?; + Ok(()) + } + + #[tokio::test] + async fn test_statistic_by_partition_of_local_limit() -> Result<()> { + let scan = create_scan_exec_with_statistics(None, Some(2)).await; + let local_limit: Arc = + Arc::new(LocalLimitExec::new(scan.clone(), 1)); + let statistics = (0..local_limit.output_partitioning().partition_count()) + .map(|idx| local_limit.partition_statistics(Some(idx))) + .collect::>>()?; + assert_eq!(statistics.len(), 2); + let schema = scan.schema(); + let mut expected_statistic_partition = Statistics::new_unknown(&schema); + expected_statistic_partition.num_rows = Precision::Exact(1); + assert_eq!(statistics[0], expected_statistic_partition); + assert_eq!(statistics[1], expected_statistic_partition); + Ok(()) + } + + #[tokio::test] + async fn test_statistic_by_partition_of_global_limit_partitions() -> Result<()> { + let scan = create_scan_exec_with_statistics(None, Some(2)).await; + // Skip 2 rows + let global_limit: Arc = + Arc::new(GlobalLimitExec::new(scan.clone(), 0, Some(2))); + let statistics = (0..global_limit.output_partitioning().partition_count()) + .map(|idx| global_limit.partition_statistics(Some(idx))) + .collect::>>()?; + assert_eq!(statistics.len(), 1); + let expected_statistic_partition = + create_partition_statistics(2, 110, 3, 4, true); + assert_eq!(statistics[0], expected_statistic_partition); + Ok(()) + } + + #[tokio::test] + async fn test_statistic_by_partition_of_agg() -> Result<()> { + let scan = create_scan_exec_with_statistics(None, Some(2)).await; + + let scan_schema = scan.schema(); + + // select id, 1+id, count(*) from t group by id, 1+id + let group_by = PhysicalGroupBy::new_single(vec![ + (col("id", &scan_schema)?, "id".to_string()), + ( + binary( + lit(1), + Operator::Plus, + col("id", &scan_schema)?, + &scan_schema, + )?, + "expr".to_string(), + ), + ]); + + let aggr_expr = vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1)]) + .schema(Arc::clone(&scan_schema)) + .alias(String::from("COUNT(c)")) + .build() + .map(Arc::new)?]; + + let aggregate_exec_partial = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + group_by.clone(), + aggr_expr.clone(), + vec![None], + Arc::clone(&scan), + scan_schema.clone(), + )?) 
as _; + + let mut plan_string = get_plan_string(&aggregate_exec_partial); + let _ = plan_string.swap_remove(1); + let expected_plan = vec![ + "AggregateExec: mode=Partial, gby=[id@0 as id, 1 + id@0 as expr], aggr=[COUNT(c)]", + //" DataSourceExec: file_groups={2 groups: [[.../datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-01/j5fUeSDQo22oPyPU.parquet, .../datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-02/j5fUeSDQo22oPyPU.parquet], [.../datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-03/j5fUeSDQo22oPyPU.parquet, .../datafusion/core/tests/data/test_statistics_per_partition/date=2025-03-04/j5fUeSDQo22oPyPU.parquet]]}, projection=[id, date], file_type=parquet + ]; + assert_eq!(plan_string, expected_plan); + + let p0_statistics = aggregate_exec_partial.partition_statistics(Some(0))?; + + let expected_p0_statistics = Statistics { + num_rows: Precision::Inexact(2), + total_byte_size: Precision::Absent, + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Absent, + max_value: Precision::Exact(ScalarValue::Int32(Some(4))), + min_value: Precision::Exact(ScalarValue::Int32(Some(3))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + }, + ColumnStatistics::new_unknown(), + ColumnStatistics::new_unknown(), + ], + }; + + assert_eq!(&p0_statistics, &expected_p0_statistics); + + let expected_p1_statistics = Statistics { + num_rows: Precision::Inexact(2), + total_byte_size: Precision::Absent, + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Absent, + max_value: Precision::Exact(ScalarValue::Int32(Some(2))), + min_value: Precision::Exact(ScalarValue::Int32(Some(1))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + }, + ColumnStatistics::new_unknown(), + ColumnStatistics::new_unknown(), + ], + }; + + let p1_statistics = aggregate_exec_partial.partition_statistics(Some(1))?; + assert_eq!(&p1_statistics, &expected_p1_statistics); + + validate_statistics_with_data( + aggregate_exec_partial.clone(), + vec![ + ExpectedStatistics::NonEmpty(3, 4, 2), + ExpectedStatistics::NonEmpty(1, 2, 2), + ], + 0, + ) + .await?; + + let agg_final = Arc::new(AggregateExec::try_new( + AggregateMode::FinalPartitioned, + group_by.clone(), + aggr_expr.clone(), + vec![None], + aggregate_exec_partial.clone(), + aggregate_exec_partial.schema(), + )?); + + let p0_statistics = agg_final.partition_statistics(Some(0))?; + assert_eq!(&p0_statistics, &expected_p0_statistics); + + let p1_statistics = agg_final.partition_statistics(Some(1))?; + assert_eq!(&p1_statistics, &expected_p1_statistics); + + validate_statistics_with_data( + agg_final.clone(), + vec![ + ExpectedStatistics::NonEmpty(3, 4, 2), + ExpectedStatistics::NonEmpty(1, 2, 2), + ], + 0, + ) + .await?; + + // select id, 1+id, count(*) from empty_table group by id, 1+id + let empty_table = + Arc::new(EmptyExec::new(scan_schema.clone()).with_partitions(2)); + + let agg_partial = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + group_by.clone(), + aggr_expr.clone(), + vec![None], + empty_table.clone(), + scan_schema.clone(), + )?) 
as _; + + let agg_plan = get_plan_string(&agg_partial).remove(0); + assert_eq!("AggregateExec: mode=Partial, gby=[id@0 as id, 1 + id@0 as expr], aggr=[COUNT(c)]",agg_plan); + + let empty_stat = Statistics { + num_rows: Precision::Exact(0), + total_byte_size: Precision::Absent, + column_statistics: vec![ + ColumnStatistics::new_unknown(), + ColumnStatistics::new_unknown(), + ColumnStatistics::new_unknown(), + ], + }; + + assert_eq!(&empty_stat, &agg_partial.partition_statistics(Some(0))?); + assert_eq!(&empty_stat, &agg_partial.partition_statistics(Some(1))?); + validate_statistics_with_data( + agg_partial.clone(), + vec![ExpectedStatistics::Empty, ExpectedStatistics::Empty], + 0, + ) + .await?; + + let agg_partial = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + group_by.clone(), + aggr_expr.clone(), + vec![None], + empty_table.clone(), + scan_schema.clone(), + )?); + + let agg_final = Arc::new(AggregateExec::try_new( + AggregateMode::FinalPartitioned, + group_by.clone(), + aggr_expr.clone(), + vec![None], + agg_partial.clone(), + agg_partial.schema(), + )?); + + assert_eq!(&empty_stat, &agg_final.partition_statistics(Some(0))?); + assert_eq!(&empty_stat, &agg_final.partition_statistics(Some(1))?); + + validate_statistics_with_data( + agg_final, + vec![ExpectedStatistics::Empty, ExpectedStatistics::Empty], + 0, + ) + .await?; + + // select count(*) from empty_table + let agg_partial = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::default(), + aggr_expr.clone(), + vec![None], + empty_table.clone(), + scan_schema.clone(), + )?); + + let coalesce = Arc::new(CoalescePartitionsExec::new(agg_partial.clone())); + + let agg_final = Arc::new(AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::default(), + aggr_expr.clone(), + vec![None], + coalesce.clone(), + coalesce.schema(), + )?); + + let expect_stat = Statistics { + num_rows: Precision::Exact(1), + total_byte_size: Precision::Absent, + column_statistics: vec![ColumnStatistics::new_unknown()], + }; + + assert_eq!(&expect_stat, &agg_final.partition_statistics(Some(0))?); + + // Verify that the aggregate final result has exactly one partition with one row + let mut partitions = execute_stream_partitioned( + agg_final.clone(), + Arc::new(TaskContext::default()), + )?; + assert_eq!(1, partitions.len()); + let result: Vec = partitions.remove(0).try_collect().await?; + assert_eq!(1, result[0].num_rows()); + + Ok(()) + } +} diff --git a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs index 911d2c0cee05..1f8aad0f2334 100644 --- a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs @@ -25,21 +25,22 @@ use datafusion::datasource::memory::MemorySourceConfig; use datafusion::datasource::physical_plan::CsvSource; use datafusion::datasource::source::DataSourceExec; use datafusion_common::config::ConfigOptions; -use datafusion_common::Result; -use datafusion_common::{JoinSide, JoinType, ScalarValue}; +use datafusion_common::{JoinSide, JoinType, NullEquality, Result, ScalarValue}; +use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::{ Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; +use datafusion_expr_common::columnar_value::ColumnarValue; use 
datafusion_physical_expr::expressions::{ binary, cast, col, BinaryExpr, CaseExpr, CastExpr, Column, Literal, NegativeExpr, }; -use datafusion_physical_expr::ScalarFunctionExpr; -use datafusion_physical_expr::{ - Distribution, Partitioning, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, +use datafusion_physical_expr::{Distribution, Partitioning, ScalarFunctionExpr}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::{ + OrderingRequirements, PhysicalSortExpr, PhysicalSortRequirement, }; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; use datafusion_physical_optimizer::output_requirements::OutputRequirementExec; use datafusion_physical_optimizer::projection_pushdown::ProjectionPushdown; use datafusion_physical_optimizer::PhysicalOptimizerRule; @@ -54,13 +55,10 @@ use datafusion_physical_plan::projection::{update_expr, ProjectionExec}; use datafusion_physical_plan::repartition::RepartitionExec; use datafusion_physical_plan::sorts::sort::SortExec; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; -use datafusion_physical_plan::streaming::PartitionStream; -use datafusion_physical_plan::streaming::StreamingTableExec; +use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; use datafusion_physical_plan::union::UnionExec; use datafusion_physical_plan::{get_plan_string, ExecutionPlan}; -use datafusion_datasource::file_scan_config::FileScanConfigBuilder; -use datafusion_expr_common::columnar_value::ColumnarValue; use itertools::Itertools; /// Mocked UDF @@ -128,7 +126,7 @@ fn test_update_matching_exprs() -> Result<()> { Arc::new(Column::new("b", 1)), )), ], - DataType::Int32, + Field::new("f", DataType::Int32, true).into(), )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d", 2))), @@ -193,7 +191,7 @@ fn test_update_matching_exprs() -> Result<()> { Arc::new(Column::new("b", 1)), )), ], - DataType::Int32, + Field::new("f", DataType::Int32, true).into(), )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d", 3))), @@ -261,7 +259,7 @@ fn test_update_projected_exprs() -> Result<()> { Arc::new(Column::new("b", 1)), )), ], - DataType::Int32, + Field::new("f", DataType::Int32, true).into(), )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d", 2))), @@ -326,7 +324,7 @@ fn test_update_projected_exprs() -> Result<()> { Arc::new(Column::new("b_new", 1)), )), ], - DataType::Int32, + Field::new("f", DataType::Int32, true).into(), )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d_new", 3))), @@ -519,7 +517,7 @@ fn test_streaming_table_after_projection() -> Result<()> { }) as _], Some(&vec![0_usize, 2, 4, 3]), vec![ - LexOrdering::new(vec![ + [ PhysicalSortExpr { expr: Arc::new(Column::new("e", 2)), options: SortOptions::default(), @@ -528,11 +526,13 @@ fn test_streaming_table_after_projection() -> Result<()> { expr: Arc::new(Column::new("a", 0)), options: SortOptions::default(), }, - ]), - LexOrdering::new(vec![PhysicalSortExpr { + ] + .into(), + [PhysicalSortExpr { expr: Arc::new(Column::new("d", 3)), options: SortOptions::default(), - }]), + }] + .into(), ] .into_iter(), true, @@ -579,7 +579,7 @@ fn test_streaming_table_after_projection() -> Result<()> { assert_eq!( result.projected_output_ordering().into_iter().collect_vec(), vec![ - LexOrdering::new(vec![ + [ PhysicalSortExpr { expr: Arc::new(Column::new("e", 1)), options: SortOptions::default(), @@ -588,11 +588,13 @@ fn 
test_streaming_table_after_projection() -> Result<()> { expr: Arc::new(Column::new("a", 2)), options: SortOptions::default(), }, - ]), - LexOrdering::new(vec![PhysicalSortExpr { + ] + .into(), + [PhysicalSortExpr { expr: Arc::new(Column::new("d", 0)), options: SortOptions::default(), - }]), + }] + .into(), ] ); assert!(result.is_infinite()); @@ -652,21 +654,24 @@ fn test_projection_after_projection() -> Result<()> { fn test_output_req_after_projection() -> Result<()> { let csv = create_simple_csv_exec(); let sort_req: Arc = Arc::new(OutputRequirementExec::new( - csv.clone(), - Some(LexRequirement::new(vec![ - PhysicalSortRequirement { - expr: Arc::new(Column::new("b", 1)), - options: Some(SortOptions::default()), - }, - PhysicalSortRequirement { - expr: Arc::new(BinaryExpr::new( - Arc::new(Column::new("c", 2)), - Operator::Plus, - Arc::new(Column::new("a", 0)), - )), - options: Some(SortOptions::default()), - }, - ])), + csv, + Some(OrderingRequirements::new( + [ + PhysicalSortRequirement::new( + Arc::new(Column::new("b", 1)), + Some(SortOptions::default()), + ), + PhysicalSortRequirement::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::Plus, + Arc::new(Column::new("a", 0)), + )), + Some(SortOptions::default()), + ), + ] + .into(), + )), Distribution::HashPartitioned(vec![ Arc::new(Column::new("a", 0)), Arc::new(Column::new("b", 1)), @@ -699,20 +704,23 @@ fn test_output_req_after_projection() -> Result<()> { ]; assert_eq!(get_plan_string(&after_optimize), expected); - let expected_reqs = LexRequirement::new(vec![ - PhysicalSortRequirement { - expr: Arc::new(Column::new("b", 2)), - options: Some(SortOptions::default()), - }, - PhysicalSortRequirement { - expr: Arc::new(BinaryExpr::new( - Arc::new(Column::new("c", 0)), - Operator::Plus, - Arc::new(Column::new("new_a", 1)), - )), - options: Some(SortOptions::default()), - }, - ]); + let expected_reqs = OrderingRequirements::new( + [ + PhysicalSortRequirement::new( + Arc::new(Column::new("b", 2)), + Some(SortOptions::default()), + ), + PhysicalSortRequirement::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 0)), + Operator::Plus, + Arc::new(Column::new("new_a", 1)), + )), + Some(SortOptions::default()), + ), + ] + .into(), + ); assert_eq!( after_optimize .as_any() @@ -795,15 +803,15 @@ fn test_filter_after_projection() -> Result<()> { Arc::new(Column::new("a", 0)), )), )); - let filter: Arc = Arc::new(FilterExec::try_new(predicate, csv)?); - let projection: Arc = Arc::new(ProjectionExec::try_new( + let filter = Arc::new(FilterExec::try_new(predicate, csv)?); + let projection = Arc::new(ProjectionExec::try_new( vec![ (Arc::new(Column::new("a", 0)), "a_new".to_string()), (Arc::new(Column::new("b", 1)), "b".to_string()), (Arc::new(Column::new("d", 3)), "d".to_string()), ], filter.clone(), - )?); + )?) as _; let initial = get_plan_string(&projection); let expected_initial = [ @@ -875,12 +883,12 @@ fn test_join_after_projection() -> Result<()> { ])), )), &JoinType::Inner, - true, + NullEquality::NullEqualsNull, None, None, StreamJoinPartitionMode::SinglePartition, )?); - let projection: Arc = Arc::new(ProjectionExec::try_new( + let projection = Arc::new(ProjectionExec::try_new( vec![ (Arc::new(Column::new("c", 2)), "c_from_left".to_string()), (Arc::new(Column::new("b", 1)), "b_from_left".to_string()), @@ -889,7 +897,7 @@ fn test_join_after_projection() -> Result<()> { (Arc::new(Column::new("c", 7)), "c_from_right".to_string()), ], join, - )?); + )?) 
as _; let initial = get_plan_string(&projection); let expected_initial = [ "ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left, a@5 as a_from_right, c@7 as c_from_right]", @@ -945,7 +953,7 @@ fn test_join_after_required_projection() -> Result<()> { let left_csv = create_simple_csv_exec(); let right_csv = create_simple_csv_exec(); - let join: Arc = Arc::new(SymmetricHashJoinExec::try_new( + let join = Arc::new(SymmetricHashJoinExec::try_new( left_csv, right_csv, vec![(Arc::new(Column::new("b", 1)), Arc::new(Column::new("c", 2)))], @@ -989,12 +997,12 @@ fn test_join_after_required_projection() -> Result<()> { ])), )), &JoinType::Inner, - true, + NullEquality::NullEqualsNull, None, None, StreamJoinPartitionMode::SinglePartition, )?); - let projection: Arc = Arc::new(ProjectionExec::try_new( + let projection = Arc::new(ProjectionExec::try_new( vec![ (Arc::new(Column::new("a", 5)), "a".to_string()), (Arc::new(Column::new("b", 6)), "b".to_string()), @@ -1008,7 +1016,7 @@ fn test_join_after_required_projection() -> Result<()> { (Arc::new(Column::new("e", 4)), "e".to_string()), ], join, - )?); + )?) as _; let initial = get_plan_string(&projection); let expected_initial = [ "ProjectionExec: expr=[a@5 as a, b@6 as b, c@7 as c, d@8 as d, e@9 as e, a@0 as a, b@1 as b, c@2 as c, d@3 as d, e@4 as e]", @@ -1061,7 +1069,7 @@ fn test_nested_loop_join_after_projection() -> Result<()> { Field::new("c", DataType::Int32, true), ]); - let join: Arc = Arc::new(NestedLoopJoinExec::try_new( + let join = Arc::new(NestedLoopJoinExec::try_new( left_csv, right_csv, Some(JoinFilter::new( @@ -1071,12 +1079,12 @@ fn test_nested_loop_join_after_projection() -> Result<()> { )), &JoinType::Inner, None, - )?); + )?) as _; - let projection: Arc = Arc::new(ProjectionExec::try_new( + let projection = Arc::new(ProjectionExec::try_new( vec![(col_left_c, "c".to_string())], Arc::clone(&join), - )?); + )?) as _; let initial = get_plan_string(&projection); let expected_initial = [ "ProjectionExec: expr=[c@2 as c]", @@ -1104,7 +1112,7 @@ fn test_hash_join_after_projection() -> Result<()> { let left_csv = create_simple_csv_exec(); let right_csv = create_simple_csv_exec(); - let join: Arc = Arc::new(HashJoinExec::try_new( + let join = Arc::new(HashJoinExec::try_new( left_csv, right_csv, vec![(Arc::new(Column::new("b", 1)), Arc::new(Column::new("c", 2)))], @@ -1150,9 +1158,9 @@ fn test_hash_join_after_projection() -> Result<()> { &JoinType::Inner, None, PartitionMode::Auto, - true, + NullEquality::NullEqualsNull, )?); - let projection: Arc = Arc::new(ProjectionExec::try_new( + let projection = Arc::new(ProjectionExec::try_new( vec![ (Arc::new(Column::new("c", 2)), "c_from_left".to_string()), (Arc::new(Column::new("b", 1)), "b_from_left".to_string()), @@ -1160,7 +1168,7 @@ fn test_hash_join_after_projection() -> Result<()> { (Arc::new(Column::new("c", 7)), "c_from_right".to_string()), ], join.clone(), - )?); + )?) 
as _; let initial = get_plan_string(&projection); let expected_initial = [ "ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left, c@7 as c_from_right]", " HashJoinExec: mode=Auto, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false" @@ -1174,7 +1182,7 @@ fn test_hash_join_after_projection() -> Result<()> { let expected = ["ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left, c@3 as c_from_right]", " HashJoinExec: mode=Auto, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2, projection=[a@0, b@1, c@2, c@7]", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false", " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=csv, has_header=false"]; assert_eq!(get_plan_string(&after_optimize), expected); - let projection: Arc = Arc::new(ProjectionExec::try_new( + let projection = Arc::new(ProjectionExec::try_new( vec![ (Arc::new(Column::new("a", 0)), "a".to_string()), (Arc::new(Column::new("b", 1)), "b".to_string()), @@ -1197,7 +1205,7 @@ fn test_hash_join_after_projection() -> Result<()> { #[test] fn test_repartition_after_projection() -> Result<()> { let csv = create_simple_csv_exec(); - let repartition: Arc = Arc::new(RepartitionExec::try_new( + let repartition = Arc::new(RepartitionExec::try_new( csv, Partitioning::Hash( vec![ @@ -1208,14 +1216,14 @@ fn test_repartition_after_projection() -> Result<()> { 6, ), )?); - let projection: Arc = Arc::new(ProjectionExec::try_new( + let projection = Arc::new(ProjectionExec::try_new( vec![ (Arc::new(Column::new("b", 1)), "b_new".to_string()), (Arc::new(Column::new("a", 0)), "a".to_string()), (Arc::new(Column::new("d", 3)), "d_new".to_string()), ], repartition, - )?); + )?) as _; let initial = get_plan_string(&projection); let expected_initial = [ "ProjectionExec: expr=[b@1 as b_new, a@0 as a, d@3 as d_new]", @@ -1257,31 +1265,26 @@ fn test_repartition_after_projection() -> Result<()> { #[test] fn test_sort_after_projection() -> Result<()> { let csv = create_simple_csv_exec(); - let sort_req: Arc = Arc::new(SortExec::new( - LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(BinaryExpr::new( - Arc::new(Column::new("c", 2)), - Operator::Plus, - Arc::new(Column::new("a", 0)), - )), - options: SortOptions::default(), - }, - ]), - csv.clone(), - )); - let projection: Arc = Arc::new(ProjectionExec::try_new( + let sort_exec = SortExec::new( + [ + PhysicalSortExpr::new_default(Arc::new(Column::new("b", 1))), + PhysicalSortExpr::new_default(Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::Plus, + Arc::new(Column::new("a", 0)), + ))), + ] + .into(), + csv, + ); + let projection = Arc::new(ProjectionExec::try_new( vec![ (Arc::new(Column::new("c", 2)), "c".to_string()), (Arc::new(Column::new("a", 0)), "new_a".to_string()), (Arc::new(Column::new("b", 1)), "b".to_string()), ], - sort_req.clone(), - )?); + Arc::new(sort_exec), + )?) 
as _; let initial = get_plan_string(&projection); let expected_initial = [ @@ -1307,31 +1310,26 @@ fn test_sort_after_projection() -> Result<()> { #[test] fn test_sort_preserving_after_projection() -> Result<()> { let csv = create_simple_csv_exec(); - let sort_req: Arc = Arc::new(SortPreservingMergeExec::new( - LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(BinaryExpr::new( - Arc::new(Column::new("c", 2)), - Operator::Plus, - Arc::new(Column::new("a", 0)), - )), - options: SortOptions::default(), - }, - ]), - csv.clone(), - )); - let projection: Arc = Arc::new(ProjectionExec::try_new( + let sort_exec = SortPreservingMergeExec::new( + [ + PhysicalSortExpr::new_default(Arc::new(Column::new("b", 1))), + PhysicalSortExpr::new_default(Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::Plus, + Arc::new(Column::new("a", 0)), + ))), + ] + .into(), + csv, + ); + let projection = Arc::new(ProjectionExec::try_new( vec![ (Arc::new(Column::new("c", 2)), "c".to_string()), (Arc::new(Column::new("a", 0)), "new_a".to_string()), (Arc::new(Column::new("b", 1)), "b".to_string()), ], - sort_req.clone(), - )?); + Arc::new(sort_exec), + )?) as _; let initial = get_plan_string(&projection); let expected_initial = [ @@ -1357,16 +1355,15 @@ fn test_sort_preserving_after_projection() -> Result<()> { #[test] fn test_union_after_projection() -> Result<()> { let csv = create_simple_csv_exec(); - let union: Arc = - Arc::new(UnionExec::new(vec![csv.clone(), csv.clone(), csv])); - let projection: Arc = Arc::new(ProjectionExec::try_new( + let union = Arc::new(UnionExec::new(vec![csv.clone(), csv.clone(), csv])); + let projection = Arc::new(ProjectionExec::try_new( vec![ (Arc::new(Column::new("c", 2)), "c".to_string()), (Arc::new(Column::new("a", 0)), "new_a".to_string()), (Arc::new(Column::new("b", 1)), "b".to_string()), ], union.clone(), - )?); + )?) 
as _; let initial = get_plan_string(&projection); let expected_initial = [ diff --git a/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs index 58eb866c590c..c9baa9a932ae 100644 --- a/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/tests/physical_optimizer/replace_with_order_preserving_variants.rs @@ -18,7 +18,10 @@ use std::sync::Arc; use crate::physical_optimizer::test_utils::{ - check_integrity, sort_preserving_merge_exec, stream_exec_ordered_with_projection, + check_integrity, coalesce_batches_exec, coalesce_partitions_exec, + create_test_schema3, parquet_exec_with_sort, sort_exec, + sort_exec_with_preserve_partitioning, sort_preserving_merge_exec, + sort_preserving_merge_exec_with_fetch, stream_exec_ordered_with_projection, }; use datafusion::prelude::SessionContext; @@ -26,26 +29,25 @@ use arrow::array::{ArrayRef, Int32Array}; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; +use datafusion_common::tree_node::{TransformedResult, TreeNode}; +use datafusion_common::{assert_contains, NullEquality, Result}; +use datafusion_common::config::ConfigOptions; +use datafusion_datasource::source::DataSourceExec; use datafusion_execution::TaskContext; +use datafusion_expr::{JoinType, Operator}; +use datafusion_physical_expr::expressions::{self, col, Column}; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion_physical_optimizer::enforce_sorting::replace_with_order_preserving_variants::{ + plan_with_order_breaking_variants, plan_with_order_preserving_variants, replace_with_order_preserving_variants, OrderPreservationContext +}; use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; -use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; -use datafusion_physical_plan::collect; use datafusion_physical_plan::filter::FilterExec; use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; use datafusion::datasource::memory::MemorySourceConfig; use datafusion_physical_plan::repartition::RepartitionExec; -use datafusion_physical_plan::sorts::sort::SortExec; use datafusion_physical_plan::{ - displayable, get_plan_string, ExecutionPlan, Partitioning, + collect, displayable, get_plan_string, ExecutionPlan, Partitioning, }; -use datafusion::datasource::source::DataSourceExec; -use datafusion_common::tree_node::{TransformedResult, TreeNode}; -use datafusion_common::Result; -use datafusion_expr::{JoinType, Operator}; -use datafusion_physical_expr::expressions::{self, col, Column}; -use datafusion_physical_expr::PhysicalSortExpr; -use datafusion_physical_optimizer::enforce_sorting::replace_with_order_preserving_variants::{replace_with_order_preserving_variants, OrderPreservationContext}; -use datafusion_common::config::ConfigOptions; use object_store::memory::InMemory; use object_store::ObjectStore; @@ -188,16 +190,15 @@ async fn test_replace_multiple_input_repartition_1( #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; + let sort_exprs: LexOrdering = [sort_expr("a", &schema)].into(); let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) + stream_exec_ordered_with_projection(&schema, sort_exprs.clone()) } else { - 
memory_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, sort_exprs.clone()) }; let repartition = repartition_exec_hash(repartition_exec_round_robin(source)); - let sort = sort_exec(vec![sort_expr("a", &schema)], repartition, true); - - let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); + let sort = sort_exec_with_preserve_partitioning(sort_exprs.clone(), repartition); + let physical_plan = sort_preserving_merge_exec(sort_exprs, sort); // Expected inputs unbounded and bounded let expected_input_unbounded = [ @@ -257,27 +258,22 @@ async fn test_with_inter_children_change_only( #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr_default("a", &schema)]; + let ordering: LexOrdering = [sort_expr_default("a", &schema)].into(); let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) + stream_exec_ordered_with_projection(&schema, ordering.clone()) } else { - memory_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, ordering.clone()) }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = coalesce_partitions_exec(repartition_hash); - let sort = sort_exec( - vec![sort_expr_default("a", &coalesce_partitions.schema())], - coalesce_partitions, - false, - ); + let sort = sort_exec(ordering.clone(), coalesce_partitions); let repartition_rr2 = repartition_exec_round_robin(sort); let repartition_hash2 = repartition_exec_hash(repartition_rr2); let filter = filter_exec(repartition_hash2); - let sort2 = sort_exec(vec![sort_expr_default("a", &filter.schema())], filter, true); + let sort2 = sort_exec_with_preserve_partitioning(ordering.clone(), filter); - let physical_plan = - sort_preserving_merge_exec(vec![sort_expr_default("a", &sort2.schema())], sort2); + let physical_plan = sort_preserving_merge_exec(ordering, sort2); // Expected inputs unbounded and bounded let expected_input_unbounded = [ @@ -360,18 +356,17 @@ async fn test_replace_multiple_input_repartition_2( #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; + let ordering: LexOrdering = [sort_expr("a", &schema)].into(); let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) + stream_exec_ordered_with_projection(&schema, ordering.clone()) } else { - memory_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, ordering.clone()) }; let repartition_rr = repartition_exec_round_robin(source); let filter = filter_exec(repartition_rr); let repartition_hash = repartition_exec_hash(filter); - let sort = sort_exec(vec![sort_expr("a", &schema)], repartition_hash, true); - - let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); + let sort = sort_exec_with_preserve_partitioning(ordering.clone(), repartition_hash); + let physical_plan = sort_preserving_merge_exec(ordering, sort); // Expected inputs unbounded and bounded let expected_input_unbounded = [ @@ -436,19 +431,19 @@ async fn test_replace_multiple_input_repartition_with_extra_steps( #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; + let ordering: LexOrdering = [sort_expr("a", &schema)].into(); let source = if source_unbounded { - 
stream_exec_ordered_with_projection(&schema, sort_exprs) + stream_exec_ordered_with_projection(&schema, ordering.clone()) } else { - memory_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, ordering.clone()) }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); - let coalesce_batches_exec: Arc = coalesce_batches_exec(filter); - let sort = sort_exec(vec![sort_expr("a", &schema)], coalesce_batches_exec, true); - - let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); + let coalesce_batches_exec = coalesce_batches_exec(filter, 8192); + let sort = + sort_exec_with_preserve_partitioning(ordering.clone(), coalesce_batches_exec); + let physical_plan = sort_preserving_merge_exec(ordering, sort); // Expected inputs unbounded and bounded let expected_input_unbounded = [ @@ -518,20 +513,20 @@ async fn test_replace_multiple_input_repartition_with_extra_steps_2( #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; + let ordering: LexOrdering = [sort_expr("a", &schema)].into(); let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) + stream_exec_ordered_with_projection(&schema, ordering.clone()) } else { - memory_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, ordering.clone()) }; let repartition_rr = repartition_exec_round_robin(source); - let coalesce_batches_exec_1 = coalesce_batches_exec(repartition_rr); + let coalesce_batches_exec_1 = coalesce_batches_exec(repartition_rr, 8192); let repartition_hash = repartition_exec_hash(coalesce_batches_exec_1); let filter = filter_exec(repartition_hash); - let coalesce_batches_exec_2 = coalesce_batches_exec(filter); - let sort = sort_exec(vec![sort_expr("a", &schema)], coalesce_batches_exec_2, true); - - let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); + let coalesce_batches_exec_2 = coalesce_batches_exec(filter, 8192); + let sort = + sort_exec_with_preserve_partitioning(ordering.clone(), coalesce_batches_exec_2); + let physical_plan = sort_preserving_merge_exec(ordering, sort); // Expected inputs unbounded and bounded let expected_input_unbounded = [ @@ -606,19 +601,17 @@ async fn test_not_replacing_when_no_need_to_preserve_sorting( #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; + let ordering: LexOrdering = [sort_expr("a", &schema)].into(); let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) + stream_exec_ordered_with_projection(&schema, ordering) } else { - memory_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, ordering) }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); - let coalesce_batches_exec: Arc = coalesce_batches_exec(filter); - - let physical_plan: Arc = - coalesce_partitions_exec(coalesce_batches_exec); + let coalesce_batches_exec = coalesce_batches_exec(filter, 8192); + let physical_plan = coalesce_partitions_exec(coalesce_batches_exec); // Expected inputs unbounded and bounded let expected_input_unbounded = [ @@ -679,20 +672,19 @@ async fn test_with_multiple_replacable_repartitions( #[values(false, true)] 
prefer_existing_sort: bool, ) -> Result<()> { let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; + let ordering: LexOrdering = [sort_expr("a", &schema)].into(); let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) + stream_exec_ordered_with_projection(&schema, ordering.clone()) } else { - memory_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, ordering.clone()) }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let filter = filter_exec(repartition_hash); - let coalesce_batches = coalesce_batches_exec(filter); + let coalesce_batches = coalesce_batches_exec(filter, 8192); let repartition_hash_2 = repartition_exec_hash(coalesce_batches); - let sort = sort_exec(vec![sort_expr("a", &schema)], repartition_hash_2, true); - - let physical_plan = sort_preserving_merge_exec(vec![sort_expr("a", &schema)], sort); + let sort = sort_exec_with_preserve_partitioning(ordering.clone(), repartition_hash_2); + let physical_plan = sort_preserving_merge_exec(ordering, sort); // Expected inputs unbounded and bounded let expected_input_unbounded = [ @@ -766,23 +758,21 @@ async fn test_not_replace_with_different_orderings( #[values(false, true)] source_unbounded: bool, #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { + use datafusion_physical_expr::LexOrdering; + let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; + let ordering_a = [sort_expr("a", &schema)].into(); let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) + stream_exec_ordered_with_projection(&schema, ordering_a) } else { - memory_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, ordering_a) }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); - let sort = sort_exec( - vec![sort_expr_default("c", &repartition_hash.schema())], - repartition_hash, - true, - ); - - let physical_plan = - sort_preserving_merge_exec(vec![sort_expr_default("c", &sort.schema())], sort); + let ordering_c: LexOrdering = + [sort_expr_default("c", &repartition_hash.schema())].into(); + let sort = sort_exec_with_preserve_partitioning(ordering_c.clone(), repartition_hash); + let physical_plan = sort_preserving_merge_exec(ordering_c, sort); // Expected inputs unbounded and bounded let expected_input_unbounded = [ @@ -839,17 +829,16 @@ async fn test_with_lost_ordering( #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; + let ordering: LexOrdering = [sort_expr("a", &schema)].into(); let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) + stream_exec_ordered_with_projection(&schema, ordering.clone()) } else { - memory_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, ordering.clone()) }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = coalesce_partitions_exec(repartition_hash); - let physical_plan = - sort_exec(vec![sort_expr("a", &schema)], coalesce_partitions, false); + let physical_plan = sort_exec(ordering, coalesce_partitions); // Expected inputs unbounded and bounded let expected_input_unbounded = [ @@ -908,28 +897,26 @@ async fn test_with_lost_and_kept_ordering( #[values(false, true)] 
source_unbounded: bool, #[values(false, true)] prefer_existing_sort: bool, ) -> Result<()> { + use datafusion_physical_expr::LexOrdering; + let schema = create_test_schema()?; - let sort_exprs = vec![sort_expr("a", &schema)]; + let ordering_a = [sort_expr("a", &schema)].into(); let source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, sort_exprs) + stream_exec_ordered_with_projection(&schema, ordering_a) } else { - memory_exec_sorted(&schema, sort_exprs) + memory_exec_sorted(&schema, ordering_a) }; let repartition_rr = repartition_exec_round_robin(source); let repartition_hash = repartition_exec_hash(repartition_rr); let coalesce_partitions = coalesce_partitions_exec(repartition_hash); - let sort = sort_exec( - vec![sort_expr_default("c", &coalesce_partitions.schema())], - coalesce_partitions, - false, - ); + let ordering_c: LexOrdering = + [sort_expr_default("c", &coalesce_partitions.schema())].into(); + let sort = sort_exec(ordering_c.clone(), coalesce_partitions); let repartition_rr2 = repartition_exec_round_robin(sort); let repartition_hash2 = repartition_exec_hash(repartition_rr2); let filter = filter_exec(repartition_hash2); - let sort2 = sort_exec(vec![sort_expr_default("c", &filter.schema())], filter, true); - - let physical_plan = - sort_preserving_merge_exec(vec![sort_expr_default("c", &sort2.schema())], sort2); + let sort2 = sort_exec_with_preserve_partitioning(ordering_c.clone(), filter); + let physical_plan = sort_preserving_merge_exec(ordering_c, sort2); // Expected inputs unbounded and bounded let expected_input_unbounded = [ @@ -1015,22 +1002,22 @@ async fn test_with_multiple_child_trees( ) -> Result<()> { let schema = create_test_schema()?; - let left_sort_exprs = vec![sort_expr("a", &schema)]; + let left_ordering = [sort_expr("a", &schema)].into(); let left_source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, left_sort_exprs) + stream_exec_ordered_with_projection(&schema, left_ordering) } else { - memory_exec_sorted(&schema, left_sort_exprs) + memory_exec_sorted(&schema, left_ordering) }; let left_repartition_rr = repartition_exec_round_robin(left_source); let left_repartition_hash = repartition_exec_hash(left_repartition_rr); let left_coalesce_partitions = Arc::new(CoalesceBatchesExec::new(left_repartition_hash, 4096)); - let right_sort_exprs = vec![sort_expr("a", &schema)]; + let right_ordering = [sort_expr("a", &schema)].into(); let right_source = if source_unbounded { - stream_exec_ordered_with_projection(&schema, right_sort_exprs) + stream_exec_ordered_with_projection(&schema, right_ordering) } else { - memory_exec_sorted(&schema, right_sort_exprs) + memory_exec_sorted(&schema, right_ordering) }; let right_repartition_rr = repartition_exec_round_robin(right_source); let right_repartition_hash = repartition_exec_hash(right_repartition_rr); @@ -1039,14 +1026,9 @@ async fn test_with_multiple_child_trees( let hash_join_exec = hash_join_exec(left_coalesce_partitions, right_coalesce_partitions); - let sort = sort_exec( - vec![sort_expr_default("a", &hash_join_exec.schema())], - hash_join_exec, - true, - ); - - let physical_plan = - sort_preserving_merge_exec(vec![sort_expr_default("a", &sort.schema())], sort); + let ordering: LexOrdering = [sort_expr_default("a", &hash_join_exec.schema())].into(); + let sort = sort_exec_with_preserve_partitioning(ordering.clone(), hash_join_exec); + let physical_plan = sort_preserving_merge_exec(ordering, sort); // Expected inputs unbounded and bounded let expected_input_unbounded = [ @@ 
-1145,18 +1127,6 @@ fn sort_expr_options( } } -fn sort_exec( - sort_exprs: impl IntoIterator, - input: Arc, - preserve_partitioning: bool, -) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - Arc::new( - SortExec::new(sort_exprs, input) - .with_preserve_partitioning(preserve_partitioning), - ) -} - fn repartition_exec_round_robin(input: Arc) -> Arc { Arc::new(RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(8)).unwrap()) } @@ -1184,14 +1154,6 @@ fn filter_exec(input: Arc) -> Arc { Arc::new(FilterExec::try_new(predicate, input).unwrap()) } -fn coalesce_batches_exec(input: Arc) -> Arc { - Arc::new(CoalesceBatchesExec::new(input, 8192)) -} - -fn coalesce_partitions_exec(input: Arc) -> Arc { - Arc::new(CoalescePartitionsExec::new(input)) -} - fn hash_join_exec( left: Arc, right: Arc, @@ -1209,7 +1171,7 @@ fn hash_join_exec( &JoinType::Inner, None, PartitionMode::Partitioned, - false, + NullEquality::NullEqualsNothing, ) .unwrap(), ) @@ -1229,7 +1191,7 @@ fn create_test_schema() -> Result { // projection parameter is given static due to testing needs fn memory_exec_sorted( schema: &SchemaRef, - sort_exprs: impl IntoIterator, + ordering: LexOrdering, ) -> Arc { pub fn make_partition(schema: &SchemaRef, sz: i32) -> RecordBatch { let values = (0..sz).collect::>(); @@ -1245,7 +1207,6 @@ fn memory_exec_sorted( let rows = 5; let partitions = 1; - let sort_exprs = sort_exprs.into_iter().collect(); Arc::new({ let data: Vec> = (0..partitions) .map(|_| vec![make_partition(schema, rows)]) @@ -1254,8 +1215,79 @@ fn memory_exec_sorted( DataSourceExec::new(Arc::new( MemorySourceConfig::try_new(&data, schema.clone(), Some(projection)) .unwrap() - .try_with_sort_information(vec![sort_exprs]) + .try_with_sort_information(vec![ordering]) .unwrap(), )) }) } + +#[test] +fn test_plan_with_order_preserving_variants_preserves_fetch() -> Result<()> { + // Create a schema + let schema = create_test_schema3()?; + let parquet_sort_exprs = vec![[sort_expr("a", &schema)].into()]; + let parquet_exec = parquet_exec_with_sort(schema, parquet_sort_exprs); + let coalesced = coalesce_partitions_exec(parquet_exec.clone()) + .with_fetch(Some(10)) + .unwrap(); + + // Test sort's fetch is greater than coalesce fetch, return error because it's not reasonable + let requirements = OrderPreservationContext::new( + coalesced.clone(), + false, + vec![OrderPreservationContext::new( + parquet_exec.clone(), + false, + vec![], + )], + ); + let res = plan_with_order_preserving_variants(requirements, false, true, Some(15)); + assert_contains!(res.unwrap_err().to_string(), "CoalescePartitionsExec fetch [10] should be greater than or equal to SortExec fetch [15]"); + + // Test sort is without fetch, expected to get the fetch value from the coalesced + let requirements = OrderPreservationContext::new( + coalesced.clone(), + false, + vec![OrderPreservationContext::new( + parquet_exec.clone(), + false, + vec![], + )], + ); + let res = plan_with_order_preserving_variants(requirements, false, true, None)?; + assert_eq!(res.plan.fetch(), Some(10),); + + // Test sort's fetch is less than coalesces fetch, expected to get the fetch value from the sort + let requirements = OrderPreservationContext::new( + coalesced, + false, + vec![OrderPreservationContext::new(parquet_exec, false, vec![])], + ); + let res = plan_with_order_preserving_variants(requirements, false, true, Some(5))?; + assert_eq!(res.plan.fetch(), Some(5),); + Ok(()) +} + +#[test] +fn test_plan_with_order_breaking_variants_preserves_fetch() -> Result<()> { + let 
schema = create_test_schema3()?; + let parquet_sort_exprs: LexOrdering = [sort_expr("a", &schema)].into(); + let parquet_exec = parquet_exec_with_sort(schema, vec![parquet_sort_exprs.clone()]); + let spm = sort_preserving_merge_exec_with_fetch( + parquet_sort_exprs, + parquet_exec.clone(), + 10, + ); + let requirements = OrderPreservationContext::new( + spm, + true, + vec![OrderPreservationContext::new( + parquet_exec.clone(), + true, + vec![], + )], + ); + let res = plan_with_order_breaking_variants(requirements)?; + assert_eq!(res.plan.fetch(), Some(10)); + Ok(()) +} diff --git a/datafusion/core/tests/physical_optimizer/sanity_checker.rs b/datafusion/core/tests/physical_optimizer/sanity_checker.rs index a73d084a081f..f7d68e5d899c 100644 --- a/datafusion/core/tests/physical_optimizer/sanity_checker.rs +++ b/datafusion/core/tests/physical_optimizer/sanity_checker.rs @@ -30,6 +30,7 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::{JoinType, Result}; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::Partitioning; +use datafusion_physical_expr_common::sort_expr::LexOrdering; use datafusion_physical_optimizer::sanity_checker::SanityCheckPlan; use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::repartition::RepartitionExec; @@ -410,18 +411,19 @@ fn assert_plan(plan: &dyn ExecutionPlan, expected_lines: Vec<&str>) { async fn test_bounded_window_agg_sort_requirement() -> Result<()> { let schema = create_test_schema(); let source = memory_exec(&schema); - let sort_exprs = vec![sort_expr_options( + let ordering: LexOrdering = [sort_expr_options( "c9", &source.schema(), SortOptions { descending: false, nulls_first: false, }, - )]; - let sort = sort_exec(sort_exprs.clone(), source); - let bw = bounded_window_exec("c9", sort_exprs, sort); + )] + .into(); + let sort = sort_exec(ordering.clone(), source); + let bw = bounded_window_exec("c9", ordering, sort); assert_plan(bw.as_ref(), vec![ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", " SortExec: expr=[c9@0 ASC NULLS LAST], preserve_partitioning=[false]", " DataSourceExec: partitions=1, partition_sizes=[0]" ]); @@ -444,7 +446,7 @@ async fn test_bounded_window_agg_no_sort_requirement() -> Result<()> { )]; let bw = bounded_window_exec("c9", sort_exprs, source); assert_plan(bw.as_ref(), vec![ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", " DataSourceExec: partitions=1, partition_sizes=[0]" ]); // Order requirement of the `BoundedWindowAggExec` is not 
satisfied. We expect to receive error during sanity check.
@@ -458,7 +460,7 @@ async fn test_global_limit_single_partition() -> Result<()> {
     let schema = create_test_schema();
     let source = memory_exec(&schema);
-    let limit = global_limit_exec(source);
+    let limit = global_limit_exec(source, 0, Some(100));
     assert_plan(
         limit.as_ref(),
@@ -477,7 +479,7 @@ async fn test_global_limit_multi_partition() -> Result<()> {
     let schema = create_test_schema();
     let source = memory_exec(&schema);
-    let limit = global_limit_exec(repartition_exec(source));
+    let limit = global_limit_exec(repartition_exec(source), 0, Some(100));
     assert_plan(
         limit.as_ref(),
@@ -497,7 +499,7 @@ async fn test_local_limit() -> Result<()> {
     let schema = create_test_schema();
     let source = memory_exec(&schema);
-    let limit = local_limit_exec(source);
+    let limit = local_limit_exec(source, 100);
     assert_plan(
         limit.as_ref(),
@@ -518,12 +520,12 @@ async fn test_sort_merge_join_satisfied() -> Result<()> {
     let source1 = memory_exec(&schema1);
     let source2 = memory_exec(&schema2);
     let sort_opts = SortOptions::default();
-    let sort_exprs1 = vec![sort_expr_options("c9", &source1.schema(), sort_opts)];
-    let sort_exprs2 = vec![sort_expr_options("a", &source2.schema(), sort_opts)];
-    let left = sort_exec(sort_exprs1, source1);
-    let right = sort_exec(sort_exprs2, source2);
-    let left_jcol = col("c9", &left.schema()).unwrap();
-    let right_jcol = col("a", &right.schema()).unwrap();
+    let ordering1 = [sort_expr_options("c9", &source1.schema(), sort_opts)].into();
+    let ordering2 = [sort_expr_options("a", &source2.schema(), sort_opts)].into();
+    let left = sort_exec(ordering1, source1);
+    let right = sort_exec(ordering2, source2);
+    let left_jcol = col("c9", &left.schema())?;
+    let right_jcol = col("a", &right.schema())?;
     let left = Arc::new(RepartitionExec::try_new(
         left,
         Partitioning::Hash(vec![left_jcol.clone()], 10),
@@ -562,15 +564,16 @@ async fn test_sort_merge_join_order_missing() -> Result<()> {
     let schema2 = create_test_schema2();
     let source1 = memory_exec(&schema1);
     let right = memory_exec(&schema2);
-    let sort_exprs1 = vec![sort_expr_options(
-        "c9",
-        &source1.schema(),
-        SortOptions::default(),
-    )];
-    let left = sort_exec(sort_exprs1, source1);
+    let ordering1 = [sort_expr_options(
+        "c9",
+        &source1.schema(),
+        SortOptions::default(),
+    )]
+    .into();
+    let left = sort_exec(ordering1, source1);
     // Missing sort of the right child here..
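// Illustrative sketch (editorial, not part of the patch above): the tests in this
// file build orderings from plain arrays via `.into()` instead of
// `LexOrdering::new(vec![...])`, and the statistics tests earlier in this change
// probe the per-partition API `ExecutionPlan::partition_statistics(Some(idx))`.
// The standalone snippet below shows both patterns together, assuming the crate
// paths already imported by the surrounding tests; `lex_ordering_sketch` and its
// single-column schema are hypothetical names used only for illustration.
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
use datafusion_physical_expr::expressions::col;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
use datafusion_physical_plan::{empty::EmptyExec, sorts::sort::SortExec, ExecutionPlan};

fn lex_ordering_sketch() -> Result<()> {
    // Hypothetical two-partition input with a single Int32 column "id".
    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
    let input = Arc::new(EmptyExec::new(schema.clone()).with_partitions(2));

    // An array literal of `PhysicalSortExpr` converts into a `LexOrdering`.
    let ordering: LexOrdering =
        [PhysicalSortExpr::new_default(col("id", &schema)?)].into();

    // Keep the two input partitions instead of merging them into one.
    let sort = SortExec::new(ordering, input).with_preserve_partitioning(true);

    // Statistics for a single partition; passing `None` asks for plan-wide statistics.
    let _per_partition = sort.partition_statistics(Some(0))?;
    let _plan_wide = sort.partition_statistics(None)?;
    Ok(())
}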
- let left_jcol = col("c9", &left.schema()).unwrap(); - let right_jcol = col("a", &right.schema()).unwrap(); + let left_jcol = col("c9", &left.schema())?; + let right_jcol = col("a", &right.schema())?; let left = Arc::new(RepartitionExec::try_new( left, Partitioning::Hash(vec![left_jcol.clone()], 10), @@ -610,16 +613,16 @@ async fn test_sort_merge_join_dist_missing() -> Result<()> { let source1 = memory_exec(&schema1); let source2 = memory_exec(&schema2); let sort_opts = SortOptions::default(); - let sort_exprs1 = vec![sort_expr_options("c9", &source1.schema(), sort_opts)]; - let sort_exprs2 = vec![sort_expr_options("a", &source2.schema(), sort_opts)]; - let left = sort_exec(sort_exprs1, source1); - let right = sort_exec(sort_exprs2, source2); + let ordering1 = [sort_expr_options("c9", &source1.schema(), sort_opts)].into(); + let ordering2 = [sort_expr_options("a", &source2.schema(), sort_opts)].into(); + let left = sort_exec(ordering1, source1); + let right = sort_exec(ordering2, source2); let right = Arc::new(RepartitionExec::try_new( right, Partitioning::RoundRobinBatch(10), )?); - let left_jcol = col("c9", &left.schema()).unwrap(); - let right_jcol = col("a", &right.schema()).unwrap(); + let left_jcol = col("c9", &left.schema())?; + let right_jcol = col("a", &right.schema())?; let left = Arc::new(RepartitionExec::try_new( left, Partitioning::Hash(vec![left_jcol.clone()], 10), diff --git a/datafusion/core/tests/physical_optimizer/test_utils.rs b/datafusion/core/tests/physical_optimizer/test_utils.rs index 4587f99989d3..c91a70989be4 100644 --- a/datafusion/core/tests/physical_optimizer/test_utils.rs +++ b/datafusion/core/tests/physical_optimizer/test_utils.rs @@ -30,19 +30,20 @@ use datafusion::datasource::memory::MemorySourceConfig; use datafusion::datasource::physical_plan::ParquetSource; use datafusion::datasource::source::DataSourceExec; use datafusion_common::config::ConfigOptions; +use datafusion_common::stats::Precision; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::utils::expr::COUNT_STAR_EXPANSION; -use datafusion_common::{JoinType, Result}; +use datafusion_common::{ColumnStatistics, JoinType, NullEquality, Result, Statistics}; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::{WindowFrame, WindowFunctionDefinition}; use datafusion_functions_aggregate::count::count_udaf; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; -use datafusion_physical_expr::expressions::col; -use datafusion_physical_expr::{expressions, PhysicalExpr}; +use datafusion_physical_expr::expressions::{self, col}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{ - LexOrdering, LexRequirement, PhysicalSortExpr, + LexOrdering, OrderingRequirements, PhysicalSortExpr, }; use datafusion_physical_optimizer::limited_distinct_aggregation::LimitedDistinctAggregation; use datafusion_physical_optimizer::PhysicalOptimizerRule; @@ -55,6 +56,7 @@ use datafusion_physical_plan::filter::FilterExec; use datafusion_physical_plan::joins::utils::{JoinFilter, JoinOn}; use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode, SortMergeJoinExec}; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; +use datafusion_physical_plan::projection::ProjectionExec; use 
datafusion_physical_plan::repartition::RepartitionExec; use datafusion_physical_plan::sorts::sort::SortExec; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; @@ -68,10 +70,10 @@ use datafusion_physical_plan::{ }; /// Create a non sorted parquet exec -pub fn parquet_exec(schema: &SchemaRef) -> Arc { +pub fn parquet_exec(schema: SchemaRef) -> Arc { let config = FileScanConfigBuilder::new( ObjectStoreUrl::parse("test:///").unwrap(), - schema.clone(), + schema, Arc::new(ParquetSource::default()), ) .with_file(PartitionedFile::new("x".to_string(), 100)) @@ -82,11 +84,12 @@ pub fn parquet_exec(schema: &SchemaRef) -> Arc { /// Create a single parquet file that is sorted pub(crate) fn parquet_exec_with_sort( + schema: SchemaRef, output_ordering: Vec, ) -> Arc { let config = FileScanConfigBuilder::new( ObjectStoreUrl::parse("test:///").unwrap(), - schema(), + schema, Arc::new(ParquetSource::default()), ) .with_file(PartitionedFile::new("x".to_string(), 100)) @@ -96,6 +99,48 @@ pub(crate) fn parquet_exec_with_sort( DataSourceExec::from_data_source(config) } +fn int64_stats() -> ColumnStatistics { + ColumnStatistics { + null_count: Precision::Absent, + sum_value: Precision::Absent, + max_value: Precision::Exact(1_000_000.into()), + min_value: Precision::Exact(0.into()), + distinct_count: Precision::Absent, + } +} + +fn column_stats() -> Vec { + vec![ + int64_stats(), // a + int64_stats(), // b + int64_stats(), // c + ColumnStatistics::default(), + ColumnStatistics::default(), + ] +} + +/// Create parquet datasource exec using schema from [`schema`]. +pub(crate) fn parquet_exec_with_stats(file_size: u64) -> Arc { + let mut statistics = Statistics::new_unknown(&schema()); + statistics.num_rows = Precision::Inexact(10000); + statistics.column_statistics = column_stats(); + + let config = FileScanConfigBuilder::new( + ObjectStoreUrl::parse("test:///").unwrap(), + schema(), + Arc::new(ParquetSource::new(Default::default())), + ) + .with_file(PartitionedFile::new("x".to_string(), file_size)) + .with_statistics(statistics) + .build(); + + assert_eq!( + config.file_source.statistics().unwrap().num_rows, + Precision::Inexact(10000) + ); + DataSourceExec::from_data_source(config) +} + pub fn schema() -> SchemaRef { Arc::new(Schema::new(vec![ Field::new("a", DataType::Int64, true), @@ -145,7 +190,7 @@ pub fn sort_merge_join_exec( None, *join_type, vec![SortOptions::default(); join_on.len()], - false, + NullEquality::NullEqualsNothing, ) .unwrap(), ) @@ -191,7 +236,7 @@ pub fn hash_join_exec( join_type, None, PartitionMode::Partitioned, - true, + NullEquality::NullEqualsNull, )?)) } @@ -200,14 +245,23 @@ pub fn bounded_window_exec( sort_exprs: impl IntoIterator, input: Arc, ) -> Arc { - let sort_exprs: LexOrdering = sort_exprs.into_iter().collect(); + bounded_window_exec_with_partition(col_name, sort_exprs, &[], input) +} + +pub fn bounded_window_exec_with_partition( + col_name: &str, + sort_exprs: impl IntoIterator, + partition_by: &[Arc], + input: Arc, +) -> Arc { + let sort_exprs = sort_exprs.into_iter().collect::>(); let schema = input.schema(); let window_expr = create_window_expr( &WindowFunctionDefinition::AggregateUDF(count_udaf()), "count".to_owned(), &[col(col_name, &schema).unwrap()], - &[], - sort_exprs.as_ref(), + partition_by, + &sort_exprs, Arc::new(WindowFrame::new(Some(false))), schema.as_ref(), false, @@ -233,36 +287,37 @@ pub fn filter_exec( } pub fn sort_preserving_merge_exec( - sort_exprs: impl IntoIterator, + ordering: LexOrdering, input: Arc, ) -> Arc { 
- let sort_exprs = sort_exprs.into_iter().collect(); - Arc::new(SortPreservingMergeExec::new(sort_exprs, input)) + Arc::new(SortPreservingMergeExec::new(ordering, input)) } pub fn sort_preserving_merge_exec_with_fetch( - sort_exprs: impl IntoIterator, + ordering: LexOrdering, input: Arc, fetch: usize, ) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - Arc::new(SortPreservingMergeExec::new(sort_exprs, input).with_fetch(Some(fetch))) + Arc::new(SortPreservingMergeExec::new(ordering, input).with_fetch(Some(fetch))) } pub fn union_exec(input: Vec>) -> Arc { Arc::new(UnionExec::new(input)) } -pub fn limit_exec(input: Arc) -> Arc { - global_limit_exec(local_limit_exec(input)) -} - -pub fn local_limit_exec(input: Arc) -> Arc { - Arc::new(LocalLimitExec::new(input, 100)) +pub fn local_limit_exec( + input: Arc, + fetch: usize, +) -> Arc { + Arc::new(LocalLimitExec::new(input, fetch)) } -pub fn global_limit_exec(input: Arc) -> Arc { - Arc::new(GlobalLimitExec::new(input, 0, Some(100))) +pub fn global_limit_exec( + input: Arc, + skip: usize, + fetch: Option, +) -> Arc { + Arc::new(GlobalLimitExec::new(input, skip, fetch)) } pub fn repartition_exec(input: Arc) -> Arc { @@ -292,30 +347,46 @@ pub fn aggregate_exec(input: Arc) -> Arc { ) } -pub fn coalesce_batches_exec(input: Arc) -> Arc { - Arc::new(CoalesceBatchesExec::new(input, 128)) +pub fn coalesce_batches_exec( + input: Arc, + batch_size: usize, +) -> Arc { + Arc::new(CoalesceBatchesExec::new(input, batch_size)) } pub fn sort_exec( - sort_exprs: impl IntoIterator, + ordering: LexOrdering, + input: Arc, +) -> Arc { + sort_exec_with_fetch(ordering, None, input) +} + +pub fn sort_exec_with_preserve_partitioning( + ordering: LexOrdering, input: Arc, ) -> Arc { - sort_exec_with_fetch(sort_exprs, None, input) + Arc::new(SortExec::new(ordering, input).with_preserve_partitioning(true)) } pub fn sort_exec_with_fetch( - sort_exprs: impl IntoIterator, + ordering: LexOrdering, fetch: Option, input: Arc, ) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - Arc::new(SortExec::new(sort_exprs, input).with_fetch(fetch)) + Arc::new(SortExec::new(ordering, input).with_fetch(fetch)) +} + +pub fn projection_exec( + expr: Vec<(Arc, String)>, + input: Arc, +) -> Result> { + Ok(Arc::new(ProjectionExec::try_new(expr, input)?)) } /// A test [`ExecutionPlan`] whose requirements can be configured. 
#[derive(Debug)] pub struct RequirementsTestExec { - required_input_ordering: LexOrdering, + required_input_ordering: Option, maintains_input_order: bool, input: Arc, } @@ -323,7 +394,7 @@ pub struct RequirementsTestExec { impl RequirementsTestExec { pub fn new(input: Arc) -> Self { Self { - required_input_ordering: LexOrdering::default(), + required_input_ordering: None, maintains_input_order: true, input, } @@ -332,7 +403,7 @@ impl RequirementsTestExec { /// sets the required input ordering pub fn with_required_input_ordering( mut self, - required_input_ordering: LexOrdering, + required_input_ordering: Option, ) -> Self { self.required_input_ordering = required_input_ordering; self @@ -377,9 +448,11 @@ impl ExecutionPlan for RequirementsTestExec { self.input.properties() } - fn required_input_ordering(&self) -> Vec> { - let requirement = LexRequirement::from(self.required_input_ordering.clone()); - vec![Some(requirement)] + fn required_input_ordering(&self) -> Vec> { + vec![self + .required_input_ordering + .as_ref() + .map(|ordering| OrderingRequirements::from(ordering.clone()))] } fn maintains_input_order(&self) -> Vec { @@ -458,13 +531,28 @@ impl PartitionStream for TestStreamPartition { } } -/// Create an unbounded stream exec +/// Create an unbounded stream table without data ordering. +pub fn stream_exec(schema: &SchemaRef) -> Arc { + Arc::new( + StreamingTableExec::try_new( + Arc::clone(schema), + vec![Arc::new(TestStreamPartition { + schema: Arc::clone(schema), + }) as _], + None, + vec![], + true, + None, + ) + .unwrap(), + ) +} + +/// Create an unbounded stream table with data ordering. pub fn stream_exec_ordered( schema: &SchemaRef, - sort_exprs: impl IntoIterator, + ordering: LexOrdering, ) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); - Arc::new( StreamingTableExec::try_new( Arc::clone(schema), @@ -472,7 +560,7 @@ pub fn stream_exec_ordered( schema: Arc::clone(schema), }) as _], None, - vec![sort_exprs], + vec![ordering], true, None, ) @@ -480,12 +568,11 @@ pub fn stream_exec_ordered( ) } -// Creates a stream exec source for the test purposes +/// Create an unbounded stream table with data ordering and built-in projection. pub fn stream_exec_ordered_with_projection( schema: &SchemaRef, - sort_exprs: impl IntoIterator, + ordering: LexOrdering, ) -> Arc { - let sort_exprs = sort_exprs.into_iter().collect(); let projection: Vec = vec![0, 2, 3]; Arc::new( @@ -495,7 +582,7 @@ pub fn stream_exec_ordered_with_projection( schema: Arc::clone(schema), }) as _], Some(&projection), - vec![sort_exprs], + vec![ordering], true, None, ) @@ -557,8 +644,7 @@ pub fn assert_plan_matches_expected( assert_eq!( &expected_lines, &actual_lines, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected_lines, actual_lines + "\n\nexpected:\n\n{expected_lines:#?}\nactual:\n\n{actual_lines:#?}\n\n" ); Ok(()) diff --git a/datafusion/core/tests/physical_optimizer/window_optimize.rs b/datafusion/core/tests/physical_optimizer/window_optimize.rs new file mode 100644 index 000000000000..ba0ffb022a03 --- /dev/null +++ b/datafusion/core/tests/physical_optimizer/window_optimize.rs @@ -0,0 +1,90 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[cfg(test)] +mod test { + use arrow::array::{Int32Array, RecordBatch}; + use arrow_schema::{DataType, Field, Schema}; + use datafusion_common::Result; + use datafusion_datasource::memory::MemorySourceConfig; + use datafusion_datasource::source::DataSourceExec; + use datafusion_execution::TaskContext; + use datafusion_expr::WindowFrame; + use datafusion_functions_aggregate::count::count_udaf; + use datafusion_physical_expr::aggregate::AggregateExprBuilder; + use datafusion_physical_expr::expressions::{col, Column}; + use datafusion_physical_expr::window::PlainAggregateWindowExpr; + use datafusion_physical_plan::windows::BoundedWindowAggExec; + use datafusion_physical_plan::{common, ExecutionPlan, InputOrderMode}; + use std::sync::Arc; + + /// Test case for + #[tokio::test] + async fn test_window_constant_aggregate() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + let c = Arc::new(Column::new("b", 1)); + let cnt = AggregateExprBuilder::new(count_udaf(), vec![c]) + .schema(schema.clone()) + .alias("t") + .build()?; + let partition = [col("a", &schema)?]; + let frame = WindowFrame::new(None); + let plain = + PlainAggregateWindowExpr::new(Arc::new(cnt), &partition, &[], Arc::new(frame)); + + let bounded_agg_exec = BoundedWindowAggExec::try_new( + vec![Arc::new(plain)], + source, + InputOrderMode::Linear, + true, + )?; + let task_ctx = Arc::new(TaskContext::default()); + common::collect(bounded_agg_exec.execute(0, task_ctx)?).await?; + + Ok(()) + } + + pub fn mock_data() -> Result> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ])); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![ + Some(1), + Some(1), + Some(3), + Some(2), + Some(1), + ])), + Arc::new(Int32Array::from(vec![ + Some(1), + Some(6), + Some(2), + Some(8), + Some(9), + ])), + ], + )?; + + MemorySourceConfig::try_new_exec(&[vec![batch]], Arc::clone(&schema), None) + } +} diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs index 52372e01d41a..b705448203d7 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ b/datafusion/core/tests/sql/aggregates.rs @@ -16,7 +16,10 @@ // under the License. use super::*; -use datafusion::scalar::ScalarValue; +use datafusion::common::test_util::batches_to_string; +use datafusion_catalog::MemTable; +use datafusion_common::ScalarValue; +use insta::assert_snapshot; #[tokio::test] async fn csv_query_array_agg_distinct() -> Result<()> { @@ -321,3 +324,120 @@ async fn test_accumulator_row_accumulator() -> Result<()> { Ok(()) } + +/// Test that COUNT(DISTINCT) correctly handles dictionary arrays with all null values. +/// Verifies behavior across both single and multiple partitions.
+#[tokio::test] +async fn count_distinct_dictionary_all_null_values() -> Result<()> { + let n: usize = 5; + let num = Arc::new(Int32Array::from_iter(0..n as i32)) as ArrayRef; + + // Create dictionary where all indices point to a null value (index 0) + let dict_values = StringArray::from(vec![None, Some("abc")]); + let dict_indices = Int32Array::from(vec![0; n]); + let dict = DictionaryArray::new(dict_indices, Arc::new(dict_values)); + + let schema = Arc::new(Schema::new(vec![ + Field::new("num1", DataType::Int32, false), + Field::new("num2", DataType::Int32, false), + Field::new( + "dict", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + ), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![num.clone(), num.clone(), Arc::new(dict)], + )?; + + // Test with single partition + let ctx = + SessionContext::new_with_config(SessionConfig::new().with_target_partitions(1)); + let provider = MemTable::try_new(schema.clone(), vec![vec![batch.clone()]])?; + ctx.register_table("t", Arc::new(provider))?; + + let df = ctx + .sql("SELECT count(distinct dict) as cnt, count(num2) FROM t GROUP BY num1") + .await?; + let results = df.collect().await?; + + assert_snapshot!( + batches_to_string(&results), + @r###" + +-----+---------------+ + | cnt | count(t.num2) | + +-----+---------------+ + | 0 | 1 | + | 0 | 1 | + | 0 | 1 | + | 0 | 1 | + | 0 | 1 | + +-----+---------------+ + "### + ); + + // Test with multiple partitions + let ctx_multi = + SessionContext::new_with_config(SessionConfig::new().with_target_partitions(2)); + let provider_multi = MemTable::try_new(schema, vec![vec![batch]])?; + ctx_multi.register_table("t", Arc::new(provider_multi))?; + + let df_multi = ctx_multi + .sql("SELECT count(distinct dict) as cnt, count(num2) FROM t GROUP BY num1") + .await?; + let results_multi = df_multi.collect().await?; + + // Results should be identical across partition configurations + assert_eq!( + batches_to_string(&results), + batches_to_string(&results_multi) + ); + + Ok(()) +} + +/// Test COUNT(DISTINCT) with mixed null and non-null dictionary values +#[tokio::test] +async fn count_distinct_dictionary_mixed_values() -> Result<()> { + let n: usize = 6; + let num = Arc::new(Int32Array::from_iter(0..n as i32)) as ArrayRef; + + // Dictionary values array with nulls and non-nulls + let dict_values = StringArray::from(vec![None, Some("abc"), Some("def"), None]); + // Create indices that point to both null and non-null values + let dict_indices = Int32Array::from(vec![0, 1, 2, 0, 1, 3]); + let dict = DictionaryArray::new(dict_indices, Arc::new(dict_values)); + + let schema = Arc::new(Schema::new(vec![ + Field::new("num1", DataType::Int32, false), + Field::new( + "dict", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + ), + ])); + + let batch = RecordBatch::try_new(schema.clone(), vec![num, Arc::new(dict)])?; + let provider = MemTable::try_new(schema, vec![vec![batch]])?; + let ctx = SessionContext::new(); + ctx.register_table("t", Arc::new(provider))?; + + // COUNT(DISTINCT) should only count non-null values "abc" and "def" + let df = ctx.sql("SELECT count(distinct dict) FROM t").await?; + let results = df.collect().await?; + + assert_snapshot!( + batches_to_string(&results), + @r###" + +------------------------+ + | count(DISTINCT t.dict) | + +------------------------+ + | 2 | + +------------------------+ + "### + ); + + Ok(()) +} diff --git a/datafusion/core/tests/sql/create_drop.rs 
b/datafusion/core/tests/sql/create_drop.rs index 83712053b954..b35e614a464e 100644 --- a/datafusion/core/tests/sql/create_drop.rs +++ b/datafusion/core/tests/sql/create_drop.rs @@ -61,7 +61,7 @@ async fn create_external_table_with_ddl() -> Result<()> { assert_eq!(3, table_schema.fields().len()); assert_eq!(&DataType::Int32, table_schema.field(0).data_type()); - assert_eq!(&DataType::Utf8, table_schema.field(1).data_type()); + assert_eq!(&DataType::Utf8View, table_schema.field(1).data_type()); assert_eq!(&DataType::Boolean, table_schema.field(2).data_type()); Ok(()) diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index e8ef34c2afe7..852b350b27df 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -16,6 +16,7 @@ // under the License. use super::*; +use insta::assert_snapshot; use rstest::rstest; use datafusion::config::ConfigOptions; @@ -52,6 +53,7 @@ async fn explain_analyze_baseline_metrics() { let formatted = arrow::util::pretty::pretty_format_batches(&results) .unwrap() .to_string(); + println!("Query Output:\n\n{formatted}"); assert_metrics!( @@ -174,69 +176,66 @@ async fn csv_explain_plans() { println!("SQL: {sql}"); // // Verify schema - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: aggregate_test_100.c1 [c1:Utf8]", - " Filter: aggregate_test_100.c2 > Int64(10) [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]", - " TableScan: aggregate_test_100 [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]", - ]; let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r" + Explain [plan_type:Utf8, plan:Utf8] + Projection: aggregate_test_100.c1 [c1:Utf8View] + Filter: aggregate_test_100.c2 > Int64(10) [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View] + TableScan: aggregate_test_100 [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View] + " ); // // Verify the text format of the plan - let expected = vec![ - "Explain", - " Projection: aggregate_test_100.c1", - " Filter: aggregate_test_100.c2 > Int64(10)", - " TableScan: aggregate_test_100", - ]; let formatted = plan.display_indent().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r###" + Explain + Projection: aggregate_test_100.c1 + Filter: aggregate_test_100.c2 > Int64(10) + TableScan: aggregate_test_100 + "### ); // // verify the grahviz format of the plan - let expected = vec![ - "// Begin DataFusion GraphViz Plan,", - "// display it online here: https://dreampuf.github.io/GraphvizOnline", - "", - "digraph {", - " subgraph cluster_1", - " {", - " graph[label=\"LogicalPlan\"]", - " 2[shape=box label=\"Explain\"]", - " 3[shape=box label=\"Projection: aggregate_test_100.c1\"]", - " 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]", - 
" 4[shape=box label=\"Filter: aggregate_test_100.c2 > Int64(10)\"]", - " 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]", - " 5[shape=box label=\"TableScan: aggregate_test_100\"]", - " 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]", - " }", - " subgraph cluster_6", - " {", - " graph[label=\"Detailed LogicalPlan\"]", - " 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]", - " 8[shape=box label=\"Projection: aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]", - " 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]", - " 9[shape=box label=\"Filter: aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]\"]", - " 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]", - " 10[shape=box label=\"TableScan: aggregate_test_100\\nSchema: [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]\"]", - " 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]", - " }", - "}", - "// End DataFusion GraphViz Plan", - ]; let formatted = plan.display_graphviz().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r#" + // Begin DataFusion GraphViz Plan, + // display it online here: https://dreampuf.github.io/GraphvizOnline + + digraph { + subgraph cluster_1 + { + graph[label="LogicalPlan"] + 2[shape=box label="Explain"] + 3[shape=box label="Projection: aggregate_test_100.c1"] + 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back] + 4[shape=box label="Filter: aggregate_test_100.c2 > Int64(10)"] + 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back] + 5[shape=box label="TableScan: aggregate_test_100"] + 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back] + } + subgraph cluster_6 + { + graph[label="Detailed LogicalPlan"] + 7[shape=box label="Explain\nSchema: [plan_type:Utf8, plan:Utf8]"] + 8[shape=box label="Projection: aggregate_test_100.c1\nSchema: [c1:Utf8View]"] + 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back] + 9[shape=box label="Filter: aggregate_test_100.c2 > Int64(10)\nSchema: [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View]"] + 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back] + 10[shape=box label="TableScan: aggregate_test_100\nSchema: [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View]"] + 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back] + } + } + // End DataFusion GraphViz Plan + "# ); // Optimized logical plan @@ -248,69 +247,67 @@ async fn csv_explain_plans() { assert_eq!(logical_schema, optimized_logical_schema.as_ref()); // // Verify schema - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: aggregate_test_100.c1 [c1:Utf8]", - " Filter: aggregate_test_100.c2 > Int8(10) [c1:Utf8, c2:Int8]", - " TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)] [c1:Utf8, c2:Int8]", - ]; let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + 
assert_snapshot!( + actual, + @r" + Explain [plan_type:Utf8, plan:Utf8] + Projection: aggregate_test_100.c1 [c1:Utf8View] + Filter: aggregate_test_100.c2 > Int8(10) [c1:Utf8View, c2:Int8] + TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)] [c1:Utf8View, c2:Int8] + " ); // // Verify the text format of the plan - let expected = vec![ - "Explain", - " Projection: aggregate_test_100.c1", - " Filter: aggregate_test_100.c2 > Int8(10)", - " TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)]", - ]; let formatted = plan.display_indent().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r###" + Explain + Projection: aggregate_test_100.c1 + Filter: aggregate_test_100.c2 > Int8(10) + TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)] + + "### ); // // verify the grahviz format of the plan - let expected = vec![ - "// Begin DataFusion GraphViz Plan,", - "// display it online here: https://dreampuf.github.io/GraphvizOnline", - "", - "digraph {", - " subgraph cluster_1", - " {", - " graph[label=\"LogicalPlan\"]", - " 2[shape=box label=\"Explain\"]", - " 3[shape=box label=\"Projection: aggregate_test_100.c1\"]", - " 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]", - " 4[shape=box label=\"Filter: aggregate_test_100.c2 > Int8(10)\"]", - " 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]", - " 5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)]\"]", - " 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]", - " }", - " subgraph cluster_6", - " {", - " graph[label=\"Detailed LogicalPlan\"]", - " 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]", - " 8[shape=box label=\"Projection: aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]", - " 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]", - " 9[shape=box label=\"Filter: aggregate_test_100.c2 > Int8(10)\\nSchema: [c1:Utf8, c2:Int8]\"]", - " 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]", - " 10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)]\\nSchema: [c1:Utf8, c2:Int8]\"]", - " 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]", - " }", - "}", - "// End DataFusion GraphViz Plan", - ]; let formatted = plan.display_graphviz().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r#" + // Begin DataFusion GraphViz Plan, + // display it online here: https://dreampuf.github.io/GraphvizOnline + + digraph { + subgraph cluster_1 + { + graph[label="LogicalPlan"] + 2[shape=box label="Explain"] + 3[shape=box label="Projection: aggregate_test_100.c1"] + 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back] + 4[shape=box label="Filter: aggregate_test_100.c2 > Int8(10)"] + 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back] + 5[shape=box label="TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)]"] + 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back] + } + subgraph cluster_6 + { + graph[label="Detailed LogicalPlan"] + 7[shape=box 
label="Explain\nSchema: [plan_type:Utf8, plan:Utf8]"] + 8[shape=box label="Projection: aggregate_test_100.c1\nSchema: [c1:Utf8View]"] + 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back] + 9[shape=box label="Filter: aggregate_test_100.c2 > Int8(10)\nSchema: [c1:Utf8View, c2:Int8]"] + 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back] + 10[shape=box label="TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)]\nSchema: [c1:Utf8View, c2:Int8]"] + 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back] + } + } + // End DataFusion GraphViz Plan + "# ); // Physical plan @@ -396,69 +393,66 @@ async fn csv_explain_verbose_plans() { // // Verify schema - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: aggregate_test_100.c1 [c1:Utf8]", - " Filter: aggregate_test_100.c2 > Int64(10) [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]", - " TableScan: aggregate_test_100 [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]", - ]; let formatted = dataframe.logical_plan().display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r" + Explain [plan_type:Utf8, plan:Utf8] + Projection: aggregate_test_100.c1 [c1:Utf8View] + Filter: aggregate_test_100.c2 > Int64(10) [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View] + TableScan: aggregate_test_100 [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View] + " ); // // Verify the text format of the plan - let expected = vec![ - "Explain", - " Projection: aggregate_test_100.c1", - " Filter: aggregate_test_100.c2 > Int64(10)", - " TableScan: aggregate_test_100", - ]; let formatted = dataframe.logical_plan().display_indent().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r###" + Explain + Projection: aggregate_test_100.c1 + Filter: aggregate_test_100.c2 > Int64(10) + TableScan: aggregate_test_100 + "### ); // // verify the grahviz format of the plan - let expected = vec![ - "// Begin DataFusion GraphViz Plan,", - "// display it online here: https://dreampuf.github.io/GraphvizOnline", - "", - "digraph {", - " subgraph cluster_1", - " {", - " graph[label=\"LogicalPlan\"]", - " 2[shape=box label=\"Explain\"]", - " 3[shape=box label=\"Projection: aggregate_test_100.c1\"]", - " 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]", - " 4[shape=box label=\"Filter: aggregate_test_100.c2 > Int64(10)\"]", - " 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]", - " 5[shape=box label=\"TableScan: aggregate_test_100\"]", - " 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]", - " }", - " subgraph cluster_6", - " {", - " graph[label=\"Detailed LogicalPlan\"]", - " 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]", - " 8[shape=box label=\"Projection: aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]", - " 7 -> 8 [arrowhead=none, arrowtail=normal, 
dir=back]", - " 9[shape=box label=\"Filter: aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]\"]", - " 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]", - " 10[shape=box label=\"TableScan: aggregate_test_100\\nSchema: [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]\"]", - " 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]", - " }", - "}", - "// End DataFusion GraphViz Plan", - ]; let formatted = dataframe.logical_plan().display_graphviz().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r#" + // Begin DataFusion GraphViz Plan, + // display it online here: https://dreampuf.github.io/GraphvizOnline + + digraph { + subgraph cluster_1 + { + graph[label="LogicalPlan"] + 2[shape=box label="Explain"] + 3[shape=box label="Projection: aggregate_test_100.c1"] + 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back] + 4[shape=box label="Filter: aggregate_test_100.c2 > Int64(10)"] + 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back] + 5[shape=box label="TableScan: aggregate_test_100"] + 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back] + } + subgraph cluster_6 + { + graph[label="Detailed LogicalPlan"] + 7[shape=box label="Explain\nSchema: [plan_type:Utf8, plan:Utf8]"] + 8[shape=box label="Projection: aggregate_test_100.c1\nSchema: [c1:Utf8View]"] + 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back] + 9[shape=box label="Filter: aggregate_test_100.c2 > Int64(10)\nSchema: [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View]"] + 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back] + 10[shape=box label="TableScan: aggregate_test_100\nSchema: [c1:Utf8View, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8View]"] + 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back] + } + } + // End DataFusion GraphViz Plan + "# ); // Optimized logical plan @@ -470,69 +464,66 @@ async fn csv_explain_verbose_plans() { assert_eq!(&logical_schema, optimized_logical_schema.as_ref()); // // Verify schema - let expected = vec![ - "Explain [plan_type:Utf8, plan:Utf8]", - " Projection: aggregate_test_100.c1 [c1:Utf8]", - " Filter: aggregate_test_100.c2 > Int8(10) [c1:Utf8, c2:Int8]", - " TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)] [c1:Utf8, c2:Int8]", - ]; let formatted = plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r" + Explain [plan_type:Utf8, plan:Utf8] + Projection: aggregate_test_100.c1 [c1:Utf8View] + Filter: aggregate_test_100.c2 > Int8(10) [c1:Utf8View, c2:Int8] + TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)] [c1:Utf8View, c2:Int8] + " ); // // Verify the text format of the plan - let expected = vec![ - "Explain", - " Projection: aggregate_test_100.c1", - " Filter: aggregate_test_100.c2 > Int8(10)", - " TableScan: 
aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)]", - ]; let formatted = plan.display_indent().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r###" + Explain + Projection: aggregate_test_100.c1 + Filter: aggregate_test_100.c2 > Int8(10) + TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)] + "### ); // // verify the grahviz format of the plan - let expected = vec![ - "// Begin DataFusion GraphViz Plan,", - "// display it online here: https://dreampuf.github.io/GraphvizOnline", - "", - "digraph {", - " subgraph cluster_1", - " {", - " graph[label=\"LogicalPlan\"]", - " 2[shape=box label=\"Explain\"]", - " 3[shape=box label=\"Projection: aggregate_test_100.c1\"]", - " 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]", - " 4[shape=box label=\"Filter: aggregate_test_100.c2 > Int8(10)\"]", - " 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]", - " 5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)]\"]", - " 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]", - " }", - " subgraph cluster_6", - " {", - " graph[label=\"Detailed LogicalPlan\"]", - " 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]", - " 8[shape=box label=\"Projection: aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]", - " 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]", - " 9[shape=box label=\"Filter: aggregate_test_100.c2 > Int8(10)\\nSchema: [c1:Utf8, c2:Int8]\"]", - " 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]", - " 10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)]\\nSchema: [c1:Utf8, c2:Int8]\"]", - " 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]", - " }", - "}", - "// End DataFusion GraphViz Plan", - ]; let formatted = plan.display_graphviz().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r#" + // Begin DataFusion GraphViz Plan, + // display it online here: https://dreampuf.github.io/GraphvizOnline + + digraph { + subgraph cluster_1 + { + graph[label="LogicalPlan"] + 2[shape=box label="Explain"] + 3[shape=box label="Projection: aggregate_test_100.c1"] + 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back] + 4[shape=box label="Filter: aggregate_test_100.c2 > Int8(10)"] + 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back] + 5[shape=box label="TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)]"] + 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back] + } + subgraph cluster_6 + { + graph[label="Detailed LogicalPlan"] + 7[shape=box label="Explain\nSchema: [plan_type:Utf8, plan:Utf8]"] + 8[shape=box label="Projection: aggregate_test_100.c1\nSchema: [c1:Utf8View]"] + 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back] + 9[shape=box label="Filter: aggregate_test_100.c2 > Int8(10)\nSchema: [c1:Utf8View, c2:Int8]"] + 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back] + 10[shape=box label="TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)]\nSchema: [c1:Utf8View, c2:Int8]"] + 9 -> 10 
[arrowhead=none, arrowtail=normal, dir=back] + } + } + // End DataFusion GraphViz Plan + "# ); // Physical plan @@ -561,7 +552,9 @@ async fn csv_explain_verbose_plans() { async fn explain_analyze_runs_optimizers(#[values("*", "1")] count_expr: &str) { // repro for https://github.com/apache/datafusion/issues/917 // where EXPLAIN ANALYZE was not correctly running optimizer - let ctx = SessionContext::new(); + let ctx = SessionContext::new_with_config( + SessionConfig::new().with_collect_statistics(true), + ); register_alltypes_parquet(&ctx).await; // This happens as an optimization pass where count(*)/count(1) can be @@ -600,19 +593,6 @@ async fn test_physical_plan_display_indent() { LIMIT 10"; let dataframe = ctx.sql(sql).await.unwrap(); let physical_plan = dataframe.create_physical_plan().await.unwrap(); - let expected = vec![ - "SortPreservingMergeExec: [the_min@2 DESC], fetch=10", - " SortExec: TopK(fetch=10), expr=[the_min@2 DESC], preserve_partitioning=[true]", - " ProjectionExec: expr=[c1@0 as c1, max(aggregate_test_100.c12)@1 as max(aggregate_test_100.c12), min(aggregate_test_100.c12)@2 as the_min]", - " AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[max(aggregate_test_100.c12), min(aggregate_test_100.c12)]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([c1@0], 9000), input_partitions=9000", - " AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[max(aggregate_test_100.c12), min(aggregate_test_100.c12)]", - " CoalesceBatchesExec: target_batch_size=4096", - " FilterExec: c12@1 < 10", - " RepartitionExec: partitioning=RoundRobinBatch(9000), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1, c12], file_type=csv, has_header=true", - ]; let normalizer = ExplainNormalizer::new(); let actual = format!("{}", displayable(physical_plan.as_ref()).indent(true)) @@ -620,10 +600,24 @@ async fn test_physical_plan_display_indent() { .lines() // normalize paths .map(|s| normalizer.normalize(s)) - .collect::>(); - assert_eq!( - expected, actual, - "expected:\n{expected:#?}\nactual:\n\n{actual:#?}\n" + .collect::>() + .join("\n"); + + assert_snapshot!( + actual, + @r###" + SortPreservingMergeExec: [the_min@2 DESC], fetch=10 + SortExec: TopK(fetch=10), expr=[the_min@2 DESC], preserve_partitioning=[true] + ProjectionExec: expr=[c1@0 as c1, max(aggregate_test_100.c12)@1 as max(aggregate_test_100.c12), min(aggregate_test_100.c12)@2 as the_min] + AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[max(aggregate_test_100.c12), min(aggregate_test_100.c12)] + CoalesceBatchesExec: target_batch_size=4096 + RepartitionExec: partitioning=Hash([c1@0], 9000), input_partitions=9000 + AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[max(aggregate_test_100.c12), min(aggregate_test_100.c12)] + CoalesceBatchesExec: target_batch_size=4096 + FilterExec: c12@1 < 10 + RepartitionExec: partitioning=RoundRobinBatch(9000), input_partitions=1 + DataSourceExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1, c12], file_type=csv, has_header=true + "### ); } @@ -645,19 +639,6 @@ async fn test_physical_plan_display_indent_multi_children() { let dataframe = ctx.sql(sql).await.unwrap(); let physical_plan = dataframe.create_physical_plan().await.unwrap(); - let expected = vec![ - "CoalesceBatchesExec: target_batch_size=4096", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c1@0, c2@0)], projection=[c1@0]", - " CoalesceBatchesExec: 
target_batch_size=4096", - " RepartitionExec: partitioning=Hash([c1@0], 9000), input_partitions=9000", - " RepartitionExec: partitioning=RoundRobinBatch(9000), input_partitions=1", - " DataSourceExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1], file_type=csv, has_header=true", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([c2@0], 9000), input_partitions=9000", - " RepartitionExec: partitioning=RoundRobinBatch(9000), input_partitions=1", - " ProjectionExec: expr=[c1@0 as c2]", - " DataSourceExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1], file_type=csv, has_header=true", - ]; let normalizer = ExplainNormalizer::new(); let actual = format!("{}", displayable(physical_plan.as_ref()).indent(true)) @@ -665,11 +646,24 @@ async fn test_physical_plan_display_indent_multi_children() { .lines() // normalize paths .map(|s| normalizer.normalize(s)) - .collect::>(); - - assert_eq!( - expected, actual, - "expected:\n{expected:#?}\nactual:\n\n{actual:#?}\n" + .collect::>() + .join("\n"); + + assert_snapshot!( + actual, + @r###" + CoalesceBatchesExec: target_batch_size=4096 + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c1@0, c2@0)], projection=[c1@0] + CoalesceBatchesExec: target_batch_size=4096 + RepartitionExec: partitioning=Hash([c1@0], 9000), input_partitions=9000 + RepartitionExec: partitioning=RoundRobinBatch(9000), input_partitions=1 + DataSourceExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1], file_type=csv, has_header=true + CoalesceBatchesExec: target_batch_size=4096 + RepartitionExec: partitioning=Hash([c2@0], 9000), input_partitions=9000 + RepartitionExec: partitioning=RoundRobinBatch(9000), input_partitions=1 + ProjectionExec: expr=[c1@0 as c2] + DataSourceExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1], file_type=csv, has_header=true + "### ); } @@ -777,14 +771,19 @@ async fn explain_logical_plan_only() { let sql = "EXPLAIN select count(*) from (values ('a', 1, 100), ('a', 2, 150)) as t (c1,c2,c3)"; let actual = execute(&ctx, sql).await; let actual = normalize_vec_for_explain(actual); - - let expected = vec![ - vec!["logical_plan", "Projection: count(Int64(1)) AS count(*)\ - \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]\ - \n SubqueryAlias: t\ - \n Projection: \ - \n Values: (Utf8(\"a\"), Int64(1), Int64(100)), (Utf8(\"a\"), Int64(2), Int64(150))"]]; - assert_eq!(expected, actual); + let actual = actual.into_iter().map(|r| r.join("\n")).collect::(); + + assert_snapshot!( + actual, + @r#" + logical_plan + Projection: count(Int64(1)) AS count(*) + Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] + SubqueryAlias: t + Projection: + Values: (Utf8("a"), Int64(1), Int64(100)), (Utf8("a"), Int64(2), Int64(150)) + "# + ); } #[tokio::test] @@ -795,14 +794,16 @@ async fn explain_physical_plan_only() { let sql = "EXPLAIN select count(*) from (values ('a', 1, 100), ('a', 2, 150)) as t (c1,c2,c3)"; let actual = execute(&ctx, sql).await; let actual = normalize_vec_for_explain(actual); - - let expected = vec![vec![ - "physical_plan", - "ProjectionExec: expr=[2 as count(*)]\ - \n PlaceholderRowExec\ - \n", - ]]; - assert_eq!(expected, actual); + let actual = actual.into_iter().map(|r| r.join("\n")).collect::(); + + assert_snapshot!( + actual, + @r###" + physical_plan + ProjectionExec: expr=[2 as count(*)] + PlaceholderRowExec + "### + ); } #[tokio::test] diff --git 
a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 77eec20eac00..729542d27e3f 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use datafusion::assert_batches_eq; use datafusion::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable}; use datafusion::test_util::register_unbounded_file_with_ordering; @@ -235,3 +236,50 @@ async fn join_change_in_planner_without_sort_not_allowed() -> Result<()> { } Ok(()) } + +#[tokio::test] +async fn join_using_uppercase_column() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new( + "UPPER", + DataType::UInt32, + false, + )])); + let tmp_dir = TempDir::new()?; + let file_path = tmp_dir.path().join("uppercase-column.csv"); + let mut file = File::create(file_path.clone())?; + file.write_all("0".as_bytes())?; + drop(file); + + let ctx = SessionContext::new(); + ctx.register_csv( + "test", + file_path.to_str().unwrap(), + CsvReadOptions::new().schema(&schema).has_header(false), + ) + .await?; + + let dataframe = ctx + .sql( + r#" + SELECT test."UPPER" FROM "test" + INNER JOIN ( + SELECT test."UPPER" FROM "test" + ) AS selection USING ("UPPER") + ; + "#, + ) + .await?; + + assert_batches_eq!( + [ + "+-------+", + "| UPPER |", + "+-------+", + "| 0 |", + "+-------+", + ], + &dataframe.collect().await? + ); + + Ok(()) +} diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 579049692e7d..e212ee269b15 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -34,7 +34,6 @@ use datafusion::{execution::context::SessionContext, physical_plan::displayable} use datafusion_common::test_util::batches_to_sort_string; use datafusion_common::utils::get_available_parallelism; use datafusion_common::{assert_contains, assert_not_contains}; -use insta::assert_snapshot; use object_store::path::Path; use std::fs::File; use std::io::Write; @@ -63,6 +62,7 @@ pub mod create_drop; pub mod explain_analyze; pub mod joins; mod path_partition; +mod runtime_config; pub mod select; mod sql_api; diff --git a/datafusion/core/tests/sql/path_partition.rs b/datafusion/core/tests/sql/path_partition.rs index fa6c7432413f..5e9748d23d8c 100644 --- a/datafusion/core/tests/sql/path_partition.rs +++ b/datafusion/core/tests/sql/path_partition.rs @@ -25,8 +25,6 @@ use std::sync::Arc; use arrow::datatypes::DataType; use datafusion::datasource::listing::ListingTableUrl; -use datafusion::datasource::physical_plan::ParquetSource; -use datafusion::datasource::source::DataSourceExec; use datafusion::{ datasource::{ file_format::{csv::CsvFormat, parquet::ParquetFormat}, @@ -42,8 +40,6 @@ use datafusion_common::stats::Precision; use datafusion_common::test_util::batches_to_sort_string; use datafusion_common::ScalarValue; use datafusion_execution::config::SessionConfig; -use datafusion_expr::{col, lit, Expr, Operator}; -use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal}; use async_trait::async_trait; use bytes::Bytes; @@ -57,55 +53,6 @@ use object_store::{ use object_store::{Attributes, MultipartUpload, PutMultipartOpts, PutPayload}; use url::Url; -#[tokio::test] -async fn parquet_partition_pruning_filter() -> Result<()> { - let ctx = SessionContext::new(); - - let table = create_partitioned_alltypes_parquet_table( - &ctx, - &[ - "year=2021/month=09/day=09/file.parquet", - "year=2021/month=10/day=09/file.parquet", - 
"year=2021/month=10/day=28/file.parquet", - ], - &[ - ("year", DataType::Int32), - ("month", DataType::Int32), - ("day", DataType::Int32), - ], - "mirror:///", - "alltypes_plain.parquet", - ) - .await; - - // The first three filters can be resolved using only the partition columns. - let filters = [ - Expr::eq(col("year"), lit(2021)), - Expr::eq(col("month"), lit(10)), - Expr::eq(col("day"), lit(28)), - Expr::gt(col("id"), lit(1)), - ]; - let exec = table.scan(&ctx.state(), None, &filters, None).await?; - let data_source_exec = exec.as_any().downcast_ref::().unwrap(); - if let Some((_, parquet_config)) = - data_source_exec.downcast_to_file_source::() - { - let pred = parquet_config.predicate().unwrap(); - // Only the last filter should be pushdown to TableScan - let expected = Arc::new(BinaryExpr::new( - Arc::new(Column::new_with_schema("id", &exec.schema()).unwrap()), - Operator::Gt, - Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), - )); - - assert!(pred.as_any().is::()); - let pred = pred.as_any().downcast_ref::().unwrap(); - - assert_eq!(pred, expected.as_ref()); - } - Ok(()) -} - #[tokio::test] async fn parquet_distinct_partition_col() -> Result<()> { let ctx = SessionContext::new(); @@ -484,7 +431,9 @@ async fn parquet_multiple_nonstring_partitions() -> Result<()> { #[tokio::test] async fn parquet_statistics() -> Result<()> { - let ctx = SessionContext::new(); + let mut config = SessionConfig::new(); + config.options_mut().execution.collect_statistics = true; + let ctx = SessionContext::new_with_config(config); register_partitioned_alltypes_parquet( &ctx, @@ -511,7 +460,7 @@ async fn parquet_statistics() -> Result<()> { let schema = physical_plan.schema(); assert_eq!(schema.fields().len(), 4); - let stat_cols = physical_plan.statistics()?.column_statistics; + let stat_cols = physical_plan.partition_statistics(None)?.column_statistics; assert_eq!(stat_cols.len(), 4); // stats for the first col are read from the parquet file assert_eq!(stat_cols[0].null_count, Precision::Exact(3)); @@ -526,7 +475,7 @@ async fn parquet_statistics() -> Result<()> { let schema = physical_plan.schema(); assert_eq!(schema.fields().len(), 2); - let stat_cols = physical_plan.statistics()?.column_statistics; + let stat_cols = physical_plan.partition_statistics(None)?.column_statistics; assert_eq!(stat_cols.len(), 2); // stats for the first col are read from the parquet file assert_eq!(stat_cols[0].null_count, Precision::Exact(1)); @@ -636,7 +585,8 @@ async fn create_partitioned_alltypes_parquet_table( .iter() .map(|x| (x.0.to_owned(), x.1.clone())) .collect::>(), - ); + ) + .with_session_config_options(&ctx.copied_config()); let table_path = ListingTableUrl::parse(table_path).unwrap(); let store_path = diff --git a/datafusion/core/tests/sql/runtime_config.rs b/datafusion/core/tests/sql/runtime_config.rs new file mode 100644 index 000000000000..18e07bb61ed9 --- /dev/null +++ b/datafusion/core/tests/sql/runtime_config.rs @@ -0,0 +1,166 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tests for runtime configuration SQL interface + +use std::sync::Arc; + +use datafusion::execution::context::SessionContext; +use datafusion::execution::context::TaskContext; +use datafusion_physical_plan::common::collect; + +#[tokio::test] +async fn test_memory_limit_with_spill() { + let ctx = SessionContext::new(); + + ctx.sql("SET datafusion.runtime.memory_limit = '1M'") + .await + .unwrap() + .collect() + .await + .unwrap(); + + ctx.sql("SET datafusion.execution.sort_spill_reservation_bytes = 0") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let query = "select * from generate_series(1,10000000) as t1(v1) order by v1;"; + let df = ctx.sql(query).await.unwrap(); + + let plan = df.create_physical_plan().await.unwrap(); + let task_ctx = Arc::new(TaskContext::from(&ctx.state())); + let stream = plan.execute(0, task_ctx).unwrap(); + + let _results = collect(stream).await; + let metrics = plan.metrics().unwrap(); + let spill_count = metrics.spill_count().unwrap(); + assert!(spill_count > 0, "Expected spills but none occurred"); +} + +#[tokio::test] +async fn test_no_spill_with_adequate_memory() { + let ctx = SessionContext::new(); + + ctx.sql("SET datafusion.runtime.memory_limit = '10M'") + .await + .unwrap() + .collect() + .await + .unwrap(); + ctx.sql("SET datafusion.execution.sort_spill_reservation_bytes = 0") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let query = "select * from generate_series(1,100000) as t1(v1) order by v1;"; + let df = ctx.sql(query).await.unwrap(); + + let plan = df.create_physical_plan().await.unwrap(); + let task_ctx = Arc::new(TaskContext::from(&ctx.state())); + let stream = plan.execute(0, task_ctx).unwrap(); + + let _results = collect(stream).await; + let metrics = plan.metrics().unwrap(); + let spill_count = metrics.spill_count().unwrap(); + assert_eq!(spill_count, 0, "Expected no spills but some occurred"); +} + +#[tokio::test] +async fn test_multiple_configs() { + let ctx = SessionContext::new(); + + ctx.sql("SET datafusion.runtime.memory_limit = '100M'") + .await + .unwrap() + .collect() + .await + .unwrap(); + ctx.sql("SET datafusion.execution.batch_size = '2048'") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let query = "select * from generate_series(1,100000) as t1(v1) order by v1;"; + let result = ctx.sql(query).await.unwrap().collect().await; + + assert!(result.is_ok(), "Should not fail due to memory limit"); + + let state = ctx.state(); + let batch_size = state.config().options().execution.batch_size; + assert_eq!(batch_size, 2048); +} + +#[tokio::test] +async fn test_memory_limit_enforcement() { + let ctx = SessionContext::new(); + + ctx.sql("SET datafusion.runtime.memory_limit = '1M'") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let query = "select * from generate_series(1,100000) as t1(v1) order by v1;"; + let result = ctx.sql(query).await.unwrap().collect().await; + + assert!(result.is_err(), "Should fail due to memory limit"); + + ctx.sql("SET datafusion.runtime.memory_limit = '100M'") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let result = 
ctx.sql(query).await.unwrap().collect().await; + + assert!(result.is_ok(), "Should not fail due to memory limit"); +} + +#[tokio::test] +async fn test_invalid_memory_limit() { + let ctx = SessionContext::new(); + + let result = ctx + .sql("SET datafusion.runtime.memory_limit = '100X'") + .await; + + assert!(result.is_err()); + let error_message = result.unwrap_err().to_string(); + assert!(error_message.contains("Unsupported unit 'X'")); +} + +#[tokio::test] +async fn test_unknown_runtime_config() { + let ctx = SessionContext::new(); + + let result = ctx + .sql("SET datafusion.runtime.unknown_config = 'value'") + .await; + + assert!(result.is_err()); + let error_message = result.unwrap_err().to_string(); + assert!(error_message.contains("Unknown runtime configuration")); +} diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs index f874dd7c0842..0e1210ebb842 100644 --- a/datafusion/core/tests/sql/select.rs +++ b/datafusion/core/tests/sql/select.rs @@ -17,6 +17,7 @@ use super::*; use datafusion_common::ScalarValue; +use insta::assert_snapshot; #[tokio::test] async fn test_list_query_parameters() -> Result<()> { diff --git a/datafusion/core/tests/tpc-ds/49.sql b/datafusion/core/tests/tpc-ds/49.sql index 090e9746c0d8..219877719f22 100644 --- a/datafusion/core/tests/tpc-ds/49.sql +++ b/datafusion/core/tests/tpc-ds/49.sql @@ -110,7 +110,7 @@ select channel, item, return_ratio, return_rank, currency_rank from where sr.sr_return_amt > 10000 and sts.ss_net_profit > 1 - and sts.ss_net_paid > 0 + and sts.ss_net_paid > 0 and sts.ss_quantity > 0 and ss_sold_date_sk = d_date_sk and d_year = 2000 diff --git a/datafusion/core/tests/tracing/mod.rs b/datafusion/core/tests/tracing/mod.rs index 787dd9f4f3cb..df8a28c021d1 100644 --- a/datafusion/core/tests/tracing/mod.rs +++ b/datafusion/core/tests/tracing/mod.rs @@ -55,9 +55,9 @@ async fn test_tracer_injection() { let untraced_result = SpawnedTask::spawn(run_query()).join().await; if let Err(e) = untraced_result { // Check if the error message contains the expected error. - assert!(e.is_panic(), "Expected a panic, but got: {:?}", e); + assert!(e.is_panic(), "Expected a panic, but got: {e:?}"); assert_contains!(e.to_string(), "Task ID not found in spawn graph"); - info!("Caught expected panic: {}", e); + info!("Caught expected panic: {e}"); } else { panic!("Expected the task to panic, but it completed successfully"); }; @@ -94,7 +94,7 @@ async fn run_query() { ctx.register_object_store(&url, traceable_store.clone()); // Register a listing table from the test data directory. - let table_path = format!("test://{}/", test_data); + let table_path = format!("test://{test_data}/"); ctx.register_listing_table("alltypes", &table_path, listing_options, None, None) .await .expect("Failed to register table"); diff --git a/datafusion/core/tests/user_defined/expr_planner.rs b/datafusion/core/tests/user_defined/expr_planner.rs index 1fc6d14c5b22..07d289cab06c 100644 --- a/datafusion/core/tests/user_defined/expr_planner.rs +++ b/datafusion/core/tests/user_defined/expr_planner.rs @@ -56,7 +56,7 @@ impl ExprPlanner for MyCustomPlanner { } BinaryOperator::Question => { Ok(PlannerResult::Planned(Expr::Alias(Alias::new( - Expr::Literal(ScalarValue::Boolean(Some(true))), + Expr::Literal(ScalarValue::Boolean(Some(true)), None), None::<&str>, format!("{} ? 
{}", expr.left, expr.right), )))) diff --git a/datafusion/core/tests/user_defined/insert_operation.rs b/datafusion/core/tests/user_defined/insert_operation.rs index 12f700ce572b..c8a4279a4211 100644 --- a/datafusion/core/tests/user_defined/insert_operation.rs +++ b/datafusion/core/tests/user_defined/insert_operation.rs @@ -26,6 +26,7 @@ use datafusion::{ use datafusion_catalog::{Session, TableProvider}; use datafusion_expr::{dml::InsertOp, Expr, TableType}; use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; +use datafusion_physical_plan::execution_plan::SchedulingType; use datafusion_physical_plan::{ execution_plan::{Boundedness, EmissionType}, DisplayAs, ExecutionPlan, PlanProperties, @@ -132,7 +133,8 @@ impl TestInsertExec { Partitioning::UnknownPartitioning(1), EmissionType::Incremental, Boundedness::Bounded, - ), + ) + .with_scheduling_type(SchedulingType::Cooperative), } } } diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 5cbb05f290a7..aa5a72c0fb45 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -18,6 +18,8 @@ //! This module contains end to end demonstrations of creating //! user defined aggregate functions +use std::any::Any; +use std::collections::HashMap; use std::hash::{DefaultHasher, Hash, Hasher}; use std::mem::{size_of, size_of_val}; use std::sync::{ @@ -26,10 +28,11 @@ use std::sync::{ }; use arrow::array::{ - types::UInt64Type, AsArray, Int32Array, PrimitiveArray, StringArray, StructArray, + record_batch, types::UInt64Type, Array, AsArray, Int32Array, PrimitiveArray, + StringArray, StructArray, UInt64Array, }; use arrow::datatypes::{Fields, Schema}; - +use arrow_schema::FieldRef; use datafusion::common::test_util::batches_to_string; use datafusion::dataframe::DataFrame; use datafusion::datasource::MemTable; @@ -48,11 +51,12 @@ use datafusion::{ prelude::SessionContext, scalar::ScalarValue, }; -use datafusion_common::assert_contains; +use datafusion_common::{assert_contains, exec_datafusion_err}; use datafusion_common::{cast::as_primitive_array, exec_err}; +use datafusion_expr::expr::WindowFunction; use datafusion_expr::{ - col, create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator, - LogicalPlanBuilder, SimpleAggregateUDF, + col, create_udaf, function::AccumulatorArgs, AggregateUDFImpl, Expr, + GroupsAccumulator, LogicalPlanBuilder, SimpleAggregateUDF, WindowFunctionDefinition, }; use datafusion_functions_aggregate::average::AvgAccumulator; @@ -569,7 +573,7 @@ impl TimeSum { // Returns the same type as its input let return_type = timestamp_type.clone(); - let state_fields = vec![Field::new("sum", timestamp_type, true)]; + let state_fields = vec![Field::new("sum", timestamp_type, true).into()]; let volatility = Volatility::Immutable; @@ -669,7 +673,7 @@ impl FirstSelector { let state_fields = state_type .into_iter() .enumerate() - .map(|(i, t)| Field::new(format!("{i}"), t, true)) + .map(|(i, t)| Field::new(format!("{i}"), t, true).into()) .collect::>(); // Possible input signatures @@ -781,7 +785,7 @@ struct TestGroupsAccumulator { } impl AggregateUDFImpl for TestGroupsAccumulator { - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } @@ -890,3 +894,264 @@ impl GroupsAccumulator for TestGroupsAccumulator { size_of::() } } + +#[derive(Debug)] +struct MetadataBasedAggregateUdf { + name: String, + signature: 
Signature, + metadata: HashMap, +} + +impl MetadataBasedAggregateUdf { + fn new(metadata: HashMap) -> Self { + // The name we return must be unique. Otherwise we will not call distinct + // instances of this UDF. This is a small hack for the unit tests to get unique + // names, but you could do something more elegant with the metadata. + let name = format!("metadata_based_udf_{}", metadata.len()); + Self { + name, + signature: Signature::exact(vec![DataType::UInt64], Volatility::Immutable), + metadata, + } + } +} + +impl AggregateUDFImpl for MetadataBasedAggregateUdf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + unimplemented!("this should never be called since return_field is implemented"); + } + + fn return_field(&self, _arg_fields: &[FieldRef]) -> Result { + Ok(Field::new(self.name(), DataType::UInt64, true) + .with_metadata(self.metadata.clone()) + .into()) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + let input_expr = acc_args + .exprs + .first() + .ok_or(exec_datafusion_err!("Expected one argument"))?; + let input_field = input_expr.return_field(acc_args.schema)?; + + let double_output = input_field + .metadata() + .get("modify_values") + .map(|v| v == "double_output") + .unwrap_or(false); + + Ok(Box::new(MetadataBasedAccumulator { + double_output, + curr_sum: 0, + })) + } +} + +#[derive(Debug)] +struct MetadataBasedAccumulator { + double_output: bool, + curr_sum: u64, +} + +impl Accumulator for MetadataBasedAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let arr = values[0] + .as_any() + .downcast_ref::() + .ok_or(exec_datafusion_err!("Expected UInt64Array"))?; + + self.curr_sum = arr.iter().fold(self.curr_sum, |a, b| a + b.unwrap_or(0)); + + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let v = match self.double_output { + true => self.curr_sum * 2, + false => self.curr_sum, + }; + + Ok(ScalarValue::from(v)) + } + + fn size(&self) -> usize { + 9 + } + + fn state(&mut self) -> Result> { + Ok(vec![ScalarValue::from(self.curr_sum)]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update_batch(states) + } +} + +#[tokio::test] +async fn test_metadata_based_aggregate() -> Result<()> { + let data_array = Arc::new(UInt64Array::from(vec![0, 5, 10, 15, 20])) as ArrayRef; + let schema = Arc::new(Schema::new(vec![ + Field::new("no_metadata", DataType::UInt64, true), + Field::new("with_metadata", DataType::UInt64, true).with_metadata( + [("modify_values".to_string(), "double_output".to_string())] + .into_iter() + .collect(), + ), + ])); + + let batch = RecordBatch::try_new( + schema, + vec![Arc::clone(&data_array), Arc::clone(&data_array)], + )?; + + let ctx = SessionContext::new(); + ctx.register_batch("t", batch)?; + let df = ctx.table("t").await?; + + let no_output_meta_udf = + AggregateUDF::from(MetadataBasedAggregateUdf::new(HashMap::new())); + let with_output_meta_udf = AggregateUDF::from(MetadataBasedAggregateUdf::new( + [("output_metatype".to_string(), "custom_value".to_string())] + .into_iter() + .collect(), + )); + + let df = df.aggregate( + vec![], + vec![ + no_output_meta_udf + .call(vec![col("no_metadata")]) + .alias("meta_no_in_no_out"), + no_output_meta_udf + .call(vec![col("with_metadata")]) + .alias("meta_with_in_no_out"), + with_output_meta_udf + .call(vec![col("no_metadata")]) + .alias("meta_no_in_with_out"), + 
with_output_meta_udf + .call(vec![col("with_metadata")]) + .alias("meta_with_in_with_out"), + ], + )?; + + let actual = df.collect().await?; + + // To test for output metadata handling, we set the expected values on the result + // To test for input metadata handling, we check the numbers returned + let mut output_meta = HashMap::new(); + let _ = output_meta.insert("output_metatype".to_string(), "custom_value".to_string()); + let expected_schema = Schema::new(vec![ + Field::new("meta_no_in_no_out", DataType::UInt64, true), + Field::new("meta_with_in_no_out", DataType::UInt64, true), + Field::new("meta_no_in_with_out", DataType::UInt64, true) + .with_metadata(output_meta.clone()), + Field::new("meta_with_in_with_out", DataType::UInt64, true) + .with_metadata(output_meta.clone()), + ]); + + let expected = record_batch!( + ("meta_no_in_no_out", UInt64, [50]), + ("meta_with_in_no_out", UInt64, [100]), + ("meta_no_in_with_out", UInt64, [50]), + ("meta_with_in_with_out", UInt64, [100]) + )? + .with_schema(Arc::new(expected_schema))?; + + assert_eq!(expected, actual[0]); + + Ok(()) +} + +#[tokio::test] +async fn test_metadata_based_aggregate_as_window() -> Result<()> { + let data_array = Arc::new(UInt64Array::from(vec![0, 5, 10, 15, 20])) as ArrayRef; + let schema = Arc::new(Schema::new(vec![ + Field::new("no_metadata", DataType::UInt64, true), + Field::new("with_metadata", DataType::UInt64, true).with_metadata( + [("modify_values".to_string(), "double_output".to_string())] + .into_iter() + .collect(), + ), + ])); + + let batch = RecordBatch::try_new( + schema, + vec![Arc::clone(&data_array), Arc::clone(&data_array)], + )?; + + let ctx = SessionContext::new(); + ctx.register_batch("t", batch)?; + let df = ctx.table("t").await?; + + let no_output_meta_udf = Arc::new(AggregateUDF::from( + MetadataBasedAggregateUdf::new(HashMap::new()), + )); + let with_output_meta_udf = + Arc::new(AggregateUDF::from(MetadataBasedAggregateUdf::new( + [("output_metatype".to_string(), "custom_value".to_string())] + .into_iter() + .collect(), + ))); + + let df = df.select(vec![ + Expr::from(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(Arc::clone(&no_output_meta_udf)), + vec![col("no_metadata")], + )) + .alias("meta_no_in_no_out"), + Expr::from(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(no_output_meta_udf), + vec![col("with_metadata")], + )) + .alias("meta_with_in_no_out"), + Expr::from(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(Arc::clone(&with_output_meta_udf)), + vec![col("no_metadata")], + )) + .alias("meta_no_in_with_out"), + Expr::from(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(with_output_meta_udf), + vec![col("with_metadata")], + )) + .alias("meta_with_in_with_out"), + ])?; + + let actual = df.collect().await?; + + // To test for output metadata handling, we set the expected values on the result + // To test for input metadata handling, we check the numbers returned + let mut output_meta = HashMap::new(); + let _ = output_meta.insert("output_metatype".to_string(), "custom_value".to_string()); + let expected_schema = Schema::new(vec![ + Field::new("meta_no_in_no_out", DataType::UInt64, true), + Field::new("meta_with_in_no_out", DataType::UInt64, true), + Field::new("meta_no_in_with_out", DataType::UInt64, true) + .with_metadata(output_meta.clone()), + Field::new("meta_with_in_with_out", DataType::UInt64, true) + .with_metadata(output_meta.clone()), + ]); + + let expected = record_batch!( + ("meta_no_in_no_out", UInt64, [50, 50, 50, 50, 
50]), + ("meta_with_in_no_out", UInt64, [100, 100, 100, 100, 100]), + ("meta_no_in_with_out", UInt64, [50, 50, 50, 50, 50]), + ("meta_with_in_with_out", UInt64, [100, 100, 100, 100, 100]) + )? + .with_schema(Arc::new(expected_schema))?; + + assert_eq!(expected, actual[0]); + + Ok(()) +} diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index e46940e63154..4d3916c1760e 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -63,15 +63,14 @@ use std::hash::Hash; use std::task::{Context, Poll}; use std::{any::Any, collections::BTreeMap, fmt, sync::Arc}; +use arrow::array::{Array, ArrayRef, StringViewArray}; use arrow::{ - array::{Int64Array, StringArray}, - datatypes::SchemaRef, - record_batch::RecordBatch, + array::Int64Array, datatypes::SchemaRef, record_batch::RecordBatch, util::pretty::pretty_format_batches, }; use datafusion::execution::session_state::SessionStateBuilder; use datafusion::{ - common::cast::{as_int64_array, as_string_array}, + common::cast::as_int64_array, common::{arrow_datafusion_err, internal_err, DFSchemaRef}, error::{DataFusionError, Result}, execution::{ @@ -100,6 +99,7 @@ use datafusion_optimizer::AnalyzerRule; use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; use async_trait::async_trait; +use datafusion_common::cast::as_string_view_array; use futures::{Stream, StreamExt}; /// Execute the specified sql and return the resulting record batches @@ -796,22 +796,26 @@ fn accumulate_batch( k: &usize, ) -> BTreeMap { let num_rows = input_batch.num_rows(); + // Assuming the input columns are - // column[0]: customer_id / UTF8 + // column[0]: customer_id UTF8View // column[1]: revenue: Int64 - let customer_id = - as_string_array(input_batch.column(0)).expect("Column 0 is not customer_id"); + let customer_id_column = input_batch.column(0); let revenue = as_int64_array(input_batch.column(1)).unwrap(); for row in 0..num_rows { - add_row( - &mut top_values, - customer_id.value(row), - revenue.value(row), - k, - ); + let customer_id = match customer_id_column.data_type() { + arrow::datatypes::DataType::Utf8View => { + let array = as_string_view_array(customer_id_column).unwrap(); + array.value(row) + } + _ => panic!("Unsupported customer_id type"), + }; + + add_row(&mut top_values, customer_id, revenue.value(row), k); } + top_values } @@ -843,11 +847,19 @@ impl Stream for TopKReader { self.state.iter().rev().unzip(); let customer: Vec<&str> = customer.iter().map(|&s| &**s).collect(); + + let customer_array: ArrayRef = match schema.field(0).data_type() { + arrow::datatypes::DataType::Utf8View => { + Arc::new(StringViewArray::from(customer)) + } + other => panic!("Unsupported customer_id output type: {other:?}"), + }; + Poll::Ready(Some( RecordBatch::try_new( schema, vec![ - Arc::new(StringArray::from(customer)), + Arc::new(customer_array), Arc::new(Int64Array::from(revenue)), ], ) @@ -900,11 +912,12 @@ impl MyAnalyzerRule { .map(|e| { e.transform(|e| { Ok(match e { - Expr::Literal(ScalarValue::Int64(i)) => { + Expr::Literal(ScalarValue::Int64(i), _) => { // transform to UInt64 - Transformed::yes(Expr::Literal(ScalarValue::UInt64( - i.map(|i| i as u64), - ))) + Transformed::yes(Expr::Literal( + ScalarValue::UInt64(i.map(|i| i as u64)), + None, + )) } _ => Transformed::no(e), }) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs 
b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 264bd6b66a60..90e49e504c75 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -16,16 +16,19 @@ // under the License. use std::any::Any; +use std::collections::HashMap; use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::Arc; -use arrow::array::as_string_array; +use arrow::array::{as_string_array, create_array, record_batch, Int8Array, UInt64Array}; use arrow::array::{ builder::BooleanBuilder, cast::AsArray, Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, StringArray, }; use arrow::compute::kernels::numeric::add; use arrow::datatypes::{DataType, Field, Schema}; +use arrow_schema::extension::{Bool8, CanonicalExtensionType, ExtensionType}; +use arrow_schema::{ArrowError, FieldRef}; use datafusion::common::test_util::batches_to_string; use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState}; use datafusion::prelude::*; @@ -35,13 +38,14 @@ use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::utils::take_function_args; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err, not_impl_err, - plan_err, DFSchema, DataFusionError, HashMap, Result, ScalarValue, + plan_err, DFSchema, DataFusionError, Result, ScalarValue, }; +use datafusion_expr::expr::FieldMetadata; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ - Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, LogicalPlanBuilder, - OperateFunctionArg, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDF, - ScalarUDFImpl, Signature, Volatility, + lit_with_metadata, Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, + LogicalPlanBuilder, OperateFunctionArg, ReturnFieldArgs, ScalarFunctionArgs, + ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; use datafusion_functions_nested::range::range_udf; use parking_lot::Mutex; @@ -57,7 +61,7 @@ async fn csv_query_custom_udf_with_cast() -> Result<()> { let ctx = create_udf_context(); register_aggregate_csv(&ctx).await?; let sql = "SELECT avg(custom_sqrt(c11)) FROM aggregate_test_100"; - let actual = plan_and_collect(&ctx, sql).await.unwrap(); + let actual = plan_and_collect(&ctx, sql).await?; insta::assert_snapshot!(batches_to_string(&actual), @r###" +------------------------------------------+ @@ -76,7 +80,7 @@ async fn csv_query_avg_sqrt() -> Result<()> { register_aggregate_csv(&ctx).await?; // Note it is a different column (c12) than above (c11) let sql = "SELECT avg(custom_sqrt(c12)) FROM aggregate_test_100"; - let actual = plan_and_collect(&ctx, sql).await.unwrap(); + let actual = plan_and_collect(&ctx, sql).await?; insta::assert_snapshot!(batches_to_string(&actual), @r###" +------------------------------------------+ @@ -389,7 +393,7 @@ async fn udaf_as_window_func() -> Result<()> { WindowAggr: windowExpr=[[my_acc(my_table.b) PARTITION BY [my_table.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] TableScan: my_table"#; - let dataframe = context.sql(sql).await.unwrap(); + let dataframe = context.sql(sql).await?; assert_eq!(format!("{}", dataframe.logical_plan()), expected); Ok(()) } @@ -399,7 +403,7 @@ async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> { let ctx = SessionContext::new(); let arr = Int32Array::from(vec![1]); let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) 
as _)])?; - ctx.register_batch("t", batch).unwrap(); + ctx.register_batch("t", batch)?; let myfunc = Arc::new(|args: &[ColumnarValue]| { let ColumnarValue::Array(array) = &args[0] else { @@ -443,7 +447,7 @@ async fn test_user_defined_functions_with_alias() -> Result<()> { let ctx = SessionContext::new(); let arr = Int32Array::from(vec![1]); let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?; - ctx.register_batch("t", batch).unwrap(); + ctx.register_batch("t", batch)?; let myfunc = Arc::new(|args: &[ColumnarValue]| { let ColumnarValue::Array(array) = &args[0] else { @@ -803,7 +807,7 @@ impl ScalarUDFImpl for TakeUDF { &self.signature } fn return_type(&self, _args: &[DataType]) -> Result { - not_impl_err!("Not called because the return_type_from_args is implemented") + not_impl_err!("Not called because the return_field_from_args is implemented") } /// This function returns the type of the first or second argument based on @@ -811,9 +815,9 @@ impl ScalarUDFImpl for TakeUDF { /// /// 1. If the third argument is '0', return the type of the first argument /// 2. If the third argument is '1', return the type of the second argument - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - if args.arg_types.len() != 3 { - return plan_err!("Expected 3 arguments, got {}.", args.arg_types.len()); + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + if args.arg_fields.len() != 3 { + return plan_err!("Expected 3 arguments, got {}.", args.arg_fields.len()); } let take_idx = if let Some(take_idx) = args.scalar_arguments.get(2) { @@ -838,9 +842,12 @@ impl ScalarUDFImpl for TakeUDF { ); }; - Ok(ReturnInfo::new_nullable( - args.arg_types[take_idx].to_owned(), - )) + Ok(Field::new( + self.name(), + args.arg_fields[take_idx].data_type().to_owned(), + true, + ) + .into()) } // The actual implementation @@ -967,10 +974,6 @@ impl ScalarUDFImpl for ScalarFunctionWrapper { Ok(ExprSimplifyResult::Simplified(replacement)) } - - fn aliases(&self) -> &[String] { - &[] - } } impl ScalarFunctionWrapper { @@ -1004,8 +1007,7 @@ impl ScalarFunctionWrapper { if let Some(value) = placeholder.strip_prefix('$') { Ok(value.parse().map(|v: usize| v - 1).map_err(|e| { DataFusionError::Execution(format!( - "Placeholder `{}` parsing error: {}!", - placeholder, e + "Placeholder `{placeholder}` parsing error: {e}!" )) })?) } else { @@ -1160,7 +1162,7 @@ async fn create_scalar_function_from_sql_statement_postgres_syntax() -> Result<( match ctx.sql(sql).await { Ok(_) => {} Err(e) => { - panic!("Error creating function: {}", e); + panic!("Error creating function: {e}"); } } @@ -1179,7 +1181,7 @@ async fn create_scalar_function_from_sql_statement_postgres_syntax() -> Result<( quote_style: None, span: Span::empty(), }), - data_type: DataType::Utf8, + data_type: DataType::Utf8View, default_expr: None, }]), return_type: Some(DataType::Int32), @@ -1367,3 +1369,404 @@ async fn register_alltypes_parquet(ctx: &SessionContext) -> Result<()> { async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result> { ctx.sql(sql).await?.collect().await } + +#[derive(Debug)] +struct MetadataBasedUdf { + name: String, + signature: Signature, + metadata: HashMap, +} + +impl MetadataBasedUdf { + fn new(metadata: HashMap) -> Self { + // The name we return must be unique. Otherwise we will not call distinct + // instances of this UDF. This is a small hack for the unit tests to get unique + // names, but you could do something more elegant with the metadata. 
+ let name = format!("metadata_based_udf_{}", metadata.len()); + Self { + name, + signature: Signature::exact(vec![DataType::UInt64], Volatility::Immutable), + metadata, + } + } +} + +impl ScalarUDFImpl for MetadataBasedUdf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _args: &[DataType]) -> Result { + unimplemented!( + "this should never be called since return_field_from_args is implemented" + ); + } + + fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result { + Ok(Field::new(self.name(), DataType::UInt64, true) + .with_metadata(self.metadata.clone()) + .into()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + assert_eq!(args.arg_fields.len(), 1); + let should_double = args.arg_fields[0] + .metadata() + .get("modify_values") + .map(|v| v == "double_output") + .unwrap_or(false); + let mulitplier = if should_double { 2 } else { 1 }; + + match &args.args[0] { + ColumnarValue::Array(array) => { + let array_values: Vec<_> = array + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .map(|v| v.map(|x| x * mulitplier)) + .collect(); + let array_ref = Arc::new(UInt64Array::from(array_values)) as ArrayRef; + Ok(ColumnarValue::Array(array_ref)) + } + ColumnarValue::Scalar(value) => { + let ScalarValue::UInt64(value) = value else { + return exec_err!("incorrect data type"); + }; + + Ok(ColumnarValue::Scalar(ScalarValue::UInt64( + value.map(|v| v * mulitplier), + ))) + } + } + } + + fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { + self.name == other.name() + } +} + +#[tokio::test] +async fn test_metadata_based_udf() -> Result<()> { + let data_array = Arc::new(UInt64Array::from(vec![0, 5, 10, 15, 20])) as ArrayRef; + let schema = Arc::new(Schema::new(vec![ + Field::new("no_metadata", DataType::UInt64, true), + Field::new("with_metadata", DataType::UInt64, true).with_metadata( + [("modify_values".to_string(), "double_output".to_string())] + .into_iter() + .collect(), + ), + ])); + let batch = RecordBatch::try_new( + schema, + vec![Arc::clone(&data_array), Arc::clone(&data_array)], + )?; + + let ctx = SessionContext::new(); + ctx.register_batch("t", batch)?; + let t = ctx.table("t").await?; + let no_output_meta_udf = ScalarUDF::from(MetadataBasedUdf::new(HashMap::new())); + let with_output_meta_udf = ScalarUDF::from(MetadataBasedUdf::new( + [("output_metatype".to_string(), "custom_value".to_string())] + .into_iter() + .collect(), + )); + + let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?) + .project(vec![ + no_output_meta_udf + .call(vec![col("no_metadata")]) + .alias("meta_no_in_no_out"), + no_output_meta_udf + .call(vec![col("with_metadata")]) + .alias("meta_with_in_no_out"), + with_output_meta_udf + .call(vec![col("no_metadata")]) + .alias("meta_no_in_with_out"), + with_output_meta_udf + .call(vec![col("with_metadata")]) + .alias("meta_with_in_with_out"), + ])? 
+ .build()?; + + let actual = DataFrame::new(ctx.state(), plan).collect().await?; + + // To test for output metadata handling, we set the expected values on the result + // To test for input metadata handling, we check the numbers returned + let mut output_meta = HashMap::new(); + let _ = output_meta.insert("output_metatype".to_string(), "custom_value".to_string()); + let expected_schema = Schema::new(vec![ + Field::new("meta_no_in_no_out", DataType::UInt64, true), + Field::new("meta_with_in_no_out", DataType::UInt64, true), + Field::new("meta_no_in_with_out", DataType::UInt64, true) + .with_metadata(output_meta.clone()), + Field::new("meta_with_in_with_out", DataType::UInt64, true) + .with_metadata(output_meta.clone()), + ]); + + let expected = record_batch!( + ("meta_no_in_no_out", UInt64, [0, 5, 10, 15, 20]), + ("meta_with_in_no_out", UInt64, [0, 10, 20, 30, 40]), + ("meta_no_in_with_out", UInt64, [0, 5, 10, 15, 20]), + ("meta_with_in_with_out", UInt64, [0, 10, 20, 30, 40]) + )? + .with_schema(Arc::new(expected_schema))?; + + assert_eq!(expected, actual[0]); + + ctx.deregister_table("t")?; + Ok(()) +} + +#[tokio::test] +async fn test_metadata_based_udf_with_literal() -> Result<()> { + let ctx = SessionContext::new(); + let input_metadata: HashMap = + [("modify_values".to_string(), "double_output".to_string())] + .into_iter() + .collect(); + let input_metadata = FieldMetadata::from(input_metadata); + let df = ctx.sql("select 0;").await?.select(vec![ + lit(5u64).alias_with_metadata("lit_with_doubling", Some(input_metadata.clone())), + lit(5u64).alias("lit_no_doubling"), + lit_with_metadata(5u64, Some(input_metadata)) + .alias("lit_with_double_no_alias_metadata"), + ])?; + + let output_metadata: HashMap = + [("output_metatype".to_string(), "custom_value".to_string())] + .into_iter() + .collect(); + let custom_udf = ScalarUDF::from(MetadataBasedUdf::new(output_metadata.clone())); + + let plan = LogicalPlanBuilder::from(df.into_optimized_plan()?) + .project(vec![ + custom_udf + .call(vec![col("lit_with_doubling")]) + .alias("doubled_output"), + custom_udf + .call(vec![col("lit_no_doubling")]) + .alias("not_doubled_output"), + custom_udf + .call(vec![col("lit_with_double_no_alias_metadata")]) + .alias("double_without_alias_metadata"), + ])? + .build()?; + + let actual = DataFrame::new(ctx.state(), plan).collect().await?; + + let schema = Arc::new(Schema::new(vec![ + Field::new("doubled_output", DataType::UInt64, false) + .with_metadata(output_metadata.clone()), + Field::new("not_doubled_output", DataType::UInt64, false) + .with_metadata(output_metadata.clone()), + Field::new("double_without_alias_metadata", DataType::UInt64, false) + .with_metadata(output_metadata.clone()), + ])); + + let expected = RecordBatch::try_new( + schema, + vec![ + create_array!(UInt64, [10]), + create_array!(UInt64, [5]), + create_array!(UInt64, [10]), + ], + )?; + + assert_eq!(expected, actual[0]); + + Ok(()) +} + +/// This UDF is to test extension handling, both on the input and output +/// sides. For the input, we will handle the data differently if there is +/// the canonical extension type Bool8. For the output we will add a +/// user defined extension type. 
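As context for the extension-based UDF below: the behavior switch is driven purely by the field-level extension type, not by the data itself. A minimal sketch of how a caller tags an Int8 column with the canonical Bool8 extension so the UDF renders it as booleans, mirroring the schema built in test_extension_based_udf further down (the helper name is illustrative only):

use std::sync::Arc;
use arrow::datatypes::{DataType, Field, Schema};
use arrow_schema::extension::Bool8;

// Builds a two-column schema: a plain Int8 column and an Int8 column tagged
// with the canonical Bool8 extension type. Only the tagged column is printed
// as "true"/"false" by the UDF defined below; the other keeps its numeric form.
fn bool8_tagged_schema() -> Arc<Schema> {
    Arc::new(Schema::new(vec![
        Field::new("no_extension", DataType::Int8, true),
        Field::new("with_extension", DataType::Int8, true).with_extension_type(Bool8),
    ]))
}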
+#[derive(Debug)] +struct ExtensionBasedUdf { + name: String, + signature: Signature, +} + +impl Default for ExtensionBasedUdf { + fn default() -> Self { + Self { + name: "canonical_extension_udf".to_string(), + signature: Signature::exact(vec![DataType::Int8], Volatility::Immutable), + } + } +} +impl ScalarUDFImpl for ExtensionBasedUdf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _args: &[DataType]) -> Result { + Ok(DataType::Utf8) + } + + fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result { + Ok(Field::new("canonical_extension_udf", DataType::Utf8, true) + .with_extension_type(MyUserExtentionType {}) + .into()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + assert_eq!(args.arg_fields.len(), 1); + let input_field = args.arg_fields[0].as_ref(); + + let output_as_bool = matches!( + CanonicalExtensionType::try_from(input_field), + Ok(CanonicalExtensionType::Bool8(_)) + ); + + // If we have the extension type set, we are outputting a boolean value. + // Otherwise we output a string representation of the numeric value. + fn print_value(v: Option, as_bool: bool) -> Option { + v.map(|x| match as_bool { + true => format!("{}", x != 0), + false => format!("{x}"), + }) + } + + match &args.args[0] { + ColumnarValue::Array(array) => { + let array_values: Vec<_> = array + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .map(|v| print_value(v, output_as_bool)) + .collect(); + let array_ref = Arc::new(StringArray::from(array_values)) as ArrayRef; + Ok(ColumnarValue::Array(array_ref)) + } + ColumnarValue::Scalar(value) => { + let ScalarValue::Int8(value) = value else { + return exec_err!("incorrect data type"); + }; + + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(print_value( + *value, + output_as_bool, + )))) + } + } + } + + fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { + self.name == other.name() + } +} + +struct MyUserExtentionType {} + +impl ExtensionType for MyUserExtentionType { + const NAME: &'static str = "my_user_extention_type"; + type Metadata = (); + + fn metadata(&self) -> &Self::Metadata { + &() + } + + fn serialize_metadata(&self) -> Option { + None + } + + fn deserialize_metadata( + _metadata: Option<&str>, + ) -> std::result::Result { + Ok(()) + } + + fn supports_data_type( + &self, + data_type: &DataType, + ) -> std::result::Result<(), ArrowError> { + if let DataType::Utf8 = data_type { + Ok(()) + } else { + Err(ArrowError::InvalidArgumentError( + "only utf8 supported".to_string(), + )) + } + } + + fn try_new( + _data_type: &DataType, + _metadata: Self::Metadata, + ) -> std::result::Result { + Ok(Self {}) + } +} + +#[tokio::test] +async fn test_extension_based_udf() -> Result<()> { + let data_array = Arc::new(Int8Array::from(vec![0, 0, 10, 20])) as ArrayRef; + let schema = Arc::new(Schema::new(vec![ + Field::new("no_extension", DataType::Int8, true), + Field::new("with_extension", DataType::Int8, true).with_extension_type(Bool8), + ])); + let batch = RecordBatch::try_new( + schema, + vec![Arc::clone(&data_array), Arc::clone(&data_array)], + )?; + + let ctx = SessionContext::new(); + ctx.register_batch("t", batch)?; + let t = ctx.table("t").await?; + let extension_based_udf = ScalarUDF::from(ExtensionBasedUdf::default()); + + let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?) 
+ .project(vec![ + extension_based_udf + .call(vec![col("no_extension")]) + .alias("without_bool8_extension"), + extension_based_udf + .call(vec![col("with_extension")]) + .alias("with_bool8_extension"), + ])? + .build()?; + + let actual = DataFrame::new(ctx.state(), plan).collect().await?; + + // To test for output extension handling, we set the expected values on the result + // To test for input extensions handling, we check the strings returned + let expected_schema = Schema::new(vec![ + Field::new("without_bool8_extension", DataType::Utf8, true) + .with_extension_type(MyUserExtentionType {}), + Field::new("with_bool8_extension", DataType::Utf8, true) + .with_extension_type(MyUserExtentionType {}), + ]); + + let expected = record_batch!( + ("without_bool8_extension", Utf8, ["0", "0", "10", "20"]), + ( + "with_bool8_extension", + Utf8, + ["false", "false", "true", "true"] + ) + )? + .with_schema(Arc::new(expected_schema))?; + + assert_eq!(expected, actual[0]); + + ctx.deregister_table("t")?; + Ok(()) +} diff --git a/datafusion/core/tests/user_defined/user_defined_table_functions.rs b/datafusion/core/tests/user_defined/user_defined_table_functions.rs index e4aff0b00705..2c6611f382ce 100644 --- a/datafusion/core/tests/user_defined/user_defined_table_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_table_functions.rs @@ -205,7 +205,7 @@ impl TableFunctionImpl for SimpleCsvTableFunc { let mut filepath = String::new(); for expr in exprs { match expr { - Expr::Literal(ScalarValue::Utf8(Some(ref path))) => { + Expr::Literal(ScalarValue::Utf8(Some(ref path)), _) => { filepath.clone_from(path); } expr => new_exprs.push(expr.clone()), diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs index 7c56507acd45..bcd2c3945e39 100644 --- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs @@ -18,11 +18,16 @@ //! This module contains end to end tests of creating //! 
user defined window functions -use arrow::array::{ArrayRef, AsArray, Int64Array, RecordBatch, StringArray}; +use arrow::array::{ + record_batch, Array, ArrayRef, AsArray, Int64Array, RecordBatch, StringArray, + UInt64Array, +}; use arrow::datatypes::{DataType, Field, Schema}; +use arrow_schema::FieldRef; use datafusion::common::test_util::batches_to_string; use datafusion::common::{Result, ScalarValue}; use datafusion::prelude::SessionContext; +use datafusion_common::exec_datafusion_err; use datafusion_expr::{ PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDF, WindowUDFImpl, }; @@ -34,6 +39,7 @@ use datafusion_physical_expr::{ expressions::{col, lit}, PhysicalExpr, }; +use std::collections::HashMap; use std::{ any::Any, ops::Range, @@ -559,8 +565,8 @@ impl OddCounter { &self.aliases } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - Ok(Field::new(field_args.name(), DataType::Int64, true)) + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Field::new(field_args.name(), DataType::Int64, true).into()) } } @@ -678,7 +684,7 @@ impl WindowUDFImpl for VariadicWindowUDF { unimplemented!("unnecessary for testing"); } - fn field(&self, _: WindowUDFFieldArgs) -> Result { + fn field(&self, _: WindowUDFFieldArgs) -> Result { unimplemented!("unnecessary for testing"); } } @@ -723,11 +729,11 @@ fn test_default_expressions() -> Result<()> { ]; for input_exprs in &test_cases { - let input_types = input_exprs + let input_fields = input_exprs .iter() - .map(|expr: &Arc| expr.data_type(&schema).unwrap()) + .map(|expr: &Arc| expr.return_field(&schema).unwrap()) .collect::>(); - let expr_args = ExpressionArgs::new(input_exprs, &input_types); + let expr_args = ExpressionArgs::new(input_exprs, &input_fields); let ret_exprs = udwf.expressions(expr_args); @@ -735,9 +741,7 @@ fn test_default_expressions() -> Result<()> { assert_eq!( input_exprs.len(), ret_exprs.len(), - "\nInput expressions: {:?}\nReturned expressions: {:?}", - input_exprs, - ret_exprs + "\nInput expressions: {input_exprs:?}\nReturned expressions: {ret_exprs:?}" ); // Compares each returned expression with original input expressions @@ -753,3 +757,149 @@ fn test_default_expressions() -> Result<()> { } Ok(()) } + +#[derive(Debug)] +struct MetadataBasedWindowUdf { + name: String, + signature: Signature, + metadata: HashMap, +} + +impl MetadataBasedWindowUdf { + fn new(metadata: HashMap) -> Self { + // The name we return must be unique. Otherwise we will not call distinct + // instances of this UDF. This is a small hack for the unit tests to get unique + // names, but you could do something more elegant with the metadata. 
+ let name = format!("metadata_based_udf_{}", metadata.len()); + Self { + name, + signature: Signature::exact(vec![DataType::UInt64], Volatility::Immutable), + metadata, + } + } +} + +impl WindowUDFImpl for MetadataBasedWindowUdf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn partition_evaluator( + &self, + partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { + let input_field = partition_evaluator_args + .input_fields() + .first() + .ok_or(exec_datafusion_err!("Expected one argument"))?; + + let double_output = input_field + .metadata() + .get("modify_values") + .map(|v| v == "double_output") + .unwrap_or(false); + + Ok(Box::new(MetadataBasedPartitionEvaluator { double_output })) + } + + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Field::new(field_args.name(), DataType::UInt64, true) + .with_metadata(self.metadata.clone()) + .into()) + } +} + +#[derive(Debug)] +struct MetadataBasedPartitionEvaluator { + double_output: bool, +} + +impl PartitionEvaluator for MetadataBasedPartitionEvaluator { + fn evaluate_all(&mut self, values: &[ArrayRef], num_rows: usize) -> Result { + let values = values[0].as_any().downcast_ref::().unwrap(); + let sum = values.iter().fold(0_u64, |acc, v| acc + v.unwrap_or(0)); + + let result = if self.double_output { sum * 2 } else { sum }; + + Ok(Arc::new(UInt64Array::from_value(result, num_rows))) + } +} + +#[tokio::test] +async fn test_metadata_based_window_fn() -> Result<()> { + let data_array = Arc::new(UInt64Array::from(vec![0, 5, 10, 15, 20])) as ArrayRef; + let schema = Arc::new(Schema::new(vec![ + Field::new("no_metadata", DataType::UInt64, true), + Field::new("with_metadata", DataType::UInt64, true).with_metadata( + [("modify_values".to_string(), "double_output".to_string())] + .into_iter() + .collect(), + ), + ])); + + let batch = RecordBatch::try_new( + schema, + vec![Arc::clone(&data_array), Arc::clone(&data_array)], + )?; + + let ctx = SessionContext::new(); + ctx.register_batch("t", batch)?; + let df = ctx.table("t").await?; + + let no_output_meta_udf = WindowUDF::from(MetadataBasedWindowUdf::new(HashMap::new())); + let with_output_meta_udf = WindowUDF::from(MetadataBasedWindowUdf::new( + [("output_metatype".to_string(), "custom_value".to_string())] + .into_iter() + .collect(), + )); + + let df = df.select(vec![ + no_output_meta_udf + .call(vec![datafusion_expr::col("no_metadata")]) + .alias("meta_no_in_no_out"), + no_output_meta_udf + .call(vec![datafusion_expr::col("with_metadata")]) + .alias("meta_with_in_no_out"), + with_output_meta_udf + .call(vec![datafusion_expr::col("no_metadata")]) + .alias("meta_no_in_with_out"), + with_output_meta_udf + .call(vec![datafusion_expr::col("with_metadata")]) + .alias("meta_with_in_with_out"), + ])?; + + let actual = df.collect().await?; + + // To test for output metadata handling, we set the expected values on the result + // To test for input metadata handling, we check the numbers returned + let mut output_meta = HashMap::new(); + let _ = output_meta.insert("output_metatype".to_string(), "custom_value".to_string()); + let expected_schema = Schema::new(vec![ + Field::new("meta_no_in_no_out", DataType::UInt64, true), + Field::new("meta_with_in_no_out", DataType::UInt64, true), + Field::new("meta_no_in_with_out", DataType::UInt64, true) + .with_metadata(output_meta.clone()), + Field::new("meta_with_in_with_out", DataType::UInt64, true) + 
.with_metadata(output_meta.clone()), + ]); + + let expected = record_batch!( + ("meta_no_in_no_out", UInt64, [50, 50, 50, 50, 50]), + ("meta_with_in_no_out", UInt64, [100, 100, 100, 100, 100]), + ("meta_no_in_with_out", UInt64, [50, 50, 50, 50, 50]), + ("meta_with_in_with_out", UInt64, [100, 100, 100, 100, 100]) + )? + .with_schema(Arc::new(expected_schema))?; + + assert_eq!(expected, actual[0]); + + Ok(()) +} diff --git a/datafusion/datasource-avro/README.md b/datafusion/datasource-avro/README.md index f8d7aebdcad1..3436d4a85ad0 100644 --- a/datafusion/datasource-avro/README.md +++ b/datafusion/datasource-avro/README.md @@ -23,4 +23,9 @@ This crate is a submodule of DataFusion that defines a Avro based file source. +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + [df]: https://crates.io/crates/datafusion +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/datasource-avro/src/avro_to_arrow/arrow_array_reader.rs b/datafusion/datasource-avro/src/avro_to_arrow/arrow_array_reader.rs index 9a1b54b872ad..36553b36bc6c 100644 --- a/datafusion/datasource-avro/src/avro_to_arrow/arrow_array_reader.rs +++ b/datafusion/datasource-avro/src/avro_to_arrow/arrow_array_reader.rs @@ -21,7 +21,7 @@ use apache_avro::schema::RecordSchema; use apache_avro::{ schema::{Schema as AvroSchema, SchemaKind}, types::Value, - AvroResult, Error as AvroError, Reader as AvroReader, + Error as AvroError, Reader as AvroReader, }; use arrow::array::{ make_array, Array, ArrayBuilder, ArrayData, ArrayDataBuilder, ArrayRef, @@ -33,7 +33,7 @@ use arrow::buffer::{Buffer, MutableBuffer}; use arrow::datatypes::{ ArrowDictionaryKeyType, ArrowNumericType, ArrowPrimitiveType, DataType, Date32Type, Date64Type, Field, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, - Int8Type, Schema, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, + Int8Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, @@ -56,23 +56,17 @@ type RecordSlice<'a> = &'a [&'a Vec<(String, Value)>]; pub struct AvroArrowArrayReader<'a, R: Read> { reader: AvroReader<'a, R>, schema: SchemaRef, - projection: Option>, schema_lookup: BTreeMap, } impl AvroArrowArrayReader<'_, R> { - pub fn try_new( - reader: R, - schema: SchemaRef, - projection: Option>, - ) -> Result { + pub fn try_new(reader: R, schema: SchemaRef) -> Result { let reader = AvroReader::new(reader)?; let writer_schema = reader.writer_schema().clone(); let schema_lookup = Self::schema_lookup(writer_schema)?; Ok(Self { reader, schema, - projection, schema_lookup, }) } @@ -123,7 +117,7 @@ impl AvroArrowArrayReader<'_, R> { AvroSchema::Record(RecordSchema { fields, lookup, .. 
}) => { lookup.iter().for_each(|(field_name, pos)| { schema_lookup - .insert(format!("{}.{}", parent_field_name, field_name), *pos); + .insert(format!("{parent_field_name}.{field_name}"), *pos); }); for field in fields { @@ -137,7 +131,7 @@ impl AvroArrowArrayReader<'_, R> { } } AvroSchema::Array(schema) => { - let sub_parent_field_name = format!("{}.element", parent_field_name); + let sub_parent_field_name = format!("{parent_field_name}.element"); Self::child_schema_lookup( &sub_parent_field_name, &schema.items, @@ -175,20 +169,9 @@ impl AvroArrowArrayReader<'_, R> { }; let rows = rows.iter().collect::>>(); - let projection = self.projection.clone().unwrap_or_default(); - let arrays = - self.build_struct_array(&rows, "", self.schema.fields(), &projection); - let projected_fields = if projection.is_empty() { - self.schema.fields().clone() - } else { - projection - .iter() - .filter_map(|name| self.schema.column_with_name(name)) - .map(|(_, field)| field.clone()) - .collect() - }; - let projected_schema = Arc::new(Schema::new(projected_fields)); - Some(arrays.and_then(|arr| RecordBatch::try_new(projected_schema, arr))) + let arrays = self.build_struct_array(&rows, "", self.schema.fields()); + + Some(arrays.and_then(|arr| RecordBatch::try_new(Arc::clone(&self.schema), arr))) } fn build_boolean_array(&self, rows: RecordSlice, col_name: &str) -> ArrayRef { @@ -615,7 +598,7 @@ impl AvroArrowArrayReader<'_, R> { let sub_parent_field_name = format!("{}.{}", parent_field_name, list_field.name()); let arrays = - self.build_struct_array(&rows, &sub_parent_field_name, fields, &[])?; + self.build_struct_array(&rows, &sub_parent_field_name, fields)?; let data_type = DataType::Struct(fields.clone()); ArrayDataBuilder::new(data_type) .len(rows.len()) @@ -645,20 +628,14 @@ impl AvroArrowArrayReader<'_, R> { /// The function does not construct the StructArray as some callers would want the child arrays. /// /// *Note*: The function is recursive, and will read nested structs. - /// - /// If `projection` is not empty, then all values are returned. The first level of projection - /// occurs at the `RecordBatch` level. No further projection currently occurs, but would be - /// useful if plucking values from a struct, e.g. getting `a.b.c.e` from `a.b.c.{d, e}`. 
fn build_struct_array( &self, rows: RecordSlice, parent_field_name: &str, struct_fields: &Fields, - projection: &[String], ) -> ArrowResult> { let arrays: ArrowResult> = struct_fields .iter() - .filter(|field| projection.is_empty() || projection.contains(field.name())) .map(|field| { let field_path = if parent_field_name.is_empty() { field.name().to_string() @@ -840,12 +817,8 @@ impl AvroArrowArrayReader<'_, R> { } }) .collect::>>(); - let arrays = self.build_struct_array( - &struct_rows, - &field_path, - fields, - &[], - )?; + let arrays = + self.build_struct_array(&struct_rows, &field_path, fields)?; // construct a struct array's data in order to set null buffer let data_type = DataType::Struct(fields.clone()); let data = ArrayDataBuilder::new(data_type) @@ -965,40 +938,31 @@ fn resolve_string(v: &Value) -> ArrowResult> { .map_err(|e| SchemaError(format!("expected resolvable string : {e:?}"))) } -fn resolve_u8(v: &Value) -> AvroResult { - let int = match v { - Value::Int(n) => Ok(Value::Int(*n)), - Value::Long(n) => Ok(Value::Int(*n as i32)), - other => Err(AvroError::GetU8(other.into())), - }?; - if let Value::Int(n) = int { - if n >= 0 && n <= From::from(u8::MAX) { - return Ok(n as u8); - } - } +fn resolve_u8(v: &Value) -> Option { + let v = match v { + Value::Union(_, inner) => inner.as_ref(), + _ => v, + }; - Err(AvroError::GetU8(int.into())) + match v { + Value::Int(n) => u8::try_from(*n).ok(), + Value::Long(n) => u8::try_from(*n).ok(), + _ => None, + } } fn resolve_bytes(v: &Value) -> Option> { - let v = if let Value::Union(_, b) = v { b } else { v }; + let v = match v { + Value::Union(_, inner) => inner.as_ref(), + _ => v, + }; + match v { - Value::Bytes(_) => Ok(v.clone()), - Value::String(s) => Ok(Value::Bytes(s.clone().into_bytes())), - Value::Array(items) => Ok(Value::Bytes( - items - .iter() - .map(resolve_u8) - .collect::, _>>() - .ok()?, - )), - other => Err(AvroError::GetBytes(other.into())), - } - .ok() - .and_then(|v| match v { - Value::Bytes(s) => Some(s), + Value::Bytes(bytes) => Some(bytes.clone()), + Value::String(s) => Some(s.as_bytes().to_vec()), + Value::Array(items) => items.iter().map(resolve_u8).collect::>>(), _ => None, - }) + } } fn resolve_fixed(v: &Value, size: usize) -> Option> { diff --git a/datafusion/datasource-avro/src/avro_to_arrow/reader.rs b/datafusion/datasource-avro/src/avro_to_arrow/reader.rs index bc7b50a9cdc3..7f5900605a06 100644 --- a/datafusion/datasource-avro/src/avro_to_arrow/reader.rs +++ b/datafusion/datasource-avro/src/avro_to_arrow/reader.rs @@ -16,7 +16,7 @@ // under the License. use super::arrow_array_reader::AvroArrowArrayReader; -use arrow::datatypes::SchemaRef; +use arrow::datatypes::{Fields, SchemaRef}; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; use datafusion_common::Result; @@ -133,19 +133,35 @@ impl Reader<'_, R> { /// /// If reading a `File`, you can customise the Reader, such as to enable schema /// inference, use `ReaderBuilder`. + /// + /// If projection is provided, it uses a schema with only the fields in the projection, respecting their order. + /// Only the first level of projection is handled. No further projection currently occurs, but would be + /// useful if plucking values from a struct, e.g. getting `a.b.c.e` from `a.b.c.{d, e}`. 
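The projection semantics described above are exercised by the new test_avro_with_projection test further down; as a quick illustration, a sketch of reading an Avro file with a top-level projection through ReaderBuilder. The file path, column names, and the datafusion_datasource_avro import path are assumptions for the sketch, not part of this PR:

use std::fs::File;
use datafusion_common::Result;
use datafusion_datasource_avro::avro_to_arrow::ReaderBuilder;

// Reads "data.avro" but only materializes two top-level columns, in the order
// requested; nested fields cannot be plucked individually.
fn read_projected() -> Result<()> {
    let mut reader = ReaderBuilder::new()
        .read_schema()
        .with_batch_size(64)
        .with_projection(vec!["string_col".to_string(), "bool_col".to_string()])
        .build(File::open("data.avro")?)?;
    while let Some(batch) = reader.next() {
        // Each batch carries only the projected schema.
        assert_eq!(2, batch?.num_columns());
    }
    Ok(())
}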
pub fn try_new( reader: R, schema: SchemaRef, batch_size: usize, projection: Option>, ) -> Result { + let projected_schema = projection.as_ref().filter(|p| !p.is_empty()).map_or_else( + || Arc::clone(&schema), + |proj| { + Arc::new(arrow::datatypes::Schema::new( + proj.iter() + .filter_map(|name| { + schema.column_with_name(name).map(|(_, f)| f.clone()) + }) + .collect::(), + )) + }, + ); + Ok(Self { array_reader: AvroArrowArrayReader::try_new( reader, - Arc::clone(&schema), - projection, + Arc::clone(&projected_schema), )?, - schema, + schema: projected_schema, batch_size, }) } @@ -179,10 +195,13 @@ mod tests { use arrow::datatypes::{DataType, Field}; use std::fs::File; - fn build_reader(name: &str) -> Reader { + fn build_reader(name: &str, projection: Option>) -> Reader { let testdata = datafusion_common::test_util::arrow_test_data(); let filename = format!("{testdata}/avro/{name}"); - let builder = ReaderBuilder::new().read_schema().with_batch_size(64); + let mut builder = ReaderBuilder::new().read_schema().with_batch_size(64); + if let Some(projection) = projection { + builder = builder.with_projection(projection); + } builder.build(File::open(filename).unwrap()).unwrap() } @@ -195,7 +214,7 @@ mod tests { #[test] fn test_avro_basic() { - let mut reader = build_reader("alltypes_dictionary.avro"); + let mut reader = build_reader("alltypes_dictionary.avro", None); let batch = reader.next().unwrap().unwrap(); assert_eq!(11, batch.num_columns()); @@ -281,4 +300,58 @@ mod tests { assert_eq!(1230768000000000, col.value(0)); assert_eq!(1230768060000000, col.value(1)); } + + #[test] + fn test_avro_with_projection() { + // Test projection to filter and reorder columns + let projection = Some(vec![ + "string_col".to_string(), + "double_col".to_string(), + "bool_col".to_string(), + ]); + let mut reader = build_reader("alltypes_dictionary.avro", projection); + let batch = reader.next().unwrap().unwrap(); + + // Only 3 columns should be present (not all 11) + assert_eq!(3, batch.num_columns()); + assert_eq!(2, batch.num_rows()); + + let schema = reader.schema(); + let batch_schema = batch.schema(); + assert_eq!(schema, batch_schema); + + // Verify columns are in the order specified in projection + // First column should be string_col (was at index 9 in original) + assert_eq!("string_col", schema.field(0).name()); + assert_eq!(&DataType::Binary, schema.field(0).data_type()); + let col = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!("0".as_bytes(), col.value(0)); + assert_eq!("1".as_bytes(), col.value(1)); + + // Second column should be double_col (was at index 7 in original) + assert_eq!("double_col", schema.field(1).name()); + assert_eq!(&DataType::Float64, schema.field(1).data_type()); + let col = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(0.0, col.value(0)); + assert_eq!(10.1, col.value(1)); + + // Third column should be bool_col (was at index 1 in original) + assert_eq!("bool_col", schema.field(2).name()); + assert_eq!(&DataType::Boolean, schema.field(2).data_type()); + let col = batch + .column(2) + .as_any() + .downcast_ref::() + .unwrap(); + assert!(col.value(0)); + assert!(!col.value(1)); + } } diff --git a/datafusion/datasource-avro/src/avro_to_arrow/schema.rs b/datafusion/datasource-avro/src/avro_to_arrow/schema.rs index 276056c24c01..f53d38e51d1f 100644 --- a/datafusion/datasource-avro/src/avro_to_arrow/schema.rs +++ b/datafusion/datasource-avro/src/avro_to_arrow/schema.rs @@ -22,7 +22,7 @@ use apache_avro::types::Value; use 
apache_avro::Schema as AvroSchema; use arrow::datatypes::{DataType, IntervalUnit, Schema, TimeUnit, UnionMode}; use arrow::datatypes::{Field, UnionFields}; -use datafusion_common::error::{DataFusionError, Result}; +use datafusion_common::error::Result; use std::collections::HashMap; use std::sync::Arc; @@ -107,9 +107,7 @@ fn schema_to_field_with_props( .data_type() .clone() } else { - return Err(DataFusionError::AvroError( - apache_avro::Error::GetUnionDuplicate, - )); + return Err(apache_avro::Error::GetUnionDuplicate.into()); } } else { let fields = sub_schemas diff --git a/datafusion/datasource-avro/src/file_format.rs b/datafusion/datasource-avro/src/file_format.rs index 4b50fee1d326..47f8d9daca0a 100644 --- a/datafusion/datasource-avro/src/file_format.rs +++ b/datafusion/datasource-avro/src/file_format.rs @@ -37,7 +37,6 @@ use datafusion_datasource::file_compression_type::FileCompressionType; use datafusion_datasource::file_format::{FileFormat, FileFormatFactory}; use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder}; use datafusion_datasource::source::DataSourceExec; -use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_plan::ExecutionPlan; use datafusion_session::Session; @@ -150,7 +149,6 @@ impl FileFormat for AvroFormat { &self, _state: &dyn Session, conf: FileScanConfig, - _filters: Option<&Arc>, ) -> Result> { let config = FileScanConfigBuilder::from(conf) .with_source(self.file_source()) diff --git a/datafusion/datasource-avro/src/source.rs b/datafusion/datasource-avro/src/source.rs index ce3722e7b11e..948049f5a747 100644 --- a/datafusion/datasource-avro/src/source.rs +++ b/datafusion/datasource-avro/src/source.rs @@ -18,142 +18,22 @@ //! Execution plan for reading line-delimited Avro files use std::any::Any; -use std::fmt::Formatter; use std::sync::Arc; use crate::avro_to_arrow::Reader as AvroReader; -use datafusion_common::error::Result; - use arrow::datatypes::SchemaRef; -use datafusion_common::{Constraints, Statistics}; +use datafusion_common::error::Result; +use datafusion_common::Statistics; use datafusion_datasource::file::FileSource; use datafusion_datasource::file_scan_config::FileScanConfig; use datafusion_datasource::file_stream::FileOpener; -use datafusion_datasource::source::DataSourceExec; -use datafusion_execution::{SendableRecordBatchStream, TaskContext}; -use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; +use datafusion_datasource::schema_adapter::SchemaAdapterFactory; use datafusion_physical_expr_common::sort_expr::LexOrdering; -use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; -use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; -use datafusion_physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, -}; +use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; use object_store::ObjectStore; -/// Execution plan for scanning Avro data source -#[derive(Debug, Clone)] -#[deprecated(since = "46.0.0", note = "use DataSourceExec instead")] -pub struct AvroExec { - inner: DataSourceExec, - base_config: FileScanConfig, -} - -#[allow(unused, deprecated)] -impl AvroExec { - /// Create a new Avro reader execution plan provided base configurations - pub fn new(base_config: FileScanConfig) -> Self { - let ( - projected_schema, - projected_constraints, - projected_statistics, - projected_output_ordering, - ) = base_config.project(); - let cache = Self::compute_properties( - Arc::clone(&projected_schema), - 
&projected_output_ordering, - projected_constraints, - &base_config, - ); - let base_config = base_config.with_source(Arc::new(AvroSource::default())); - Self { - inner: DataSourceExec::new(Arc::new(base_config.clone())), - base_config, - } - } - - /// Ref to the base configs - pub fn base_config(&self) -> &FileScanConfig { - &self.base_config - } - - /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. - fn compute_properties( - schema: SchemaRef, - orderings: &[LexOrdering], - constraints: Constraints, - file_scan_config: &FileScanConfig, - ) -> PlanProperties { - // Equivalence Properties - let eq_properties = EquivalenceProperties::new_with_orderings(schema, orderings) - .with_constraints(constraints); - let n_partitions = file_scan_config.file_groups.len(); - - PlanProperties::new( - eq_properties, - Partitioning::UnknownPartitioning(n_partitions), // Output Partitioning - EmissionType::Incremental, - Boundedness::Bounded, - ) - } -} - -#[allow(unused, deprecated)] -impl DisplayAs for AvroExec { - fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { - self.inner.fmt_as(t, f) - } -} - -#[allow(unused, deprecated)] -impl ExecutionPlan for AvroExec { - fn name(&self) -> &'static str { - "AvroExec" - } - - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { - self.inner.properties() - } - fn children(&self) -> Vec<&Arc> { - Vec::new() - } - fn with_new_children( - self: Arc, - _: Vec>, - ) -> Result> { - Ok(self) - } - - fn execute( - &self, - partition: usize, - context: Arc, - ) -> Result { - self.inner.execute(partition, context) - } - - fn statistics(&self) -> Result { - self.inner.statistics() - } - - fn metrics(&self) -> Option { - self.inner.metrics() - } - - fn fetch(&self) -> Option { - self.inner.fetch() - } - - fn with_fetch(&self, limit: Option) -> Option> { - self.inner.with_fetch(limit) - } -} - /// AvroSource holds the extra configuration that is necessary for opening avro files #[derive(Clone, Default)] pub struct AvroSource { @@ -162,6 +42,7 @@ pub struct AvroSource { projection: Option>, metrics: ExecutionPlanMetricsSet, projected_statistics: Option, + schema_adapter_factory: Option>, } impl AvroSource { @@ -244,13 +125,29 @@ impl FileSource for AvroSource { ) -> Result> { Ok(None) } + + fn with_schema_adapter_factory( + &self, + schema_adapter_factory: Arc, + ) -> Result> { + Ok(Arc::new(Self { + schema_adapter_factory: Some(schema_adapter_factory), + ..self.clone() + })) + } + + fn schema_adapter_factory(&self) -> Option> { + self.schema_adapter_factory.clone() + } } mod private { use super::*; use bytes::Buf; - use datafusion_datasource::{file_meta::FileMeta, file_stream::FileOpenFuture}; + use datafusion_datasource::{ + file_meta::FileMeta, file_stream::FileOpenFuture, PartitionedFile, + }; use futures::StreamExt; use object_store::{GetResultPayload, ObjectStore}; @@ -260,7 +157,11 @@ mod private { } impl FileOpener for AvroOpener { - fn open(&self, file_meta: FileMeta) -> Result { + fn open( + &self, + file_meta: FileMeta, + _file: PartitionedFile, + ) -> Result { let config = Arc::clone(&self.config); let object_store = Arc::clone(&self.object_store); Ok(Box::pin(async move { diff --git a/datafusion/datasource-csv/README.md b/datafusion/datasource-csv/README.md index c5944f9e438f..0ebddb538663 100644 --- a/datafusion/datasource-csv/README.md +++ b/datafusion/datasource-csv/README.md @@ -23,4 +23,9 @@ This crate is 
a submodule of DataFusion that defines a CSV based file source. +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + [df]: https://crates.io/crates/datafusion +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/datasource-csv/src/file_format.rs b/datafusion/datasource-csv/src/file_format.rs index 76f3c50a70a7..c9cd09bf676b 100644 --- a/datafusion/datasource-csv/src/file_format.rs +++ b/datafusion/datasource-csv/src/file_format.rs @@ -50,7 +50,6 @@ use datafusion_datasource::write::orchestration::spawn_writer_tasks_and_join; use datafusion_datasource::write::BatchSerializer; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::dml::InsertOp; -use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::LexRequirement; use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; use datafusion_session::Session; @@ -408,17 +407,16 @@ impl FileFormat for CsvFormat { &self, state: &dyn Session, conf: FileScanConfig, - _filters: Option<&Arc>, ) -> Result> { // Consult configuration options for default values let has_header = self .options .has_header - .unwrap_or(state.config_options().catalog.has_header); + .unwrap_or_else(|| state.config_options().catalog.has_header); let newlines_in_values = self .options .newlines_in_values - .unwrap_or(state.config_options().catalog.newlines_in_values); + .unwrap_or_else(|| state.config_options().catalog.newlines_in_values); let conf_builder = FileScanConfigBuilder::from(conf) .with_file_compression_type(self.options.compression.into()) @@ -454,11 +452,11 @@ impl FileFormat for CsvFormat { let has_header = self .options() .has_header - .unwrap_or(state.config_options().catalog.has_header); + .unwrap_or_else(|| state.config_options().catalog.has_header); let newlines_in_values = self .options() .newlines_in_values - .unwrap_or(state.config_options().catalog.newlines_in_values); + .unwrap_or_else(|| state.config_options().catalog.newlines_in_values); let options = self .options() @@ -504,7 +502,7 @@ impl CsvFormat { && self .options .has_header - .unwrap_or(state.config_options().catalog.has_header), + .unwrap_or_else(|| state.config_options().catalog.has_header), ) .with_delimiter(self.options.delimiter) .with_quote(self.options.quote); diff --git a/datafusion/datasource-csv/src/source.rs b/datafusion/datasource-csv/src/source.rs index f5d45cd3fc88..6c994af940d1 100644 --- a/datafusion/datasource-csv/src/source.rs +++ b/datafusion/datasource-csv/src/source.rs @@ -17,6 +17,7 @@ //! 
Execution plan for reading CSV files +use datafusion_datasource::schema_adapter::SchemaAdapterFactory; use std::any::Any; use std::fmt; use std::io::{Read, Seek, SeekFrom}; @@ -28,379 +29,28 @@ use datafusion_datasource::file_compression_type::FileCompressionType; use datafusion_datasource::file_meta::FileMeta; use datafusion_datasource::file_stream::{FileOpenFuture, FileOpener}; use datafusion_datasource::{ - calculate_range, FileRange, ListingTableUrl, RangeCalculation, + as_file_source, calculate_range, FileRange, ListingTableUrl, PartitionedFile, + RangeCalculation, }; use arrow::csv; use arrow::datatypes::SchemaRef; -use datafusion_common::config::ConfigOptions; -use datafusion_common::{Constraints, DataFusionError, Result, Statistics}; +use datafusion_common::{DataFusionError, Result, Statistics}; use datafusion_common_runtime::JoinSet; use datafusion_datasource::file::FileSource; use datafusion_datasource::file_scan_config::FileScanConfig; -use datafusion_datasource::source::DataSourceExec; -use datafusion_execution::{SendableRecordBatchStream, TaskContext}; -use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; -use datafusion_physical_expr_common::sort_expr::LexOrdering; -use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; -use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; -use datafusion_physical_plan::projection::ProjectionExec; +use datafusion_execution::TaskContext; +use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; use datafusion_physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, + DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, }; use crate::file_format::CsvDecoder; -use datafusion_datasource::file_groups::FileGroup; use futures::{StreamExt, TryStreamExt}; use object_store::buffered::BufWriter; use object_store::{GetOptions, GetResultPayload, ObjectStore}; use tokio::io::AsyncWriteExt; -/// Old Csv source, deprecated with DataSourceExec implementation and CsvSource -/// -/// See examples on `CsvSource` -#[derive(Debug, Clone)] -#[deprecated(since = "46.0.0", note = "use DataSourceExec instead")] -pub struct CsvExec { - base_config: FileScanConfig, - inner: DataSourceExec, -} - -/// Builder for [`CsvExec`]. -/// -/// See example on [`CsvExec`]. -#[derive(Debug, Clone)] -#[deprecated(since = "46.0.0", note = "use FileScanConfig instead")] -pub struct CsvExecBuilder { - file_scan_config: FileScanConfig, - file_compression_type: FileCompressionType, - // TODO: it seems like these format options could be reused across all the various CSV config - has_header: bool, - delimiter: u8, - quote: u8, - terminator: Option, - escape: Option, - comment: Option, - newlines_in_values: bool, -} - -#[allow(unused, deprecated)] -impl CsvExecBuilder { - /// Create a new builder to read the provided file scan configuration. - pub fn new(file_scan_config: FileScanConfig) -> Self { - Self { - file_scan_config, - // TODO: these defaults are duplicated from `CsvOptions` - should they be computed? - has_header: false, - delimiter: b',', - quote: b'"', - terminator: None, - escape: None, - comment: None, - newlines_in_values: false, - file_compression_type: FileCompressionType::UNCOMPRESSED, - } - } - - /// Set whether the first row defines the column names. - /// - /// The default value is `false`. - pub fn with_has_header(mut self, has_header: bool) -> Self { - self.has_header = has_header; - self - } - - /// Set the column delimeter. 
- /// - /// The default is `,`. - pub fn with_delimeter(mut self, delimiter: u8) -> Self { - self.delimiter = delimiter; - self - } - - /// Set the quote character. - /// - /// The default is `"`. - pub fn with_quote(mut self, quote: u8) -> Self { - self.quote = quote; - self - } - - /// Set the line terminator. If not set, the default is CRLF. - /// - /// The default is None. - pub fn with_terminator(mut self, terminator: Option) -> Self { - self.terminator = terminator; - self - } - - /// Set the escape character. - /// - /// The default is `None` (i.e. quotes cannot be escaped). - pub fn with_escape(mut self, escape: Option) -> Self { - self.escape = escape; - self - } - - /// Set the comment character. - /// - /// The default is `None` (i.e. comments are not supported). - pub fn with_comment(mut self, comment: Option) -> Self { - self.comment = comment; - self - } - - /// Set whether newlines in (quoted) values are supported. - /// - /// Parsing newlines in quoted values may be affected by execution behaviour such as - /// parallel file scanning. Setting this to `true` ensures that newlines in values are - /// parsed successfully, which may reduce performance. - /// - /// The default value is `false`. - pub fn with_newlines_in_values(mut self, newlines_in_values: bool) -> Self { - self.newlines_in_values = newlines_in_values; - self - } - - /// Set the file compression type. - /// - /// The default is [`FileCompressionType::UNCOMPRESSED`]. - pub fn with_file_compression_type( - mut self, - file_compression_type: FileCompressionType, - ) -> Self { - self.file_compression_type = file_compression_type; - self - } - - /// Build a [`CsvExec`]. - #[must_use] - pub fn build(self) -> CsvExec { - let Self { - file_scan_config: base_config, - file_compression_type, - has_header, - delimiter, - quote, - terminator, - escape, - comment, - newlines_in_values, - } = self; - - let ( - projected_schema, - projected_constraints, - projected_statistics, - projected_output_ordering, - ) = base_config.project(); - let cache = CsvExec::compute_properties( - projected_schema, - &projected_output_ordering, - projected_constraints, - &base_config, - ); - let csv = CsvSource::new(has_header, delimiter, quote) - .with_comment(comment) - .with_escape(escape) - .with_terminator(terminator); - let base_config = base_config - .with_newlines_in_values(newlines_in_values) - .with_file_compression_type(file_compression_type) - .with_source(Arc::new(csv)); - - CsvExec { - inner: DataSourceExec::new(Arc::new(base_config.clone())), - base_config, - } - } -} - -#[allow(unused, deprecated)] -impl CsvExec { - /// Create a new CSV reader execution plan provided base and specific configurations - #[allow(clippy::too_many_arguments)] - pub fn new( - base_config: FileScanConfig, - has_header: bool, - delimiter: u8, - quote: u8, - terminator: Option, - escape: Option, - comment: Option, - newlines_in_values: bool, - file_compression_type: FileCompressionType, - ) -> Self { - CsvExecBuilder::new(base_config) - .with_has_header(has_header) - .with_delimeter(delimiter) - .with_quote(quote) - .with_terminator(terminator) - .with_escape(escape) - .with_comment(comment) - .with_newlines_in_values(newlines_in_values) - .with_file_compression_type(file_compression_type) - .build() - } - - /// Return a [`CsvExecBuilder`]. - /// - /// See example on [`CsvExec`] and [`CsvExecBuilder`] for specifying CSV table options. 
- pub fn builder(file_scan_config: FileScanConfig) -> CsvExecBuilder { - CsvExecBuilder::new(file_scan_config) - } - - /// Ref to the base configs - pub fn base_config(&self) -> &FileScanConfig { - &self.base_config - } - - fn file_scan_config(&self) -> FileScanConfig { - self.inner - .data_source() - .as_any() - .downcast_ref::() - .unwrap() - .clone() - } - - fn csv_source(&self) -> CsvSource { - let source = self.file_scan_config(); - source - .file_source() - .as_any() - .downcast_ref::() - .unwrap() - .clone() - } - - /// true if the first line of each file is a header - pub fn has_header(&self) -> bool { - self.csv_source().has_header() - } - - /// Specifies whether newlines in (quoted) values are supported. - /// - /// Parsing newlines in quoted values may be affected by execution behaviour such as - /// parallel file scanning. Setting this to `true` ensures that newlines in values are - /// parsed successfully, which may reduce performance. - /// - /// The default behaviour depends on the `datafusion.catalog.newlines_in_values` setting. - pub fn newlines_in_values(&self) -> bool { - let source = self.file_scan_config(); - source.newlines_in_values() - } - - fn output_partitioning_helper(file_scan_config: &FileScanConfig) -> Partitioning { - Partitioning::UnknownPartitioning(file_scan_config.file_groups.len()) - } - - /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. - fn compute_properties( - schema: SchemaRef, - orderings: &[LexOrdering], - constraints: Constraints, - file_scan_config: &FileScanConfig, - ) -> PlanProperties { - // Equivalence Properties - let eq_properties = EquivalenceProperties::new_with_orderings(schema, orderings) - .with_constraints(constraints); - - PlanProperties::new( - eq_properties, - Self::output_partitioning_helper(file_scan_config), // Output Partitioning - EmissionType::Incremental, - Boundedness::Bounded, - ) - } - - fn with_file_groups(mut self, file_groups: Vec) -> Self { - self.base_config.file_groups = file_groups.clone(); - let mut file_source = self.file_scan_config(); - file_source = file_source.with_file_groups(file_groups); - self.inner = self.inner.with_data_source(Arc::new(file_source)); - self - } -} - -#[allow(unused, deprecated)] -impl DisplayAs for CsvExec { - fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { - self.inner.fmt_as(t, f) - } -} - -#[allow(unused, deprecated)] -impl ExecutionPlan for CsvExec { - fn name(&self) -> &'static str { - "CsvExec" - } - - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { - self.inner.properties() - } - - fn children(&self) -> Vec<&Arc> { - // this is a leaf node and has no children - vec![] - } - - fn with_new_children( - self: Arc, - _: Vec>, - ) -> Result> { - Ok(self) - } - - /// Redistribute files across partitions according to their size - /// See comments on `FileGroupPartitioner` for more detail. - /// - /// Return `None` if can't get repartitioned (empty, compressed file, or `newlines_in_values` set). 
- fn repartitioned( - &self, - target_partitions: usize, - config: &ConfigOptions, - ) -> Result>> { - self.inner.repartitioned(target_partitions, config) - } - - fn execute( - &self, - partition: usize, - context: Arc, - ) -> Result { - self.inner.execute(partition, context) - } - - fn statistics(&self) -> Result { - self.inner.statistics() - } - - fn metrics(&self) -> Option { - self.inner.metrics() - } - - fn fetch(&self) -> Option { - self.inner.fetch() - } - - fn with_fetch(&self, limit: Option) -> Option> { - self.inner.with_fetch(limit) - } - - fn try_swapping_with_projection( - &self, - projection: &ProjectionExec, - ) -> Result>> { - self.inner.try_swapping_with_projection(projection) - } -} - /// A Config for [`CsvOpener`] /// /// # Example: create a `DataSourceExec` for CSV @@ -443,6 +93,7 @@ pub struct CsvSource { comment: Option, metrics: ExecutionPlanMetricsSet, projected_statistics: Option, + schema_adapter_factory: Option>, } impl CsvSource { @@ -564,6 +215,12 @@ impl CsvOpener { } } +impl From for Arc { + fn from(source: CsvSource) -> Self { + as_file_source(source) + } +} + impl FileSource for CsvSource { fn create_file_opener( &self, @@ -626,6 +283,20 @@ impl FileSource for CsvSource { DisplayFormatType::TreeRender => Ok(()), } } + + fn with_schema_adapter_factory( + &self, + schema_adapter_factory: Arc, + ) -> Result> { + Ok(Arc::new(Self { + schema_adapter_factory: Some(schema_adapter_factory), + ..self.clone() + })) + } + + fn schema_adapter_factory(&self) -> Option> { + self.schema_adapter_factory.clone() + } } impl FileOpener for CsvOpener { @@ -652,7 +323,11 @@ impl FileOpener for CsvOpener { /// A,1,2,3,4,5,6,7,8,9\n /// A},1,2,3,4,5,6,7,8,9\n /// The lines read would be: [1, 2] - fn open(&self, file_meta: FileMeta) -> Result { + fn open( + &self, + file_meta: FileMeta, + _file: PartitionedFile, + ) -> Result { // `self.config.has_header` controls whether to skip reading the 1st line header // If the .csv file is read in parallel and this `CsvOpener` is only reading some middle // partition, then don't skip first line @@ -743,6 +418,11 @@ pub async fn plan_to_csv( let parsed = ListingTableUrl::parse(path)?; let object_store_url = parsed.object_store(); let store = task_ctx.runtime_env().object_store(&object_store_url)?; + let writer_buffer_size = task_ctx + .session_config() + .options() + .execution + .objectstore_writer_buffer_size; let mut join_set = JoinSet::new(); for i in 0..plan.output_partitioning().partition_count() { let storeref = Arc::clone(&store); @@ -752,7 +432,8 @@ pub async fn plan_to_csv( let mut stream = plan.execute(i, Arc::clone(&task_ctx))?; join_set.spawn(async move { - let mut buf_writer = BufWriter::new(storeref, file.clone()); + let mut buf_writer = + BufWriter::with_capacity(storeref, file.clone(), writer_buffer_size); let mut buffer = Vec::with_capacity(1024); //only write headers on first iteration let mut write_headers = true; diff --git a/datafusion/datasource-json/README.md b/datafusion/datasource-json/README.md index 64181814736d..ac0b73b78e69 100644 --- a/datafusion/datasource-json/README.md +++ b/datafusion/datasource-json/README.md @@ -23,4 +23,9 @@ This crate is a submodule of DataFusion that defines a JSON based file source. +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. 
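Editor's aside: the `with_schema_adapter_factory` / `schema_adapter_factory` hooks added to `CsvSource` above (and to the Avro and JSON sources elsewhere in this patch) let callers plug a custom `SchemaAdapterFactory` into a file source. A minimal sketch of attaching a factory, assuming the stock `DefaultSchemaAdapterFactory` from `datafusion_datasource::schema_adapter` and the `datafusion_datasource_csv` crate path; this is an illustration, not code from the patch:

```rust
use std::sync::Arc;

use datafusion_common::Result;
use datafusion_datasource::file::FileSource;
use datafusion_datasource::schema_adapter::{
    DefaultSchemaAdapterFactory, SchemaAdapterFactory,
};
use datafusion_datasource_csv::source::CsvSource;

/// Attach a schema adapter factory to a CSV source. The trait method returns a
/// new `Arc<dyn FileSource>` with the factory set and leaves the original
/// source untouched.
fn csv_source_with_adapter() -> Result<Arc<dyn FileSource>> {
    // Assumption: `DefaultSchemaAdapterFactory` is the default adapter shipped
    // with the datasource crate; any user-defined factory would work the same way.
    let factory: Arc<dyn SchemaAdapterFactory> = Arc::new(DefaultSchemaAdapterFactory);
    CsvSource::new(true, b',', b'"').with_schema_adapter_factory(factory)
}
```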
+ [df]: https://crates.io/crates/datafusion +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/datasource-json/src/file_format.rs b/datafusion/datasource-json/src/file_format.rs index 8d0515804fc7..f6b758b5bc51 100644 --- a/datafusion/datasource-json/src/file_format.rs +++ b/datafusion/datasource-json/src/file_format.rs @@ -52,7 +52,6 @@ use datafusion_datasource::write::orchestration::spawn_writer_tasks_and_join; use datafusion_datasource::write::BatchSerializer; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::dml::InsertOp; -use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::LexRequirement; use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; use datafusion_session::Session; @@ -249,7 +248,6 @@ impl FileFormat for JsonFormat { &self, _state: &dyn Session, conf: FileScanConfig, - _filters: Option<&Arc>, ) -> Result> { let source = Arc::new(JsonSource::new()); let conf = FileScanConfigBuilder::from(conf) diff --git a/datafusion/datasource-json/src/source.rs b/datafusion/datasource-json/src/source.rs index ee96d050966d..d318928e5c6b 100644 --- a/datafusion/datasource-json/src/source.rs +++ b/datafusion/datasource-json/src/source.rs @@ -30,198 +30,25 @@ use datafusion_datasource::decoder::{deserialize_stream, DecoderDeserializer}; use datafusion_datasource::file_compression_type::FileCompressionType; use datafusion_datasource::file_meta::FileMeta; use datafusion_datasource::file_stream::{FileOpenFuture, FileOpener}; -use datafusion_datasource::{calculate_range, ListingTableUrl, RangeCalculation}; +use datafusion_datasource::schema_adapter::SchemaAdapterFactory; +use datafusion_datasource::{ + as_file_source, calculate_range, ListingTableUrl, PartitionedFile, RangeCalculation, +}; use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; use arrow::json::ReaderBuilder; use arrow::{datatypes::SchemaRef, json}; -use datafusion_common::{Constraints, Statistics}; +use datafusion_common::Statistics; use datafusion_datasource::file::FileSource; use datafusion_datasource::file_scan_config::FileScanConfig; -use datafusion_datasource::source::DataSourceExec; -use datafusion_execution::{SendableRecordBatchStream, TaskContext}; -use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; -use datafusion_physical_expr_common::sort_expr::LexOrdering; -use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; -use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; -use datafusion_physical_plan::{DisplayAs, DisplayFormatType, PlanProperties}; - -use datafusion_datasource::file_groups::FileGroup; +use datafusion_execution::TaskContext; +use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; + use futures::{StreamExt, TryStreamExt}; use object_store::buffered::BufWriter; use object_store::{GetOptions, GetResultPayload, ObjectStore}; use tokio::io::AsyncWriteExt; -/// Execution plan for scanning NdJson data source -#[derive(Debug, Clone)] -#[deprecated(since = "46.0.0", note = "use DataSourceExec instead")] -pub struct NdJsonExec { - inner: DataSourceExec, - base_config: FileScanConfig, - file_compression_type: FileCompressionType, -} - -#[allow(unused, deprecated)] -impl NdJsonExec { - /// Create a new JSON reader execution plan provided base configurations - pub fn new( - base_config: FileScanConfig, - file_compression_type: FileCompressionType, - ) -> Self { - let ( - projected_schema, - 
projected_constraints, - projected_statistics, - projected_output_ordering, - ) = base_config.project(); - let cache = Self::compute_properties( - projected_schema, - &projected_output_ordering, - projected_constraints, - &base_config, - ); - - let json = JsonSource::default(); - let base_config = base_config - .with_file_compression_type(file_compression_type) - .with_source(Arc::new(json)); - - Self { - inner: DataSourceExec::new(Arc::new(base_config.clone())), - file_compression_type: base_config.file_compression_type, - base_config, - } - } - - /// Ref to the base configs - pub fn base_config(&self) -> &FileScanConfig { - &self.base_config - } - - /// Ref to file compression type - pub fn file_compression_type(&self) -> &FileCompressionType { - &self.file_compression_type - } - - fn file_scan_config(&self) -> FileScanConfig { - self.inner - .data_source() - .as_any() - .downcast_ref::() - .unwrap() - .clone() - } - - fn json_source(&self) -> JsonSource { - let source = self.file_scan_config(); - source - .file_source() - .as_any() - .downcast_ref::() - .unwrap() - .clone() - } - - fn output_partitioning_helper(file_scan_config: &FileScanConfig) -> Partitioning { - Partitioning::UnknownPartitioning(file_scan_config.file_groups.len()) - } - - /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. - fn compute_properties( - schema: SchemaRef, - orderings: &[LexOrdering], - constraints: Constraints, - file_scan_config: &FileScanConfig, - ) -> PlanProperties { - // Equivalence Properties - let eq_properties = EquivalenceProperties::new_with_orderings(schema, orderings) - .with_constraints(constraints); - - PlanProperties::new( - eq_properties, - Self::output_partitioning_helper(file_scan_config), // Output Partitioning - EmissionType::Incremental, - Boundedness::Bounded, - ) - } - - fn with_file_groups(mut self, file_groups: Vec) -> Self { - self.base_config.file_groups = file_groups.clone(); - let mut file_source = self.file_scan_config(); - file_source = file_source.with_file_groups(file_groups); - self.inner = self.inner.with_data_source(Arc::new(file_source)); - self - } -} - -#[allow(unused, deprecated)] -impl DisplayAs for NdJsonExec { - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - self.inner.fmt_as(t, f) - } -} - -#[allow(unused, deprecated)] -impl ExecutionPlan for NdJsonExec { - fn name(&self) -> &'static str { - "NdJsonExec" - } - - fn as_any(&self) -> &dyn Any { - self - } - fn properties(&self) -> &PlanProperties { - self.inner.properties() - } - - fn children(&self) -> Vec<&Arc> { - Vec::new() - } - - fn with_new_children( - self: Arc, - _: Vec>, - ) -> Result> { - Ok(self) - } - - fn repartitioned( - &self, - target_partitions: usize, - config: &datafusion_common::config::ConfigOptions, - ) -> Result>> { - self.inner.repartitioned(target_partitions, config) - } - - fn execute( - &self, - partition: usize, - context: Arc, - ) -> Result { - self.inner.execute(partition, context) - } - - fn statistics(&self) -> Result { - self.inner.statistics() - } - - fn metrics(&self) -> Option { - self.inner.metrics() - } - - fn fetch(&self) -> Option { - self.inner.fetch() - } - - fn with_fetch(&self, limit: Option) -> Option> { - self.inner.with_fetch(limit) - } -} - /// A [`FileOpener`] that opens a JSON file and yields a [`FileOpenFuture`] pub struct JsonOpener { batch_size: usize, @@ -253,6 +80,7 @@ pub struct JsonSource { batch_size: Option, 
metrics: ExecutionPlanMetricsSet, projected_statistics: Option, + schema_adapter_factory: Option>, } impl JsonSource { @@ -262,6 +90,12 @@ impl JsonSource { } } +impl From for Arc { + fn from(source: JsonSource) -> Self { + as_file_source(source) + } +} + impl FileSource for JsonSource { fn create_file_opener( &self, @@ -316,6 +150,20 @@ impl FileSource for JsonSource { fn file_type(&self) -> &str { "json" } + + fn with_schema_adapter_factory( + &self, + schema_adapter_factory: Arc, + ) -> Result> { + Ok(Arc::new(Self { + schema_adapter_factory: Some(schema_adapter_factory), + ..self.clone() + })) + } + + fn schema_adapter_factory(&self) -> Option> { + self.schema_adapter_factory.clone() + } } impl FileOpener for JsonOpener { @@ -328,7 +176,11 @@ impl FileOpener for JsonOpener { /// are applied to determine which lines to read: /// 1. The first line of the partition is the line in which the index of the first character >= `start`. /// 2. The last line of the partition is the line in which the byte at position `end - 1` resides. - fn open(&self, file_meta: FileMeta) -> Result { + fn open( + &self, + file_meta: FileMeta, + _file: PartitionedFile, + ) -> Result { let store = Arc::clone(&self.object_store); let schema = Arc::clone(&self.projected_schema); let batch_size = self.batch_size; @@ -399,6 +251,11 @@ pub async fn plan_to_json( let parsed = ListingTableUrl::parse(path)?; let object_store_url = parsed.object_store(); let store = task_ctx.runtime_env().object_store(&object_store_url)?; + let writer_buffer_size = task_ctx + .session_config() + .options() + .execution + .objectstore_writer_buffer_size; let mut join_set = JoinSet::new(); for i in 0..plan.output_partitioning().partition_count() { let storeref = Arc::clone(&store); @@ -408,7 +265,8 @@ pub async fn plan_to_json( let mut stream = plan.execute(i, Arc::clone(&task_ctx))?; join_set.spawn(async move { - let mut buf_writer = BufWriter::new(storeref, file.clone()); + let mut buf_writer = + BufWriter::with_capacity(storeref, file.clone(), writer_buffer_size); let mut buffer = Vec::with_capacity(1024); while let Some(batch) = stream.next().await.transpose()? { diff --git a/datafusion/datasource-parquet/Cargo.toml b/datafusion/datasource-parquet/Cargo.toml index b6a548c998dc..08d258852a20 100644 --- a/datafusion/datasource-parquet/Cargo.toml +++ b/datafusion/datasource-parquet/Cargo.toml @@ -45,6 +45,7 @@ datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } datafusion-physical-optimizer = { workspace = true } datafusion-physical-plan = { workspace = true } +datafusion-pruning = { workspace = true } datafusion-session = { workspace = true } futures = { workspace = true } itertools = { workspace = true } diff --git a/datafusion/datasource-parquet/README.md b/datafusion/datasource-parquet/README.md index abcdd5ab1340..9ac472a9f4f0 100644 --- a/datafusion/datasource-parquet/README.md +++ b/datafusion/datasource-parquet/README.md @@ -23,4 +23,9 @@ This crate is a submodule of DataFusion that defines a Parquet based file source. +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. 
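Editor's aside: `plan_to_csv` and `plan_to_json` (and the Parquet sink later in this patch) now size their object-store writers from the `objectstore_writer_buffer_size` execution option via `BufWriter::with_capacity` instead of relying on `BufWriter::new`'s default. A hedged sketch of how a caller might tune that option, assuming the usual `datafusion_execution::config::SessionConfig` API; the 16 MiB value is only an example:

```rust
use datafusion_execution::config::SessionConfig;

/// Build a session config with a larger object-store writer buffer.
fn config_with_large_writer_buffer() -> SessionConfig {
    let mut config = SessionConfig::new();
    // Field used by this patch: datafusion.execution.objectstore_writer_buffer_size
    config.options_mut().execution.objectstore_writer_buffer_size = 16 * 1024 * 1024;
    config
}
```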
+ [df]: https://crates.io/crates/datafusion +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/datasource-parquet/src/file_format.rs b/datafusion/datasource-parquet/src/file_format.rs index 2ef4f236f278..59663fe5100a 100644 --- a/datafusion/datasource-parquet/src/file_format.rs +++ b/datafusion/datasource-parquet/src/file_format.rs @@ -18,20 +18,22 @@ //! [`ParquetFormat`]: Parquet [`FileFormat`] abstractions use std::any::Any; +use std::cell::RefCell; use std::fmt; use std::fmt::Debug; use std::ops::Range; +use std::rc::Rc; use std::sync::Arc; use arrow::array::RecordBatch; use arrow::datatypes::{Fields, Schema, SchemaRef, TimeUnit}; use datafusion_datasource::file_compression_type::FileCompressionType; use datafusion_datasource::file_sink_config::{FileSink, FileSinkConfig}; -use datafusion_datasource::write::{create_writer, get_writer_schema, SharedBuffer}; - -use datafusion_datasource::file_format::{ - FileFormat, FileFormatFactory, FilePushdownSupport, +use datafusion_datasource::write::{ + get_writer_schema, ObjectWriterBuilder, SharedBuffer, }; + +use datafusion_datasource::file_format::{FileFormat, FileFormatFactory}; use datafusion_datasource::write::demux::DemuxedStreamReceiver; use arrow::compute::sum; @@ -41,7 +43,7 @@ use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; use datafusion_common::{ internal_datafusion_err, internal_err, not_impl_err, ColumnStatistics, - DataFusionError, GetExt, Result, DEFAULT_PARQUET_EXTENSION, + DataFusionError, GetExt, HashSet, Result, DEFAULT_PARQUET_EXTENSION, }; use datafusion_common::{HashMap, Statistics}; use datafusion_common_runtime::{JoinSet, SpawnedTask}; @@ -52,16 +54,13 @@ use datafusion_datasource::sink::{DataSink, DataSinkExec}; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::dml::InsertOp; -use datafusion_expr::Expr; use datafusion_functions_aggregate::min_max::{MaxAccumulator, MinAccumulator}; -use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::LexRequirement; use datafusion_physical_plan::Accumulator; use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; use datafusion_session::Session; -use crate::can_expr_be_pushed_down_with_schemas; -use crate::source::ParquetSource; +use crate::source::{parse_coerce_int96_string, ParquetSource}; use async_trait::async_trait; use bytes::Bytes; use datafusion_datasource::source::DataSourceExec; @@ -79,6 +78,7 @@ use parquet::arrow::arrow_writer::{ use parquet::arrow::async_reader::MetadataFetch; use parquet::arrow::{parquet_to_arrow_schema, ArrowSchemaConverter, AsyncArrowWriter}; use parquet::basic::Type; +use parquet::encryption::decrypt::FileDecryptionProperties; use parquet::errors::ParquetError; use parquet::file::metadata::{ParquetMetaData, ParquetMetaDataReader, RowGroupMetaData}; use parquet::file::properties::{WriterProperties, WriterPropertiesBuilder}; @@ -304,9 +304,18 @@ async fn fetch_schema_with_location( store: &dyn ObjectStore, file: &ObjectMeta, metadata_size_hint: Option, + file_decryption_properties: Option<&FileDecryptionProperties>, + coerce_int96: Option, ) -> Result<(Path, Schema)> { let loc_path = file.location.clone(); - let schema = fetch_schema(store, file, metadata_size_hint).await?; + let schema = fetch_schema( + store, + file, + metadata_size_hint, + file_decryption_properties, + coerce_int96, + ) + 
.await?; Ok((loc_path, schema)) } @@ -337,12 +346,27 @@ impl FileFormat for ParquetFormat { store: &Arc, objects: &[ObjectMeta], ) -> Result { + let coerce_int96 = match self.coerce_int96() { + Some(time_unit) => Some(parse_coerce_int96_string(time_unit.as_str())?), + None => None, + }; + let config_file_decryption_properties = &self.options.crypto.file_decryption; + let file_decryption_properties: Option = + match config_file_decryption_properties { + Some(cfd) => { + let fd: FileDecryptionProperties = cfd.clone().into(); + Some(fd) + } + None => None, + }; let mut schemas: Vec<_> = futures::stream::iter(objects) .map(|object| { fetch_schema_with_location( store.as_ref(), object, self.metadata_size_hint(), + file_decryption_properties.as_ref(), + coerce_int96, ) }) .boxed() // Workaround https://github.com/rust-lang/rust/issues/64552 @@ -391,11 +415,21 @@ impl FileFormat for ParquetFormat { table_schema: SchemaRef, object: &ObjectMeta, ) -> Result { + let config_file_decryption_properties = &self.options.crypto.file_decryption; + let file_decryption_properties: Option = + match config_file_decryption_properties { + Some(cfd) => { + let fd: FileDecryptionProperties = cfd.clone().into(); + Some(fd) + } + None => None, + }; let stats = fetch_statistics( store.as_ref(), table_schema, object, self.metadata_size_hint(), + file_decryption_properties.as_ref(), ) .await?; Ok(stats) @@ -405,34 +439,23 @@ impl FileFormat for ParquetFormat { &self, _state: &dyn Session, conf: FileScanConfig, - filters: Option<&Arc>, ) -> Result> { - let mut predicate = None; let mut metadata_size_hint = None; - // If enable pruning then combine the filters to build the predicate. - // If disable pruning then set the predicate to None, thus readers - // will not prune data based on the statistics. - if self.enable_pruning() { - if let Some(pred) = filters.cloned() { - predicate = Some(pred); - } - } if let Some(metadata) = self.metadata_size_hint() { metadata_size_hint = Some(metadata); } let mut source = ParquetSource::new(self.options.clone()); - if let Some(predicate) = predicate { - source = source.with_predicate(Arc::clone(&conf.file_schema), predicate); - } if let Some(metadata_size_hint) = metadata_size_hint { source = source.with_metadata_size_hint(metadata_size_hint) } + // Apply schema adapter factory before building the new config + let file_source = source.apply_schema_adapter(&conf)?; let conf = FileScanConfigBuilder::from(conf) - .with_source(Arc::new(source)) + .with_source(file_source) .build(); Ok(DataSourceExec::from_data_source(conf)) } @@ -453,27 +476,6 @@ impl FileFormat for ParquetFormat { Ok(Arc::new(DataSinkExec::new(input, sink, order_requirements)) as _) } - fn supports_filters_pushdown( - &self, - file_schema: &Schema, - table_schema: &Schema, - filters: &[&Expr], - ) -> Result { - if !self.options().global.pushdown_filters { - return Ok(FilePushdownSupport::NoSupport); - } - - let all_supported = filters.iter().all(|filter| { - can_expr_be_pushed_down_with_schemas(filter, file_schema, table_schema) - }); - - Ok(if all_supported { - FilePushdownSupport::Supported - } else { - FilePushdownSupport::NotSupportedForFilter - }) - } - fn file_source(&self) -> Arc { Arc::new(ParquetSource::default()) } @@ -588,38 +590,186 @@ pub fn coerce_int96_to_resolution( file_schema: &Schema, time_unit: &TimeUnit, ) -> Option { - let mut transform = false; - let parquet_fields: HashMap<_, _> = parquet_schema + // Traverse the parquet_schema columns looking for int96 physical types. 
If encountered, insert + // the field's full path into a set. + let int96_fields: HashSet<_> = parquet_schema .columns() .iter() - .map(|f| { - let dt = f.physical_type(); - if dt.eq(&Type::INT96) { - transform = true; - } - (f.name(), dt) - }) + .filter(|f| f.physical_type() == Type::INT96) + .map(|f| f.path().string()) .collect(); - if !transform { + if int96_fields.is_empty() { + // The schema doesn't contain any int96 fields, so skip the remaining logic. return None; } - let transformed_fields: Vec> = file_schema - .fields - .iter() - .map(|field| match parquet_fields.get(field.name().as_str()) { - Some(Type::INT96) => { - field_with_new_type(field, DataType::Timestamp(*time_unit, None)) + // Do a DFS into the schema using a stack, looking for timestamp(nanos) fields that originated + // as int96 to coerce to the provided time_unit. + + type NestedFields = Rc>>; + type StackContext<'a> = ( + Vec<&'a str>, // The Parquet column path (e.g., "c0.list.element.c1") for the current field. + &'a FieldRef, // The current field to be processed. + NestedFields, // The parent's fields that this field will be (possibly) type-coerced and + // inserted into. All fields have a parent, so this is not an Option type. + Option, // Nested types need to create their own vector of fields for their + // children. For primitive types this will remain None. For nested + // types it is None the first time they are processed. Then, we + // instantiate a vector for its children, push the field back onto the + // stack to be processed again, and DFS into its children. The next + // time we process the field, we know we have DFS'd into the children + // because this field is Some. + ); + + // This is our top-level fields from which we will construct our schema. We pass this into our + // initial stack context as the parent fields, and the DFS populates it. + let fields = Rc::new(RefCell::new(Vec::with_capacity(file_schema.fields.len()))); + + // TODO: It might be possible to only DFS into nested fields that we know contain an int96 if we + // use some sort of LPM data structure to check if we're currently DFS'ing nested types that are + // in a column path that contains an int96. That can be a future optimization for large schemas. + let transformed_schema = { + // Populate the stack with our top-level fields. + let mut stack: Vec = file_schema + .fields() + .iter() + .rev() + .map(|f| (vec![f.name().as_str()], f, Rc::clone(&fields), None)) + .collect(); + + // Pop fields to DFS into until we have exhausted the stack. + while let Some((parquet_path, current_field, parent_fields, child_fields)) = + stack.pop() + { + match (current_field.data_type(), child_fields) { + (DataType::Struct(unprocessed_children), None) => { + // This is the first time popping off this struct. We don't yet know the + // correct types of its children (i.e., if they need coercing) so we create + // a vector for child_fields, push the struct node back onto the stack to be + // processed again (see below) after processing all its children. + let child_fields = Rc::new(RefCell::new(Vec::with_capacity( + unprocessed_children.len(), + ))); + // Note that here we push the struct back onto the stack with its + // parent_fields in the same position, now with Some(child_fields). + stack.push(( + parquet_path.clone(), + current_field, + parent_fields, + Some(Rc::clone(&child_fields)), + )); + // Push all the children in reverse to maintain original schema order due to + // stack processing. 
+ for child in unprocessed_children.into_iter().rev() { + let mut child_path = parquet_path.clone(); + // Build up a normalized path that we'll use as a key into the original + // int96_fields set above to test if this originated as int96. + child_path.push("."); + child_path.push(child.name()); + // Note that here we push the field onto the stack using the struct's + // new child_fields vector as the field's parent_fields. + stack.push((child_path, child, Rc::clone(&child_fields), None)); + } + } + (DataType::Struct(unprocessed_children), Some(processed_children)) => { + // This is the second time popping off this struct. The child_fields vector + // now contains each field that has been DFS'd into, and we can construct + // the resulting struct with correct child types. + let processed_children = processed_children.borrow(); + assert_eq!(processed_children.len(), unprocessed_children.len()); + let processed_struct = Field::new_struct( + current_field.name(), + processed_children.as_slice(), + current_field.is_nullable(), + ); + parent_fields.borrow_mut().push(Arc::new(processed_struct)); + } + (DataType::List(unprocessed_child), None) => { + // This is the first time popping off this list. See struct docs above. + let child_fields = Rc::new(RefCell::new(Vec::with_capacity(1))); + stack.push(( + parquet_path.clone(), + current_field, + parent_fields, + Some(Rc::clone(&child_fields)), + )); + let mut child_path = parquet_path.clone(); + // Spark uses a definition for arrays/lists that results in a group + // named "list" that is not maintained when parsing to Arrow. We just push + // this name into the path. + child_path.push(".list."); + child_path.push(unprocessed_child.name()); + stack.push(( + child_path.clone(), + unprocessed_child, + Rc::clone(&child_fields), + None, + )); + } + (DataType::List(_), Some(processed_children)) => { + // This is the second time popping off this list. See struct docs above. + let processed_children = processed_children.borrow(); + assert_eq!(processed_children.len(), 1); + let processed_list = Field::new_list( + current_field.name(), + Arc::clone(&processed_children[0]), + current_field.is_nullable(), + ); + parent_fields.borrow_mut().push(Arc::new(processed_list)); + } + (DataType::Map(unprocessed_child, _), None) => { + // This is the first time popping off this map. See struct docs above. + let child_fields = Rc::new(RefCell::new(Vec::with_capacity(1))); + stack.push(( + parquet_path.clone(), + current_field, + parent_fields, + Some(Rc::clone(&child_fields)), + )); + let mut child_path = parquet_path.clone(); + child_path.push("."); + child_path.push(unprocessed_child.name()); + stack.push(( + child_path.clone(), + unprocessed_child, + Rc::clone(&child_fields), + None, + )); + } + (DataType::Map(_, sorted), Some(processed_children)) => { + // This is the second time popping off this map. See struct docs above. + let processed_children = processed_children.borrow(); + assert_eq!(processed_children.len(), 1); + let processed_map = Field::new( + current_field.name(), + DataType::Map(Arc::clone(&processed_children[0]), *sorted), + current_field.is_nullable(), + ); + parent_fields.borrow_mut().push(Arc::new(processed_map)); + } + (DataType::Timestamp(TimeUnit::Nanosecond, None), None) + if int96_fields.contains(parquet_path.concat().as_str()) => + // We found a timestamp(nanos) and it originated as int96. Coerce it to the correct + // time_unit. 
+ { + parent_fields.borrow_mut().push(field_with_new_type( + current_field, + DataType::Timestamp(*time_unit, None), + )); + } + // Other types can be cloned as they are. + _ => parent_fields.borrow_mut().push(Arc::clone(current_field)), } - _ => Arc::clone(field), - }) - .collect(); + } + assert_eq!(fields.borrow().len(), file_schema.fields.len()); + Schema::new_with_metadata( + fields.borrow_mut().clone(), + file_schema.metadata.clone(), + ) + }; - Some(Schema::new_with_metadata( - transformed_fields, - file_schema.metadata.clone(), - )) + Some(transformed_schema) } /// Coerces the file schema if the table schema uses a view type. @@ -809,12 +959,14 @@ pub async fn fetch_parquet_metadata( store: &dyn ObjectStore, meta: &ObjectMeta, size_hint: Option, + decryption_properties: Option<&FileDecryptionProperties>, ) -> Result { let file_size = meta.size; let fetch = ObjectStoreFetch::new(store, meta); ParquetMetaDataReader::new() .with_prefetch_hint(size_hint) + .with_decryption_properties(decryption_properties) .load_and_finish(fetch, file_size) .await .map_err(DataFusionError::from) @@ -825,13 +977,26 @@ async fn fetch_schema( store: &dyn ObjectStore, file: &ObjectMeta, metadata_size_hint: Option, + file_decryption_properties: Option<&FileDecryptionProperties>, + coerce_int96: Option, ) -> Result { - let metadata = fetch_parquet_metadata(store, file, metadata_size_hint).await?; + let metadata = fetch_parquet_metadata( + store, + file, + metadata_size_hint, + file_decryption_properties, + ) + .await?; let file_metadata = metadata.file_metadata(); let schema = parquet_to_arrow_schema( file_metadata.schema_descr(), file_metadata.key_value_metadata(), )?; + let schema = coerce_int96 + .and_then(|time_unit| { + coerce_int96_to_resolution(file_metadata.schema_descr(), &schema, &time_unit) + }) + .unwrap_or(schema); Ok(schema) } @@ -843,8 +1008,11 @@ pub async fn fetch_statistics( table_schema: SchemaRef, file: &ObjectMeta, metadata_size_hint: Option, + decryption_properties: Option<&FileDecryptionProperties>, ) -> Result { - let metadata = fetch_parquet_metadata(store, file, metadata_size_hint).await?; + let metadata = + fetch_parquet_metadata(store, file, metadata_size_hint, decryption_properties) + .await?; statistics_from_parquet_meta_calc(&metadata, table_schema) } @@ -938,7 +1106,7 @@ pub fn statistics_from_parquet_meta_calc( .ok(); } Err(e) => { - debug!("Failed to create statistics converter: {}", e); + debug!("Failed to create statistics converter: {e}"); null_counts_array[idx] = Precision::Exact(num_rows); } } @@ -1090,9 +1258,18 @@ impl ParquetSink { &self, location: &Path, object_store: Arc, + context: &Arc, parquet_props: WriterProperties, ) -> Result> { - let buf_writer = BufWriter::new(object_store, location.clone()); + let buf_writer = BufWriter::with_capacity( + object_store, + location.clone(), + context + .session_config() + .options() + .execution + .objectstore_writer_buffer_size, + ); let options = ArrowWriterOptions::new() .with_properties(parquet_props) .with_skip_arrow_metadata(self.parquet_options.global.skip_arrow_metadata); @@ -1125,9 +1302,15 @@ impl FileSink for ParquetSink { object_store: Arc, ) -> Result { let parquet_opts = &self.parquet_options; - let allow_single_file_parallelism = + let mut allow_single_file_parallelism = parquet_opts.global.allow_single_file_parallelism; + if parquet_opts.crypto.file_encryption.is_some() { + // For now, arrow-rs does not support parallel writes with encryption + // See https://github.com/apache/arrow-rs/issues/7359 + 
allow_single_file_parallelism = false; + } + let mut file_write_tasks: JoinSet< std::result::Result<(Path, FileMetaData), DataFusionError>, > = JoinSet::new(); @@ -1148,12 +1331,12 @@ impl FileSink for ParquetSink { .create_async_arrow_writer( &path, Arc::clone(&object_store), + context, parquet_props.clone(), ) .await?; - let mut reservation = - MemoryConsumer::new(format!("ParquetSink[{}]", path)) - .register(context.memory_pool()); + let mut reservation = MemoryConsumer::new(format!("ParquetSink[{path}]")) + .register(context.memory_pool()); file_write_tasks.spawn(async move { while let Some(batch) = rx.recv().await { writer.write(&batch).await?; @@ -1166,14 +1349,21 @@ impl FileSink for ParquetSink { Ok((path, file_metadata)) }); } else { - let writer = create_writer( + let writer = ObjectWriterBuilder::new( // Parquet files as a whole are never compressed, since they // manage compressed blocks themselves. FileCompressionType::UNCOMPRESSED, &path, Arc::clone(&object_store), ) - .await?; + .with_buffer_size(Some( + context + .session_config() + .options() + .execution + .objectstore_writer_buffer_size, + )) + .build()?; let schema = get_writer_schema(&self.config); let props = parquet_props.clone(); let parallel_options_clone = parallel_options.clone(); @@ -1585,3 +1775,220 @@ fn create_max_min_accs( .collect(); (max_values, min_values) } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use super::*; + + use arrow::datatypes::DataType; + use parquet::schema::parser::parse_message_type; + + #[test] + fn coerce_int96_to_resolution_with_mixed_timestamps() { + // Unclear if Spark (or other writer) could generate a file with mixed timestamps like this, + // but we want to test the scenario just in case since it's at least a valid schema as far + // as the Parquet spec is concerned. + let spark_schema = " + message spark_schema { + optional int96 c0; + optional int64 c1 (TIMESTAMP(NANOS,true)); + optional int64 c2 (TIMESTAMP(NANOS,false)); + optional int64 c3 (TIMESTAMP(MILLIS,true)); + optional int64 c4 (TIMESTAMP(MILLIS,false)); + optional int64 c5 (TIMESTAMP(MICROS,true)); + optional int64 c6 (TIMESTAMP(MICROS,false)); + } + "; + + let schema = parse_message_type(spark_schema).expect("should parse schema"); + let descr = SchemaDescriptor::new(Arc::new(schema)); + + let arrow_schema = parquet_to_arrow_schema(&descr, None).unwrap(); + + let result = + coerce_int96_to_resolution(&descr, &arrow_schema, &TimeUnit::Microsecond) + .unwrap(); + + // Only the first field (c0) should be converted to a microsecond timestamp because it's the + // only timestamp that originated from an INT96. 
+ let expected_schema = Schema::new(vec![ + Field::new("c0", DataType::Timestamp(TimeUnit::Microsecond, None), true), + Field::new( + "c1", + DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())), + true, + ), + Field::new("c2", DataType::Timestamp(TimeUnit::Nanosecond, None), true), + Field::new( + "c3", + DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())), + true, + ), + Field::new("c4", DataType::Timestamp(TimeUnit::Millisecond, None), true), + Field::new( + "c5", + DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())), + true, + ), + Field::new("c6", DataType::Timestamp(TimeUnit::Microsecond, None), true), + ]); + + assert_eq!(result, expected_schema); + } + + #[test] + fn coerce_int96_to_resolution_with_nested_types() { + // This schema is derived from Comet's CometFuzzTestSuite ParquetGenerator only using int96 + // primitive types with generateStruct, generateArray, and generateMap set to true, with one + // additional field added to c4's struct to make sure all fields in a struct get modified. + // https://github.com/apache/datafusion-comet/blob/main/spark/src/main/scala/org/apache/comet/testing/ParquetGenerator.scala + let spark_schema = " + message spark_schema { + optional int96 c0; + optional group c1 { + optional int96 c0; + } + optional group c2 { + optional group c0 (LIST) { + repeated group list { + optional int96 element; + } + } + } + optional group c3 (LIST) { + repeated group list { + optional int96 element; + } + } + optional group c4 (LIST) { + repeated group list { + optional group element { + optional int96 c0; + optional int96 c1; + } + } + } + optional group c5 (MAP) { + repeated group key_value { + required int96 key; + optional int96 value; + } + } + optional group c6 (LIST) { + repeated group list { + optional group element (MAP) { + repeated group key_value { + required int96 key; + optional int96 value; + } + } + } + } + } + "; + + let schema = parse_message_type(spark_schema).expect("should parse schema"); + let descr = SchemaDescriptor::new(Arc::new(schema)); + + let arrow_schema = parquet_to_arrow_schema(&descr, None).unwrap(); + + let result = + coerce_int96_to_resolution(&descr, &arrow_schema, &TimeUnit::Microsecond) + .unwrap(); + + let expected_schema = Schema::new(vec![ + Field::new("c0", DataType::Timestamp(TimeUnit::Microsecond, None), true), + Field::new_struct( + "c1", + vec![Field::new( + "c0", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + )], + true, + ), + Field::new_struct( + "c2", + vec![Field::new_list( + "c0", + Field::new( + "element", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + true, + )], + true, + ), + Field::new_list( + "c3", + Field::new( + "element", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + true, + ), + Field::new_list( + "c4", + Field::new_struct( + "element", + vec![ + Field::new( + "c0", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + Field::new( + "c1", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + ], + true, + ), + true, + ), + Field::new_map( + "c5", + "key_value", + Field::new( + "key", + DataType::Timestamp(TimeUnit::Microsecond, None), + false, + ), + Field::new( + "value", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + false, + true, + ), + Field::new_list( + "c6", + Field::new_map( + "element", + "key_value", + Field::new( + "key", + DataType::Timestamp(TimeUnit::Microsecond, None), + false, + ), + Field::new( + "value", + DataType::Timestamp(TimeUnit::Microsecond, None), 
+ true, + ), + false, + true, + ), + true, + ), + ]); + + assert_eq!(result, expected_schema); + } +} diff --git a/datafusion/datasource-parquet/src/metrics.rs b/datafusion/datasource-parquet/src/metrics.rs index 3213d0201295..574fe2a040ea 100644 --- a/datafusion/datasource-parquet/src/metrics.rs +++ b/datafusion/datasource-parquet/src/metrics.rs @@ -27,6 +27,21 @@ use datafusion_physical_plan::metrics::{ /// [`ParquetFileReaderFactory`]: super::ParquetFileReaderFactory #[derive(Debug, Clone)] pub struct ParquetFileMetrics { + /// Number of file **ranges** pruned by partition or file level statistics. + /// Pruning of files often happens at planning time but may happen at execution time + /// if dynamic filters (e.g. from a join) result in additional pruning. + /// + /// This does **not** necessarily equal the number of files pruned: + /// files may be scanned in sub-ranges to increase parallelism, + /// in which case this will represent the number of sub-ranges pruned, not the number of files. + /// The number of files pruned will always be less than or equal to this number. + /// + /// A single file may have some ranges that are not pruned and some that are pruned. + /// For example, with a query like `ORDER BY col LIMIT 10`, the TopK dynamic filter + /// pushdown optimization may fill up the TopK heap when reading the first part of a file, + /// then skip the second part if file statistics indicate it cannot contain rows + /// that would be in the TopK. + pub files_ranges_pruned_statistics: Count, /// Number of times the predicate could not be evaluated pub predicate_evaluation_errors: Count, /// Number of row groups whose bloom filters were checked and matched (not pruned) @@ -122,7 +137,11 @@ impl ParquetFileMetrics { .with_new_label("filename", filename.to_string()) .subset_time("metadata_load_time", partition); + let files_ranges_pruned_statistics = MetricBuilder::new(metrics) + .counter("files_ranges_pruned_statistics", partition); + Self { + files_ranges_pruned_statistics, predicate_evaluation_errors, row_groups_matched_bloom_filter, row_groups_pruned_bloom_filter, diff --git a/datafusion/datasource-parquet/src/mod.rs b/datafusion/datasource-parquet/src/mod.rs index 516b13792189..0b4e86240383 100644 --- a/datafusion/datasource-parquet/src/mod.rs +++ b/datafusion/datasource-parquet/src/mod.rs @@ -19,8 +19,6 @@ // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] -//! 
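Editor's aside: the `files_ranges_pruned_statistics` counter documented in `metrics.rs` above is a per-file `Count`, so a reader that skips a file range based on statistics can simply bump it. A small sketch under the assumption that `ParquetFileMetrics::new(partition, filename, &metrics)` keeps its existing signature; not code from the patch:

```rust
use datafusion_datasource_parquet::ParquetFileMetrics;
use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet;

/// Record that one file range was skipped using partition or file level statistics.
fn record_pruned_range(
    partition: usize,
    file_name: &str,
    metrics: &ExecutionPlanMetricsSet,
) {
    let file_metrics = ParquetFileMetrics::new(partition, file_name, metrics);
    // Counts pruned *ranges*, not files: a file scanned in sub-ranges may
    // contribute several increments.
    file_metrics.files_ranges_pruned_statistics.add(1);
}
```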
[`ParquetExec`] FileSource for reading Parquet files - pub mod access_plan; pub mod file_format; mod metrics; @@ -32,28 +30,7 @@ mod row_group_filter; pub mod source; mod writer; -use std::any::Any; -use std::fmt::Formatter; -use std::sync::Arc; - pub use access_plan::{ParquetAccessPlan, RowGroupAccess}; -use arrow::datatypes::SchemaRef; -use datafusion_common::config::{ConfigOptions, TableParquetOptions}; -use datafusion_common::Result; -use datafusion_common::{Constraints, Statistics}; -use datafusion_datasource::file_scan_config::FileScanConfig; -use datafusion_datasource::schema_adapter::SchemaAdapterFactory; -use datafusion_datasource::source::DataSourceExec; -use datafusion_execution::{SendableRecordBatchStream, TaskContext}; -use datafusion_physical_expr::{ - EquivalenceProperties, LexOrdering, Partitioning, PhysicalExpr, -}; -use datafusion_physical_optimizer::pruning::PruningPredicate; -use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; -use datafusion_physical_plan::metrics::MetricsSet; -use datafusion_physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, -}; pub use file_format::*; pub use metrics::ParquetFileMetrics; pub use page_filter::PagePruningAccessPlanFilter; @@ -61,491 +38,4 @@ pub use reader::{DefaultParquetFileReaderFactory, ParquetFileReaderFactory}; pub use row_filter::build_row_filter; pub use row_filter::can_expr_be_pushed_down_with_schemas; pub use row_group_filter::RowGroupAccessPlanFilter; -use source::ParquetSource; pub use writer::plan_to_parquet; - -use datafusion_datasource::file_groups::FileGroup; -use log::debug; - -#[derive(Debug, Clone)] -#[deprecated(since = "46.0.0", note = "use DataSourceExec instead")] -/// Deprecated Execution plan replaced with DataSourceExec -pub struct ParquetExec { - inner: DataSourceExec, - base_config: FileScanConfig, - table_parquet_options: TableParquetOptions, - /// Optional predicate for row filtering during parquet scan - predicate: Option>, - /// Optional predicate for pruning row groups (derived from `predicate`) - pruning_predicate: Option>, - /// Optional user defined parquet file reader factory - parquet_file_reader_factory: Option>, - /// Optional user defined schema adapter - schema_adapter_factory: Option>, -} - -#[allow(unused, deprecated)] -impl From for ParquetExecBuilder { - fn from(exec: ParquetExec) -> Self { - exec.into_builder() - } -} - -/// [`ParquetExecBuilder`], deprecated builder for [`ParquetExec`]. -/// -/// ParquetExec is replaced with `DataSourceExec` and it includes `ParquetSource` -/// -/// See example on [`ParquetSource`]. -#[deprecated( - since = "46.0.0", - note = "use DataSourceExec with ParquetSource instead" -)] -#[allow(unused, deprecated)] -pub struct ParquetExecBuilder { - file_scan_config: FileScanConfig, - predicate: Option>, - metadata_size_hint: Option, - table_parquet_options: TableParquetOptions, - parquet_file_reader_factory: Option>, - schema_adapter_factory: Option>, -} - -#[allow(unused, deprecated)] -impl ParquetExecBuilder { - /// Create a new builder to read the provided file scan configuration - pub fn new(file_scan_config: FileScanConfig) -> Self { - Self::new_with_options(file_scan_config, TableParquetOptions::default()) - } - - /// Create a new builder to read the data specified in the file scan - /// configuration with the provided `TableParquetOptions`. 
- pub fn new_with_options( - file_scan_config: FileScanConfig, - table_parquet_options: TableParquetOptions, - ) -> Self { - Self { - file_scan_config, - predicate: None, - metadata_size_hint: None, - table_parquet_options, - parquet_file_reader_factory: None, - schema_adapter_factory: None, - } - } - - /// Update the list of files groups to read - pub fn with_file_groups(mut self, file_groups: Vec) -> Self { - self.file_scan_config.file_groups = file_groups; - self - } - - /// Set the filter predicate when reading. - /// - /// See the "Predicate Pushdown" section of the [`ParquetExec`] documentation - /// for more details. - pub fn with_predicate(mut self, predicate: Arc) -> Self { - self.predicate = Some(predicate); - self - } - - /// Set the metadata size hint - /// - /// This value determines how many bytes at the end of the file the default - /// [`ParquetFileReaderFactory`] will request in the initial IO. If this is - /// too small, the ParquetExec will need to make additional IO requests to - /// read the footer. - pub fn with_metadata_size_hint(mut self, metadata_size_hint: usize) -> Self { - self.metadata_size_hint = Some(metadata_size_hint); - self - } - - /// Set the options for controlling how the ParquetExec reads parquet files. - /// - /// See also [`Self::new_with_options`] - pub fn with_table_parquet_options( - mut self, - table_parquet_options: TableParquetOptions, - ) -> Self { - self.table_parquet_options = table_parquet_options; - self - } - - /// Set optional user defined parquet file reader factory. - /// - /// You can use [`ParquetFileReaderFactory`] to more precisely control how - /// data is read from parquet files (e.g. skip re-reading metadata, coalesce - /// I/O operations, etc). - /// - /// The default reader factory reads directly from an [`ObjectStore`] - /// instance using individual I/O operations for the footer and each page. - /// - /// If a custom `ParquetFileReaderFactory` is provided, then data access - /// operations will be routed to this factory instead of [`ObjectStore`]. - /// - /// [`ObjectStore`]: object_store::ObjectStore - pub fn with_parquet_file_reader_factory( - mut self, - parquet_file_reader_factory: Arc, - ) -> Self { - self.parquet_file_reader_factory = Some(parquet_file_reader_factory); - self - } - - /// Set optional schema adapter factory. - /// - /// [`SchemaAdapterFactory`] allows user to specify how fields from the - /// parquet file get mapped to that of the table schema. The default schema - /// adapter uses arrow's cast library to map the parquet fields to the table - /// schema. 
- pub fn with_schema_adapter_factory( - mut self, - schema_adapter_factory: Arc, - ) -> Self { - self.schema_adapter_factory = Some(schema_adapter_factory); - self - } - - /// Convenience: build an `Arc`d `ParquetExec` from this builder - pub fn build_arc(self) -> Arc { - Arc::new(self.build()) - } - - /// Build a [`ParquetExec`] - #[must_use] - pub fn build(self) -> ParquetExec { - let Self { - file_scan_config, - predicate, - metadata_size_hint, - table_parquet_options, - parquet_file_reader_factory, - schema_adapter_factory, - } = self; - let mut parquet = ParquetSource::new(table_parquet_options); - if let Some(predicate) = predicate.clone() { - parquet = parquet - .with_predicate(Arc::clone(&file_scan_config.file_schema), predicate); - } - if let Some(metadata_size_hint) = metadata_size_hint { - parquet = parquet.with_metadata_size_hint(metadata_size_hint) - } - if let Some(parquet_reader_factory) = parquet_file_reader_factory { - parquet = parquet.with_parquet_file_reader_factory(parquet_reader_factory) - } - if let Some(schema_factory) = schema_adapter_factory { - parquet = parquet.with_schema_adapter_factory(schema_factory); - } - - let base_config = file_scan_config.with_source(Arc::new(parquet.clone())); - debug!("Creating ParquetExec, files: {:?}, projection {:?}, predicate: {:?}, limit: {:?}", - base_config.file_groups, base_config.projection, predicate, base_config.limit); - - ParquetExec { - inner: DataSourceExec::new(Arc::new(base_config.clone())), - base_config, - predicate, - pruning_predicate: parquet.pruning_predicate, - schema_adapter_factory: parquet.schema_adapter_factory, - parquet_file_reader_factory: parquet.parquet_file_reader_factory, - table_parquet_options: parquet.table_parquet_options, - } - } -} - -#[allow(unused, deprecated)] -impl ParquetExec { - /// Create a new Parquet reader execution plan provided file list and schema. - pub fn new( - base_config: FileScanConfig, - predicate: Option>, - metadata_size_hint: Option, - table_parquet_options: TableParquetOptions, - ) -> Self { - let mut builder = - ParquetExecBuilder::new_with_options(base_config, table_parquet_options); - if let Some(predicate) = predicate { - builder = builder.with_predicate(predicate); - } - if let Some(metadata_size_hint) = metadata_size_hint { - builder = builder.with_metadata_size_hint(metadata_size_hint); - } - builder.build() - } - /// Return a [`ParquetExecBuilder`]. - /// - /// See example on [`ParquetExec`] and [`ParquetExecBuilder`] for specifying - /// parquet table options. 
- pub fn builder(file_scan_config: FileScanConfig) -> ParquetExecBuilder { - ParquetExecBuilder::new(file_scan_config) - } - - /// Convert this `ParquetExec` into a builder for modification - pub fn into_builder(self) -> ParquetExecBuilder { - // list out fields so it is clear what is being dropped - // (note the fields which are dropped are re-created as part of calling - // `build` on the builder) - let file_scan_config = self.file_scan_config(); - let parquet = self.parquet_source(); - - ParquetExecBuilder { - file_scan_config, - predicate: parquet.predicate, - metadata_size_hint: parquet.metadata_size_hint, - table_parquet_options: parquet.table_parquet_options, - parquet_file_reader_factory: parquet.parquet_file_reader_factory, - schema_adapter_factory: parquet.schema_adapter_factory, - } - } - fn file_scan_config(&self) -> FileScanConfig { - self.inner - .data_source() - .as_any() - .downcast_ref::() - .unwrap() - .clone() - } - - fn parquet_source(&self) -> ParquetSource { - self.file_scan_config() - .file_source() - .as_any() - .downcast_ref::() - .unwrap() - .clone() - } - - /// [`FileScanConfig`] that controls this scan (such as which files to read) - pub fn base_config(&self) -> &FileScanConfig { - &self.base_config - } - /// Options passed to the parquet reader for this scan - pub fn table_parquet_options(&self) -> &TableParquetOptions { - &self.table_parquet_options - } - /// Optional predicate. - pub fn predicate(&self) -> Option<&Arc> { - self.predicate.as_ref() - } - /// Optional reference to this parquet scan's pruning predicate - pub fn pruning_predicate(&self) -> Option<&Arc> { - self.pruning_predicate.as_ref() - } - /// return the optional file reader factory - pub fn parquet_file_reader_factory( - &self, - ) -> Option<&Arc> { - self.parquet_file_reader_factory.as_ref() - } - /// Optional user defined parquet file reader factory. - pub fn with_parquet_file_reader_factory( - mut self, - parquet_file_reader_factory: Arc, - ) -> Self { - let mut parquet = self.parquet_source(); - parquet.parquet_file_reader_factory = - Some(Arc::clone(&parquet_file_reader_factory)); - let file_source = self.file_scan_config(); - self.inner = self - .inner - .with_data_source(Arc::new(file_source.with_source(Arc::new(parquet)))); - self.parquet_file_reader_factory = Some(parquet_file_reader_factory); - self - } - /// return the optional schema adapter factory - pub fn schema_adapter_factory(&self) -> Option<&Arc> { - self.schema_adapter_factory.as_ref() - } - /// Set optional schema adapter factory. - /// - /// [`SchemaAdapterFactory`] allows user to specify how fields from the - /// parquet file get mapped to that of the table schema. The default schema - /// adapter uses arrow's cast library to map the parquet fields to the table - /// schema. - pub fn with_schema_adapter_factory( - mut self, - schema_adapter_factory: Arc, - ) -> Self { - let mut parquet = self.parquet_source(); - parquet.schema_adapter_factory = Some(Arc::clone(&schema_adapter_factory)); - let file_source = self.file_scan_config(); - self.inner = self - .inner - .with_data_source(Arc::new(file_source.with_source(Arc::new(parquet)))); - self.schema_adapter_factory = Some(schema_adapter_factory); - self - } - /// If true, the predicate will be used during the parquet scan. 
- /// Defaults to false - /// - /// [`Expr`]: datafusion_expr::Expr - pub fn with_pushdown_filters(mut self, pushdown_filters: bool) -> Self { - let mut parquet = self.parquet_source(); - parquet.table_parquet_options.global.pushdown_filters = pushdown_filters; - let file_source = self.file_scan_config(); - self.inner = self - .inner - .with_data_source(Arc::new(file_source.with_source(Arc::new(parquet)))); - self.table_parquet_options.global.pushdown_filters = pushdown_filters; - self - } - - /// Return the value described in [`Self::with_pushdown_filters`] - fn pushdown_filters(&self) -> bool { - self.parquet_source() - .table_parquet_options - .global - .pushdown_filters - } - /// If true, the `RowFilter` made by `pushdown_filters` may try to - /// minimize the cost of filter evaluation by reordering the - /// predicate [`Expr`]s. If false, the predicates are applied in - /// the same order as specified in the query. Defaults to false. - /// - /// [`Expr`]: datafusion_expr::Expr - pub fn with_reorder_filters(mut self, reorder_filters: bool) -> Self { - let mut parquet = self.parquet_source(); - parquet.table_parquet_options.global.reorder_filters = reorder_filters; - let file_source = self.file_scan_config(); - self.inner = self - .inner - .with_data_source(Arc::new(file_source.with_source(Arc::new(parquet)))); - self.table_parquet_options.global.reorder_filters = reorder_filters; - self - } - /// Return the value described in [`Self::with_reorder_filters`] - fn reorder_filters(&self) -> bool { - self.parquet_source() - .table_parquet_options - .global - .reorder_filters - } - /// If enabled, the reader will read the page index - /// This is used to optimize filter pushdown - /// via `RowSelector` and `RowFilter` by - /// eliminating unnecessary IO and decoding - fn bloom_filter_on_read(&self) -> bool { - self.parquet_source() - .table_parquet_options - .global - .bloom_filter_on_read - } - /// Return the value described in [`ParquetSource::with_enable_page_index`] - fn enable_page_index(&self) -> bool { - self.parquet_source() - .table_parquet_options - .global - .enable_page_index - } - - fn output_partitioning_helper(file_config: &FileScanConfig) -> Partitioning { - Partitioning::UnknownPartitioning(file_config.file_groups.len()) - } - - /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. - fn compute_properties( - schema: SchemaRef, - orderings: &[LexOrdering], - constraints: Constraints, - file_config: &FileScanConfig, - ) -> PlanProperties { - PlanProperties::new( - EquivalenceProperties::new_with_orderings(schema, orderings) - .with_constraints(constraints), - Self::output_partitioning_helper(file_config), // Output Partitioning - EmissionType::Incremental, - Boundedness::Bounded, - ) - } - - /// Updates the file groups to read and recalculates the output partitioning - /// - /// Note this function does not update statistics or other properties - /// that depend on the file groups. 
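The pushdown and filter-reordering flags described above are plain `TableParquetOptions` fields and survive the removal unchanged; a minimal sketch of configuring them on the replacement source (the helper name is illustrative):

use datafusion_common::config::TableParquetOptions;

use crate::source::ParquetSource;

fn source_with_filter_pushdown() -> ParquetSource {
    let mut options = TableParquetOptions::default();
    // Evaluate pushed-down predicates while decoding row groups, instead of
    // returning all rows and filtering afterwards.
    options.global.pushdown_filters = true;
    // Allow the reader to reorder predicates so cheaper ones run first.
    options.global.reorder_filters = true;
    ParquetSource::new(options)
}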
- fn with_file_groups_and_update_partitioning( - mut self, - file_groups: Vec, - ) -> Self { - let mut config = self.file_scan_config(); - config.file_groups = file_groups; - self.inner = self.inner.with_data_source(Arc::new(config)); - self - } -} - -#[allow(unused, deprecated)] -impl DisplayAs for ParquetExec { - fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { - self.inner.fmt_as(t, f) - } -} - -#[allow(unused, deprecated)] -impl ExecutionPlan for ParquetExec { - fn name(&self) -> &'static str { - "ParquetExec" - } - - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { - self.inner.properties() - } - - fn children(&self) -> Vec<&Arc> { - // this is a leaf node and has no children - vec![] - } - - fn with_new_children( - self: Arc, - _: Vec>, - ) -> Result> { - Ok(self) - } - - /// Redistribute files across partitions according to their size - /// See comments on `FileGroupPartitioner` for more detail. - fn repartitioned( - &self, - target_partitions: usize, - config: &ConfigOptions, - ) -> Result>> { - self.inner.repartitioned(target_partitions, config) - } - - fn execute( - &self, - partition_index: usize, - ctx: Arc, - ) -> Result { - self.inner.execute(partition_index, ctx) - } - fn metrics(&self) -> Option { - self.inner.metrics() - } - fn statistics(&self) -> Result { - self.inner.statistics() - } - fn fetch(&self) -> Option { - self.inner.fetch() - } - - fn with_fetch(&self, limit: Option) -> Option> { - self.inner.with_fetch(limit) - } -} - -fn should_enable_page_index( - enable_page_index: bool, - page_pruning_predicate: &Option>, -) -> bool { - enable_page_index - && page_pruning_predicate.is_some() - && page_pruning_predicate - .as_ref() - .map(|p| p.filter_number() > 0) - .unwrap_or(false) -} diff --git a/datafusion/datasource-parquet/src/opener.rs b/datafusion/datasource-parquet/src/opener.rs index cfe8213f86e4..561be82cf75d 100644 --- a/datafusion/datasource-parquet/src/opener.rs +++ b/datafusion/datasource-parquet/src/opener.rs @@ -23,25 +23,29 @@ use crate::page_filter::PagePruningAccessPlanFilter; use crate::row_group_filter::RowGroupAccessPlanFilter; use crate::{ apply_file_schema_type_coercions, coerce_int96_to_resolution, row_filter, - should_enable_page_index, ParquetAccessPlan, ParquetFileMetrics, - ParquetFileReaderFactory, + ParquetAccessPlan, ParquetFileMetrics, ParquetFileReaderFactory, }; use datafusion_datasource::file_meta::FileMeta; use datafusion_datasource::file_stream::{FileOpenFuture, FileOpener}; use datafusion_datasource::schema_adapter::SchemaAdapterFactory; -use arrow::datatypes::{SchemaRef, TimeUnit}; +use arrow::datatypes::{FieldRef, SchemaRef, TimeUnit}; use arrow::error::ArrowError; -use datafusion_common::{exec_err, Result}; -use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use datafusion_physical_optimizer::pruning::PruningPredicate; +use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_datasource::PartitionedFile; +use datafusion_physical_expr::PhysicalExprSchemaRewriter; +use datafusion_physical_expr_common::physical_expr::{ + is_dynamic_physical_expr, PhysicalExpr, +}; use datafusion_physical_plan::metrics::{Count, ExecutionPlanMetricsSet, MetricBuilder}; +use datafusion_pruning::{build_pruning_predicate, FilePruner, PruningPredicate}; use futures::{StreamExt, TryStreamExt}; use log::debug; use parquet::arrow::arrow_reader::{ArrowReaderMetadata, ArrowReaderOptions}; use 
parquet::arrow::async_reader::AsyncFileReader; use parquet::arrow::{ParquetRecordBatchStreamBuilder, ProjectionMask}; +use parquet::encryption::decrypt::FileDecryptionProperties; use parquet::file::metadata::ParquetMetaDataReader; /// Implements [`FileOpener`] for a parquet file @@ -56,8 +60,11 @@ pub(super) struct ParquetOpener { pub limit: Option, /// Optional predicate to apply during the scan pub predicate: Option>, - /// Schema of the output table - pub table_schema: SchemaRef, + /// Schema of the output table without partition columns. + /// This is the schema we coerce the physical file schema into. + pub logical_file_schema: SchemaRef, + /// Partition columns + pub partition_fields: Vec, /// Optional hint for how large the initial request to read parquet metadata /// should be pub metadata_size_hint: Option, @@ -82,10 +89,12 @@ pub(super) struct ParquetOpener { pub enable_row_group_stats_pruning: bool, /// Coerce INT96 timestamps to specific TimeUnit pub coerce_int96: Option, + /// Optional parquet FileDecryptionProperties + pub file_decryption_properties: Option>, } impl FileOpener for ParquetOpener { - fn open(&self, file_meta: FileMeta) -> Result { + fn open(&self, file_meta: FileMeta, file: PartitionedFile) -> Result { let file_range = file_meta.range.clone(); let extensions = file_meta.extensions.clone(); let file_name = file_meta.location().to_string(); @@ -97,7 +106,7 @@ impl FileOpener for ParquetOpener { let mut async_file_reader: Box = self.parquet_file_reader_factory.create_reader( self.partition_index, - file_meta, + file_meta.clone(), metadata_size_hint, &self.metrics, )?; @@ -105,13 +114,13 @@ impl FileOpener for ParquetOpener { let batch_size = self.batch_size; let projected_schema = - SchemaRef::from(self.table_schema.project(&self.projection)?); - let schema_adapter_factory = Arc::clone(&self.schema_adapter_factory); + SchemaRef::from(self.logical_file_schema.project(&self.projection)?); let schema_adapter = self .schema_adapter_factory - .create(projected_schema, Arc::clone(&self.table_schema)); + .create(projected_schema, Arc::clone(&self.logical_file_schema)); let predicate = self.predicate.clone(); - let table_schema = Arc::clone(&self.table_schema); + let logical_file_schema = Arc::clone(&self.logical_file_schema); + let partition_fields = self.partition_fields.clone(); let reorder_predicates = self.reorder_filters; let pushdown_filters = self.pushdown_filters; let coerce_int96 = self.coerce_int96; @@ -122,15 +131,59 @@ impl FileOpener for ParquetOpener { let predicate_creation_errors = MetricBuilder::new(&self.metrics) .global_counter("num_predicate_creation_errors"); - let enable_page_index = self.enable_page_index; + let mut enable_page_index = self.enable_page_index; + let file_decryption_properties = self.file_decryption_properties.clone(); + + // For now, page index does not work with encrypted files. See: + // https://github.com/apache/arrow-rs/issues/7629 + if file_decryption_properties.is_some() { + enable_page_index = false; + } Ok(Box::pin(async move { + // Prune this file using the file level statistics and partition values. + // Since dynamic filters may have been updated since planning it is possible that we are able + // to prune files now that we couldn't prune at planning time. + // It is assumed that there is no point in doing pruning here if the predicate is not dynamic, + // as it would have been done at planning time. 
+ // We'll also check this after every record batch we read, + // and if at some point we are able to prove we can prune the file using just the file level statistics + // we can end the stream early. + let mut file_pruner = predicate + .as_ref() + .map(|p| { + Ok::<_, DataFusionError>( + (is_dynamic_physical_expr(p) | file.has_statistics()).then_some( + FilePruner::new( + Arc::clone(p), + &logical_file_schema, + partition_fields.clone(), + file.clone(), + predicate_creation_errors.clone(), + )?, + ), + ) + }) + .transpose()? + .flatten(); + + if let Some(file_pruner) = &mut file_pruner { + if file_pruner.should_prune()? { + // Return an empty stream immediately to skip the work of setting up the actual stream + file_metrics.files_ranges_pruned_statistics.add(1); + return Ok(futures::stream::empty().boxed()); + } + } + // Don't load the page index yet. Since it is not stored inline in // the footer, loading the page index if it is not needed will do // unecessary I/O. We decide later if it is needed to evaluate the // pruning predicates. Thus default to not requesting if from the // underlying reader. let mut options = ArrowReaderOptions::new().with_page_index(false); + if let Some(fd_val) = file_decryption_properties { + options = options.with_file_decryption_properties((*fd_val).clone()); + } let mut metadata_timer = file_metrics.metadata_load_time.timer(); // Begin by loading the metadata from the underlying reader (note @@ -142,17 +195,20 @@ impl FileOpener for ParquetOpener { .await?; // Note about schemas: we are actually dealing with **3 different schemas** here: - // - The table schema as defined by the TableProvider. This is what the user sees, what they get when they `SELECT * FROM table`, etc. - // - The "virtual" file schema: this is the table schema minus any hive partition columns and projections. This is what the file schema is coerced to. + // - The table schema as defined by the TableProvider. + // This is what the user sees, what they get when they `SELECT * FROM table`, etc. + // - The logical file schema: this is the table schema minus any hive partition columns and projections. + // This is what the physicalfile schema is coerced to. // - The physical file schema: this is the schema as defined by the parquet file. This is what the parquet file actually contains. let mut physical_file_schema = Arc::clone(reader_metadata.schema()); // The schema loaded from the file may not be the same as the // desired schema (for example if we want to instruct the parquet // reader to read strings using Utf8View instead). Update if necessary - if let Some(merged) = - apply_file_schema_type_coercions(&table_schema, &physical_file_schema) - { + if let Some(merged) = apply_file_schema_type_coercions( + &logical_file_schema, + &physical_file_schema, + ) { physical_file_schema = Arc::new(merged); options = options.with_schema(Arc::clone(&physical_file_schema)); reader_metadata = ArrowReaderMetadata::try_new( @@ -176,9 +232,26 @@ impl FileOpener for ParquetOpener { } } + // Adapt the predicate to the physical file schema. + // This evaluates missing columns and inserts any necessary casts. 
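The three-schema note above is easier to follow with a concrete, purely hypothetical layout; the rewrite that follows then adapts the predicate from the logical to the physical flavour of the schema.

use arrow::datatypes::{DataType, Field, Schema};

fn three_schema_example() -> (Schema, Schema, Schema) {
    // Table schema: what the TableProvider exposes, including the hive
    // partition column `part`.
    let table_schema = Schema::new(vec![
        Field::new("part", DataType::Int32, false),
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8View, true),
    ]);
    // Logical file schema: the table schema minus partition columns; this is
    // what the physical file schema is coerced to.
    let logical_file_schema = Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8View, true),
    ]);
    // Physical file schema: whatever the parquet footer declares, e.g. an
    // older file that still stores `b` as Utf8 and predates column `a`.
    let physical_file_schema =
        Schema::new(vec![Field::new("b", DataType::Utf8, true)]);
    (table_schema, logical_file_schema, physical_file_schema)
}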
+ let predicate = predicate + .map(|p| { + PhysicalExprSchemaRewriter::new( + &physical_file_schema, + &logical_file_schema, + ) + .with_partition_columns( + partition_fields.to_vec(), + file.partition_values, + ) + .rewrite(p) + .map_err(ArrowError::from) + }) + .transpose()?; + // Build predicates for this specific file let (pruning_predicate, page_pruning_predicate) = build_pruning_predicates( - &predicate, + predicate.as_ref(), &physical_file_schema, &predicate_creation_errors, ); @@ -216,11 +289,9 @@ impl FileOpener for ParquetOpener { let row_filter = row_filter::build_row_filter( &predicate, &physical_file_schema, - &table_schema, builder.metadata(), reorder_predicates, &file_metrics, - &schema_adapter_factory, ); match row_filter { @@ -230,8 +301,7 @@ impl FileOpener for ParquetOpener { Ok(None) => {} Err(e) => { debug!( - "Ignoring error building row filter for '{:?}': {}", - predicate, e + "Ignoring error building row filter for '{predicate:?}': {e}" ); } }; @@ -353,30 +423,6 @@ fn create_initial_plan( Ok(ParquetAccessPlan::new_all(row_group_count)) } -/// Build a pruning predicate from an optional predicate expression. -/// If the predicate is None or the predicate cannot be converted to a pruning -/// predicate, return None. -/// If there is an error creating the pruning predicate it is recorded by incrementing -/// the `predicate_creation_errors` counter. -pub(crate) fn build_pruning_predicate( - predicate: Arc, - file_schema: &SchemaRef, - predicate_creation_errors: &Count, -) -> Option> { - match PruningPredicate::try_new(predicate, Arc::clone(file_schema)) { - Ok(pruning_predicate) => { - if !pruning_predicate.always_true() { - return Some(Arc::new(pruning_predicate)); - } - } - Err(e) => { - debug!("Could not create pruning predicate for: {e}"); - predicate_creation_errors.add(1); - } - } - None -} - /// Build a page pruning predicate from an optional predicate expression. /// If the predicate is None or the predicate cannot be converted to a page pruning /// predicate, return None. 
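For readers new to the rewriter, the same call chain as a standalone sketch; the function name is hypothetical and the signatures are mirrored from the opener code above.

use std::sync::Arc;

use arrow::datatypes::{FieldRef, SchemaRef};
use datafusion_common::{Result, ScalarValue};
use datafusion_physical_expr::PhysicalExprSchemaRewriter;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

// Rewrites a table-level predicate so it can be evaluated directly against
// the physical file schema: missing columns are evaluated, partition columns
// become literals from `partition_values`, and any necessary casts are
// inserted where the types differ.
fn adapt_predicate_to_file(
    predicate: Arc<dyn PhysicalExpr>,
    physical_file_schema: &SchemaRef,
    logical_file_schema: &SchemaRef,
    partition_fields: &[FieldRef],
    partition_values: Vec<ScalarValue>,
) -> Result<Arc<dyn PhysicalExpr>> {
    PhysicalExprSchemaRewriter::new(physical_file_schema, logical_file_schema)
        .with_partition_columns(partition_fields.to_vec(), partition_values)
        .rewrite(predicate)
}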
@@ -390,8 +436,8 @@ pub(crate) fn build_page_pruning_predicate( )) } -fn build_pruning_predicates( - predicate: &Option>, +pub(crate) fn build_pruning_predicates( + predicate: Option<&Arc>, file_schema: &SchemaRef, predicate_creation_errors: &Count, ) -> ( @@ -440,3 +486,589 @@ async fn load_page_index( Ok(reader_metadata) } } + +fn should_enable_page_index( + enable_page_index: bool, + page_pruning_predicate: &Option>, +) -> bool { + enable_page_index + && page_pruning_predicate.is_some() + && page_pruning_predicate + .as_ref() + .map(|p| p.filter_number() > 0) + .unwrap_or(false) +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use arrow::datatypes::{DataType, Field, Schema}; + use bytes::{BufMut, BytesMut}; + use chrono::Utc; + use datafusion_common::{ + record_batch, stats::Precision, ColumnStatistics, ScalarValue, Statistics, + }; + use datafusion_datasource::{ + file_meta::FileMeta, file_stream::FileOpener, + schema_adapter::DefaultSchemaAdapterFactory, PartitionedFile, + }; + use datafusion_expr::{col, lit}; + use datafusion_physical_expr::{ + expressions::DynamicFilterPhysicalExpr, planner::logical2physical, PhysicalExpr, + }; + use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; + use futures::{Stream, StreamExt}; + use object_store::{memory::InMemory, path::Path, ObjectMeta, ObjectStore}; + use parquet::arrow::ArrowWriter; + + use crate::{opener::ParquetOpener, DefaultParquetFileReaderFactory}; + + async fn count_batches_and_rows( + mut stream: std::pin::Pin< + Box< + dyn Stream< + Item = Result< + arrow::array::RecordBatch, + arrow::error::ArrowError, + >, + > + Send, + >, + >, + ) -> (usize, usize) { + let mut num_batches = 0; + let mut num_rows = 0; + while let Some(Ok(batch)) = stream.next().await { + num_rows += batch.num_rows(); + num_batches += 1; + } + (num_batches, num_rows) + } + + async fn write_parquet( + store: Arc, + filename: &str, + batch: arrow::record_batch::RecordBatch, + ) -> usize { + let mut out = BytesMut::new().writer(); + { + let mut writer = + ArrowWriter::try_new(&mut out, batch.schema(), None).unwrap(); + writer.write(&batch).unwrap(); + writer.finish().unwrap(); + } + let data = out.into_inner().freeze(); + let data_len = data.len(); + store.put(&Path::from(filename), data.into()).await.unwrap(); + data_len + } + + fn make_dynamic_expr(expr: Arc) -> Arc { + Arc::new(DynamicFilterPhysicalExpr::new( + expr.children().into_iter().map(Arc::clone).collect(), + expr, + )) + } + + #[tokio::test] + async fn test_prune_on_statistics() { + let store = Arc::new(InMemory::new()) as Arc; + + let batch = record_batch!( + ("a", Int32, vec![Some(1), Some(2), Some(2)]), + ("b", Float32, vec![Some(1.0), Some(2.0), None]) + ) + .unwrap(); + + let data_size = + write_parquet(Arc::clone(&store), "test.parquet", batch.clone()).await; + + let schema = batch.schema(); + let file = PartitionedFile::new( + "file.parquet".to_string(), + u64::try_from(data_size).unwrap(), + ) + .with_statistics(Arc::new( + Statistics::new_unknown(&schema) + .add_column_statistics(ColumnStatistics::new_unknown()) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::Float32(Some(1.0)))) + .with_max_value(Precision::Exact(ScalarValue::Float32(Some(2.0)))) + .with_null_count(Precision::Exact(1)), + ), + )); + + let make_opener = |predicate| { + ParquetOpener { + partition_index: 0, + projection: Arc::new([0, 1]), + batch_size: 1024, + limit: None, + predicate: Some(predicate), + logical_file_schema: schema.clone(), + 
metadata_size_hint: None, + metrics: ExecutionPlanMetricsSet::new(), + parquet_file_reader_factory: Arc::new( + DefaultParquetFileReaderFactory::new(Arc::clone(&store)), + ), + partition_fields: vec![], + pushdown_filters: false, // note that this is false! + reorder_filters: false, + enable_page_index: false, + enable_bloom_filter: false, + schema_adapter_factory: Arc::new(DefaultSchemaAdapterFactory), + enable_row_group_stats_pruning: true, + coerce_int96: None, + file_decryption_properties: None, + } + }; + + let make_meta = || FileMeta { + object_meta: ObjectMeta { + location: Path::from("test.parquet"), + last_modified: Utc::now(), + size: u64::try_from(data_size).unwrap(), + e_tag: None, + version: None, + }, + range: None, + extensions: None, + metadata_size_hint: None, + }; + + // A filter on "a" should not exclude any rows even if it matches the data + let expr = col("a").eq(lit(1)); + let predicate = logical2physical(&expr, &schema); + let opener = make_opener(predicate); + let stream = opener + .open(make_meta(), file.clone()) + .unwrap() + .await + .unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 1); + assert_eq!(num_rows, 3); + + // A filter on `b = 5.0` should exclude all rows + let expr = col("b").eq(lit(ScalarValue::Float32(Some(5.0)))); + let predicate = logical2physical(&expr, &schema); + let opener = make_opener(predicate); + let stream = opener.open(make_meta(), file).unwrap().await.unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 0); + assert_eq!(num_rows, 0); + } + + #[tokio::test] + async fn test_prune_on_partition_statistics_with_dynamic_expression() { + let store = Arc::new(InMemory::new()) as Arc; + + let batch = record_batch!(("a", Int32, vec![Some(1), Some(2), Some(3)])).unwrap(); + let data_size = + write_parquet(Arc::clone(&store), "part=1/file.parquet", batch.clone()).await; + + let file_schema = batch.schema(); + let mut file = PartitionedFile::new( + "part=1/file.parquet".to_string(), + u64::try_from(data_size).unwrap(), + ); + file.partition_values = vec![ScalarValue::Int32(Some(1))]; + + let table_schema = Arc::new(Schema::new(vec![ + Field::new("part", DataType::Int32, false), + Field::new("a", DataType::Int32, false), + ])); + + let make_opener = |predicate| { + ParquetOpener { + partition_index: 0, + projection: Arc::new([0]), + batch_size: 1024, + limit: None, + predicate: Some(predicate), + logical_file_schema: file_schema.clone(), + metadata_size_hint: None, + metrics: ExecutionPlanMetricsSet::new(), + parquet_file_reader_factory: Arc::new( + DefaultParquetFileReaderFactory::new(Arc::clone(&store)), + ), + partition_fields: vec![Arc::new(Field::new( + "part", + DataType::Int32, + false, + ))], + pushdown_filters: false, // note that this is false! 
+ reorder_filters: false, + enable_page_index: false, + enable_bloom_filter: false, + schema_adapter_factory: Arc::new(DefaultSchemaAdapterFactory), + enable_row_group_stats_pruning: true, + coerce_int96: None, + file_decryption_properties: None, + } + }; + + let make_meta = || FileMeta { + object_meta: ObjectMeta { + location: Path::from("part=1/file.parquet"), + last_modified: Utc::now(), + size: u64::try_from(data_size).unwrap(), + e_tag: None, + version: None, + }, + range: None, + extensions: None, + metadata_size_hint: None, + }; + + // Filter should match the partition value + let expr = col("part").eq(lit(1)); + // Mark the expression as dynamic even if it's not to force partition pruning to happen + // Otherwise we assume it already happened at the planning stage and won't re-do the work here + let predicate = make_dynamic_expr(logical2physical(&expr, &table_schema)); + let opener = make_opener(predicate); + let stream = opener + .open(make_meta(), file.clone()) + .unwrap() + .await + .unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 1); + assert_eq!(num_rows, 3); + + // Filter should not match the partition value + let expr = col("part").eq(lit(2)); + // Mark the expression as dynamic even if it's not to force partition pruning to happen + // Otherwise we assume it already happened at the planning stage and won't re-do the work here + let predicate = make_dynamic_expr(logical2physical(&expr, &table_schema)); + let opener = make_opener(predicate); + let stream = opener.open(make_meta(), file).unwrap().await.unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 0); + assert_eq!(num_rows, 0); + } + + #[tokio::test] + async fn test_prune_on_partition_values_and_file_statistics() { + let store = Arc::new(InMemory::new()) as Arc; + + let batch = record_batch!( + ("a", Int32, vec![Some(1), Some(2), Some(3)]), + ("b", Float64, vec![Some(1.0), Some(2.0), None]) + ) + .unwrap(); + let data_size = + write_parquet(Arc::clone(&store), "part=1/file.parquet", batch.clone()).await; + let file_schema = batch.schema(); + let mut file = PartitionedFile::new( + "part=1/file.parquet".to_string(), + u64::try_from(data_size).unwrap(), + ); + file.partition_values = vec![ScalarValue::Int32(Some(1))]; + file.statistics = Some(Arc::new( + Statistics::new_unknown(&file_schema) + .add_column_statistics(ColumnStatistics::new_unknown()) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::Float64(Some(1.0)))) + .with_max_value(Precision::Exact(ScalarValue::Float64(Some(2.0)))) + .with_null_count(Precision::Exact(1)), + ), + )); + let table_schema = Arc::new(Schema::new(vec![ + Field::new("part", DataType::Int32, false), + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Float32, true), + ])); + let make_opener = |predicate| { + ParquetOpener { + partition_index: 0, + projection: Arc::new([0]), + batch_size: 1024, + limit: None, + predicate: Some(predicate), + logical_file_schema: file_schema.clone(), + metadata_size_hint: None, + metrics: ExecutionPlanMetricsSet::new(), + parquet_file_reader_factory: Arc::new( + DefaultParquetFileReaderFactory::new(Arc::clone(&store)), + ), + partition_fields: vec![Arc::new(Field::new( + "part", + DataType::Int32, + false, + ))], + pushdown_filters: false, // note that this is false! 
+ reorder_filters: false, + enable_page_index: false, + enable_bloom_filter: false, + schema_adapter_factory: Arc::new(DefaultSchemaAdapterFactory), + enable_row_group_stats_pruning: true, + coerce_int96: None, + file_decryption_properties: None, + } + }; + let make_meta = || FileMeta { + object_meta: ObjectMeta { + location: Path::from("part=1/file.parquet"), + last_modified: Utc::now(), + size: u64::try_from(data_size).unwrap(), + e_tag: None, + version: None, + }, + range: None, + extensions: None, + metadata_size_hint: None, + }; + + // Filter should match the partition value and file statistics + let expr = col("part").eq(lit(1)).and(col("b").eq(lit(1.0))); + let predicate = logical2physical(&expr, &table_schema); + let opener = make_opener(predicate); + let stream = opener + .open(make_meta(), file.clone()) + .unwrap() + .await + .unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 1); + assert_eq!(num_rows, 3); + + // Should prune based on partition value but not file statistics + let expr = col("part").eq(lit(2)).and(col("b").eq(lit(1.0))); + let predicate = logical2physical(&expr, &table_schema); + let opener = make_opener(predicate); + let stream = opener + .open(make_meta(), file.clone()) + .unwrap() + .await + .unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 0); + assert_eq!(num_rows, 0); + + // Should prune based on file statistics but not partition value + let expr = col("part").eq(lit(1)).and(col("b").eq(lit(7.0))); + let predicate = logical2physical(&expr, &table_schema); + let opener = make_opener(predicate); + let stream = opener + .open(make_meta(), file.clone()) + .unwrap() + .await + .unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 0); + assert_eq!(num_rows, 0); + + // Should prune based on both partition value and file statistics + let expr = col("part").eq(lit(2)).and(col("b").eq(lit(7.0))); + let predicate = logical2physical(&expr, &table_schema); + let opener = make_opener(predicate); + let stream = opener.open(make_meta(), file).unwrap().await.unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 0); + assert_eq!(num_rows, 0); + } + + #[tokio::test] + async fn test_prune_on_partition_value_and_data_value() { + let store = Arc::new(InMemory::new()) as Arc; + + // Note: number 3 is missing! 
+ let batch = record_batch!(("a", Int32, vec![Some(1), Some(2), Some(4)])).unwrap(); + let data_size = + write_parquet(Arc::clone(&store), "part=1/file.parquet", batch.clone()).await; + + let file_schema = batch.schema(); + let mut file = PartitionedFile::new( + "part=1/file.parquet".to_string(), + u64::try_from(data_size).unwrap(), + ); + file.partition_values = vec![ScalarValue::Int32(Some(1))]; + + let table_schema = Arc::new(Schema::new(vec![ + Field::new("part", DataType::Int32, false), + Field::new("a", DataType::Int32, false), + ])); + + let make_opener = |predicate| { + ParquetOpener { + partition_index: 0, + projection: Arc::new([0]), + batch_size: 1024, + limit: None, + predicate: Some(predicate), + logical_file_schema: file_schema.clone(), + metadata_size_hint: None, + metrics: ExecutionPlanMetricsSet::new(), + parquet_file_reader_factory: Arc::new( + DefaultParquetFileReaderFactory::new(Arc::clone(&store)), + ), + partition_fields: vec![Arc::new(Field::new( + "part", + DataType::Int32, + false, + ))], + pushdown_filters: true, // note that this is true! + reorder_filters: true, + enable_page_index: false, + enable_bloom_filter: false, + schema_adapter_factory: Arc::new(DefaultSchemaAdapterFactory), + enable_row_group_stats_pruning: false, // note that this is false! + coerce_int96: None, + file_decryption_properties: None, + } + }; + + let make_meta = || FileMeta { + object_meta: ObjectMeta { + location: Path::from("part=1/file.parquet"), + last_modified: Utc::now(), + size: u64::try_from(data_size).unwrap(), + e_tag: None, + version: None, + }, + range: None, + extensions: None, + metadata_size_hint: None, + }; + + // Filter should match the partition value and data value + let expr = col("part").eq(lit(1)).or(col("a").eq(lit(1))); + let predicate = logical2physical(&expr, &table_schema); + let opener = make_opener(predicate); + let stream = opener + .open(make_meta(), file.clone()) + .unwrap() + .await + .unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 1); + assert_eq!(num_rows, 3); + + // Filter should match the partition value but not the data value + let expr = col("part").eq(lit(1)).or(col("a").eq(lit(3))); + let predicate = logical2physical(&expr, &table_schema); + let opener = make_opener(predicate); + let stream = opener + .open(make_meta(), file.clone()) + .unwrap() + .await + .unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 1); + assert_eq!(num_rows, 3); + + // Filter should not match the partition value but match the data value + let expr = col("part").eq(lit(2)).or(col("a").eq(lit(1))); + let predicate = logical2physical(&expr, &table_schema); + let opener = make_opener(predicate); + let stream = opener + .open(make_meta(), file.clone()) + .unwrap() + .await + .unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 1); + assert_eq!(num_rows, 1); + + // Filter should not match the partition value or the data value + let expr = col("part").eq(lit(2)).or(col("a").eq(lit(3))); + let predicate = logical2physical(&expr, &table_schema); + let opener = make_opener(predicate); + let stream = opener.open(make_meta(), file).unwrap().await.unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 0); + assert_eq!(num_rows, 0); + } + + /// Test that if the filter is not a dynamic filter and we have no stats we don't do extra pruning work at the file level. 
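These opener tests, including the one introduced by the doc comment above, all hinge on the gate at the top of `ParquetOpener::open`: file-level pruning is only attempted when the predicate is dynamic or the file carries statistics. A standalone sketch of that check, reusing names from the opener code above (the helper name is hypothetical):

use std::sync::Arc;

use datafusion_datasource::PartitionedFile;
use datafusion_physical_expr_common::physical_expr::{
    is_dynamic_physical_expr, PhysicalExpr,
};

// Worth building a FilePruner at open time only if the predicate may have
// tightened since planning (dynamic filter) or the file has statistics to
// prune against; otherwise planning-time pruning already did the work.
fn worth_pruning_at_open(
    predicate: &Arc<dyn PhysicalExpr>,
    file: &PartitionedFile,
) -> bool {
    is_dynamic_physical_expr(predicate) || file.has_statistics()
}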
+ #[tokio::test] + async fn test_opener_pruning_skipped_on_static_filters() { + let store = Arc::new(InMemory::new()) as Arc; + + let batch = record_batch!(("a", Int32, vec![Some(1), Some(2), Some(3)])).unwrap(); + let data_size = + write_parquet(Arc::clone(&store), "part=1/file.parquet", batch.clone()).await; + + let file_schema = batch.schema(); + let mut file = PartitionedFile::new( + "part=1/file.parquet".to_string(), + u64::try_from(data_size).unwrap(), + ); + file.partition_values = vec![ScalarValue::Int32(Some(1))]; + + let table_schema = Arc::new(Schema::new(vec![ + Field::new("part", DataType::Int32, false), + Field::new("a", DataType::Int32, false), + ])); + + let make_opener = |predicate| { + ParquetOpener { + partition_index: 0, + projection: Arc::new([0]), + batch_size: 1024, + limit: None, + predicate: Some(predicate), + logical_file_schema: file_schema.clone(), + metadata_size_hint: None, + metrics: ExecutionPlanMetricsSet::new(), + parquet_file_reader_factory: Arc::new( + DefaultParquetFileReaderFactory::new(Arc::clone(&store)), + ), + partition_fields: vec![Arc::new(Field::new( + "part", + DataType::Int32, + false, + ))], + pushdown_filters: false, // note that this is false! + reorder_filters: false, + enable_page_index: false, + enable_bloom_filter: false, + schema_adapter_factory: Arc::new(DefaultSchemaAdapterFactory), + enable_row_group_stats_pruning: true, + coerce_int96: None, + file_decryption_properties: None, + } + }; + + let make_meta = || FileMeta { + object_meta: ObjectMeta { + location: Path::from("part=1/file.parquet"), + last_modified: Utc::now(), + size: u64::try_from(data_size).unwrap(), + e_tag: None, + version: None, + }, + range: None, + extensions: None, + metadata_size_hint: None, + }; + + // Filter should NOT match the stats but the file is never attempted to be pruned because the filters are not dynamic + let expr = col("part").eq(lit(2)); + let predicate = logical2physical(&expr, &table_schema); + let opener = make_opener(predicate); + let stream = opener + .open(make_meta(), file.clone()) + .unwrap() + .await + .unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 1); + assert_eq!(num_rows, 3); + + // If we make the filter dynamic, it should prune + let predicate = make_dynamic_expr(logical2physical(&expr, &table_schema)); + let opener = make_opener(predicate); + let stream = opener + .open(make_meta(), file.clone()) + .unwrap() + .await + .unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 0); + assert_eq!(num_rows, 0); + } +} diff --git a/datafusion/datasource-parquet/src/page_filter.rs b/datafusion/datasource-parquet/src/page_filter.rs index 148527998ab5..5f3e05747d40 100644 --- a/datafusion/datasource-parquet/src/page_filter.rs +++ b/datafusion/datasource-parquet/src/page_filter.rs @@ -28,9 +28,10 @@ use arrow::{ array::ArrayRef, datatypes::{Schema, SchemaRef}, }; +use datafusion_common::pruning::PruningStatistics; use datafusion_common::ScalarValue; use datafusion_physical_expr::{split_conjunction, PhysicalExpr}; -use datafusion_physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; +use datafusion_pruning::PruningPredicate; use log::{debug, trace}; use parquet::arrow::arrow_reader::statistics::StatisticsConverter; @@ -333,7 +334,7 @@ fn prune_pages_in_one_row_group( assert_eq!(page_row_counts.len(), values.len()); let mut sum_row = *page_row_counts.first().unwrap(); let mut selected = *values.first().unwrap(); - 
trace!("Pruned to {:?} using {:?}", values, pruning_stats); + trace!("Pruned to {values:?} using {pruning_stats:?}"); for (i, &f) in values.iter().enumerate().skip(1) { if f == selected { sum_row += *page_row_counts.get(i).unwrap(); diff --git a/datafusion/datasource-parquet/src/row_filter.rs b/datafusion/datasource-parquet/src/row_filter.rs index 2d2993c29a6f..5626f83186e3 100644 --- a/datafusion/datasource-parquet/src/row_filter.rs +++ b/datafusion/datasource-parquet/src/row_filter.rs @@ -67,6 +67,7 @@ use arrow::array::BooleanArray; use arrow::datatypes::{DataType, Schema, SchemaRef}; use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::record_batch::RecordBatch; +use itertools::Itertools; use parquet::arrow::arrow_reader::{ArrowPredicate, RowFilter}; use parquet::arrow::ProjectionMask; use parquet::file::metadata::ParquetMetaData; @@ -74,9 +75,8 @@ use parquet::file::metadata::ParquetMetaData; use datafusion_common::cast::as_boolean_array; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}; use datafusion_common::Result; -use datafusion_datasource::schema_adapter::{SchemaAdapterFactory, SchemaMapper}; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::utils::reassign_predicate_columns; +use datafusion_physical_expr::utils::{collect_columns, reassign_predicate_columns}; use datafusion_physical_expr::{split_conjunction, PhysicalExpr}; use datafusion_physical_plan::metrics; @@ -106,8 +106,6 @@ pub(crate) struct DatafusionArrowPredicate { rows_matched: metrics::Count, /// how long was spent evaluating this predicate time: metrics::Time, - /// used to perform type coercion while filtering rows - schema_mapper: Arc, } impl DatafusionArrowPredicate { @@ -132,7 +130,6 @@ impl DatafusionArrowPredicate { rows_pruned, rows_matched, time, - schema_mapper: candidate.schema_mapper, }) } } @@ -143,8 +140,6 @@ impl ArrowPredicate for DatafusionArrowPredicate { } fn evaluate(&mut self, batch: RecordBatch) -> ArrowResult { - let batch = self.schema_mapper.map_batch(batch)?; - // scoped timer updates on drop let mut timer = self.time.timer(); @@ -187,9 +182,6 @@ pub(crate) struct FilterCandidate { /// required to pass thorugh a `SchemaMapper` to the table schema /// upon which we then evaluate the filter expression. projection: Vec, - /// A `SchemaMapper` used to map batches read from the file schema to - /// the filter's projection of the table schema. - schema_mapper: Arc, /// The projected table schema that this filter references filter_schema: SchemaRef, } @@ -230,26 +222,11 @@ struct FilterCandidateBuilder { /// columns in the file schema that are not in the table schema or columns that /// are in the table schema that are not in the file schema. file_schema: SchemaRef, - /// The schema of the table (merged schema) -- columns may be in different - /// order than in the file and have columns that are not in the file schema - table_schema: SchemaRef, - /// A `SchemaAdapterFactory` used to map the file schema to the table schema. 
- schema_adapter_factory: Arc, } impl FilterCandidateBuilder { - pub fn new( - expr: Arc, - file_schema: Arc, - table_schema: Arc, - schema_adapter_factory: Arc, - ) -> Self { - Self { - expr, - file_schema, - table_schema, - schema_adapter_factory, - } + pub fn new(expr: Arc, file_schema: Arc) -> Self { + Self { expr, file_schema } } /// Attempt to build a `FilterCandidate` from the expression @@ -261,20 +238,21 @@ impl FilterCandidateBuilder { /// * `Err(e)` if an error occurs while building the candidate pub fn build(self, metadata: &ParquetMetaData) -> Result> { let Some(required_indices_into_table_schema) = - pushdown_columns(&self.expr, &self.table_schema)? + pushdown_columns(&self.expr, &self.file_schema)? else { return Ok(None); }; let projected_table_schema = Arc::new( - self.table_schema + self.file_schema .project(&required_indices_into_table_schema)?, ); - let (schema_mapper, projection_into_file_schema) = self - .schema_adapter_factory - .create(Arc::clone(&projected_table_schema), self.table_schema) - .map_schema(&self.file_schema)?; + let projection_into_file_schema = collect_columns(&self.expr) + .iter() + .map(|c| c.index()) + .sorted_unstable() + .collect_vec(); let required_bytes = size_of_columns(&projection_into_file_schema, metadata)?; let can_use_index = columns_sorted(&projection_into_file_schema, metadata)?; @@ -284,7 +262,6 @@ impl FilterCandidateBuilder { required_bytes, can_use_index, projection: projection_into_file_schema, - schema_mapper: Arc::clone(&schema_mapper), filter_schema: Arc::clone(&projected_table_schema), })) } @@ -299,6 +276,7 @@ struct PushdownChecker<'schema> { non_primitive_columns: bool, /// Does the expression reference any columns that are in the table /// schema but not in the file schema? + /// This includes partition columns and projected columns. projected_columns: bool, // Indices into the table schema of the columns required to evaluate the expression required_columns: BTreeSet, @@ -366,44 +344,19 @@ fn pushdown_columns( .then_some(checker.required_columns.into_iter().collect())) } -/// creates a PushdownChecker for a single use to check a given column with the given schemes. Used -/// to check preemptively if a column name would prevent pushdowning. -/// effectively does the inverse of [`pushdown_columns`] does, but with a single given column -/// (instead of traversing the entire tree to determine this) -fn would_column_prevent_pushdown(column_name: &str, table_schema: &Schema) -> bool { - let mut checker = PushdownChecker::new(table_schema); - - // the return of this is only used for [`PushdownChecker::f_down()`], so we can safely ignore - // it here. I'm just verifying we know the return type of this so nobody accidentally changes - // the return type of this fn and it gets implicitly ignored here. - let _: Option = checker.check_single_column(column_name); - - // and then return a value based on the state of the checker - checker.prevents_pushdown() -} - /// Recurses through expr as a tree, finds all `column`s, and checks if any of them would prevent /// this expression from being predicate pushed down. If any of them would, this returns false. /// Otherwise, true. +/// Note that the schema passed in here is *not* the physical file schema (as it is not available at that point in time); +/// it is the schema of the table that this expression is being evaluated against minus any projected columns and partition columns. 
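A hedged usage sketch of the new single-schema signature documented above, mirroring the updated unit tests further down; `string_col` and the helper name are illustrative only.

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_expr::col;
use datafusion_physical_expr::planner::logical2physical;

use crate::can_expr_be_pushed_down_with_schemas;

// Plan a predicate against the (partition-free) table schema, then check
// whether it is still eligible for parquet filter pushdown.
fn pushdown_check_example() -> bool {
    let schema = Schema::new(vec![Field::new("string_col", DataType::Utf8, true)]);
    let expr = logical2physical(&col("string_col").is_null(), &schema);
    can_expr_be_pushed_down_with_schemas(&expr, &schema)
}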
pub fn can_expr_be_pushed_down_with_schemas( - expr: &datafusion_expr::Expr, - _file_schema: &Schema, - table_schema: &Schema, + expr: &Arc, + file_schema: &Schema, ) -> bool { - let mut can_be_pushed = true; - expr.apply(|expr| match expr { - datafusion_expr::Expr::Column(column) => { - can_be_pushed &= !would_column_prevent_pushdown(column.name(), table_schema); - Ok(if can_be_pushed { - TreeNodeRecursion::Jump - } else { - TreeNodeRecursion::Stop - }) - } - _ => Ok(TreeNodeRecursion::Continue), - }) - .unwrap(); // we never return an Err, so we can safely unwrap this - can_be_pushed + match pushdown_columns(expr, file_schema) { + Ok(Some(_)) => true, + Ok(None) | Err(_) => false, + } } /// Calculate the total compressed size of all `Column`'s required for @@ -450,11 +403,9 @@ fn columns_sorted(_columns: &[usize], _metadata: &ParquetMetaData) -> Result, physical_file_schema: &SchemaRef, - table_schema: &SchemaRef, metadata: &ParquetMetaData, reorder_predicates: bool, file_metrics: &ParquetFileMetrics, - schema_adapter_factory: &Arc, ) -> Result> { let rows_pruned = &file_metrics.pushdown_rows_pruned; let rows_matched = &file_metrics.pushdown_rows_matched; @@ -471,8 +422,6 @@ pub fn build_row_filter( FilterCandidateBuilder::new( Arc::clone(expr), Arc::clone(physical_file_schema), - Arc::clone(table_schema), - Arc::clone(schema_adapter_factory), ) .build(metadata) }) @@ -516,13 +465,9 @@ mod test { use super::*; use datafusion_common::ScalarValue; - use arrow::datatypes::{Field, Fields, TimeUnit::Nanosecond}; - use datafusion_datasource::schema_adapter::DefaultSchemaAdapterFactory; use datafusion_expr::{col, Expr}; use datafusion_physical_expr::planner::logical2physical; - use datafusion_physical_plan::metrics::{Count, Time}; - use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder; use parquet::arrow::parquet_to_arrow_schema; use parquet::file::reader::{FileReader, SerializedFileReader}; @@ -544,178 +489,56 @@ mod test { let expr = col("int64_list").is_not_null(); let expr = logical2physical(&expr, &table_schema); - let schema_adapter_factory = Arc::new(DefaultSchemaAdapterFactory); let table_schema = Arc::new(table_schema.clone()); - let candidate = FilterCandidateBuilder::new( - expr, - table_schema.clone(), - table_schema, - schema_adapter_factory, - ) - .build(metadata) - .expect("building candidate"); + let candidate = FilterCandidateBuilder::new(expr, table_schema.clone()) + .build(metadata) + .expect("building candidate"); assert!(candidate.is_none()); } - #[test] - fn test_filter_type_coercion() { - let testdata = datafusion_common::test_util::parquet_test_data(); - let file = std::fs::File::open(format!("{testdata}/alltypes_plain.parquet")) - .expect("opening file"); - - let parquet_reader_builder = - ParquetRecordBatchReaderBuilder::try_new(file).expect("creating reader"); - let metadata = parquet_reader_builder.metadata().clone(); - let file_schema = parquet_reader_builder.schema().clone(); - - // This is the schema we would like to coerce to, - // which is different from the physical schema of the file. 
- let table_schema = Schema::new(vec![Field::new( - "timestamp_col", - DataType::Timestamp(Nanosecond, Some(Arc::from("UTC"))), - false, - )]); - - // Test all should fail - let expr = col("timestamp_col").lt(Expr::Literal( - ScalarValue::TimestampNanosecond(Some(1), Some(Arc::from("UTC"))), - )); - let expr = logical2physical(&expr, &table_schema); - let schema_adapter_factory = Arc::new(DefaultSchemaAdapterFactory); - let table_schema = Arc::new(table_schema.clone()); - let candidate = FilterCandidateBuilder::new( - expr, - file_schema.clone(), - table_schema.clone(), - schema_adapter_factory, - ) - .build(&metadata) - .expect("building candidate") - .expect("candidate expected"); - - let mut row_filter = DatafusionArrowPredicate::try_new( - candidate, - &metadata, - Count::new(), - Count::new(), - Time::new(), - ) - .expect("creating filter predicate"); - - let mut parquet_reader = parquet_reader_builder - .with_projection(row_filter.projection().clone()) - .build() - .expect("building reader"); - - // Parquet file is small, we only need 1 record batch - let first_rb = parquet_reader - .next() - .expect("expected record batch") - .expect("expected error free record batch"); - - let filtered = row_filter.evaluate(first_rb.clone()); - assert!(matches!(filtered, Ok(a) if a == BooleanArray::from(vec![false; 8]))); - - // Test all should pass - let expr = col("timestamp_col").gt(Expr::Literal( - ScalarValue::TimestampNanosecond(Some(0), Some(Arc::from("UTC"))), - )); - let expr = logical2physical(&expr, &table_schema); - let schema_adapter_factory = Arc::new(DefaultSchemaAdapterFactory); - let candidate = FilterCandidateBuilder::new( - expr, - file_schema, - table_schema, - schema_adapter_factory, - ) - .build(&metadata) - .expect("building candidate") - .expect("candidate expected"); - - let mut row_filter = DatafusionArrowPredicate::try_new( - candidate, - &metadata, - Count::new(), - Count::new(), - Time::new(), - ) - .expect("creating filter predicate"); - - let filtered = row_filter.evaluate(first_rb); - assert!(matches!(filtered, Ok(a) if a == BooleanArray::from(vec![true; 8]))); - } - #[test] fn nested_data_structures_prevent_pushdown() { - let table_schema = get_basic_table_schema(); + let table_schema = Arc::new(get_lists_table_schema()); - let file_schema = Schema::new(vec![Field::new( - "list_col", - DataType::Struct(Fields::empty()), - true, - )]); - - let expr = col("list_col").is_not_null(); + let expr = col("utf8_list").is_not_null(); + let expr = logical2physical(&expr, &table_schema); + check_expression_can_evaluate_against_schema(&expr, &table_schema); - assert!(!can_expr_be_pushed_down_with_schemas( - &expr, - &file_schema, - &table_schema - )); + assert!(!can_expr_be_pushed_down_with_schemas(&expr, &table_schema)); } #[test] fn projected_columns_prevent_pushdown() { let table_schema = get_basic_table_schema(); - let file_schema = - Schema::new(vec![Field::new("existing_col", DataType::Int64, true)]); - - let expr = col("nonexistent_column").is_null(); + let expr = + Arc::new(Column::new("nonexistent_column", 0)) as Arc; - assert!(!can_expr_be_pushed_down_with_schemas( - &expr, - &file_schema, - &table_schema - )); + assert!(!can_expr_be_pushed_down_with_schemas(&expr, &table_schema)); } #[test] fn basic_expr_doesnt_prevent_pushdown() { let table_schema = get_basic_table_schema(); - let file_schema = - Schema::new(vec![Field::new("string_col", DataType::Utf8, true)]); - let expr = col("string_col").is_null(); + let expr = logical2physical(&expr, &table_schema); - 
assert!(can_expr_be_pushed_down_with_schemas( - &expr, - &file_schema, - &table_schema - )); + assert!(can_expr_be_pushed_down_with_schemas(&expr, &table_schema)); } #[test] fn complex_expr_doesnt_prevent_pushdown() { let table_schema = get_basic_table_schema(); - let file_schema = Schema::new(vec![ - Field::new("string_col", DataType::Utf8, true), - Field::new("bigint_col", DataType::Int64, true), - ]); - let expr = col("string_col") .is_not_null() - .or(col("bigint_col").gt(Expr::Literal(ScalarValue::Int64(Some(5))))); + .or(col("bigint_col").gt(Expr::Literal(ScalarValue::Int64(Some(5)), None))); + let expr = logical2physical(&expr, &table_schema); - assert!(can_expr_be_pushed_down_with_schemas( - &expr, - &file_schema, - &table_schema - )); + assert!(can_expr_be_pushed_down_with_schemas(&expr, &table_schema)); } fn get_basic_table_schema() -> Schema { @@ -730,4 +553,27 @@ mod test { parquet_to_arrow_schema(metadata.file_metadata().schema_descr(), None) .expect("parsing schema") } + + fn get_lists_table_schema() -> Schema { + let testdata = datafusion_common::test_util::parquet_test_data(); + let file = std::fs::File::open(format!("{testdata}/list_columns.parquet")) + .expect("opening file"); + + let reader = SerializedFileReader::new(file).expect("creating reader"); + + let metadata = reader.metadata(); + + parquet_to_arrow_schema(metadata.file_metadata().schema_descr(), None) + .expect("parsing schema") + } + + /// Sanity check that the given expression could be evaluated against the given schema without any errors. + /// This will fail if the expression references columns that are not in the schema or if the types of the columns are incompatible, etc. + fn check_expression_can_evaluate_against_schema( + expr: &Arc, + table_schema: &Arc, + ) -> bool { + let batch = RecordBatch::new_empty(Arc::clone(table_schema)); + expr.evaluate(&batch).is_ok() + } } diff --git a/datafusion/datasource-parquet/src/row_group_filter.rs b/datafusion/datasource-parquet/src/row_group_filter.rs index 13418cdeee22..51d50d780f10 100644 --- a/datafusion/datasource-parquet/src/row_group_filter.rs +++ b/datafusion/datasource-parquet/src/row_group_filter.rs @@ -21,9 +21,10 @@ use std::sync::Arc; use super::{ParquetAccessPlan, ParquetFileMetrics}; use arrow::array::{ArrayRef, BooleanArray}; use arrow::datatypes::Schema; +use datafusion_common::pruning::PruningStatistics; use datafusion_common::{Column, Result, ScalarValue}; use datafusion_datasource::FileRange; -use datafusion_physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; +use datafusion_pruning::PruningPredicate; use parquet::arrow::arrow_reader::statistics::StatisticsConverter; use parquet::arrow::parquet_column; use parquet::basic::Type; @@ -1241,12 +1242,16 @@ mod tests { .run( lit("1").eq(lit("1")).and( col(r#""String""#) - .eq(Expr::Literal(ScalarValue::Utf8View(Some(String::from( - "Hello_Not_Exists", - ))))) - .or(col(r#""String""#).eq(Expr::Literal(ScalarValue::Utf8View( - Some(String::from("Hello_Not_Exists2")), - )))), + .eq(Expr::Literal( + ScalarValue::Utf8View(Some(String::from("Hello_Not_Exists"))), + None, + )) + .or(col(r#""String""#).eq(Expr::Literal( + ScalarValue::Utf8View(Some(String::from( + "Hello_Not_Exists2", + ))), + None, + ))), ), ) .await @@ -1265,7 +1270,7 @@ mod tests { let expr = col(r#""String""#).in_list( (1..25) - .map(|i| lit(format!("Hello_Not_Exists{}", i))) + .map(|i| lit(format!("Hello_Not_Exists{i}"))) .collect::>(), false, ); @@ -1326,15 +1331,18 @@ mod tests { // generate pruning predicate `(String = 
"Hello") OR (String = "the quick") OR (String = "are you")` .run( col(r#""String""#) - .eq(Expr::Literal(ScalarValue::Utf8View(Some(String::from( - "Hello", - ))))) - .or(col(r#""String""#).eq(Expr::Literal(ScalarValue::Utf8View( - Some(String::from("the quick")), - )))) - .or(col(r#""String""#).eq(Expr::Literal(ScalarValue::Utf8View( - Some(String::from("are you")), - )))), + .eq(Expr::Literal( + ScalarValue::Utf8View(Some(String::from("Hello"))), + None, + )) + .or(col(r#""String""#).eq(Expr::Literal( + ScalarValue::Utf8View(Some(String::from("the quick"))), + None, + ))) + .or(col(r#""String""#).eq(Expr::Literal( + ScalarValue::Utf8View(Some(String::from("are you"))), + None, + ))), ) .await } diff --git a/datafusion/datasource-parquet/src/source.rs b/datafusion/datasource-parquet/src/source.rs index 6236525fcb9f..f2e782315100 100644 --- a/datafusion/datasource-parquet/src/source.rs +++ b/datafusion/datasource-parquet/src/source.rs @@ -21,31 +21,34 @@ use std::fmt::Debug; use std::fmt::Formatter; use std::sync::Arc; -use crate::opener::build_page_pruning_predicate; -use crate::opener::build_pruning_predicate; +use crate::opener::build_pruning_predicates; use crate::opener::ParquetOpener; -use crate::page_filter::PagePruningAccessPlanFilter; +use crate::row_filter::can_expr_be_pushed_down_with_schemas; use crate::DefaultParquetFileReaderFactory; use crate::ParquetFileReaderFactory; +use datafusion_common::config::ConfigOptions; +use datafusion_datasource::as_file_source; use datafusion_datasource::file_stream::FileOpener; use datafusion_datasource::schema_adapter::{ DefaultSchemaAdapterFactory, SchemaAdapterFactory, }; -use arrow::datatypes::{Schema, SchemaRef, TimeUnit}; +use arrow::datatypes::{SchemaRef, TimeUnit}; use datafusion_common::config::TableParquetOptions; use datafusion_common::{DataFusionError, Statistics}; use datafusion_datasource::file::FileSource; use datafusion_datasource::file_scan_config::FileScanConfig; +use datafusion_physical_expr::conjunction; use datafusion_physical_expr_common::physical_expr::fmt_sql; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use datafusion_physical_optimizer::pruning::PruningPredicate; -use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, MetricBuilder}; +use datafusion_physical_plan::filter_pushdown::FilterPushdownPropagation; +use datafusion_physical_plan::filter_pushdown::PredicateSupports; +use datafusion_physical_plan::metrics::Count; +use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; use datafusion_physical_plan::DisplayFormatType; use itertools::Itertools; use object_store::ObjectStore; - /// Execution plan for reading one or more Parquet files. /// /// ```text @@ -92,7 +95,7 @@ use object_store::ObjectStore; /// # let predicate = lit(true); /// let source = Arc::new( /// ParquetSource::default() -/// .with_predicate(Arc::clone(&file_schema), predicate) +/// .with_predicate(predicate) /// ); /// // Create a DataSourceExec for reading `file1.parquet` with a file size of 100MB /// let config = FileScanConfigBuilder::new(object_store_url, file_schema, source) @@ -259,12 +262,12 @@ pub struct ParquetSource { pub(crate) table_parquet_options: TableParquetOptions, /// Optional metrics pub(crate) metrics: ExecutionPlanMetricsSet, + /// The schema of the file. + /// In particular, this is the schema of the table without partition columns, + /// *not* the physical schema of the file. 
+ pub(crate) file_schema: Option, /// Optional predicate for row filtering during parquet scan pub(crate) predicate: Option>, - /// Optional predicate for pruning row groups (derived from `predicate`) - pub(crate) pruning_predicate: Option>, - /// Optional predicate for pruning pages (derived from `predicate`) - pub(crate) page_pruning_predicate: Option>, /// Optional user defined parquet file reader factory pub(crate) parquet_file_reader_factory: Option>, /// Optional user defined schema adapter @@ -303,26 +306,12 @@ impl ParquetSource { self } - /// Set predicate information, also sets pruning_predicate and page_pruning_predicate attributes - pub fn with_predicate( - &self, - file_schema: Arc, - predicate: Arc, - ) -> Self { + /// Set predicate information + pub fn with_predicate(&self, predicate: Arc) -> Self { let mut conf = self.clone(); - let metrics = ExecutionPlanMetricsSet::new(); - let predicate_creation_errors = - MetricBuilder::new(&metrics).global_counter("num_predicate_creation_errors"); - conf = conf.with_metrics(metrics); conf.predicate = Some(Arc::clone(&predicate)); - - conf.page_pruning_predicate = - Some(build_page_pruning_predicate(&predicate, &file_schema)); - conf.pruning_predicate = - build_pruning_predicate(predicate, &file_schema, &predicate_creation_errors); - conf } @@ -353,25 +342,6 @@ impl ParquetSource { self } - /// return the optional schema adapter factory - pub fn schema_adapter_factory(&self) -> Option<&Arc> { - self.schema_adapter_factory.as_ref() - } - - /// Set optional schema adapter factory. - /// - /// [`SchemaAdapterFactory`] allows user to specify how fields from the - /// parquet file get mapped to that of the table schema. The default schema - /// adapter uses arrow's cast library to map the parquet fields to the table - /// schema. - pub fn with_schema_adapter_factory( - mut self, - schema_adapter_factory: Arc, - ) -> Self { - self.schema_adapter_factory = Some(schema_adapter_factory); - self - } - /// If true, the predicate will be used during the parquet scan. /// Defaults to false /// @@ -436,10 +406,34 @@ impl ParquetSource { fn bloom_filter_on_read(&self) -> bool { self.table_parquet_options.global.bloom_filter_on_read } + + /// Applies schema adapter factory from the FileScanConfig if present. 
+ /// + /// # Arguments + /// * `conf` - FileScanConfig that may contain a schema adapter factory + /// # Returns + /// The converted FileSource with schema adapter factory applied if provided + pub fn apply_schema_adapter( + self, + conf: &FileScanConfig, + ) -> datafusion_common::Result> { + let file_source: Arc = self.into(); + + // If the FileScanConfig.file_source() has a schema adapter factory, apply it + if let Some(factory) = conf.file_source().schema_adapter_factory() { + file_source.with_schema_adapter_factory( + Arc::::clone(&factory), + ) + } else { + Ok(file_source) + } + } } /// Parses datafusion.common.config.ParquetOptions.coerce_int96 String to a arrow_schema.datatype.TimeUnit -fn parse_coerce_int96_string(str_setting: &str) -> datafusion_common::Result { +pub(crate) fn parse_coerce_int96_string( + str_setting: &str, +) -> datafusion_common::Result { let str_setting_lower: &str = &str_setting.to_lowercase(); match str_setting_lower { @@ -454,6 +448,13 @@ fn parse_coerce_int96_string(str_setting: &str) -> datafusion_common::Result for Arc { + fn from(source: ParquetSource) -> Self { + as_file_source(source) + } +} + impl FileSource for ParquetSource { fn create_file_opener( &self, @@ -474,6 +475,13 @@ impl FileSource for ParquetSource { Arc::new(DefaultParquetFileReaderFactory::new(object_store)) as _ }); + let file_decryption_properties = self + .table_parquet_options() + .crypto + .file_decryption + .as_ref() + .map(|props| Arc::new(props.clone().into())); + let coerce_int96 = self .table_parquet_options .global @@ -489,7 +497,8 @@ impl FileSource for ParquetSource { .expect("Batch size must set before creating ParquetOpener"), limit: base_config.limit, predicate: self.predicate.clone(), - table_schema: Arc::clone(&base_config.file_schema), + logical_file_schema: Arc::clone(&base_config.file_schema), + partition_fields: base_config.table_partition_cols.clone(), metadata_size_hint: self.metadata_size_hint, metrics: self.metrics().clone(), parquet_file_reader_factory, @@ -500,6 +509,7 @@ impl FileSource for ParquetSource { enable_row_group_stats_pruning: self.table_parquet_options.global.pruning, schema_adapter_factory, coerce_int96, + file_decryption_properties, }) } @@ -513,8 +523,11 @@ impl FileSource for ParquetSource { Arc::new(conf) } - fn with_schema(&self, _schema: SchemaRef) -> Arc { - Arc::new(Self { ..self.clone() }) + fn with_schema(&self, schema: SchemaRef) -> Arc { + Arc::new(Self { + file_schema: Some(schema), + ..self.clone() + }) } fn with_statistics(&self, statistics: Statistics) -> Arc { @@ -559,25 +572,41 @@ impl FileSource for ParquetSource { .predicate() .map(|p| format!(", predicate={p}")) .unwrap_or_default(); - let pruning_predicate_string = self - .pruning_predicate - .as_ref() - .map(|pre| { - let mut guarantees = pre + + write!(f, "{predicate_string}")?; + + // Try to build a the pruning predicates. + // These are only generated here because it's useful to have *some* + // idea of what pushdown is happening when viewing plans. + // However it is important to note that these predicates are *not* + // necessarily the predicates that are actually evaluated: + // the actual predicates are built in reference to the physical schema of + // each file, which we do not have at this point and hence cannot use. + // Instead we use the logical schema of the file (the table schema without partition columns). 
+ if let (Some(file_schema), Some(predicate)) = + (&self.file_schema, &self.predicate) + { + let predicate_creation_errors = Count::new(); + if let (Some(pruning_predicate), _) = build_pruning_predicates( + Some(predicate), + file_schema, + &predicate_creation_errors, + ) { + let mut guarantees = pruning_predicate .literal_guarantees() .iter() - .map(|item| format!("{}", item)) + .map(|item| format!("{item}")) .collect_vec(); guarantees.sort(); - format!( + writeln!( + f, ", pruning_predicate={}, required_guarantees=[{}]", - pre.predicate_expr(), + pruning_predicate.predicate_expr(), guarantees.join(", ") - ) - }) - .unwrap_or_default(); - - write!(f, "{}{}", predicate_string, pruning_predicate_string) + )?; + } + }; + Ok(()) } DisplayFormatType::TreeRender => { if let Some(predicate) = self.predicate() { @@ -587,4 +616,60 @@ impl FileSource for ParquetSource { } } } + + fn try_pushdown_filters( + &self, + filters: Vec>, + config: &ConfigOptions, + ) -> datafusion_common::Result>> { + let Some(file_schema) = self.file_schema.clone() else { + return Ok(FilterPushdownPropagation::unsupported(filters)); + }; + // Can we push down the filters themselves into the scan or only use stats pruning? + let config_pushdown_enabled = config.execution.parquet.pushdown_filters; + let table_pushdown_enabled = self.pushdown_filters(); + let pushdown_filters = table_pushdown_enabled || config_pushdown_enabled; + + let mut source = self.clone(); + let filters = PredicateSupports::new_with_supported_check(filters, |filter| { + can_expr_be_pushed_down_with_schemas(filter, &file_schema) + }); + if filters.is_all_unsupported() { + // No filters can be pushed down, so we can just return the remaining filters + // and avoid replacing the source in the physical plan. + return Ok(FilterPushdownPropagation::with_filters(filters)); + } + let allowed_filters = filters.collect_supported(); + let predicate = match source.predicate { + Some(predicate) => conjunction( + std::iter::once(predicate).chain(allowed_filters.iter().cloned()), + ), + None => conjunction(allowed_filters.iter().cloned()), + }; + source.predicate = Some(predicate); + let source = Arc::new(source); + // If pushdown_filters is false we tell our parents that they still have to handle the filters, + // even if we updated the predicate to include the filters (they will only be used for stats pruning). 
+ if !pushdown_filters { + return Ok(FilterPushdownPropagation::with_filters( + filters.make_unsupported(), + ) + .with_updated_node(source)); + } + Ok(FilterPushdownPropagation::with_filters(filters).with_updated_node(source)) + } + + fn with_schema_adapter_factory( + &self, + schema_adapter_factory: Arc, + ) -> datafusion_common::Result> { + Ok(Arc::new(Self { + schema_adapter_factory: Some(schema_adapter_factory), + ..self.clone() + })) + } + + fn schema_adapter_factory(&self) -> Option> { + self.schema_adapter_factory.clone() + } } diff --git a/datafusion/datasource-parquet/src/writer.rs b/datafusion/datasource-parquet/src/writer.rs index 64eb37c81f5d..d37b6e26a753 100644 --- a/datafusion/datasource-parquet/src/writer.rs +++ b/datafusion/datasource-parquet/src/writer.rs @@ -46,7 +46,15 @@ pub async fn plan_to_parquet( let propclone = writer_properties.clone(); let storeref = Arc::clone(&store); - let buf_writer = BufWriter::new(storeref, file.clone()); + let buf_writer = BufWriter::with_capacity( + storeref, + file.clone(), + task_ctx + .session_config() + .options() + .execution + .objectstore_writer_buffer_size, + ); let mut stream = plan.execute(i, Arc::clone(&task_ctx))?; join_set.spawn(async move { let mut writer = diff --git a/datafusion/datasource-parquet/tests/apply_schema_adapter_tests.rs b/datafusion/datasource-parquet/tests/apply_schema_adapter_tests.rs new file mode 100644 index 000000000000..955cd224e6a4 --- /dev/null +++ b/datafusion/datasource-parquet/tests/apply_schema_adapter_tests.rs @@ -0,0 +1,206 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
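// A minimal sketch, not part of this patch: the `try_pushdown_filters` path added
// to `ParquetSource` above is gated on the `pushdown_filters` setting, either the
// per-table option or `ConfigOptions::execution.parquet.pushdown_filters`. The
// helper below (its name is illustrative only) shows the config-level switch that
// makes the source absorb supported filters into its predicate rather than
// reporting them back to the parent as unsupported.
use datafusion_common::config::ConfigOptions;

fn pushdown_enabled_options() -> ConfigOptions {
    let mut options = ConfigOptions::new();
    // Flag consulted by `ParquetSource::try_pushdown_filters` alongside the
    // table-level `pushdown_filters` option.
    options.execution.parquet.pushdown_filters = true;
    options
}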
+ +mod parquet_adapter_tests { + use arrow::{ + datatypes::{DataType, Field, Schema, SchemaRef}, + record_batch::RecordBatch, + }; + use datafusion_common::{ColumnStatistics, DataFusionError, Result}; + use datafusion_datasource::{ + file::FileSource, + file_scan_config::FileScanConfigBuilder, + schema_adapter::{SchemaAdapter, SchemaAdapterFactory, SchemaMapper}, + }; + use datafusion_datasource_parquet::source::ParquetSource; + use datafusion_execution::object_store::ObjectStoreUrl; + use std::{fmt::Debug, sync::Arc}; + + /// A test schema adapter factory that adds prefix to column names + #[derive(Debug)] + struct PrefixAdapterFactory { + prefix: String, + } + + impl SchemaAdapterFactory for PrefixAdapterFactory { + fn create( + &self, + projected_table_schema: SchemaRef, + _table_schema: SchemaRef, + ) -> Box { + Box::new(PrefixAdapter { + input_schema: projected_table_schema, + prefix: self.prefix.clone(), + }) + } + } + + /// A test schema adapter that adds prefix to column names + #[derive(Debug)] + struct PrefixAdapter { + input_schema: SchemaRef, + prefix: String, + } + + impl SchemaAdapter for PrefixAdapter { + fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option { + let field = self.input_schema.field(index); + file_schema.fields.find(field.name()).map(|(i, _)| i) + } + + fn map_schema( + &self, + file_schema: &Schema, + ) -> Result<(Arc, Vec)> { + let mut projection = Vec::with_capacity(file_schema.fields().len()); + for (file_idx, file_field) in file_schema.fields().iter().enumerate() { + if self.input_schema.fields().find(file_field.name()).is_some() { + projection.push(file_idx); + } + } + + // Create a schema mapper that adds a prefix to column names + #[derive(Debug)] + struct PrefixSchemaMapping { + // Keep only the prefix field which is actually used in the implementation + prefix: String, + } + + impl SchemaMapper for PrefixSchemaMapping { + fn map_batch(&self, batch: RecordBatch) -> Result { + // Create a new schema with prefixed field names + let prefixed_fields: Vec = batch + .schema() + .fields() + .iter() + .map(|field| { + Field::new( + format!("{}{}", self.prefix, field.name()), + field.data_type().clone(), + field.is_nullable(), + ) + }) + .collect(); + let prefixed_schema = Arc::new(Schema::new(prefixed_fields)); + + // Create a new batch with the prefixed schema but the same data + let options = arrow::record_batch::RecordBatchOptions::default(); + RecordBatch::try_new_with_options( + prefixed_schema, + batch.columns().to_vec(), + &options, + ) + .map_err(|e| DataFusionError::ArrowError(e, None)) + } + + fn map_column_statistics( + &self, + stats: &[ColumnStatistics], + ) -> Result> { + // For testing, just return the input statistics + Ok(stats.to_vec()) + } + } + + Ok(( + Arc::new(PrefixSchemaMapping { + prefix: self.prefix.clone(), + }), + projection, + )) + } + } + + #[test] + fn test_apply_schema_adapter_with_factory() { + // Create a schema + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ])); + + // Create a parquet source + let source = ParquetSource::default(); + + // Create a file scan config with source that has a schema adapter factory + let factory = Arc::new(PrefixAdapterFactory { + prefix: "test_".to_string(), + }); + + let file_source = source.clone().with_schema_adapter_factory(factory).unwrap(); + + let config = FileScanConfigBuilder::new( + ObjectStoreUrl::local_filesystem(), + schema.clone(), + file_source, + ) + .build(); + + // Apply 
schema adapter to a new source + let result_source = source.apply_schema_adapter(&config).unwrap(); + + // Verify the adapter was applied + assert!(result_source.schema_adapter_factory().is_some()); + + // Create adapter and test it produces expected schema + let adapter_factory = result_source.schema_adapter_factory().unwrap(); + let adapter = adapter_factory.create(schema.clone(), schema.clone()); + + // Create a dummy batch to test the schema mapping + let dummy_batch = RecordBatch::new_empty(schema.clone()); + + // Get the file schema (which is the same as the table schema in this test) + let (mapper, _) = adapter.map_schema(&schema).unwrap(); + + // Apply the mapping to get the output schema + let mapped_batch = mapper.map_batch(dummy_batch).unwrap(); + let output_schema = mapped_batch.schema(); + + // Check the column names have the prefix + assert_eq!(output_schema.field(0).name(), "test_id"); + assert_eq!(output_schema.field(1).name(), "test_name"); + } + + #[test] + fn test_apply_schema_adapter_without_factory() { + // Create a schema + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ])); + + // Create a parquet source + let source = ParquetSource::default(); + + // Convert to Arc + let file_source: Arc = Arc::new(source.clone()); + + // Create a file scan config without a schema adapter factory + let config = FileScanConfigBuilder::new( + ObjectStoreUrl::local_filesystem(), + schema.clone(), + file_source, + ) + .build(); + + // Apply schema adapter function - should pass through the source unchanged + let result_source = source.apply_schema_adapter(&config).unwrap(); + + // Verify no adapter was applied + assert!(result_source.schema_adapter_factory().is_none()); + } +} diff --git a/datafusion/datasource/Cargo.toml b/datafusion/datasource/Cargo.toml index 1088efc268c9..afef1901fad8 100644 --- a/datafusion/datasource/Cargo.toml +++ b/datafusion/datasource/Cargo.toml @@ -46,7 +46,7 @@ async-compression = { version = "0.4.19", features = [ ], optional = true } async-trait = { workspace = true } bytes = { workspace = true } -bzip2 = { version = "0.5.2", optional = true } +bzip2 = { version = "0.6.0", optional = true } chrono = { workspace = true } datafusion-common = { workspace = true, features = ["object_store"] } datafusion-common-runtime = { workspace = true } @@ -56,7 +56,7 @@ datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } datafusion-physical-plan = { workspace = true } datafusion-session = { workspace = true } -flate2 = { version = "1.1.1", optional = true } +flate2 = { version = "1.1.2", optional = true } futures = { workspace = true } glob = "0.3.0" itertools = { workspace = true } @@ -66,7 +66,7 @@ parquet = { workspace = true, optional = true } rand = { workspace = true } tempfile = { workspace = true, optional = true } tokio = { workspace = true } -tokio-util = { version = "0.7.14", features = ["io"], optional = true } +tokio-util = { version = "0.7.15", features = ["io"], optional = true } url = { workspace = true } xz2 = { version = "0.1", optional = true, features = ["static"] } zstd = { version = "0.13", optional = true, default-features = false } diff --git a/datafusion/datasource/README.md b/datafusion/datasource/README.md index 750ee9375154..5d743bc83063 100644 --- a/datafusion/datasource/README.md +++ b/datafusion/datasource/README.md @@ -23,4 +23,9 @@ This crate is a submodule of DataFusion that defines common DataSource 
related components like FileScanConfig, FileCompression etc. +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + [df]: https://crates.io/crates/datafusion +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/datasource/benches/split_groups_by_statistics.rs b/datafusion/datasource/benches/split_groups_by_statistics.rs index f7c5e1b44ae0..d51fdfc0a6e9 100644 --- a/datafusion/datasource/benches/split_groups_by_statistics.rs +++ b/datafusion/datasource/benches/split_groups_by_statistics.rs @@ -15,14 +15,16 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; +use std::time::Duration; + use arrow::datatypes::{DataType, Field, Schema}; -use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use datafusion_datasource::file_scan_config::FileScanConfig; use datafusion_datasource::{generate_test_files, verify_sort_integrity}; -use datafusion_physical_expr::PhysicalSortExpr; -use datafusion_physical_expr_common::sort_expr::LexOrdering; -use std::sync::Arc; -use std::time::Duration; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; + +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; pub fn compare_split_groups_by_statistics_algorithms(c: &mut Criterion) { let file_schema = Arc::new(Schema::new(vec![Field::new( @@ -31,13 +33,8 @@ pub fn compare_split_groups_by_statistics_algorithms(c: &mut Criterion) { false, )])); - let sort_expr = PhysicalSortExpr { - expr: Arc::new(datafusion_physical_expr::expressions::Column::new( - "value", 0, - )), - options: arrow::compute::SortOptions::default(), - }; - let sort_ordering = LexOrdering::from(vec![sort_expr]); + let sort_expr = PhysicalSortExpr::new_default(Arc::new(Column::new("value", 0))); + let sort_ordering = LexOrdering::from([sort_expr]); // Small, medium, large number of files let file_counts = [10, 100, 1000]; @@ -55,7 +52,7 @@ pub fn compare_split_groups_by_statistics_algorithms(c: &mut Criterion) { group.bench_with_input( BenchmarkId::new( "original", - format!("files={},overlap={:.1}", num_files, overlap), + format!("files={num_files},overlap={overlap:.1}"), ), &( file_groups.clone(), @@ -77,8 +74,8 @@ pub fn compare_split_groups_by_statistics_algorithms(c: &mut Criterion) { for &tp in &target_partitions { group.bench_with_input( BenchmarkId::new( - format!("v2_partitions={}", tp), - format!("files={},overlap={:.1}", num_files, overlap), + format!("v2_partitions={tp}"), + format!("files={num_files},overlap={overlap:.1}"), ), &( file_groups.clone(), diff --git a/datafusion/datasource/src/file.rs b/datafusion/datasource/src/file.rs index 0066f39801a1..c5f21ebf1a0f 100644 --- a/datafusion/datasource/src/file.rs +++ b/datafusion/datasource/src/file.rs @@ -25,17 +25,32 @@ use std::sync::Arc; use crate::file_groups::FileGroupPartitioner; use crate::file_scan_config::FileScanConfig; use crate::file_stream::FileOpener; +use crate::schema_adapter::SchemaAdapterFactory; use arrow::datatypes::SchemaRef; -use datafusion_common::Statistics; -use datafusion_physical_expr::LexOrdering; +use datafusion_common::config::ConfigOptions; +use datafusion_common::{not_impl_err, Result, Statistics}; +use datafusion_physical_expr::{LexOrdering, PhysicalExpr}; +use 
datafusion_physical_plan::filter_pushdown::FilterPushdownPropagation; use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; use datafusion_physical_plan::DisplayFormatType; use object_store::ObjectStore; -/// Common file format behaviors needs to implement. +/// Helper function to convert any type implementing FileSource to Arc<dyn FileSource> +pub fn as_file_source(source: T) -> Arc { + Arc::new(source) +} + +/// file format specific behaviors for elements in [`DataSource`] /// -/// See implementation examples such as `ParquetSource`, `CsvSource` +/// See more details on specific implementations: +/// * [`ArrowSource`](https://docs.rs/datafusion/latest/datafusion/datasource/physical_plan/struct.ArrowSource.html) +/// * [`AvroSource`](https://docs.rs/datafusion/latest/datafusion/datasource/physical_plan/struct.AvroSource.html) +/// * [`CsvSource`](https://docs.rs/datafusion/latest/datafusion/datasource/physical_plan/struct.CsvSource.html) +/// * [`JsonSource`](https://docs.rs/datafusion/latest/datafusion/datasource/physical_plan/struct.JsonSource.html) +/// * [`ParquetSource`](https://docs.rs/datafusion/latest/datafusion/datasource/physical_plan/struct.ParquetSource.html) +/// +/// [`DataSource`]: crate::source::DataSource pub trait FileSource: Send + Sync { /// Creates a `dyn FileOpener` based on given parameters fn create_file_opener( @@ -57,7 +72,7 @@ pub trait FileSource: Send + Sync { /// Return execution plan metrics fn metrics(&self) -> &ExecutionPlanMetricsSet; /// Return projected statistics - fn statistics(&self) -> datafusion_common::Result; + fn statistics(&self) -> Result; /// String representation of file source such as "csv", "json", "parquet" fn file_type(&self) -> &str; /// Format FileType specific information @@ -65,17 +80,19 @@ pub trait FileSource: Send + Sync { Ok(()) } - /// If supported by the [`FileSource`], redistribute files across partitions according to their size. - /// Allows custom file formats to implement their own repartitioning logic. + /// If supported by the [`FileSource`], redistribute files across partitions + /// according to their size. Allows custom file formats to implement their + /// own repartitioning logic. /// - /// Provides a default repartitioning behavior, see comments on [`FileGroupPartitioner`] for more detail. + /// The default implementation uses [`FileGroupPartitioner`]. See that + /// struct for more details. fn repartitioned( &self, target_partitions: usize, repartition_file_min_size: usize, output_ordering: Option, config: &FileScanConfig, - ) -> datafusion_common::Result> { + ) -> Result> { if config.file_compression_type.is_compressed() || config.new_lines_in_values { return Ok(None); } @@ -93,4 +110,42 @@ pub trait FileSource: Send + Sync { } Ok(None) } + + /// Try to push down filters into this FileSource. + /// See [`ExecutionPlan::handle_child_pushdown_result`] for more details. + /// + /// [`ExecutionPlan::handle_child_pushdown_result`]: datafusion_physical_plan::ExecutionPlan::handle_child_pushdown_result + fn try_pushdown_filters( + &self, + filters: Vec>, + _config: &ConfigOptions, + ) -> Result>> { + Ok(FilterPushdownPropagation::unsupported(filters)) + } + + /// Set optional schema adapter factory. + /// + /// [`SchemaAdapterFactory`] allows user to specify how fields from the + /// file get mapped to that of the table schema. If you implement this + /// method, you should also implement [`schema_adapter_factory`]. + /// + /// The default implementation returns a not implemented error. 
+ /// + /// [`schema_adapter_factory`]: Self::schema_adapter_factory + fn with_schema_adapter_factory( + &self, + _factory: Arc, + ) -> Result> { + not_impl_err!( + "FileSource {} does not support schema adapter factory", + self.file_type() + ) + } + + /// Returns the current schema adapter factory if set + /// + /// Default implementation returns `None`. + fn schema_adapter_factory(&self) -> Option> { + None + } } diff --git a/datafusion/datasource/src/file_format.rs b/datafusion/datasource/src/file_format.rs index 0e0b7b12e16a..b2caf5277a25 100644 --- a/datafusion/datasource/src/file_format.rs +++ b/datafusion/datasource/src/file_format.rs @@ -28,11 +28,10 @@ use crate::file_compression_type::FileCompressionType; use crate::file_scan_config::FileScanConfig; use crate::file_sink_config::FileSinkConfig; -use arrow::datatypes::{Schema, SchemaRef}; +use arrow::datatypes::SchemaRef; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{internal_err, not_impl_err, GetExt, Result, Statistics}; -use datafusion_expr::Expr; -use datafusion_physical_expr::{LexRequirement, PhysicalExpr}; +use datafusion_physical_expr::LexRequirement; use datafusion_physical_plan::ExecutionPlan; use datafusion_session::Session; @@ -94,7 +93,6 @@ pub trait FileFormat: Send + Sync + fmt::Debug { &self, state: &dyn Session, conf: FileScanConfig, - filters: Option<&Arc>, ) -> Result>; /// Take a list of files and the configuration to convert it to the @@ -109,37 +107,10 @@ pub trait FileFormat: Send + Sync + fmt::Debug { not_impl_err!("Writer not implemented for this format") } - /// Check if the specified file format has support for pushing down the provided filters within - /// the given schemas. Added initially to support the Parquet file format's ability to do this. - fn supports_filters_pushdown( - &self, - _file_schema: &Schema, - _table_schema: &Schema, - _filters: &[&Expr], - ) -> Result { - Ok(FilePushdownSupport::NoSupport) - } - /// Return the related FileSource such as `CsvSource`, `JsonSource`, etc. fn file_source(&self) -> Arc; } -/// An enum to distinguish between different states when determining if certain filters can be -/// pushed down to file scanning -#[derive(Debug, PartialEq)] -pub enum FilePushdownSupport { - /// The file format/system being asked does not support any sort of pushdown. This should be - /// used even if the file format theoretically supports some sort of pushdown, but it's not - /// enabled or implemented yet. 
- NoSupport, - /// The file format/system being asked *does* support pushdown, but it can't make it work for - /// the provided filter/expression - NotSupportedForFilter, - /// The file format/system being asked *does* support pushdown and *can* make it work for the - /// provided filter/expression - Supported, -} - /// Factory for creating [`FileFormat`] instances based on session and command level options /// /// Users can provide their own `FileFormatFactory` to support arbitrary file formats diff --git a/datafusion/datasource/src/file_groups.rs b/datafusion/datasource/src/file_groups.rs index 15c86427ed00..8bfadbef775c 100644 --- a/datafusion/datasource/src/file_groups.rs +++ b/datafusion/datasource/src/file_groups.rs @@ -420,9 +420,19 @@ impl FileGroup { self.files.push(file); } - /// Get the statistics for this group - pub fn statistics(&self) -> Option<&Statistics> { - self.statistics.as_deref() + /// Get the specific file statistics for the given index + /// If the index is None, return the `FileGroup` statistics + pub fn file_statistics(&self, index: Option) -> Option<&Statistics> { + if let Some(index) = index { + self.files.get(index).and_then(|f| f.statistics.as_deref()) + } else { + self.statistics.as_deref() + } + } + + /// Get the mutable reference to the statistics for this group + pub fn statistics_mut(&mut self) -> Option<&mut Statistics> { + self.statistics.as_mut().map(Arc::make_mut) } /// Partition the list of files into `n` groups @@ -953,8 +963,8 @@ mod test { (Some(_), None) => panic!("Expected Some, got None"), (None, Some(_)) => panic!("Expected None, got Some"), (Some(expected), Some(actual)) => { - let expected_string = format!("{:#?}", expected); - let actual_string = format!("{:#?}", actual); + let expected_string = format!("{expected:#?}"); + let actual_string = format!("{actual:#?}"); assert_eq!(expected_string, actual_string); } } diff --git a/datafusion/datasource/src/file_meta.rs b/datafusion/datasource/src/file_meta.rs index 098a15eeb38a..ed7d958c6020 100644 --- a/datafusion/datasource/src/file_meta.rs +++ b/datafusion/datasource/src/file_meta.rs @@ -22,6 +22,7 @@ use object_store::{path::Path, ObjectMeta}; use crate::FileRange; /// A single file or part of a file that should be read, along with its schema, statistics +#[derive(Debug, Clone)] pub struct FileMeta { /// Path for the file (e.g. 
URL, filesystem path, etc) pub object_meta: ObjectMeta, diff --git a/datafusion/datasource/src/file_scan_config.rs b/datafusion/datasource/src/file_scan_config.rs index 19482eb2ccc6..431b6ab0bcf0 100644 --- a/datafusion/datasource/src/file_scan_config.rs +++ b/datafusion/datasource/src/file_scan_config.rs @@ -23,6 +23,19 @@ use std::{ fmt::Result as FmtResult, marker::PhantomData, sync::Arc, }; +use crate::file_groups::FileGroup; +#[allow(unused_imports)] +use crate::schema_adapter::SchemaAdapterFactory; +use crate::{ + display::FileGroupsDisplay, + file::FileSource, + file_compression_type::FileCompressionType, + file_stream::FileStream, + source::{DataSource, DataSourceExec}, + statistics::MinMaxStatistics, + PartitionedFile, +}; +use arrow::datatypes::FieldRef; use arrow::{ array::{ ArrayData, ArrayRef, BufferBuilder, DictionaryArray, RecordBatch, @@ -31,33 +44,29 @@ use arrow::{ buffer::Buffer, datatypes::{ArrowNativeType, DataType, Field, Schema, SchemaRef, UInt16Type}, }; -use datafusion_common::{exec_err, ColumnStatistics, Constraints, Result, Statistics}; -use datafusion_common::{DataFusionError, ScalarValue}; +use datafusion_common::config::ConfigOptions; +use datafusion_common::{ + exec_err, ColumnStatistics, Constraints, DataFusionError, Result, ScalarValue, + Statistics, +}; use datafusion_execution::{ object_store::ObjectStoreUrl, SendableRecordBatchStream, TaskContext, }; -use datafusion_physical_expr::{ - expressions::Column, EquivalenceProperties, LexOrdering, Partitioning, - PhysicalSortExpr, -}; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion_physical_plan::filter_pushdown::FilterPushdownPropagation; use datafusion_physical_plan::{ display::{display_orderings, ProjectSchemaDisplay}, metrics::ExecutionPlanMetricsSet, projection::{all_alias_free_columns, new_projections_for_columns, ProjectionExec}, DisplayAs, DisplayFormatType, ExecutionPlan, }; -use log::{debug, warn}; -use crate::file_groups::FileGroup; -use crate::{ - display::FileGroupsDisplay, - file::FileSource, - file_compression_type::FileCompressionType, - file_stream::FileStream, - source::{DataSource, DataSourceExec}, - statistics::MinMaxStatistics, - PartitionedFile, -}; +use datafusion_physical_plan::coop::cooperative; +use datafusion_physical_plan::execution_plan::SchedulingType; +use log::{debug, warn}; /// The base configurations for a [`DataSourceExec`], the a physical plan for /// any given file format. 
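The `FileGroup::file_statistics` accessor introduced above replaces the old `statistics()` getter: passing `None` returns the group-level statistics, while `Some(i)` returns the statistics recorded for the i-th file. A minimal sketch of a caller, assuming an already-populated group (the helper name is illustrative, not part of this change):

use datafusion_common::Statistics;
use datafusion_datasource::file_groups::FileGroup;

// Prefer group-level statistics; fall back to the first file's statistics
// when no group-level value has been computed.
fn stats_for_pruning(group: &FileGroup) -> Option<&Statistics> {
    group
        .file_statistics(None)
        .or_else(|| group.file_statistics(Some(0)))
}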
@@ -71,6 +80,7 @@ use crate::{ /// # use arrow::datatypes::{Field, Fields, DataType, Schema, SchemaRef}; /// # use object_store::ObjectStore; /// # use datafusion_common::Statistics; +/// # use datafusion_common::Result; /// # use datafusion_datasource::file::FileSource; /// # use datafusion_datasource::file_groups::FileGroup; /// # use datafusion_datasource::PartitionedFile; @@ -80,6 +90,7 @@ use crate::{ /// # use datafusion_execution::object_store::ObjectStoreUrl; /// # use datafusion_physical_plan::ExecutionPlan; /// # use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; +/// # use datafusion_datasource::schema_adapter::SchemaAdapterFactory; /// # let file_schema = Arc::new(Schema::new(vec![ /// # Field::new("c1", DataType::Int32, false), /// # Field::new("c2", DataType::Int32, false), @@ -87,22 +98,26 @@ use crate::{ /// # Field::new("c4", DataType::Int32, false), /// # ])); /// # // Note: crate mock ParquetSource, as ParquetSource is not in the datasource crate +/// #[derive(Clone)] /// # struct ParquetSource { -/// # projected_statistics: Option +/// # projected_statistics: Option, +/// # schema_adapter_factory: Option> /// # }; /// # impl FileSource for ParquetSource { /// # fn create_file_opener(&self, _: Arc, _: &FileScanConfig, _: usize) -> Arc { unimplemented!() } /// # fn as_any(&self) -> &dyn Any { self } /// # fn with_batch_size(&self, _: usize) -> Arc { unimplemented!() } -/// # fn with_schema(&self, _: SchemaRef) -> Arc { unimplemented!() } +/// # fn with_schema(&self, _: SchemaRef) -> Arc { Arc::new(self.clone()) as Arc } /// # fn with_projection(&self, _: &FileScanConfig) -> Arc { unimplemented!() } -/// # fn with_statistics(&self, statistics: Statistics) -> Arc { Arc::new(Self {projected_statistics: Some(statistics)} ) } +/// # fn with_statistics(&self, statistics: Statistics) -> Arc { Arc::new(Self {projected_statistics: Some(statistics), schema_adapter_factory: self.schema_adapter_factory.clone()} ) } /// # fn metrics(&self) -> &ExecutionPlanMetricsSet { unimplemented!() } -/// # fn statistics(&self) -> datafusion_common::Result { Ok(self.projected_statistics.clone().expect("projected_statistics should be set")) } +/// # fn statistics(&self) -> Result { Ok(self.projected_statistics.clone().expect("projected_statistics should be set")) } /// # fn file_type(&self) -> &str { "parquet" } +/// # fn with_schema_adapter_factory(&self, factory: Arc) -> Result> { Ok(Arc::new(Self {projected_statistics: self.projected_statistics.clone(), schema_adapter_factory: Some(factory)} )) } +/// # fn schema_adapter_factory(&self) -> Option> { self.schema_adapter_factory.clone() } /// # } /// # impl ParquetSource { -/// # fn new() -> Self { Self {projected_statistics: None} } +/// # fn new() -> Self { Self {projected_statistics: None, schema_adapter_factory: None} } /// # } /// // create FileScan config for reading parquet files from file:// /// let object_store_url = ObjectStoreUrl::local_filesystem(); @@ -161,7 +176,7 @@ pub struct FileScanConfig { /// all records after filtering are returned. pub limit: Option, /// The partitioning columns - pub table_partition_cols: Vec, + pub table_partition_cols: Vec, /// All equivalent lexicographical orderings that describe the schema. pub output_ordering: Vec, /// File compression type @@ -228,15 +243,21 @@ pub struct FileScanConfig { pub struct FileScanConfigBuilder { object_store_url: ObjectStoreUrl, /// Table schema before any projections or partition columns are applied. 
- /// This schema is used to read the files, but is **not** necessarily the schema of the physical files. - /// Rather this is the schema that the physical file schema will be mapped onto, and the schema that the + /// + /// This schema is used to read the files, but is **not** necessarily the + /// schema of the physical files. Rather this is the schema that the + /// physical file schema will be mapped onto, and the schema that the /// [`DataSourceExec`] will return. + /// + /// This is usually the same as the table schema as specified by the `TableProvider` minus any partition columns. + /// + /// This probably would be better named `table_schema` file_schema: SchemaRef, file_source: Arc, limit: Option, projection: Option>, - table_partition_cols: Vec, + table_partition_cols: Vec, constraints: Option, file_groups: Vec, statistics: Option, @@ -300,7 +321,10 @@ impl FileScanConfigBuilder { /// Set the partitioning columns pub fn with_table_partition_cols(mut self, table_partition_cols: Vec) -> Self { - self.table_partition_cols = table_partition_cols; + self.table_partition_cols = table_partition_cols + .into_iter() + .map(|f| Arc::new(f) as FieldRef) + .collect(); self } @@ -402,7 +426,9 @@ impl FileScanConfigBuilder { let statistics = statistics.unwrap_or_else(|| Statistics::new_unknown(&file_schema)); - let file_source = file_source.with_statistics(statistics.clone()); + let file_source = file_source + .with_statistics(statistics.clone()) + .with_schema(Arc::clone(&file_schema)); let file_compression_type = file_compression_type.unwrap_or(FileCompressionType::UNCOMPRESSED); let new_lines_in_values = new_lines_in_values.unwrap_or(false); @@ -458,13 +484,12 @@ impl DataSource for FileScanConfig { let source = self .file_source .with_batch_size(batch_size) - .with_schema(Arc::clone(&self.file_schema)) .with_projection(self); let opener = source.create_file_opener(object_store, self, partition); let stream = FileStream::new(self, partition, opener, source.metrics())?; - Ok(Box::pin(stream)) + Ok(Box::pin(cooperative(stream))) } fn as_any(&self) -> &dyn Any { @@ -474,7 +499,8 @@ impl DataSource for FileScanConfig { fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> FmtResult { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - let (schema, _, _, orderings) = self.project(); + let schema = self.projected_schema(); + let orderings = get_projected_output_ordering(self, &schema); write!(f, "file_groups=")?; FileGroupsDisplay(&self.file_groups).fmt_as(t, f)?; @@ -528,10 +554,14 @@ impl DataSource for FileScanConfig { fn eq_properties(&self) -> EquivalenceProperties { let (schema, constraints, _, orderings) = self.project(); - EquivalenceProperties::new_with_orderings(schema, orderings.as_slice()) + EquivalenceProperties::new_with_orderings(schema, orderings) .with_constraints(constraints) } + fn scheduling_type(&self) -> SchedulingType { + SchedulingType::Cooperative + } + fn statistics(&self) -> Result { Ok(self.projected_stats()) } @@ -576,7 +606,7 @@ impl DataSource for FileScanConfig { &file_scan .projection .clone() - .unwrap_or((0..self.file_schema.fields().len()).collect()), + .unwrap_or_else(|| (0..self.file_schema.fields().len()).collect()), ); DataSourceExec::from_data_source( FileScanConfigBuilder::from(file_scan) @@ -587,6 +617,32 @@ impl DataSource for FileScanConfig { ) as _ })) } + + fn try_pushdown_filters( + &self, + filters: Vec>, + config: &ConfigOptions, + ) -> Result>> { + let result = self.file_source.try_pushdown_filters(filters, config)?; + 
match result.updated_node { + Some(new_file_source) => { + let file_scan_config = FileScanConfigBuilder::from(self.clone()) + .with_source(new_file_source) + .build(); + Ok(FilterPushdownPropagation { + filters: result.filters, + updated_node: Some(Arc::new(file_scan_config) as _), + }) + } + None => { + // If the file source does not support filter pushdown, return the original config + Ok(FilterPushdownPropagation { + filters: result.filters, + updated_node: None, + }) + } + } + } } impl FileScanConfig { @@ -607,12 +663,14 @@ impl FileScanConfig { file_source: Arc, ) -> Self { let statistics = Statistics::new_unknown(&file_schema); - let file_source = file_source.with_statistics(statistics.clone()); + let file_source = file_source + .with_statistics(statistics.clone()) + .with_schema(Arc::clone(&file_schema)); Self { object_store_url, file_schema, file_groups: vec![], - constraints: Constraints::empty(), + constraints: Constraints::default(), projection: None, limit: None, table_partition_cols: vec![], @@ -688,7 +746,9 @@ impl FileScanConfig { self.file_schema.field(idx).clone() } else { let partition_idx = idx - self.file_schema.fields().len(); - self.table_partition_cols[partition_idx].clone() + Arc::unwrap_or_clone(Arc::clone( + &self.table_partition_cols[partition_idx], + )) } }) .collect(); @@ -701,10 +761,7 @@ impl FileScanConfig { pub fn projected_constraints(&self) -> Constraints { let indexes = self.projection_indices(); - - self.constraints - .project(&indexes) - .unwrap_or_else(Constraints::empty) + self.constraints.project(&indexes).unwrap_or_default() } /// Set the projection of the files @@ -751,7 +808,10 @@ impl FileScanConfig { /// Set the partitioning columns of the files #[deprecated(since = "47.0.0", note = "use FileScanConfigBuilder instead")] pub fn with_table_partition_cols(mut self, table_partition_cols: Vec) -> Self { - self.table_partition_cols = table_partition_cols; + self.table_partition_cols = table_partition_cols + .into_iter() + .map(|f| Arc::new(f) as FieldRef) + .collect(); self } @@ -1051,7 +1111,8 @@ impl Debug for FileScanConfig { impl DisplayAs for FileScanConfig { fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> FmtResult { - let (schema, _, _, orderings) = self.project(); + let schema = self.projected_schema(); + let orderings = get_projected_output_ordering(self, &schema); write!(f, "file_groups=")?; FileGroupsDisplay(&self.file_groups).fmt_as(t, f)?; @@ -1387,16 +1448,16 @@ fn get_projected_output_ordering( ) -> Vec { let mut all_orderings = vec![]; for output_ordering in &base_config.output_ordering { - let mut new_ordering = LexOrdering::default(); + let mut new_ordering = vec![]; for PhysicalSortExpr { expr, options } in output_ordering.iter() { if let Some(col) = expr.as_any().downcast_ref::() { let name = col.name(); if let Some((idx, _)) = projected_schema.column_with_name(name) { // Compute the new sort expression (with correct index) after projection: - new_ordering.push(PhysicalSortExpr { - expr: Arc::new(Column::new(name, idx)), - options: *options, - }); + new_ordering.push(PhysicalSortExpr::new( + Arc::new(Column::new(name, idx)), + *options, + )); continue; } } @@ -1405,11 +1466,9 @@ fn get_projected_output_ordering( break; } - // do not push empty entries - // otherwise we may have `Some(vec![])` at the output ordering. 
- if new_ordering.is_empty() { + let Some(new_ordering) = LexOrdering::new(new_ordering) else { continue; - } + }; // Check if any file groups are not sorted if base_config.file_groups.iter().any(|group| { @@ -1470,41 +1529,17 @@ pub fn wrap_partition_value_in_dict(val: ScalarValue) -> ScalarValue { #[cfg(test)] mod tests { + use super::*; use crate::{ generate_test_files, test_util::MockSource, tests::aggr_test_schema, verify_sort_integrity, }; - use super::*; - use arrow::{ - array::{Int32Array, RecordBatch}, - compute::SortOptions, - }; - + use arrow::array::{Int32Array, RecordBatch}; use datafusion_common::stats::Precision; - use datafusion_common::{assert_batches_eq, DFSchema}; - use datafusion_expr::{execution_props::ExecutionProps, SortExpr}; - use datafusion_physical_expr::create_physical_expr; - use std::collections::HashMap; - - fn create_physical_sort_expr( - e: &SortExpr, - input_dfschema: &DFSchema, - execution_props: &ExecutionProps, - ) -> Result { - let SortExpr { - expr, - asc, - nulls_first, - } = e; - Ok(PhysicalSortExpr { - expr: create_physical_expr(expr, input_dfschema, execution_props)?, - options: SortOptions { - descending: !asc, - nulls_first: *nulls_first, - }, - }) - } + use datafusion_common::{assert_batches_eq, internal_err}; + use datafusion_expr::SortExpr; + use datafusion_physical_expr::create_physical_sort_expr; /// Returns the column names on the schema pub fn columns(schema: &Schema) -> Vec { @@ -1564,7 +1599,7 @@ mod tests { ); // verify the proj_schema includes the last column and exactly the same the field it is defined - let (proj_schema, _, _, _) = conf.project(); + let proj_schema = conf.projected_schema(); assert_eq!(proj_schema.fields().len(), file_schema.fields().len() + 1); assert_eq!( *proj_schema.field(file_schema.fields().len()), @@ -1670,7 +1705,7 @@ mod tests { assert_eq!(source_statistics, statistics); assert_eq!(source_statistics.column_statistics.len(), 3); - let (proj_schema, ..) 
= conf.project(); + let proj_schema = conf.projected_schema(); // created a projector for that projected schema let mut proj = PartitionColumnProjector::new( proj_schema, @@ -2021,7 +2056,7 @@ mod tests { )))) .collect::>(), )); - let sort_order = LexOrdering::from( + let Some(sort_order) = LexOrdering::new( case.sort .into_iter() .map(|expr| { @@ -2032,7 +2067,9 @@ mod tests { ) }) .collect::>>()?, - ); + ) else { + return internal_err!("This test should always use an ordering"); + }; let partitioned_files = FileGroup::new( case.files.into_iter().map(From::from).collect::>(), @@ -2195,13 +2232,15 @@ mod tests { wrap_partition_type_in_dict(DataType::Utf8), false, )]) - .with_constraints(Constraints::empty()) .with_statistics(Statistics::new_unknown(&file_schema)) .with_file_groups(vec![FileGroup::new(vec![PartitionedFile::new( "test.parquet".to_string(), 1024, )])]) - .with_output_ordering(vec![LexOrdering::default()]) + .with_output_ordering(vec![[PhysicalSortExpr::new_default(Arc::new( + Column::new("date", 0), + ))] + .into()]) .with_file_compression_type(FileCompressionType::UNCOMPRESSED) .with_newlines_in_values(true) .build(); @@ -2315,6 +2354,7 @@ mod tests { let new_config = new_builder.build(); // Verify properties match + let partition_cols = partition_cols.into_iter().map(Arc::new).collect::>(); assert_eq!(new_config.object_store_url, object_store_url); assert_eq!(new_config.file_schema, schema); assert_eq!(new_config.projection, Some(vec![0, 2])); @@ -2344,14 +2384,12 @@ mod tests { // Setup sort expression let exec_props = ExecutionProps::new(); let df_schema = DFSchema::try_from_qualified_schema("test", schema.as_ref())?; - let sort_expr = vec![col("value").sort(true, false)]; - - let physical_sort_exprs: Vec<_> = sort_expr - .iter() - .map(|expr| create_physical_sort_expr(expr, &df_schema, &exec_props).unwrap()) - .collect(); - - let sort_ordering = LexOrdering::from(physical_sort_exprs); + let sort_expr = [col("value").sort(true, false)]; + let sort_ordering = sort_expr + .map(|expr| { + create_physical_sort_expr(&expr, &df_schema, &exec_props).unwrap() + }) + .into(); // Test case parameters struct TestCase { @@ -2463,10 +2501,7 @@ mod tests { avg_files_per_partition ); - println!( - "Distribution - min files: {}, max files: {}", - min_size, max_size - ); + println!("Distribution - min files: {min_size}, max files: {max_size}"); } } diff --git a/datafusion/datasource/src/file_sink_config.rs b/datafusion/datasource/src/file_sink_config.rs index 2968bd1ee044..8a86b11a4743 100644 --- a/datafusion/datasource/src/file_sink_config.rs +++ b/datafusion/datasource/src/file_sink_config.rs @@ -22,12 +22,14 @@ use crate::sink::DataSink; use crate::write::demux::{start_demuxer_task, DemuxedStreamReceiver}; use crate::ListingTableUrl; +use arrow::array::RecordBatch; use arrow::datatypes::{DataType, SchemaRef}; use datafusion_common::Result; use datafusion_common_runtime::SpawnedTask; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::dml::InsertOp; +use datafusion_physical_plan::stream::RecordBatchStreamAdapter; use async_trait::async_trait; use object_store::ObjectStore; @@ -77,13 +79,34 @@ pub trait FileSink: DataSink { .runtime_env() .object_store(&config.object_store_url)?; let (demux_task, file_stream_rx) = start_demuxer_task(config, data, context); - self.spawn_writer_tasks_and_join( - context, - demux_task, - file_stream_rx, - object_store, - ) - .await + let mut num_rows = self + 
.spawn_writer_tasks_and_join( + context, + demux_task, + file_stream_rx, + Arc::clone(&object_store), + ) + .await?; + if num_rows == 0 { + // If no rows were written, then no files are output either. + // In this case, send an empty record batch through to ensure the output file is generated + let schema = Arc::clone(&config.output_schema); + let empty_batch = RecordBatch::new_empty(Arc::clone(&schema)); + let data = Box::pin(RecordBatchStreamAdapter::new( + schema, + futures::stream::iter(vec![Ok(empty_batch)]), + )); + let (demux_task, file_stream_rx) = start_demuxer_task(config, data, context); + num_rows = self + .spawn_writer_tasks_and_join( + context, + demux_task, + file_stream_rx, + Arc::clone(&object_store), + ) + .await?; + } + Ok(num_rows) } } diff --git a/datafusion/datasource/src/file_stream.rs b/datafusion/datasource/src/file_stream.rs index 1caefc3277ac..25546b3263c9 100644 --- a/datafusion/datasource/src/file_stream.rs +++ b/datafusion/datasource/src/file_stream.rs @@ -78,7 +78,7 @@ impl FileStream { file_opener: Arc, metrics: &ExecutionPlanMetricsSet, ) -> Result { - let (projected_schema, ..) = config.project(); + let projected_schema = config.projected_schema(); let pc_projector = PartitionColumnProjector::new( Arc::clone(&projected_schema), &config @@ -120,16 +120,17 @@ impl FileStream { let part_file = self.file_iter.pop_front()?; let file_meta = FileMeta { - object_meta: part_file.object_meta, - range: part_file.range, - extensions: part_file.extensions, + object_meta: part_file.object_meta.clone(), + range: part_file.range.clone(), + extensions: part_file.extensions.clone(), metadata_size_hint: part_file.metadata_size_hint, }; + let partition_values = part_file.partition_values.clone(); Some( self.file_opener - .open(file_meta) - .map(|future| (future, part_file.partition_values)), + .open(file_meta, part_file) + .map(|future| (future, partition_values)), ) } @@ -367,7 +368,7 @@ impl Default for OnError { pub trait FileOpener: Unpin + Send + Sync { /// Asynchronously open the specified file and return a stream /// of [`RecordBatch`] - fn open(&self, file_meta: FileMeta) -> Result; + fn open(&self, file_meta: FileMeta, file: PartitionedFile) -> Result; } /// Represents the state of the next `FileOpenFuture`. Since we need to poll @@ -555,7 +556,11 @@ mod tests { } impl FileOpener for TestOpener { - fn open(&self, _file_meta: FileMeta) -> Result { + fn open( + &self, + _file_meta: FileMeta, + _file: PartitionedFile, + ) -> Result { let idx = self.current_idx.fetch_add(1, Ordering::SeqCst); if self.error_opening_idx.contains(&idx) { diff --git a/datafusion/datasource/src/memory.rs b/datafusion/datasource/src/memory.rs index 6d0e16ef4b91..f5eb354ea13f 100644 --- a/datafusion/datasource/src/memory.rs +++ b/datafusion/datasource/src/memory.rs @@ -15,345 +15,48 @@ // specific language governing permissions and limitations // under the License. -//! 
Execution plan for reading in-memory batches of data - use std::any::Any; +use std::cmp::Ordering; +use std::collections::BinaryHeap; use std::fmt; use std::fmt::Debug; use std::sync::Arc; use crate::sink::DataSink; use crate::source::{DataSource, DataSourceExec}; -use async_trait::async_trait; -use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; + +use arrow::array::{RecordBatch, RecordBatchOptions}; +use arrow::datatypes::{Schema, SchemaRef}; +use datafusion_common::{internal_err, plan_err, project_schema, Result, ScalarValue}; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::equivalence::{ + OrderingEquivalenceClass, ProjectionMapping, +}; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::utils::collect_columns; +use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; use datafusion_physical_plan::memory::MemoryStream; use datafusion_physical_plan::projection::{ all_alias_free_columns, new_projections_for_columns, ProjectionExec, }; use datafusion_physical_plan::{ common, ColumnarValue, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, - PhysicalExpr, PlanProperties, SendableRecordBatchStream, Statistics, + PhysicalExpr, SendableRecordBatchStream, Statistics, }; -use arrow::array::{RecordBatch, RecordBatchOptions}; -use arrow::datatypes::{Schema, SchemaRef}; -use datafusion_common::{ - internal_err, plan_err, project_schema, Constraints, Result, ScalarValue, -}; -use datafusion_execution::TaskContext; -use datafusion_physical_expr::equivalence::ProjectionMapping; -use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::utils::collect_columns; -use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; +use async_trait::async_trait; +use datafusion_physical_plan::coop::cooperative; +use datafusion_physical_plan::execution_plan::SchedulingType; use futures::StreamExt; +use itertools::Itertools; use tokio::sync::RwLock; -/// Execution plan for reading in-memory batches of data -#[derive(Clone)] -#[deprecated( - since = "46.0.0", - note = "use MemorySourceConfig and DataSourceExec instead" -)] -pub struct MemoryExec { - inner: DataSourceExec, - /// The partitions to query - partitions: Vec>, - /// Optional projection - projection: Option>, - // Sort information: one or more equivalent orderings - sort_information: Vec, - /// if partition sizes should be displayed - show_sizes: bool, -} - -#[allow(unused, deprecated)] -impl Debug for MemoryExec { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - self.inner.fmt_as(DisplayFormatType::Default, f) - } -} - -#[allow(unused, deprecated)] -impl DisplayAs for MemoryExec { - fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { - self.inner.fmt_as(t, f) - } -} - -#[allow(unused, deprecated)] -impl ExecutionPlan for MemoryExec { - fn name(&self) -> &'static str { - "MemoryExec" - } - - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { - self.inner.properties() - } - - fn children(&self) -> Vec<&Arc> { - // This is a leaf node and has no children - vec![] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result> { - // MemoryExec has no children - if children.is_empty() { - Ok(self) - } else { - internal_err!("Children cannot be replaced in {self:?}") - } - } - - fn execute( - &self, - partition: usize, - context: Arc, - ) -> Result { - self.inner.execute(partition, 
context) - } - - /// We recompute the statistics dynamically from the arrow metadata as it is pretty cheap to do so - fn statistics(&self) -> Result { - self.inner.statistics() - } - - fn try_swapping_with_projection( - &self, - projection: &ProjectionExec, - ) -> Result>> { - self.inner.try_swapping_with_projection(projection) - } -} - -#[allow(unused, deprecated)] -impl MemoryExec { - /// Create a new execution plan for reading in-memory record batches - /// The provided `schema` should not have the projection applied. - pub fn try_new( - partitions: &[Vec], - schema: SchemaRef, - projection: Option>, - ) -> Result { - let source = MemorySourceConfig::try_new(partitions, schema, projection.clone())?; - let data_source = DataSourceExec::new(Arc::new(source)); - Ok(Self { - inner: data_source, - partitions: partitions.to_vec(), - projection, - sort_information: vec![], - show_sizes: true, - }) - } - - /// Create a new execution plan from a list of constant values (`ValuesExec`) - pub fn try_new_as_values( - schema: SchemaRef, - data: Vec>>, - ) -> Result { - if data.is_empty() { - return plan_err!("Values list cannot be empty"); - } - - let n_row = data.len(); - let n_col = schema.fields().len(); - - // We have this single row batch as a placeholder to satisfy evaluation argument - // and generate a single output row - let placeholder_schema = Arc::new(Schema::empty()); - let placeholder_batch = RecordBatch::try_new_with_options( - Arc::clone(&placeholder_schema), - vec![], - &RecordBatchOptions::new().with_row_count(Some(1)), - )?; - - // Evaluate each column - let arrays = (0..n_col) - .map(|j| { - (0..n_row) - .map(|i| { - let expr = &data[i][j]; - let result = expr.evaluate(&placeholder_batch)?; - - match result { - ColumnarValue::Scalar(scalar) => Ok(scalar), - ColumnarValue::Array(array) if array.len() == 1 => { - ScalarValue::try_from_array(&array, 0) - } - ColumnarValue::Array(_) => { - plan_err!("Cannot have array values in a values list") - } - } - }) - .collect::>>() - .and_then(ScalarValue::iter_to_array) - }) - .collect::>>()?; - - let batch = RecordBatch::try_new_with_options( - Arc::clone(&schema), - arrays, - &RecordBatchOptions::new().with_row_count(Some(n_row)), - )?; - - let partitions = vec![batch]; - Self::try_new_from_batches(Arc::clone(&schema), partitions) - } - - /// Create a new plan using the provided schema and batches. - /// - /// Errors if any of the batches don't match the provided schema, or if no - /// batches are provided. - pub fn try_new_from_batches( - schema: SchemaRef, - batches: Vec, - ) -> Result { - if batches.is_empty() { - return plan_err!("Values list cannot be empty"); - } - - for batch in &batches { - let batch_schema = batch.schema(); - if batch_schema != schema { - return plan_err!( - "Batch has invalid schema. 
Expected: {}, got: {}", - schema, - batch_schema - ); - } - } - - let partitions = vec![batches]; - let source = MemorySourceConfig { - partitions: partitions.clone(), - schema: Arc::clone(&schema), - projected_schema: Arc::clone(&schema), - projection: None, - sort_information: vec![], - show_sizes: true, - fetch: None, - }; - let data_source = DataSourceExec::new(Arc::new(source)); - Ok(Self { - inner: data_source, - partitions, - projection: None, - sort_information: vec![], - show_sizes: true, - }) - } - - fn memory_source_config(&self) -> MemorySourceConfig { - self.inner - .data_source() - .as_any() - .downcast_ref::() - .unwrap() - .clone() - } - - pub fn with_constraints(mut self, constraints: Constraints) -> Self { - self.inner = self.inner.with_constraints(constraints); - self - } - - /// Set `show_sizes` to determine whether to display partition sizes - pub fn with_show_sizes(mut self, show_sizes: bool) -> Self { - let mut memory_source = self.memory_source_config(); - memory_source.show_sizes = show_sizes; - self.show_sizes = show_sizes; - self.inner = DataSourceExec::new(Arc::new(memory_source)); - self - } - - /// Ref to constraints - pub fn constraints(&self) -> &Constraints { - self.properties().equivalence_properties().constraints() - } - - /// Ref to partitions - pub fn partitions(&self) -> &[Vec] { - &self.partitions - } - - /// Ref to projection - pub fn projection(&self) -> &Option> { - &self.projection - } - - /// Show sizes - pub fn show_sizes(&self) -> bool { - self.show_sizes - } - - /// Ref to sort information - pub fn sort_information(&self) -> &[LexOrdering] { - &self.sort_information - } - - /// A memory table can be ordered by multiple expressions simultaneously. - /// [`EquivalenceProperties`] keeps track of expressions that describe the - /// global ordering of the schema. These columns are not necessarily same; e.g. - /// ```text - /// ┌-------┐ - /// | a | b | - /// |---|---| - /// | 1 | 9 | - /// | 2 | 8 | - /// | 3 | 7 | - /// | 5 | 5 | - /// └---┴---┘ - /// ``` - /// where both `a ASC` and `b DESC` can describe the table ordering. With - /// [`EquivalenceProperties`], we can keep track of these equivalences - /// and treat `a ASC` and `b DESC` as the same ordering requirement. - /// - /// Note that if there is an internal projection, that projection will be - /// also applied to the given `sort_information`. - pub fn try_with_sort_information( - mut self, - sort_information: Vec, - ) -> Result { - self.sort_information = sort_information.clone(); - let mut memory_source = self.memory_source_config(); - memory_source = memory_source.try_with_sort_information(sort_information)?; - self.inner = DataSourceExec::new(Arc::new(memory_source)); - Ok(self) - } - - /// Arc clone of ref to original schema - pub fn original_schema(&self) -> SchemaRef { - Arc::clone(&self.inner.schema()) - } - - /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. 
- fn compute_properties( - schema: SchemaRef, - orderings: &[LexOrdering], - constraints: Constraints, - partitions: &[Vec], - ) -> PlanProperties { - PlanProperties::new( - EquivalenceProperties::new_with_orderings(schema, orderings) - .with_constraints(constraints), - Partitioning::UnknownPartitioning(partitions.len()), - EmissionType::Incremental, - Boundedness::Bounded, - ) - } -} - /// Data source configuration for reading in-memory batches of data #[derive(Clone, Debug)] pub struct MemorySourceConfig { - /// The partitions to query + /// The partitions to query. + /// + /// Each partition is a `Vec`. partitions: Vec>, /// Schema representing the data before projection schema: SchemaRef, @@ -376,14 +79,14 @@ impl DataSource for MemorySourceConfig { partition: usize, _context: Arc, ) -> Result { - Ok(Box::pin( + Ok(Box::pin(cooperative( MemoryStream::try_new( self.partitions[partition].clone(), Arc::clone(&self.projected_schema), self.projection.clone(), )? .with_fetch(self.fetch), - )) + ))) } fn as_any(&self) -> &dyn Any { @@ -399,9 +102,7 @@ impl DataSource for MemorySourceConfig { let output_ordering = self .sort_information .first() - .map(|output_ordering| { - format!(", output_ordering={}", output_ordering) - }) + .map(|output_ordering| format!(", output_ordering={output_ordering}")) .unwrap_or_default(); let eq_properties = self.eq_properties(); @@ -409,12 +110,12 @@ impl DataSource for MemorySourceConfig { let constraints = if constraints.is_empty() { String::new() } else { - format!(", {}", constraints) + format!(", {constraints}") }; let limit = self .fetch - .map_or(String::new(), |limit| format!(", fetch={}", limit)); + .map_or(String::new(), |limit| format!(", fetch={limit}")); if self.show_sizes { write!( f, @@ -445,6 +146,39 @@ impl DataSource for MemorySourceConfig { } } + /// If possible, redistribute batches across partitions according to their size. + /// + /// Returns `Ok(None)` if unable to repartition. Preserve output ordering if exists. + /// Refer to [`DataSource::repartitioned`] for further details. + fn repartitioned( + &self, + target_partitions: usize, + _repartition_file_min_size: usize, + output_ordering: Option, + ) -> Result>> { + if self.partitions.is_empty() || self.partitions.len() >= target_partitions + // if have no partitions, or already have more partitions than desired, do not repartition + { + return Ok(None); + } + + let maybe_repartitioned = if let Some(output_ordering) = output_ordering { + self.repartition_preserving_order(target_partitions, output_ordering)? + } else { + self.repartition_evenly_by_size(target_partitions)? 
+ }; + + if let Some(repartitioned) = maybe_repartitioned { + Ok(Some(Arc::new(Self::try_new( + &repartitioned, + self.original_schema(), + self.projection.clone(), + )?))) + } else { + Ok(None) + } + } + fn output_partitioning(&self) -> Partitioning { Partitioning::UnknownPartitioning(self.partitions.len()) } @@ -452,10 +186,14 @@ impl DataSource for MemorySourceConfig { fn eq_properties(&self) -> EquivalenceProperties { EquivalenceProperties::new_with_orderings( Arc::clone(&self.projected_schema), - self.sort_information.as_slice(), + self.sort_information.clone(), ) } + fn scheduling_type(&self) -> SchedulingType { + SchedulingType::Cooperative + } + fn statistics(&self) -> Result { Ok(common::compute_record_batch_statistics( &self.partitions, @@ -695,24 +433,21 @@ impl MemorySourceConfig { // If there is a projection on the source, we also need to project orderings if let Some(projection) = &self.projection { + let base_schema = self.original_schema(); + let proj_exprs = projection.iter().map(|idx| { + let name = base_schema.field(*idx).name(); + (Arc::new(Column::new(name, *idx)) as _, name.to_string()) + }); + let projection_mapping = + ProjectionMapping::try_new(proj_exprs, &base_schema)?; let base_eqp = EquivalenceProperties::new_with_orderings( - self.original_schema(), - &sort_information, + Arc::clone(&base_schema), + sort_information, ); - let proj_exprs = projection - .iter() - .map(|idx| { - let base_schema = self.original_schema(); - let name = base_schema.field(*idx).name(); - (Arc::new(Column::new(name, *idx)) as _, name.to_string()) - }) - .collect::>(); - let projection_mapping = - ProjectionMapping::try_new(&proj_exprs, &self.original_schema())?; - sort_information = base_eqp - .project(&projection_mapping, Arc::clone(&self.projected_schema)) - .into_oeq_class() - .into_inner(); + let proj_eqp = + base_eqp.project(&projection_mapping, Arc::clone(&self.projected_schema)); + let oeq_class: OrderingEquivalenceClass = proj_eqp.into(); + sort_information = oeq_class.into(); } self.sort_information = sort_information; @@ -723,6 +458,226 @@ impl MemorySourceConfig { pub fn original_schema(&self) -> SchemaRef { Arc::clone(&self.schema) } + + /// Repartition while preserving order. + /// + /// Returns `Ok(None)` if cannot fulfill the requested repartitioning, such + /// as having too few batches to fulfill the `target_partitions` or if unable + /// to preserve output ordering. + fn repartition_preserving_order( + &self, + target_partitions: usize, + output_ordering: LexOrdering, + ) -> Result>>> { + if !self.eq_properties().ordering_satisfy(output_ordering)? { + Ok(None) + } else { + let total_num_batches = + self.partitions.iter().map(|b| b.len()).sum::(); + if total_num_batches < target_partitions { + // no way to create the desired repartitioning + return Ok(None); + } + + let cnt_to_repartition = target_partitions - self.partitions.len(); + + // Label the current partitions and their order. + // Such that when we later split up the partitions into smaller sizes, we are maintaining the order. + let to_repartition = self + .partitions + .iter() + .enumerate() + .map(|(idx, batches)| RePartition { + idx: idx + (cnt_to_repartition * idx), // make space in ordering for split partitions + row_count: batches.iter().map(|batch| batch.num_rows()).sum(), + batches: batches.clone(), + }) + .collect_vec(); + + // Put all of the partitions into a heap ordered by `RePartition::partial_cmp`, which sizes + // by count of rows. 
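+            // Illustrative walk-through (an assumed example, not from the original
+            // change): with partitions of 300, 200, and 100 rows and
+            // target_partitions = 4, cnt_to_repartition is 1, so the 300-row
+            // partition is popped from the heap and, assuming it holds more than
+            // one batch, split into two pieces that keep indices idx and idx + 1;
+            // the final sort by `idx` below restores the original output order.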
+ let mut max_heap = BinaryHeap::with_capacity(target_partitions); + for rep in to_repartition { + max_heap.push(rep); + } + + // Split the largest partitions into smaller partitions. Maintaining the output + // order of the partitions & newly created partitions. + let mut cannot_split_further = Vec::with_capacity(target_partitions); + for _ in 0..cnt_to_repartition { + // triggers loop for the cnt_to_repartition. So if need another 4 partitions, it attempts to split 4 times. + loop { + // Take the largest item off the heap, and attempt to split. + let Some(to_split) = max_heap.pop() else { + // Nothing left to attempt repartition. Break inner loop. + break; + }; + + // Split the partition. The new partitions will be ordered with idx and idx+1. + let mut new_partitions = to_split.split(); + if new_partitions.len() > 1 { + for new_partition in new_partitions { + max_heap.push(new_partition); + } + // Successful repartition. Break inner loop, and return to outer `cnt_to_repartition` loop. + break; + } else { + cannot_split_further.push(new_partitions.remove(0)); + } + } + } + let mut partitions = max_heap.drain().collect_vec(); + partitions.extend(cannot_split_further); + + // Finally, sort all partitions by the output ordering. + // This was the original ordering of the batches within the partition. We are maintaining this ordering. + partitions.sort_by_key(|p| p.idx); + let partitions = partitions.into_iter().map(|rep| rep.batches).collect_vec(); + + Ok(Some(partitions)) + } + } + + /// Repartition into evenly sized chunks (as much as possible without batch splitting), + /// disregarding any ordering. + /// + /// Current implementation uses a first-fit-decreasing bin packing, modified to enable + /// us to still return the desired count of `target_partitions`. + /// + /// Returns `Ok(None)` if cannot fulfill the requested repartitioning, such + /// as having too few batches to fulfill the `target_partitions`. + fn repartition_evenly_by_size( + &self, + target_partitions: usize, + ) -> Result>>> { + // determine if we have enough total batches to fulfill request + let mut flatten_batches = + self.partitions.clone().into_iter().flatten().collect_vec(); + if flatten_batches.len() < target_partitions { + return Ok(None); + } + + // Take all flattened batches (all in 1 partititon/vec) and divide evenly into the desired number of `target_partitions`. + let total_num_rows = flatten_batches.iter().map(|b| b.num_rows()).sum::(); + // sort by size, so we pack multiple smaller batches into the same partition + flatten_batches.sort_by_key(|b| std::cmp::Reverse(b.num_rows())); + + // Divide. 
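+        // Worked example (illustrative; mirrors the test below): batches of
+        // 100_000, 10_000, 100, and 1 rows with target_partitions = 2 give an
+        // initial target_partition_size of ceil(110_101 / 2) = 55_051, so
+        // batch(100_000) fills the first bin on its own and the remaining three
+        // batches are packed into the second bin.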
+ let mut partitions = + vec![Vec::with_capacity(flatten_batches.len()); target_partitions]; + let mut target_partition_size = total_num_rows.div_ceil(target_partitions); + let mut total_rows_seen = 0; + let mut curr_bin_row_count = 0; + let mut idx = 0; + for batch in flatten_batches { + let row_cnt = batch.num_rows(); + idx = std::cmp::min(idx, target_partitions - 1); + + partitions[idx].push(batch); + curr_bin_row_count += row_cnt; + total_rows_seen += row_cnt; + + if curr_bin_row_count >= target_partition_size { + idx += 1; + curr_bin_row_count = 0; + + // update target_partition_size, to handle very lopsided batch distributions + // while still returning the count of `target_partitions` + if total_rows_seen < total_num_rows { + target_partition_size = (total_num_rows - total_rows_seen) + .div_ceil(target_partitions - idx); + } + } + } + + Ok(Some(partitions)) + } +} + +/// For use in repartitioning, track the total size and original partition index. +/// +/// Do not implement clone, in order to avoid unnecessary copying during repartitioning. +struct RePartition { + /// Original output ordering for the partition. + idx: usize, + /// Total size of the partition, for use in heap ordering + /// (a.k.a. splitting up the largest partitions). + row_count: usize, + /// A partition containing record batches. + batches: Vec, +} + +impl RePartition { + /// Split [`RePartition`] into 2 pieces, consuming self. + /// + /// Returns only 1 partition if cannot be split further. + fn split(self) -> Vec { + if self.batches.len() == 1 { + return vec![self]; + } + + let new_0 = RePartition { + idx: self.idx, // output ordering + row_count: 0, + batches: vec![], + }; + let new_1 = RePartition { + idx: self.idx + 1, // output ordering +1 + row_count: 0, + batches: vec![], + }; + let split_pt = self.row_count / 2; + + let [new_0, new_1] = self.batches.into_iter().fold( + [new_0, new_1], + |[mut new0, mut new1], batch| { + if new0.row_count < split_pt { + new0.add_batch(batch); + } else { + new1.add_batch(batch); + } + [new0, new1] + }, + ); + vec![new_0, new_1] + } + + fn add_batch(&mut self, batch: RecordBatch) { + self.row_count += batch.num_rows(); + self.batches.push(batch); + } +} + +impl PartialOrd for RePartition { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.row_count.cmp(&other.row_count)) + } +} + +impl Ord for RePartition { + fn cmp(&self, other: &Self) -> Ordering { + self.row_count.cmp(&other.row_count) + } +} + +impl PartialEq for RePartition { + fn eq(&self, other: &Self) -> bool { + self.row_count.eq(&other.row_count) + } +} + +impl Eq for RePartition {} + +impl fmt::Display for RePartition { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{}rows-in-{}batches@{}", + self.row_count, + self.batches.len(), + self.idx + ) + } } /// Type alias for partition data @@ -816,22 +771,22 @@ mod memory_source_tests { use crate::memory::MemorySourceConfig; use crate::source::DataSourceExec; - use datafusion_physical_plan::ExecutionPlan; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::Result; use datafusion_physical_expr::expressions::col; - use datafusion_physical_expr::PhysicalSortExpr; - use datafusion_physical_expr_common::sort_expr::LexOrdering; + use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; + use datafusion_physical_plan::ExecutionPlan; #[test] - fn test_memory_order_eq() -> datafusion_common::Result<()> { + fn test_memory_order_eq() -> Result<()> { let schema 
= Arc::new(Schema::new(vec![ Field::new("a", DataType::Int64, false), Field::new("b", DataType::Int64, false), Field::new("c", DataType::Int64, false), ])); - let sort1 = LexOrdering::new(vec![ + let sort1: LexOrdering = [ PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions::default(), @@ -840,13 +795,14 @@ mod memory_source_tests { expr: col("b", &schema)?, options: SortOptions::default(), }, - ]); - let sort2 = LexOrdering::new(vec![PhysicalSortExpr { + ] + .into(); + let sort2: LexOrdering = [PhysicalSortExpr { expr: col("c", &schema)?, options: SortOptions::default(), - }]); - let mut expected_output_order = LexOrdering::default(); - expected_output_order.extend(sort1.clone()); + }] + .into(); + let mut expected_output_order = sort1.clone(); expected_output_order.extend(sort2.clone()); let sort_information = vec![sort1.clone(), sort2.clone()]; @@ -868,15 +824,17 @@ mod memory_source_tests { #[cfg(test)] mod tests { - use crate::tests::{aggr_test_schema, make_partition}; - use super::*; + use crate::test_util::col; + use crate::tests::{aggr_test_schema, make_partition}; - use datafusion_physical_plan::expressions::lit; - + use arrow::array::{ArrayRef, Int32Array, Int64Array, StringArray}; use arrow::datatypes::{DataType, Field}; use datafusion_common::assert_batches_eq; use datafusion_common::stats::{ColumnStatistics, Precision}; + use datafusion_physical_expr::PhysicalSortExpr; + use datafusion_physical_plan::expressions::lit; + use futures::StreamExt; #[tokio::test] @@ -976,7 +934,7 @@ mod tests { )?; assert_eq!( - values.statistics()?, + values.partition_statistics(None)?, Statistics { num_rows: Precision::Exact(rows), total_byte_size: Precision::Exact(8), // not important @@ -992,4 +950,458 @@ mod tests { Ok(()) } + + fn batch(row_size: usize) -> RecordBatch { + let a: ArrayRef = Arc::new(Int32Array::from(vec![1; row_size])); + let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("foo"); row_size])); + let c: ArrayRef = Arc::new(Int64Array::from_iter(vec![1; row_size])); + RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap() + } + + fn schema() -> SchemaRef { + batch(1).schema() + } + + fn memorysrcconfig_no_partitions( + sort_information: Vec, + ) -> Result { + let partitions = vec![]; + MemorySourceConfig::try_new(&partitions, schema(), None)? + .try_with_sort_information(sort_information) + } + + fn memorysrcconfig_1_partition_1_batch( + sort_information: Vec, + ) -> Result { + let partitions = vec![vec![batch(100)]]; + MemorySourceConfig::try_new(&partitions, schema(), None)? + .try_with_sort_information(sort_information) + } + + fn memorysrcconfig_3_partitions_1_batch_each( + sort_information: Vec, + ) -> Result { + let partitions = vec![vec![batch(100)], vec![batch(100)], vec![batch(100)]]; + MemorySourceConfig::try_new(&partitions, schema(), None)? + .try_with_sort_information(sort_information) + } + + fn memorysrcconfig_3_partitions_with_2_batches_each( + sort_information: Vec, + ) -> Result { + let partitions = vec![ + vec![batch(100), batch(100)], + vec![batch(100), batch(100)], + vec![batch(100), batch(100)], + ]; + MemorySourceConfig::try_new(&partitions, schema(), None)? + .try_with_sort_information(sort_information) + } + + /// Batches of different sizes, with batches ordered by size (100_000, 10_000, 100, 1) + /// in the Memtable partition (a.k.a. vector of batches). 
+ fn memorysrcconfig_1_partition_with_different_sized_batches( + sort_information: Vec, + ) -> Result { + let partitions = vec![vec![batch(100_000), batch(10_000), batch(100), batch(1)]]; + MemorySourceConfig::try_new(&partitions, schema(), None)? + .try_with_sort_information(sort_information) + } + + /// Same as [`memorysrcconfig_1_partition_with_different_sized_batches`], + /// but the batches are ordered differently (not by size) + /// in the Memtable partition (a.k.a. vector of batches). + fn memorysrcconfig_1_partition_with_ordering_not_matching_size( + sort_information: Vec, + ) -> Result { + let partitions = vec![vec![batch(100_000), batch(1), batch(100), batch(10_000)]]; + MemorySourceConfig::try_new(&partitions, schema(), None)? + .try_with_sort_information(sort_information) + } + + fn memorysrcconfig_2_partition_with_different_sized_batches( + sort_information: Vec, + ) -> Result { + let partitions = vec![ + vec![batch(100_000), batch(10_000), batch(1_000)], + vec![batch(2_000), batch(20)], + ]; + MemorySourceConfig::try_new(&partitions, schema(), None)? + .try_with_sort_information(sort_information) + } + + fn memorysrcconfig_2_partition_with_extreme_sized_batches( + sort_information: Vec, + ) -> Result { + let partitions = vec![ + vec![ + batch(100_000), + batch(1), + batch(1), + batch(1), + batch(1), + batch(0), + ], + vec![batch(1), batch(1), batch(1), batch(1), batch(0), batch(100)], + ]; + MemorySourceConfig::try_new(&partitions, schema(), None)? + .try_with_sort_information(sort_information) + } + + /// Assert that we get the expected count of partitions after repartitioning. + /// + /// If None, then we expected the [`DataSource::repartitioned`] to return None. + fn assert_partitioning( + partitioned_datasrc: Option>, + partition_cnt: Option, + ) { + let should_exist = if let Some(partition_cnt) = partition_cnt { + format!("new datasource should exist and have {partition_cnt:?} partitions") + } else { + "new datasource should not exist".into() + }; + + let actual = partitioned_datasrc + .map(|datasrc| datasrc.output_partitioning().partition_count()); + assert_eq!( + actual, + partition_cnt, + "partitioned datasrc does not match expected, we expected {should_exist}, instead found {actual:?}" + ); + } + + fn run_all_test_scenarios( + output_ordering: Option, + sort_information_on_config: Vec, + ) -> Result<()> { + let not_used = usize::MAX; + + // src has no partitions + let mem_src_config = + memorysrcconfig_no_partitions(sort_information_on_config.clone())?; + let partitioned_datasrc = + mem_src_config.repartitioned(1, not_used, output_ordering.clone())?; + assert_partitioning(partitioned_datasrc, None); + + // src has partitions == target partitions (=1) + let target_partitions = 1; + let mem_src_config = + memorysrcconfig_1_partition_1_batch(sort_information_on_config.clone())?; + let partitioned_datasrc = mem_src_config.repartitioned( + target_partitions, + not_used, + output_ordering.clone(), + )?; + assert_partitioning(partitioned_datasrc, None); + + // src has partitions == target partitions (=3) + let target_partitions = 3; + let mem_src_config = memorysrcconfig_3_partitions_1_batch_each( + sort_information_on_config.clone(), + )?; + let partitioned_datasrc = mem_src_config.repartitioned( + target_partitions, + not_used, + output_ordering.clone(), + )?; + assert_partitioning(partitioned_datasrc, None); + + // src has partitions > target partitions, but we don't merge them + let target_partitions = 2; + let mem_src_config = memorysrcconfig_3_partitions_1_batch_each( 
+ sort_information_on_config.clone(), + )?; + let partitioned_datasrc = mem_src_config.repartitioned( + target_partitions, + not_used, + output_ordering.clone(), + )?; + assert_partitioning(partitioned_datasrc, None); + + // src has partitions < target partitions, but not enough batches (per partition) to split into more partitions + let target_partitions = 4; + let mem_src_config = memorysrcconfig_3_partitions_1_batch_each( + sort_information_on_config.clone(), + )?; + let partitioned_datasrc = mem_src_config.repartitioned( + target_partitions, + not_used, + output_ordering.clone(), + )?; + assert_partitioning(partitioned_datasrc, None); + + // src has partitions < target partitions, and can split to sufficient amount + // has 6 batches across 3 partitions. Will need to split 2 of it's partitions. + let target_partitions = 5; + let mem_src_config = memorysrcconfig_3_partitions_with_2_batches_each( + sort_information_on_config.clone(), + )?; + let partitioned_datasrc = mem_src_config.repartitioned( + target_partitions, + not_used, + output_ordering.clone(), + )?; + assert_partitioning(partitioned_datasrc, Some(5)); + + // src has partitions < target partitions, and can split to sufficient amount + // has 6 batches across 3 partitions. Will need to split all of it's partitions. + let target_partitions = 6; + let mem_src_config = memorysrcconfig_3_partitions_with_2_batches_each( + sort_information_on_config.clone(), + )?; + let partitioned_datasrc = mem_src_config.repartitioned( + target_partitions, + not_used, + output_ordering.clone(), + )?; + assert_partitioning(partitioned_datasrc, Some(6)); + + // src has partitions < target partitions, but not enough total batches to fulfill the split (desired target_partitions) + let target_partitions = 3 * 2 + 1; + let mem_src_config = memorysrcconfig_3_partitions_with_2_batches_each( + sort_information_on_config.clone(), + )?; + let partitioned_datasrc = mem_src_config.repartitioned( + target_partitions, + not_used, + output_ordering.clone(), + )?; + assert_partitioning(partitioned_datasrc, None); + + // src has 1 partition with many batches of lopsided sizes + // make sure it handles the split properly + let target_partitions = 2; + let mem_src_config = memorysrcconfig_1_partition_with_different_sized_batches( + sort_information_on_config, + )?; + let partitioned_datasrc = mem_src_config.clone().repartitioned( + target_partitions, + not_used, + output_ordering, + )?; + assert_partitioning(partitioned_datasrc.clone(), Some(2)); + // Starting = batch(100_000), batch(10_000), batch(100), batch(1). + // It should have split as p1=batch(100_000), p2=[batch(10_000), batch(100), batch(1)] + let partitioned_datasrc = partitioned_datasrc.unwrap(); + let Some(mem_src_config) = partitioned_datasrc + .as_any() + .downcast_ref::() + else { + unreachable!() + }; + let repartitioned_raw_batches = mem_src_config.partitions.clone(); + assert_eq!(repartitioned_raw_batches.len(), 2); + let [ref p1, ref p2] = repartitioned_raw_batches[..] 
else {
+            unreachable!()
+        };
+        // p1=batch(100_000)
+        assert_eq!(p1.len(), 1);
+        assert_eq!(p1[0].num_rows(), 100_000);
+        // p2=[batch(10_000), batch(100), batch(1)]
+        assert_eq!(p2.len(), 3);
+        assert_eq!(p2[0].num_rows(), 10_000);
+        assert_eq!(p2[1].num_rows(), 100);
+        assert_eq!(p2[2].num_rows(), 1);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_repartition_no_sort_information_no_output_ordering() -> Result<()> {
+        let no_sort = vec![];
+        let no_output_ordering = None;
+
+        // Test: common set of functionality
+        run_all_test_scenarios(no_output_ordering.clone(), no_sort.clone())?;
+
+        // Test: how no-sort-order divides differently.
+        // * does not preserve separate partitions (with own internal ordering) on even split,
+        // * nor does it preserve ordering (re-orders batch(2_000) vs batch(1_000)).
+        let target_partitions = 3;
+        let mem_src_config =
+            memorysrcconfig_2_partition_with_different_sized_batches(no_sort)?;
+        let partitioned_datasrc = mem_src_config.clone().repartitioned(
+            target_partitions,
+            usize::MAX,
+            no_output_ordering,
+        )?;
+        assert_partitioning(partitioned_datasrc.clone(), Some(3));
+        // Starting = batch(100_000), batch(10_000), batch(1_000), batch(2_000), batch(20)
+        // It should have split as p1=batch(100_000), p2=batch(10_000), p3=rest(mixed across original partitions)
+        let repartitioned_raw_batches = mem_src_config
+            .repartition_evenly_by_size(target_partitions)?
+            .unwrap();
+        assert_eq!(repartitioned_raw_batches.len(), 3);
+        let [ref p1, ref p2, ref p3] = repartitioned_raw_batches[..] else {
+            unreachable!()
+        };
+        // p1=batch(100_000)
+        assert_eq!(p1.len(), 1);
+        assert_eq!(p1[0].num_rows(), 100_000);
+        // p2=batch(10_000)
+        assert_eq!(p2.len(), 1);
+        assert_eq!(p2[0].num_rows(), 10_000);
+        // p3= batch(2_000), batch(1_000), batch(20)
+        assert_eq!(p3.len(), 3);
+        assert_eq!(p3[0].num_rows(), 2_000);
+        assert_eq!(p3[1].num_rows(), 1_000);
+        assert_eq!(p3[2].num_rows(), 20);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_repartition_no_sort_information_no_output_ordering_lopsized_batches(
+    ) -> Result<()> {
+        let no_sort = vec![];
+        let no_output_ordering = None;
+
+        // Test: case has two input partitions:
+        //   b(100_000), b(1), b(1), b(1), b(1), b(0)
+        //   b(1), b(1), b(1), b(1), b(0), b(100)
+        //
+        // We want an output with target_partitions=5, which means the ideal division is:
+        //    b(100_000)
+        //    b(100)
+        //    b(1), b(1), b(1)
+        //    b(1), b(1), b(1)
+        //    b(1), b(1), b(0), b(0)
+        let target_partitions = 5;
+        let mem_src_config =
+            memorysrcconfig_2_partition_with_extreme_sized_batches(no_sort)?;
+        let partitioned_datasrc = mem_src_config.clone().repartitioned(
+            target_partitions,
+            usize::MAX,
+            no_output_ordering,
+        )?;
+        assert_partitioning(partitioned_datasrc.clone(), Some(5));
+        // Starting partition 1 = batch(100_000), batch(1), batch(1), batch(1), batch(1), batch(0)
+        // Starting partition 2 = batch(1), batch(1), batch(1), batch(1), batch(0), batch(100)
+        // It should have split as p1=batch(100_000), p2=batch(100), p3=[batch(1),batch(1),batch(1)], p4=[batch(1),batch(1),batch(1)], p5=[batch(1),batch(1),batch(0),batch(0)]
+        let repartitioned_raw_batches = mem_src_config
+            .repartition_evenly_by_size(target_partitions)?
+            .unwrap();
+        assert_eq!(repartitioned_raw_batches.len(), 5);
+        let [ref p1, ref p2, ref p3, ref p4, ref p5] = repartitioned_raw_batches[..]
+        else {
+            unreachable!()
+        };
+        // p1=batch(100_000)
+        assert_eq!(p1.len(), 1);
+        assert_eq!(p1[0].num_rows(), 100_000);
+        // p2=batch(100)
+        assert_eq!(p2.len(), 1);
+        assert_eq!(p2[0].num_rows(), 100);
+        // p3=[batch(1),batch(1),batch(1)]
+        assert_eq!(p3.len(), 3);
+        assert_eq!(p3[0].num_rows(), 1);
+        assert_eq!(p3[1].num_rows(), 1);
+        assert_eq!(p3[2].num_rows(), 1);
+        // p4=[batch(1),batch(1),batch(1)]
+        assert_eq!(p4.len(), 3);
+        assert_eq!(p4[0].num_rows(), 1);
+        assert_eq!(p4[1].num_rows(), 1);
+        assert_eq!(p4[2].num_rows(), 1);
+        // p5=[batch(1),batch(1),batch(0),batch(0)]
+        assert_eq!(p5.len(), 4);
+        assert_eq!(p5[0].num_rows(), 1);
+        assert_eq!(p5[1].num_rows(), 1);
+        assert_eq!(p5[2].num_rows(), 0);
+        assert_eq!(p5[3].num_rows(), 0);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_repartition_with_sort_information() -> Result<()> {
+        let schema = schema();
+        let sort_key: LexOrdering =
+            [PhysicalSortExpr::new_default(col("c", &schema)?)].into();
+        let has_sort = vec![sort_key.clone()];
+        let output_ordering = Some(sort_key);
+
+        // Test: common set of functionality
+        run_all_test_scenarios(output_ordering.clone(), has_sort.clone())?;
+
+        // Test: DOES preserve separate partitions (with own internal ordering)
+        let target_partitions = 3;
+        let mem_src_config =
+            memorysrcconfig_2_partition_with_different_sized_batches(has_sort)?;
+        let partitioned_datasrc = mem_src_config.clone().repartitioned(
+            target_partitions,
+            usize::MAX,
+            output_ordering.clone(),
+        )?;
+        assert_partitioning(partitioned_datasrc.clone(), Some(3));
+        // Starting = batch(100_000), batch(10_000), batch(1_000), batch(2_000), batch(20)
+        // It should have split as p1=batch(100_000), p2=[batch(10_000),batch(1_000)], p3=[batch(2_000), batch(20)]
+        let Some(output_ord) = output_ordering else {
+            unreachable!()
+        };
+        let repartitioned_raw_batches = mem_src_config
+            .repartition_preserving_order(target_partitions, output_ord)?
+            .unwrap();
+        assert_eq!(repartitioned_raw_batches.len(), 3);
+        let [ref p1, ref p2, ref p3] = repartitioned_raw_batches[..] else {
+            unreachable!()
+        };
+        // p1=batch(100_000)
+        assert_eq!(p1.len(), 1);
+        assert_eq!(p1[0].num_rows(), 100_000);
+        // p2=[batch(10_000),batch(1_000)]
+        assert_eq!(p2.len(), 2);
+        assert_eq!(p2[0].num_rows(), 10_000);
+        assert_eq!(p2[1].num_rows(), 1_000);
+        // p3=batch(2_000), batch(20)
+        assert_eq!(p3.len(), 2);
+        assert_eq!(p3[0].num_rows(), 2_000);
+        assert_eq!(p3[1].num_rows(), 20);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_repartition_with_batch_ordering_not_matching_sizing() -> Result<()> {
+        let schema = schema();
+        let sort_key: LexOrdering =
+            [PhysicalSortExpr::new_default(col("c", &schema)?)].into();
+        let has_sort = vec![sort_key.clone()];
+        let output_ordering = Some(sort_key);
+
+        // src has 1 partition with many batches of lopsided sizes
+        // note that the input vector of batches is not ordered by decreasing size
+        let target_partitions = 2;
+        let mem_src_config =
+            memorysrcconfig_1_partition_with_ordering_not_matching_size(has_sort)?;
+        let partitioned_datasrc = mem_src_config.clone().repartitioned(
+            target_partitions,
+            usize::MAX,
+            output_ordering,
+        )?;
+        assert_partitioning(partitioned_datasrc.clone(), Some(2));
+        // Starting = batch(100_000), batch(1), batch(100), batch(10_000).
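+        // Because an output ordering is requested, repartition_preserving_order
+        // splits the partition without re-sorting its batches by size, so the
+        // original batch order within the partition carries over to the output.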
+ // It should have split as p1=batch(100_000), p2=[batch(1), batch(100), batch(10_000)] + let partitioned_datasrc = partitioned_datasrc.unwrap(); + let Some(mem_src_config) = partitioned_datasrc + .as_any() + .downcast_ref::() + else { + unreachable!() + }; + let repartitioned_raw_batches = mem_src_config.partitions.clone(); + assert_eq!(repartitioned_raw_batches.len(), 2); + let [ref p1, ref p2] = repartitioned_raw_batches[..] else { + unreachable!() + }; + // p1=batch(100_000) + assert_eq!(p1.len(), 1); + assert_eq!(p1[0].num_rows(), 100_000); + // p2=[batch(1), batch(100), batch(10_000)] -- **this is preserving the partition order** + assert_eq!(p2.len(), 3); + assert_eq!(p2[0].num_rows(), 1); + assert_eq!(p2[1].num_rows(), 100); + assert_eq!(p2[2].num_rows(), 10_000); + + Ok(()) + } } diff --git a/datafusion/datasource/src/mod.rs b/datafusion/datasource/src/mod.rs index 3e44851d145b..92e25a97c3a4 100644 --- a/datafusion/datasource/src/mod.rs +++ b/datafusion/datasource/src/mod.rs @@ -48,6 +48,7 @@ pub mod test_util; pub mod url; pub mod write; +pub use self::file::as_file_source; pub use self::url::ListingTableUrl; use crate::file_groups::FileGroup; use chrono::TimeZone; @@ -197,6 +198,23 @@ impl PartitionedFile { self.statistics = Some(statistics); self } + + /// Check if this file has any statistics. + /// This returns `true` if the file has any Exact or Inexact statistics + /// and `false` if all statistics are `Precision::Absent`. + pub fn has_statistics(&self) -> bool { + if let Some(stats) = &self.statistics { + stats.column_statistics.iter().any(|col_stats| { + col_stats.null_count != Precision::Absent + || col_stats.max_value != Precision::Absent + || col_stats.min_value != Precision::Absent + || col_stats.sum_value != Precision::Absent + || col_stats.distinct_count != Precision::Absent + }) + } else { + false + } + } } impl From for PartitionedFile { @@ -379,7 +397,7 @@ pub fn generate_test_files(num_files: usize, overlap_factor: f64) -> Vec datafusion_common::Result + Send + Sync; /// Factory for creating [`SchemaAdapter`] /// @@ -96,6 +103,12 @@ pub trait SchemaAdapter: Send + Sync { pub trait SchemaMapper: Debug + Send + Sync { /// Adapts a `RecordBatch` to match the `table_schema` fn map_batch(&self, batch: RecordBatch) -> datafusion_common::Result; + + /// Adapts file-level column `Statistics` to match the `table_schema` + fn map_column_statistics( + &self, + file_col_statistics: &[ColumnStatistics], + ) -> datafusion_common::Result>; } /// Default [`SchemaAdapterFactory`] for mapping schemas. 
@@ -219,6 +232,32 @@ pub(crate) struct DefaultSchemaAdapter { projected_table_schema: SchemaRef, } +/// Checks if a file field can be cast to a table field +/// +/// Returns Ok(true) if casting is possible, or an error explaining why casting is not possible +pub(crate) fn can_cast_field( + file_field: &Field, + table_field: &Field, +) -> datafusion_common::Result { + match (file_field.data_type(), table_field.data_type()) { + (DataType::Struct(source_fields), DataType::Struct(target_fields)) => { + validate_struct_compatibility(source_fields, target_fields) + } + _ => { + if can_cast_types(file_field.data_type(), table_field.data_type()) { + Ok(true) + } else { + plan_err!( + "Cannot cast file schema field {} of type {:?} to table schema field of type {:?}", + file_field.name(), + file_field.data_type(), + table_field.data_type() + ) + } + } + } +} + impl SchemaAdapter for DefaultSchemaAdapter { /// Map a column index in the table schema to a column index in a particular /// file schema @@ -242,40 +281,54 @@ impl SchemaAdapter for DefaultSchemaAdapter { &self, file_schema: &Schema, ) -> datafusion_common::Result<(Arc, Vec)> { - let mut projection = Vec::with_capacity(file_schema.fields().len()); - let mut field_mappings = vec![None; self.projected_table_schema.fields().len()]; - - for (file_idx, file_field) in file_schema.fields.iter().enumerate() { - if let Some((table_idx, table_field)) = - self.projected_table_schema.fields().find(file_field.name()) - { - match can_cast_types(file_field.data_type(), table_field.data_type()) { - true => { - field_mappings[table_idx] = Some(projection.len()); - projection.push(file_idx); - } - false => { - return plan_err!( - "Cannot cast file schema field {} of type {:?} to table schema field of type {:?}", - file_field.name(), - file_field.data_type(), - table_field.data_type() - ) - } - } - } - } + let (field_mappings, projection) = create_field_mapping( + file_schema, + &self.projected_table_schema, + can_cast_field, + )?; Ok(( - Arc::new(SchemaMapping { - projected_table_schema: Arc::clone(&self.projected_table_schema), + Arc::new(SchemaMapping::new( + Arc::clone(&self.projected_table_schema), field_mappings, - }), + Arc::new(|array: &ArrayRef, field: &Field| cast_column(array, field)), + )), projection, )) } } +/// Helper function that creates field mappings between file schema and table schema +/// +/// Maps columns from the file schema to their corresponding positions in the table schema, +/// applying type compatibility checking via the provided predicate function. +/// +/// Returns field mappings (for column reordering) and a projection (for field selection). +pub(crate) fn create_field_mapping( + file_schema: &Schema, + projected_table_schema: &SchemaRef, + can_map_field: F, +) -> datafusion_common::Result<(Vec>, Vec)> +where + F: Fn(&Field, &Field) -> datafusion_common::Result, +{ + let mut projection = Vec::with_capacity(file_schema.fields().len()); + let mut field_mappings = vec![None; projected_table_schema.fields().len()]; + + for (file_idx, file_field) in file_schema.fields.iter().enumerate() { + if let Some((table_idx, table_field)) = + projected_table_schema.fields().find(file_field.name()) + { + if can_map_field(file_field, table_field)? { + field_mappings[table_idx] = Some(projection.len()); + projection.push(file_idx); + } + } + } + + Ok((field_mappings, projection)) +} + /// The SchemaMapping struct holds a mapping from the file schema to the table /// schema and any necessary type conversions. 
/// @@ -285,7 +338,6 @@ impl SchemaAdapter for DefaultSchemaAdapter { /// `projected_table_schema` as it can only operate on the projected fields. /// /// [`map_batch`]: Self::map_batch -#[derive(Debug)] pub struct SchemaMapping { /// The schema of the table. This is the expected schema after conversion /// and it should match the schema of the query result. @@ -296,6 +348,36 @@ pub struct SchemaMapping { /// They are Options instead of just plain `usize`s because the table could /// have fields that don't exist in the file. field_mappings: Vec>, + /// Function used to adapt a column from the file schema to the table schema + /// when it exists in both schemas + cast_column: Arc, +} + +impl Debug for SchemaMapping { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SchemaMapping") + .field("projected_table_schema", &self.projected_table_schema) + .field("field_mappings", &self.field_mappings) + .field("cast_column", &"") + .finish() + } +} + +impl SchemaMapping { + /// Creates a new SchemaMapping instance + /// + /// Initializes the field mappings needed to transform file data to the projected table schema + pub fn new( + projected_table_schema: SchemaRef, + field_mappings: Vec>, + cast_column: Arc, + ) -> Self { + Self { + projected_table_schema, + field_mappings, + cast_column, + } + } } impl SchemaMapper for SchemaMapping { @@ -320,9 +402,9 @@ impl SchemaMapper for SchemaMapping { // If this field only exists in the table, and not in the file, then we know // that it's null, so just return that. || Ok(new_null_array(field.data_type(), batch_rows)), - // However, if it does exist in both, then try to cast it to the correct output - // type - |batch_idx| cast(&batch_cols[batch_idx], field.data_type()), + // However, if it does exist in both, use the cast_column function + // to perform any necessary conversions + |batch_idx| (self.cast_column)(&batch_cols[batch_idx], field), ) }) .collect::, _>>()?; @@ -334,4 +416,597 @@ impl SchemaMapper for SchemaMapping { let record_batch = RecordBatch::try_new_with_options(schema, cols, &options)?; Ok(record_batch) } + + /// Adapts file-level column `Statistics` to match the `table_schema` + fn map_column_statistics( + &self, + file_col_statistics: &[ColumnStatistics], + ) -> datafusion_common::Result> { + let mut table_col_statistics = vec![]; + + // Map the statistics for each field in the file schema to the corresponding field in the + // table schema, if a field is not present in the file schema, we need to fill it with `ColumnStatistics::new_unknown` + for (_, file_col_idx) in self + .projected_table_schema + .fields() + .iter() + .zip(&self.field_mappings) + { + if let Some(file_col_idx) = file_col_idx { + table_col_statistics.push( + file_col_statistics + .get(*file_col_idx) + .cloned() + .unwrap_or_default(), + ); + } else { + table_col_statistics.push(ColumnStatistics::new_unknown()); + } + } + + Ok(table_col_statistics) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::{ + array::{Array, ArrayRef, StringBuilder, StructArray, TimestampMillisecondArray}, + compute::cast, + datatypes::{DataType, Field, TimeUnit}, + record_batch::RecordBatch, + }; + use datafusion_common::{stats::Precision, Result, ScalarValue, Statistics}; + + #[test] + fn test_schema_mapping_map_statistics_basic() { + // Create table schema (a, b, c) + let table_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + Field::new("c", DataType::Float64, true), + 
])); + + // Create file schema (b, a) - different order, missing c + let file_schema = Schema::new(vec![ + Field::new("b", DataType::Utf8, true), + Field::new("a", DataType::Int32, true), + ]); + + // Create SchemaAdapter + let adapter = DefaultSchemaAdapter { + projected_table_schema: Arc::clone(&table_schema), + }; + + // Get mapper and projection + let (mapper, projection) = adapter.map_schema(&file_schema).unwrap(); + + // Should project columns 0,1 from file + assert_eq!(projection, vec![0, 1]); + + // Create file statistics + let mut file_stats = Statistics::default(); + + // Statistics for column b (index 0 in file) + let b_stats = ColumnStatistics { + null_count: Precision::Exact(5), + ..Default::default() + }; + + // Statistics for column a (index 1 in file) + let a_stats = ColumnStatistics { + null_count: Precision::Exact(10), + ..Default::default() + }; + + file_stats.column_statistics = vec![b_stats, a_stats]; + + // Map statistics + let table_col_stats = mapper + .map_column_statistics(&file_stats.column_statistics) + .unwrap(); + + // Verify stats + assert_eq!(table_col_stats.len(), 3); + assert_eq!(table_col_stats[0].null_count, Precision::Exact(10)); // a from file idx 1 + assert_eq!(table_col_stats[1].null_count, Precision::Exact(5)); // b from file idx 0 + assert_eq!(table_col_stats[2].null_count, Precision::Absent); // c (unknown) + } + + #[test] + fn test_schema_mapping_map_statistics_empty() { + // Create schemas + let table_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + ])); + let file_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + ]); + + let adapter = DefaultSchemaAdapter { + projected_table_schema: Arc::clone(&table_schema), + }; + let (mapper, _) = adapter.map_schema(&file_schema).unwrap(); + + // Empty file statistics + let file_stats = Statistics::default(); + let table_col_stats = mapper + .map_column_statistics(&file_stats.column_statistics) + .unwrap(); + + // All stats should be unknown + assert_eq!(table_col_stats.len(), 2); + assert_eq!(table_col_stats[0], ColumnStatistics::new_unknown(),); + assert_eq!(table_col_stats[1], ColumnStatistics::new_unknown(),); + } + + #[test] + fn test_can_cast_field() { + // Same type should work + let from_field = Field::new("col", DataType::Int32, true); + let to_field = Field::new("col", DataType::Int32, true); + assert!(can_cast_field(&from_field, &to_field).unwrap()); + + // Casting Int32 to Float64 is allowed + let from_field = Field::new("col", DataType::Int32, true); + let to_field = Field::new("col", DataType::Float64, true); + assert!(can_cast_field(&from_field, &to_field).unwrap()); + + // Casting Float64 to Utf8 should work (converts to string) + let from_field = Field::new("col", DataType::Float64, true); + let to_field = Field::new("col", DataType::Utf8, true); + assert!(can_cast_field(&from_field, &to_field).unwrap()); + + // Binary to Utf8 is not supported - this is an example of a cast that should fail + // Note: We use Binary instead of Utf8->Int32 because Arrow actually supports that cast + let from_field = Field::new("col", DataType::Binary, true); + let to_field = Field::new("col", DataType::Decimal128(10, 2), true); + let result = can_cast_field(&from_field, &to_field); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert!(error_msg.contains("Cannot cast file schema field col")); + } + + #[test] + fn test_create_field_mapping() { + 
// Define the table schema + let table_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + Field::new("c", DataType::Float64, true), + ])); + + // Define file schema: different order, missing column c, and b has different type + let file_schema = Schema::new(vec![ + Field::new("b", DataType::Float64, true), // Different type but castable to Utf8 + Field::new("a", DataType::Int32, true), // Same type + Field::new("d", DataType::Boolean, true), // Not in table schema + ]); + + // Custom can_map_field function that allows all mappings for testing + let allow_all = |_: &Field, _: &Field| Ok(true); + + // Test field mapping + let (field_mappings, projection) = + create_field_mapping(&file_schema, &table_schema, allow_all).unwrap(); + + // Expected: + // - field_mappings[0] (a) maps to projection[1] + // - field_mappings[1] (b) maps to projection[0] + // - field_mappings[2] (c) is None (not in file) + assert_eq!(field_mappings, vec![Some(1), Some(0), None]); + assert_eq!(projection, vec![0, 1]); // Projecting file columns b, a + + // Test with a failing mapper + let fails_all = |_: &Field, _: &Field| Ok(false); + let (field_mappings, projection) = + create_field_mapping(&file_schema, &table_schema, fails_all).unwrap(); + + // Should have no mappings or projections if all cast checks fail + assert_eq!(field_mappings, vec![None, None, None]); + assert_eq!(projection, Vec::::new()); + + // Test with error-producing mapper + let error_mapper = |_: &Field, _: &Field| plan_err!("Test error"); + let result = create_field_mapping(&file_schema, &table_schema, error_mapper); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Test error")); + } + + #[test] + fn test_schema_mapping_new() { + // Define the projected table schema + let projected_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + ])); + + // Define field mappings from table to file + let field_mappings = vec![Some(1), Some(0)]; + + // Create SchemaMapping manually + let mapping = SchemaMapping::new( + Arc::clone(&projected_schema), + field_mappings.clone(), + Arc::new(|array: &ArrayRef, field: &Field| cast_column(array, field)), + ); + + // Check that fields were set correctly + assert_eq!(*mapping.projected_table_schema, *projected_schema); + assert_eq!(mapping.field_mappings, field_mappings); + + // Test with a batch to ensure it works properly + let batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![ + Field::new("b_file", DataType::Utf8, true), + Field::new("a_file", DataType::Int32, true), + ])), + vec![ + Arc::new(arrow::array::StringArray::from(vec!["hello", "world"])), + Arc::new(arrow::array::Int32Array::from(vec![1, 2])), + ], + ) + .unwrap(); + + // Test that map_batch works with our manually created mapping + let mapped_batch = mapping.map_batch(batch).unwrap(); + + // Verify the mapped batch has the correct schema and data + assert_eq!(*mapped_batch.schema(), *projected_schema); + assert_eq!(mapped_batch.num_columns(), 2); + assert_eq!(mapped_batch.column(0).len(), 2); // a column + assert_eq!(mapped_batch.column(1).len(), 2); // b column + } + + #[test] + fn test_map_schema_error_path() { + // Define the table schema + let table_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + Field::new("c", DataType::Decimal128(10, 2), true), // Use Decimal which has stricter cast rules + ])); + + // 
Define file schema with incompatible type for column c + let file_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Float64, true), // Different but castable + Field::new("c", DataType::Binary, true), // Not castable to Decimal128 + ]); + + // Create DefaultSchemaAdapter + let adapter = DefaultSchemaAdapter { + projected_table_schema: Arc::clone(&table_schema), + }; + + // map_schema should error due to incompatible types + let result = adapter.map_schema(&file_schema); + assert!(result.is_err()); + let error_msg = result.unwrap_err().to_string(); + assert!(error_msg.contains("Cannot cast file schema field c")); + } + + #[test] + fn test_map_schema_happy_path() { + // Define the table schema + let table_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + Field::new("c", DataType::Decimal128(10, 2), true), + ])); + + // Create DefaultSchemaAdapter + let adapter = DefaultSchemaAdapter { + projected_table_schema: Arc::clone(&table_schema), + }; + + // Define compatible file schema (missing column c) + let compatible_file_schema = Schema::new(vec![ + Field::new("a", DataType::Int64, true), // Can be cast to Int32 + Field::new("b", DataType::Float64, true), // Can be cast to Utf8 + ]); + + // Test successful schema mapping + let (mapper, projection) = adapter.map_schema(&compatible_file_schema).unwrap(); + + // Verify field_mappings and projection created correctly + assert_eq!(projection, vec![0, 1]); // Projecting a and b + + // Verify the SchemaMapping works with actual data + let file_batch = RecordBatch::try_new( + Arc::new(compatible_file_schema.clone()), + vec![ + Arc::new(arrow::array::Int64Array::from(vec![100, 200])), + Arc::new(arrow::array::Float64Array::from(vec![1.5, 2.5])), + ], + ) + .unwrap(); + + let mapped_batch = mapper.map_batch(file_batch).unwrap(); + + // Verify correct schema mapping + assert_eq!(*mapped_batch.schema(), *table_schema); + assert_eq!(mapped_batch.num_columns(), 3); // a, b, c + + // Column c should be null since it wasn't in the file schema + let c_array = mapped_batch.column(2); + assert_eq!(c_array.len(), 2); + assert_eq!(c_array.null_count(), 2); + } + + #[test] + fn test_adapt_struct_with_added_nested_fields() -> Result<()> { + let (file_schema, table_schema) = create_test_schemas_with_nested_fields(); + let batch = create_test_batch_with_struct_data(&file_schema)?; + + let adapter = DefaultSchemaAdapter { + projected_table_schema: Arc::clone(&table_schema), + }; + let (mapper, _) = adapter.map_schema(file_schema.as_ref())?; + let mapped_batch = mapper.map_batch(batch)?; + + verify_adapted_batch_with_nested_fields(&mapped_batch, &table_schema)?; + Ok(()) + } + + #[test] + fn test_map_column_statistics_struct() -> Result<()> { + let (file_schema, table_schema) = create_test_schemas_with_nested_fields(); + + let adapter = DefaultSchemaAdapter { + projected_table_schema: Arc::clone(&table_schema), + }; + let (mapper, _) = adapter.map_schema(file_schema.as_ref())?; + + let file_stats = vec![ + create_test_column_statistics( + 0, + 100, + Some(ScalarValue::Int32(Some(1))), + Some(ScalarValue::Int32(Some(100))), + Some(ScalarValue::Int32(Some(5100))), + ), + create_test_column_statistics(10, 50, None, None, None), + ]; + + let table_stats = mapper.map_column_statistics(&file_stats)?; + assert_eq!(table_stats.len(), 1); + verify_column_statistics( + &table_stats[0], + Some(0), + Some(100), + Some(ScalarValue::Int32(Some(1))), + 
Some(ScalarValue::Int32(Some(100))), + Some(ScalarValue::Int32(Some(5100))), + ); + let missing_stats = mapper.map_column_statistics(&[])?; + assert_eq!(missing_stats.len(), 1); + assert_eq!(missing_stats[0], ColumnStatistics::new_unknown()); + Ok(()) + } + + fn create_test_schemas_with_nested_fields() -> (SchemaRef, SchemaRef) { + let file_schema = Arc::new(Schema::new(vec![Field::new( + "info", + DataType::Struct( + vec![ + Field::new("location", DataType::Utf8, true), + Field::new( + "timestamp_utc", + DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())), + true, + ), + ] + .into(), + ), + true, + )])); + + let table_schema = Arc::new(Schema::new(vec![Field::new( + "info", + DataType::Struct( + vec![ + Field::new("location", DataType::Utf8, true), + Field::new( + "timestamp_utc", + DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())), + true, + ), + Field::new( + "reason", + DataType::Struct( + vec![ + Field::new("_level", DataType::Float64, true), + Field::new( + "details", + DataType::Struct( + vec![ + Field::new("rurl", DataType::Utf8, true), + Field::new("s", DataType::Float64, true), + Field::new("t", DataType::Utf8, true), + ] + .into(), + ), + true, + ), + ] + .into(), + ), + true, + ), + ] + .into(), + ), + true, + )])); + + (file_schema, table_schema) + } + + fn create_test_batch_with_struct_data( + file_schema: &SchemaRef, + ) -> Result { + let mut location_builder = StringBuilder::new(); + location_builder.append_value("San Francisco"); + location_builder.append_value("New York"); + + let timestamp_array = TimestampMillisecondArray::from(vec![ + Some(1640995200000), + Some(1641081600000), + ]); + + let timestamp_type = + DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())); + let timestamp_array = cast(×tamp_array, ×tamp_type)?; + + let info_struct = StructArray::from(vec![ + ( + Arc::new(Field::new("location", DataType::Utf8, true)), + Arc::new(location_builder.finish()) as ArrayRef, + ), + ( + Arc::new(Field::new("timestamp_utc", timestamp_type, true)), + timestamp_array, + ), + ]); + + Ok(RecordBatch::try_new( + Arc::clone(file_schema), + vec![Arc::new(info_struct)], + )?) 
+    }
+
+    fn verify_adapted_batch_with_nested_fields(
+        mapped_batch: &RecordBatch,
+        table_schema: &SchemaRef,
+    ) -> Result<()> {
+        assert_eq!(mapped_batch.schema(), *table_schema);
+        assert_eq!(mapped_batch.num_rows(), 2);
+
+        let info_col = mapped_batch.column(0);
+        let info_array = info_col
+            .as_any()
+            .downcast_ref::<StructArray>()
+            .expect("Expected info column to be a StructArray");
+
+        verify_preserved_fields(info_array)?;
+        verify_reason_field_structure(info_array)?;
+        Ok(())
+    }
+
+    fn verify_preserved_fields(info_array: &StructArray) -> Result<()> {
+        let location_col = info_array
+            .column_by_name("location")
+            .expect("Expected location field in struct");
+        let location_array = location_col
+            .as_any()
+            .downcast_ref::<StringArray>()
+            .expect("Expected location to be a StringArray");
+        assert_eq!(location_array.value(0), "San Francisco");
+        assert_eq!(location_array.value(1), "New York");
+
+        let timestamp_col = info_array
+            .column_by_name("timestamp_utc")
+            .expect("Expected timestamp_utc field in struct");
+        let timestamp_array = timestamp_col
+            .as_any()
+            .downcast_ref::<TimestampMillisecondArray>()
+            .expect("Expected timestamp_utc to be a TimestampMillisecondArray");
+        assert_eq!(timestamp_array.value(0), 1640995200000);
+        assert_eq!(timestamp_array.value(1), 1641081600000);
+        Ok(())
+    }
+
+    fn verify_reason_field_structure(info_array: &StructArray) -> Result<()> {
+        let reason_col = info_array
+            .column_by_name("reason")
+            .expect("Expected reason field in struct");
+        let reason_array = reason_col
+            .as_any()
+            .downcast_ref::<StructArray>()
+            .expect("Expected reason to be a StructArray");
+        assert_eq!(reason_array.fields().len(), 2);
+        assert!(reason_array.column_by_name("_level").is_some());
+        assert!(reason_array.column_by_name("details").is_some());
+
+        let details_col = reason_array
+            .column_by_name("details")
+            .expect("Expected details field in reason struct");
+        let details_array = details_col
+            .as_any()
+            .downcast_ref::<StructArray>()
+            .expect("Expected details to be a StructArray");
+        assert_eq!(details_array.fields().len(), 3);
+        assert!(details_array.column_by_name("rurl").is_some());
+        assert!(details_array.column_by_name("s").is_some());
+        assert!(details_array.column_by_name("t").is_some());
+        for i in 0..2 {
+            assert!(reason_array.is_null(i), "reason field should be null");
+        }
+        Ok(())
+    }
+
+    fn verify_column_statistics(
+        stats: &ColumnStatistics,
+        expected_null_count: Option<usize>,
+        expected_distinct_count: Option<usize>,
+        expected_min: Option<ScalarValue>,
+        expected_max: Option<ScalarValue>,
+        expected_sum: Option<ScalarValue>,
+    ) {
+        if let Some(count) = expected_null_count {
+            assert_eq!(
+                stats.null_count,
+                Precision::Exact(count),
+                "Null count should match expected value"
+            );
+        }
+        if let Some(count) = expected_distinct_count {
+            assert_eq!(
+                stats.distinct_count,
+                Precision::Exact(count),
+                "Distinct count should match expected value"
+            );
+        }
+        if let Some(min) = expected_min {
+            assert_eq!(
+                stats.min_value,
+                Precision::Exact(min),
+                "Min value should match expected value"
+            );
+        }
+        if let Some(max) = expected_max {
+            assert_eq!(
+                stats.max_value,
+                Precision::Exact(max),
+                "Max value should match expected value"
+            );
+        }
+        if let Some(sum) = expected_sum {
+            assert_eq!(
+                stats.sum_value,
+                Precision::Exact(sum),
+                "Sum value should match expected value"
+            );
+        }
+    }
+
+    fn create_test_column_statistics(
+        null_count: usize,
+        distinct_count: usize,
+        min_value: Option<ScalarValue>,
+        max_value: Option<ScalarValue>,
+        sum_value: Option<ScalarValue>,
+    ) -> ColumnStatistics {
+        ColumnStatistics {
+            null_count: Precision::Exact(null_count),
+            distinct_count:
Precision::Exact(distinct_count), + min_value: min_value.map_or_else(|| Precision::Absent, Precision::Exact), + max_value: max_value.map_or_else(|| Precision::Absent, Precision::Exact), + sum_value: sum_value.map_or_else(|| Precision::Absent, Precision::Exact), + } + } } diff --git a/datafusion/datasource/src/sink.rs b/datafusion/datasource/src/sink.rs index 0552370d8ed0..b8c5b42bf767 100644 --- a/datafusion/datasource/src/sink.rs +++ b/datafusion/datasource/src/sink.rs @@ -22,22 +22,21 @@ use std::fmt; use std::fmt::Debug; use std::sync::Arc; -use datafusion_physical_plan::metrics::MetricsSet; -use datafusion_physical_plan::stream::RecordBatchStreamAdapter; -use datafusion_physical_plan::ExecutionPlanProperties; -use datafusion_physical_plan::{ - execute_input_stream, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, - PlanProperties, SendableRecordBatchStream, -}; - use arrow::array::{ArrayRef, RecordBatch, UInt64Array}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::{internal_err, Result}; use datafusion_execution::TaskContext; use datafusion_physical_expr::{Distribution, EquivalenceProperties}; -use datafusion_physical_expr_common::sort_expr::LexRequirement; +use datafusion_physical_expr_common::sort_expr::{LexRequirement, OrderingRequirements}; +use datafusion_physical_plan::metrics::MetricsSet; +use datafusion_physical_plan::stream::RecordBatchStreamAdapter; +use datafusion_physical_plan::{ + execute_input_stream, DisplayAs, DisplayFormatType, ExecutionPlan, + ExecutionPlanProperties, Partitioning, PlanProperties, SendableRecordBatchStream, +}; use async_trait::async_trait; +use datafusion_physical_plan::execution_plan::{EvaluationType, SchedulingType}; use futures::StreamExt; /// `DataSink` implements writing streams of [`RecordBatch`]es to @@ -143,6 +142,8 @@ impl DataSinkExec { input.pipeline_behavior(), input.boundedness(), ) + .with_scheduling_type(SchedulingType::Cooperative) + .with_evaluation_type(EvaluationType::Eager) } } @@ -184,10 +185,10 @@ impl ExecutionPlan for DataSinkExec { vec![Distribution::SinglePartition; self.children().len()] } - fn required_input_ordering(&self) -> Vec> { + fn required_input_ordering(&self) -> Vec> { // The required input ordering is set externally (e.g. by a `ListingTable`). - // Otherwise, there is no specific requirement (i.e. `sort_expr` is `None`). - vec![self.sort_order.as_ref().cloned()] + // Otherwise, there is no specific requirement (i.e. `sort_order` is `None`). 
+ vec![self.sort_order.as_ref().cloned().map(Into::into)] } fn maintains_input_order(&self) -> Vec { diff --git a/datafusion/datasource/src/source.rs b/datafusion/datasource/src/source.rs index 6c9122ce1ac1..4dda95b0856b 100644 --- a/datafusion/datasource/src/source.rs +++ b/datafusion/datasource/src/source.rs @@ -22,7 +22,9 @@ use std::fmt; use std::fmt::{Debug, Formatter}; use std::sync::Arc; -use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; +use datafusion_physical_plan::execution_plan::{ + Boundedness, EmissionType, SchedulingType, +}; use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use datafusion_physical_plan::projection::ProjectionExec; use datafusion_physical_plan::{ @@ -31,44 +33,122 @@ use datafusion_physical_plan::{ use crate::file_scan_config::FileScanConfig; use datafusion_common::config::ConfigOptions; -use datafusion_common::{Constraints, Statistics}; +use datafusion_common::{Constraints, Result, Statistics}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; -use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; +use datafusion_physical_expr::{EquivalenceProperties, Partitioning, PhysicalExpr}; use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_plan::filter_pushdown::{ + ChildPushdownResult, FilterPushdownPhase, FilterPushdownPropagation, +}; -/// Common behaviors in Data Sources for both from Files and Memory. +/// A source of data, typically a list of files or memory +/// +/// This trait provides common behaviors for abstract sources of data. It has +/// two common implementations: +/// +/// 1. [`FileScanConfig`]: lists of files +/// 2. [`MemorySourceConfig`]: in memory list of `RecordBatch` +/// +/// File format specific behaviors are defined by [`FileSource`] /// /// # See Also -/// * [`DataSourceExec`] for physical plan implementation -/// * [`FileSource`] for file format implementations (Parquet, Json, etc) +/// * [`FileSource`] for file format specific implementations (Parquet, Json, etc) +/// * [`DataSourceExec`]: The [`ExecutionPlan`] that reads from a `DataSource` /// /// # Notes +/// /// Requires `Debug` to assist debugging /// +/// [`FileScanConfig`]: https://docs.rs/datafusion/latest/datafusion/datasource/physical_plan/struct.FileScanConfig.html +/// [`MemorySourceConfig`]: https://docs.rs/datafusion/latest/datafusion/datasource/memory/struct.MemorySourceConfig.html /// [`FileSource`]: crate::file::FileSource +/// [`FileFormat``]: https://docs.rs/datafusion/latest/datafusion/datasource/file_format/index.html +/// [`TableProvider`]: https://docs.rs/datafusion/latest/datafusion/catalog/trait.TableProvider.html +/// +/// The following diagram shows how DataSource, FileSource, and DataSourceExec are related +/// ```text +/// ┌─────────────────────┐ -----► execute path +/// │ │ ┄┄┄┄┄► init path +/// │ DataSourceExec │ +/// │ │ +/// └───────▲─────────────┘ +/// ┊ │ +/// ┊ │ +/// ┌──────────▼──────────┐ ┌──────────-──────────┐ +/// │ │ | | +/// │ DataSource(trait) │ | TableProvider(trait)| +/// │ │ | | +/// └───────▲─────────────┘ └─────────────────────┘ +/// ┊ │ ┊ +/// ┌───────────────┿──┴────────────────┐ ┊ +/// | ┌┄┄┄┄┄┄┄┄┄┄┄┘ | ┊ +/// | ┊ | ┊ +/// ┌──────────▼──────────┐ ┌──────────▼──────────┐ ┊ +/// │ │ │ │ ┌──────────▼──────────┐ +/// │ FileScanConfig │ │ MemorySourceConfig │ | | +/// │ │ │ │ | FileFormat(trait) | +/// └──────────────▲──────┘ └─────────────────────┘ | | +/// │ ┊ └─────────────────────┘ +/// │ ┊ ┊ +/// │ ┊ ┊ +/// 
┌──────────▼──────────┐ ┌──────────▼──────────┐ +/// │ │ │ ArrowSource │ +/// │ FileSource(trait) ◄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄┄│ ... │ +/// │ │ │ ParquetSource │ +/// └─────────────────────┘ └─────────────────────┘ +/// │ +/// │ +/// │ +/// │ +/// ┌──────────▼──────────┐ +/// │ ArrowSource │ +/// │ ... │ +/// │ ParquetSource │ +/// └─────────────────────┘ +/// | +/// FileOpener (called by FileStream) +/// │ +/// ┌──────────▼──────────┐ +/// │ │ +/// │ RecordBatch │ +/// │ │ +/// └─────────────────────┘ +/// ``` pub trait DataSource: Send + Sync + Debug { fn open( &self, partition: usize, context: Arc, - ) -> datafusion_common::Result; + ) -> Result; fn as_any(&self) -> &dyn Any; /// Format this source for display in explain plans fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> fmt::Result; - /// Return a copy of this DataSource with a new partitioning scheme + /// Return a copy of this DataSource with a new partitioning scheme. + /// + /// Returns `Ok(None)` (the default) if the partitioning cannot be changed. + /// Refer to [`ExecutionPlan::repartitioned`] for details on when None should be returned. + /// + /// Repartitioning should not change the output ordering, if this ordering exists. + /// Refer to [`MemorySourceConfig::repartition_preserving_order`](crate::memory::MemorySourceConfig) + /// and the FileSource's + /// [`FileGroupPartitioner::repartition_file_groups`](crate::file_groups::FileGroupPartitioner::repartition_file_groups) + /// for examples. fn repartitioned( &self, _target_partitions: usize, _repartition_file_min_size: usize, _output_ordering: Option, - ) -> datafusion_common::Result>> { + ) -> Result>> { Ok(None) } fn output_partitioning(&self) -> Partitioning; fn eq_properties(&self) -> EquivalenceProperties; - fn statistics(&self) -> datafusion_common::Result; + fn scheduling_type(&self) -> SchedulingType { + SchedulingType::NonCooperative + } + fn statistics(&self) -> Result; /// Return a copy of this DataSource with a new fetch limit fn with_fetch(&self, _limit: Option) -> Option>; fn fetch(&self) -> Option; @@ -78,17 +158,30 @@ pub trait DataSource: Send + Sync + Debug { fn try_swapping_with_projection( &self, _projection: &ProjectionExec, - ) -> datafusion_common::Result>>; + ) -> Result>>; + /// Try to push down filters into this DataSource. + /// See [`ExecutionPlan::handle_child_pushdown_result`] for more details. + /// + /// [`ExecutionPlan::handle_child_pushdown_result`]: datafusion_physical_plan::ExecutionPlan::handle_child_pushdown_result + fn try_pushdown_filters( + &self, + filters: Vec>, + _config: &ConfigOptions, + ) -> Result>> { + Ok(FilterPushdownPropagation::unsupported(filters)) + } } -/// [`ExecutionPlan`] handles different file formats like JSON, CSV, AVRO, ARROW, PARQUET +/// [`ExecutionPlan`] that reads one or more files +/// +/// `DataSourceExec` implements common functionality such as applying +/// projections, and caching plan properties. /// -/// `DataSourceExec` implements common functionality such as applying projections, -/// and caching plan properties. +/// The [`DataSource`] describes where to find the data for this data source +/// (for example in files or what in memory partitions). /// -/// The [`DataSource`] trait describes where to find the data for this data -/// source (for example what files or what in memory partitions). Format -/// specifics are implemented with the [`FileSource`] trait. 
+/// For file based [`DataSource`]s, format specific behavior is implemented in
+/// the [`FileSource`] trait.
 ///
 /// [`FileSource`]: crate::file::FileSource
 #[derive(Clone, Debug)]
@@ -131,15 +224,19 @@ impl ExecutionPlan for DataSourceExec {
     fn with_new_children(
         self: Arc<Self>,
         _: Vec<Arc<dyn ExecutionPlan>>,
-    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
+    ) -> Result<Arc<dyn ExecutionPlan>> {
         Ok(self)
     }
 
+    /// Implementation of [`ExecutionPlan::repartitioned`] which relies upon the inner [`DataSource::repartitioned`].
+    ///
+    /// If the data source does not support changing its partitioning, returns `Ok(None)` (the default). Refer
+    /// to [`ExecutionPlan::repartitioned`] for more details.
     fn repartitioned(
         &self,
         target_partitions: usize,
         config: &ConfigOptions,
-    ) -> datafusion_common::Result<Option<Arc<dyn ExecutionPlan>>> {
+    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
         let data_source = self.data_source.repartitioned(
             target_partitions,
             config.optimizer.repartition_file_min_size,
@@ -163,18 +260,36 @@ impl ExecutionPlan for DataSourceExec {
         &self,
         partition: usize,
         context: Arc<TaskContext>,
-    ) -> datafusion_common::Result<SendableRecordBatchStream> {
-        self.data_source.open(partition, context)
+    ) -> Result<SendableRecordBatchStream> {
+        self.data_source.open(partition, Arc::clone(&context))
     }
 
     fn metrics(&self) -> Option<MetricsSet> {
         Some(self.data_source.metrics().clone_inner())
     }
 
-    fn statistics(&self) -> datafusion_common::Result<Statistics> {
+    fn statistics(&self) -> Result<Statistics> {
         self.data_source.statistics()
     }
 
+    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
+        if let Some(partition) = partition {
+            let mut statistics = Statistics::new_unknown(&self.schema());
+            if let Some(file_config) =
+                self.data_source.as_any().downcast_ref::<FileScanConfig>()
+            {
+                if let Some(file_group) = file_config.file_groups.get(partition) {
+                    if let Some(stat) = file_group.file_statistics(None) {
+                        statistics = stat.clone();
+                    }
+                }
+            }
+            Ok(statistics)
+        } else {
+            Ok(self.data_source.statistics()?)
+        }
+    }
+
     fn with_fetch(&self, limit: Option<usize>) -> Option<Arc<dyn ExecutionPlan>> {
         let data_source = self.data_source.with_fetch(limit)?;
         let cache = self.cache.clone();
@@ -189,9 +304,38 @@ impl ExecutionPlan for DataSourceExec {
     fn try_swapping_with_projection(
         &self,
         projection: &ProjectionExec,
-    ) -> datafusion_common::Result<Option<Arc<dyn ExecutionPlan>>> {
+    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
         self.data_source.try_swapping_with_projection(projection)
     }
+
+    fn handle_child_pushdown_result(
+        &self,
+        _phase: FilterPushdownPhase,
+        child_pushdown_result: ChildPushdownResult,
+        config: &ConfigOptions,
+    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
+        // Push any remaining filters into our data source
+        let res = self.data_source.try_pushdown_filters(
+            child_pushdown_result.parent_filters.collect_all(),
+            config,
+        )?;
+        match res.updated_node {
+            Some(data_source) => {
+                let mut new_node = self.clone();
+                new_node.data_source = data_source;
+                new_node.cache =
+                    Self::compute_properties(Arc::clone(&new_node.data_source));
+                Ok(FilterPushdownPropagation {
+                    filters: res.filters,
+                    updated_node: Some(Arc::new(new_node)),
+                })
+            }
+            None => Ok(FilterPushdownPropagation {
+                filters: res.filters,
+                updated_node: None,
+            }),
+        }
+    }
 }
 
 impl DataSourceExec {
@@ -199,6 +343,7 @@ impl DataSourceExec {
         Arc::new(Self::new(Arc::new(data_source)))
     }
 
+    // Default constructor for `DataSourceExec`, setting the `cooperative` flag to `true`.
     pub fn new(data_source: Arc<dyn DataSource>) -> Self {
         let cache = Self::compute_properties(Arc::clone(&data_source));
         Self { data_source, cache }
@@ -234,6 +379,7 @@ impl DataSourceExec {
             EmissionType::Incremental,
             Boundedness::Bounded,
         )
+        .with_scheduling_type(data_source.scheduling_type())
     }
 
     /// Downcast the `DataSourceExec`'s `data_source` to a specific file source
@@ -254,3 +400,13 @@ impl DataSourceExec {
         })
     }
 }
+
+/// Create a new `DataSourceExec` from a `DataSource`
+impl<S> From<S> for DataSourceExec
+where
+    S: DataSource + 'static,
+{
+    fn from(source: S) -> Self {
+        Self::new(Arc::new(source))
+    }
+}
diff --git a/datafusion/datasource/src/statistics.rs b/datafusion/datasource/src/statistics.rs
index 8a04d77b273d..db9af0ff7675 100644
--- a/datafusion/datasource/src/statistics.rs
+++ b/datafusion/datasource/src/statistics.rs
@@ -20,24 +20,25 @@
 //! Currently, this module houses code to sort file groups if they are non-overlapping with
 //! respect to the required sort order. See [`MinMaxStatistics`]
 
-use futures::{Stream, StreamExt};
 use std::sync::Arc;
 
 use crate::file_groups::FileGroup;
 use crate::PartitionedFile;
 use arrow::array::RecordBatch;
+use arrow::compute::SortColumn;
 use arrow::datatypes::SchemaRef;
-use arrow::{
-    compute::SortColumn,
-    row::{Row, Rows},
-};
+use arrow::row::{Row, Rows};
 use datafusion_common::stats::Precision;
-use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result};
-use datafusion_physical_expr::{expressions::Column, PhysicalSortExpr};
-use datafusion_physical_expr_common::sort_expr::LexOrdering;
+use datafusion_common::{
+    plan_datafusion_err, plan_err, DataFusionError, Result, ScalarValue,
+};
+use datafusion_physical_expr::expressions::Column;
+use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
 use datafusion_physical_plan::{ColumnStatistics, Statistics};
+use futures::{Stream, StreamExt};
+
 /// A normalized representation of file min/max statistics that allows for efficient sorting & comparison.
 /// The min/max values are ordered by [`Self::sort_order`].
 /// Furthermore, any columns that are reversed in the sort order have their min/max values swapped.
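The `From` impl added above makes any concrete `DataSource` convertible straight into a plan node. The following is a minimal sketch of that idea using `MemorySourceConfig`; it assumes `MemorySourceConfig::try_new(partitions, schema, projection)` keeps its current signature and that the crate is consumed directly as `datafusion_datasource` (an illustration, not a definitive usage):

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array, RecordBatch};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
use datafusion_datasource::memory::MemorySourceConfig;
use datafusion_datasource::source::DataSourceExec;

fn main() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef],
    )?;

    // One in-memory partition, no projection (assumed `try_new` signature).
    let source = MemorySourceConfig::try_new(&[vec![batch]], schema, None)?;

    // New in this change: the blanket `From` impl makes `.into()` equivalent to
    // wrapping the source manually via `DataSourceExec::from_data_source`.
    let _exec: DataSourceExec = source.into();
    Ok(())
}
```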
@@ -71,9 +72,7 @@ impl MinMaxStatistics { projection: Option<&[usize]>, // Indices of projection in full table schema (None = all columns) files: impl IntoIterator, ) -> Result { - use datafusion_common::ScalarValue; - - let statistics_and_partition_values = files + let Some(statistics_and_partition_values) = files .into_iter() .map(|file| { file.statistics @@ -81,9 +80,9 @@ impl MinMaxStatistics { .zip(Some(file.partition_values.as_slice())) }) .collect::>>() - .ok_or_else(|| { - DataFusionError::Plan("Parquet file missing statistics".to_string()) - })?; + else { + return plan_err!("Parquet file missing statistics"); + }; // Helper function to get min/max statistics for a given column of projected_schema let get_min_max = |i: usize| -> Result<(Vec, Vec)> { @@ -96,9 +95,7 @@ impl MinMaxStatistics { .get_value() .cloned() .zip(s.column_statistics[i].max_value.get_value().cloned()) - .ok_or_else(|| { - DataFusionError::Plan("statistics not found".to_string()) - }) + .ok_or_else(|| plan_datafusion_err!("statistics not found")) } else { let partition_value = &pv[i - s.column_statistics.len()]; Ok((partition_value.clone(), partition_value.clone())) @@ -109,27 +106,28 @@ impl MinMaxStatistics { .unzip()) }; - let sort_columns = sort_columns_from_physical_sort_exprs(projected_sort_order) - .ok_or(DataFusionError::Plan( - "sort expression must be on column".to_string(), - ))?; + let Some(sort_columns) = + sort_columns_from_physical_sort_exprs(projected_sort_order) + else { + return plan_err!("sort expression must be on column"); + }; // Project the schema & sort order down to just the relevant columns let min_max_schema = Arc::new( projected_schema .project(&(sort_columns.iter().map(|c| c.index()).collect::>()))?, ); - let min_max_sort_order = LexOrdering::from( - sort_columns - .iter() - .zip(projected_sort_order.iter()) - .enumerate() - .map(|(i, (col, sort))| PhysicalSortExpr { - expr: Arc::new(Column::new(col.name(), i)), - options: sort.options, - }) - .collect::>(), - ); + + let min_max_sort_order = projected_sort_order + .iter() + .zip(sort_columns.iter()) + .enumerate() + .map(|(idx, (sort_expr, col))| { + let expr = Arc::new(Column::new(col.name(), idx)); + PhysicalSortExpr::new(expr, sort_expr.options) + }); + // Safe to `unwrap` as we know that sort columns are non-empty: + let min_max_sort_order = LexOrdering::new(min_max_sort_order).unwrap(); let (min_values, max_values): (Vec<_>, Vec<_>) = sort_columns .iter() @@ -137,7 +135,9 @@ impl MinMaxStatistics { // Reverse the projection to get the index of the column in the full statistics // The file statistics contains _every_ column , but the sort column's index() // refers to the index in projected_schema - let i = projection.map(|p| p[c.index()]).unwrap_or(c.index()); + let i = projection + .map(|p| p[c.index()]) + .unwrap_or_else(|| c.index()); let (min, max) = get_min_max(i).map_err(|e| { e.context(format!("get min/max for column: '{}'", c.name())) @@ -187,25 +187,23 @@ impl MinMaxStatistics { .map_err(|e| e.context("create sort fields"))?; let converter = RowConverter::new(sort_fields)?; - let sort_columns = sort_columns_from_physical_sort_exprs(sort_order).ok_or( - DataFusionError::Plan("sort expression must be on column".to_string()), - )?; + let Some(sort_columns) = sort_columns_from_physical_sort_exprs(sort_order) else { + return plan_err!("sort expression must be on column"); + }; // swap min/max if they're reversed in the ordering let (new_min_cols, new_max_cols): (Vec<_>, Vec<_>) = sort_order .iter() 
.zip(sort_columns.iter().copied()) .map(|(sort_expr, column)| { - if sort_expr.options.descending { - max_values - .column_by_name(column.name()) - .zip(min_values.column_by_name(column.name())) + let maxes = max_values.column_by_name(column.name()); + let mins = min_values.column_by_name(column.name()); + let opt_value = if sort_expr.options.descending { + maxes.zip(mins) } else { - min_values - .column_by_name(column.name()) - .zip(max_values.column_by_name(column.name())) - } - .ok_or_else(|| { + mins.zip(maxes) + }; + opt_value.ok_or_else(|| { plan_datafusion_err!( "missing column in MinMaxStatistics::new: '{}'", column.name() @@ -283,7 +281,7 @@ fn sort_columns_from_physical_sort_exprs( sort_order .iter() .map(|expr| expr.expr.as_any().downcast_ref::()) - .collect::>>() + .collect() } /// Get all files as well as the file level summary statistics (no statistic for partition columns). @@ -476,7 +474,7 @@ pub fn compute_all_files_statistics( // Then summary statistics across all file groups let file_groups_statistics = file_groups_with_stats .iter() - .filter_map(|file_group| file_group.statistics()); + .filter_map(|file_group| file_group.file_statistics(None)); let mut statistics = Statistics::try_merge_iter(file_groups_statistics, &table_schema)?; diff --git a/datafusion/datasource/src/test_util.rs b/datafusion/datasource/src/test_util.rs index 9a9b98d5041b..e4a5114aa073 100644 --- a/datafusion/datasource/src/test_util.rs +++ b/datafusion/datasource/src/test_util.rs @@ -17,12 +17,14 @@ use crate::{ file::FileSource, file_scan_config::FileScanConfig, file_stream::FileOpener, + schema_adapter::SchemaAdapterFactory, }; use std::sync::Arc; -use arrow::datatypes::SchemaRef; +use arrow::datatypes::{Schema, SchemaRef}; use datafusion_common::{Result, Statistics}; +use datafusion_physical_expr::{expressions::Column, PhysicalExpr}; use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; use object_store::ObjectStore; @@ -31,6 +33,7 @@ use object_store::ObjectStore; pub(crate) struct MockSource { metrics: ExecutionPlanMetricsSet, projected_statistics: Option, + schema_adapter_factory: Option>, } impl FileSource for MockSource { @@ -80,4 +83,23 @@ impl FileSource for MockSource { fn file_type(&self) -> &str { "mock" } + + fn with_schema_adapter_factory( + &self, + schema_adapter_factory: Arc, + ) -> Result> { + Ok(Arc::new(Self { + schema_adapter_factory: Some(schema_adapter_factory), + ..self.clone() + })) + } + + fn schema_adapter_factory(&self) -> Option> { + self.schema_adapter_factory.clone() + } +} + +/// Create a column expression +pub(crate) fn col(name: &str, schema: &Schema) -> Result> { + Ok(Arc::new(Column::new_with_schema(name, schema)?)) } diff --git a/datafusion/datasource/src/url.rs b/datafusion/datasource/src/url.rs index bddfdbcc06d1..348791be9828 100644 --- a/datafusion/datasource/src/url.rs +++ b/datafusion/datasource/src/url.rs @@ -282,6 +282,28 @@ impl ListingTableUrl { let url = &self.url[url::Position::BeforeScheme..url::Position::BeforePath]; ObjectStoreUrl::parse(url).unwrap() } + + /// Returns true if the [`ListingTableUrl`] points to the folder + pub fn is_folder(&self) -> bool { + self.url.scheme() == "file" && self.is_collection() + } + + /// Return the `url` for [`ListingTableUrl`] + pub fn get_url(&self) -> &Url { + &self.url + } + + /// Return the `glob` for [`ListingTableUrl`] + pub fn get_glob(&self) -> &Option { + &self.glob + } + + /// Returns a copy of current [`ListingTableUrl`] with a specified `glob` + pub fn with_glob(self, glob: &str) -> Result 
{ + let glob = + Pattern::new(glob).map_err(|e| DataFusionError::External(Box::new(e)))?; + Self::try_new(self.url, Some(glob)) + } } /// Creates a file URL from a potentially relative filesystem path diff --git a/datafusion/datasource/src/write/demux.rs b/datafusion/datasource/src/write/demux.rs index 49c3a64d24aa..75fb557b63d2 100644 --- a/datafusion/datasource/src/write/demux.rs +++ b/datafusion/datasource/src/write/demux.rs @@ -45,7 +45,7 @@ use datafusion_execution::TaskContext; use chrono::NaiveDate; use futures::StreamExt; use object_store::path::Path; -use rand::distributions::DistString; +use rand::distr::SampleString; use tokio::sync::mpsc::{self, Receiver, Sender, UnboundedReceiver, UnboundedSender}; type RecordBatchReceiver = Receiver; @@ -151,8 +151,7 @@ async fn row_count_demuxer( let max_buffered_batches = exec_options.max_buffered_batches_per_output_file; let minimum_parallel_files = exec_options.minimum_parallel_output_files; let mut part_idx = 0; - let write_id = - rand::distributions::Alphanumeric.sample_string(&mut rand::thread_rng(), 16); + let write_id = rand::distr::Alphanumeric.sample_string(&mut rand::rng(), 16); let mut open_file_streams = Vec::with_capacity(minimum_parallel_files); @@ -225,7 +224,7 @@ fn generate_file_path( if !single_file_output { base_output_path .prefix() - .child(format!("{}_{}.{}", write_id, part_idx, file_extension)) + .child(format!("{write_id}_{part_idx}.{file_extension}")) } else { base_output_path.prefix().to_owned() } @@ -267,8 +266,7 @@ async fn hive_style_partitions_demuxer( file_extension: String, keep_partition_by_columns: bool, ) -> Result<()> { - let write_id = - rand::distributions::Alphanumeric.sample_string(&mut rand::thread_rng(), 16); + let write_id = rand::distr::Alphanumeric.sample_string(&mut rand::rng(), 16); let exec_options = &context.session_config().options().execution; let max_buffered_recordbatches = exec_options.max_buffered_batches_per_output_file; @@ -513,7 +511,7 @@ fn compute_take_arrays( for vals in all_partition_values.iter() { part_key.push(vals[i].clone().into()); } - let builder = take_map.entry(part_key).or_insert(UInt64Builder::new()); + let builder = take_map.entry(part_key).or_insert_with(UInt64Builder::new); builder.append_value(i as u64); } take_map @@ -556,5 +554,5 @@ fn compute_hive_style_file_path( file_path = file_path.child(format!("{}={}", partition_by[j].0, part_key[j])); } - file_path.child(format!("{}.{}", write_id, file_extension)) + file_path.child(format!("{write_id}.{file_extension}")) } diff --git a/datafusion/datasource/src/write/mod.rs b/datafusion/datasource/src/write/mod.rs index f581126095a7..3694568682a5 100644 --- a/datafusion/datasource/src/write/mod.rs +++ b/datafusion/datasource/src/write/mod.rs @@ -77,15 +77,18 @@ pub trait BatchSerializer: Sync + Send { /// Returns an [`AsyncWrite`] which writes to the given object store location /// with the specified compression. +/// +/// The writer will have a default buffer size as chosen by [`BufWriter::new`]. +/// /// We drop the `AbortableWrite` struct and the writer will not try to cleanup on failure. /// Users can configure automatic cleanup with their cloud provider. +#[deprecated(since = "48.0.0", note = "Use ObjectWriterBuilder::new(...) 
instead")] pub async fn create_writer( file_compression_type: FileCompressionType, location: &Path, object_store: Arc, ) -> Result> { - let buf_writer = BufWriter::new(object_store, location.clone()); - file_compression_type.convert_async_writer(buf_writer) + ObjectWriterBuilder::new(file_compression_type, location, object_store).build() } /// Converts table schema to writer schema, which may differ in the case @@ -109,3 +112,108 @@ pub fn get_writer_schema(config: &FileSinkConfig) -> Arc { Arc::clone(config.output_schema()) } } + +/// A builder for an [`AsyncWrite`] that writes to an object store location. +/// +/// This can be used to specify file compression on the writer. The writer +/// will have a default buffer size unless altered. The specific default size +/// is chosen by [`BufWriter::new`]. +/// +/// We drop the `AbortableWrite` struct and the writer will not try to cleanup on failure. +/// Users can configure automatic cleanup with their cloud provider. +#[derive(Debug)] +pub struct ObjectWriterBuilder { + /// Compression type for object writer. + file_compression_type: FileCompressionType, + /// Output path + location: Path, + /// The related store that handles the given path + object_store: Arc, + /// The size of the buffer for the object writer. + buffer_size: Option, +} + +impl ObjectWriterBuilder { + /// Create a new [`ObjectWriterBuilder`] for the specified path and compression type. + pub fn new( + file_compression_type: FileCompressionType, + location: &Path, + object_store: Arc, + ) -> Self { + Self { + file_compression_type, + location: location.clone(), + object_store, + buffer_size: None, + } + } + + /// Set buffer size in bytes for object writer. + /// + /// # Example + /// ``` + /// # use datafusion_datasource::file_compression_type::FileCompressionType; + /// # use datafusion_datasource::write::ObjectWriterBuilder; + /// # use object_store::memory::InMemory; + /// # use object_store::path::Path; + /// # use std::sync::Arc; + /// # let compression_type = FileCompressionType::UNCOMPRESSED; + /// # let location = Path::from("/foo/bar"); + /// # let object_store = Arc::new(InMemory::new()); + /// let mut builder = ObjectWriterBuilder::new(compression_type, &location, object_store); + /// builder.set_buffer_size(Some(20 * 1024 * 1024)); //20 MiB + /// assert_eq!(builder.get_buffer_size(), Some(20 * 1024 * 1024), "Internal error: Builder buffer size doesn't match"); + /// ``` + pub fn set_buffer_size(&mut self, buffer_size: Option) { + self.buffer_size = buffer_size; + } + + /// Set buffer size in bytes for object writer, returning the builder. + /// + /// # Example + /// ``` + /// # use datafusion_datasource::file_compression_type::FileCompressionType; + /// # use datafusion_datasource::write::ObjectWriterBuilder; + /// # use object_store::memory::InMemory; + /// # use object_store::path::Path; + /// # use std::sync::Arc; + /// # let compression_type = FileCompressionType::UNCOMPRESSED; + /// # let location = Path::from("/foo/bar"); + /// # let object_store = Arc::new(InMemory::new()); + /// let builder = ObjectWriterBuilder::new(compression_type, &location, object_store) + /// .with_buffer_size(Some(20 * 1024 * 1024)); //20 MiB + /// assert_eq!(builder.get_buffer_size(), Some(20 * 1024 * 1024), "Internal error: Builder buffer size doesn't match"); + /// ``` + pub fn with_buffer_size(mut self, buffer_size: Option) -> Self { + self.buffer_size = buffer_size; + self + } + + /// Currently specified buffer size in bytes. 
+ pub fn get_buffer_size(&self) -> Option { + self.buffer_size + } + + /// Return a writer object that writes to the object store location. + /// + /// If a buffer size has not been set, the default buffer buffer size will + /// be used. + /// + /// # Errors + /// If there is an error applying the compression type. + pub fn build(self) -> Result> { + let Self { + file_compression_type, + location, + object_store, + buffer_size, + } = self; + + let buf_writer = match buffer_size { + Some(size) => BufWriter::with_capacity(object_store, location, size), + None => BufWriter::new(object_store, location), + }; + + file_compression_type.convert_async_writer(buf_writer) + } +} diff --git a/datafusion/datasource/src/write/orchestration.rs b/datafusion/datasource/src/write/orchestration.rs index 0ac1d26c6cc1..a09509ac5862 100644 --- a/datafusion/datasource/src/write/orchestration.rs +++ b/datafusion/datasource/src/write/orchestration.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use super::demux::DemuxedStreamReceiver; -use super::{create_writer, BatchSerializer}; +use super::{BatchSerializer, ObjectWriterBuilder}; use crate::file_compression_type::FileCompressionType; use datafusion_common::error::Result; @@ -257,7 +257,15 @@ pub async fn spawn_writer_tasks_and_join( }); while let Some((location, rb_stream)) = file_stream_rx.recv().await { let writer = - create_writer(compression, &location, Arc::clone(&object_store)).await?; + ObjectWriterBuilder::new(compression, &location, Arc::clone(&object_store)) + .with_buffer_size(Some( + context + .session_config() + .options() + .execution + .objectstore_writer_buffer_size, + )) + .build()?; if tx_file_bundle .send((rb_stream, Arc::clone(&serializer), writer)) diff --git a/datafusion/doc/README.md b/datafusion/doc/README.md new file mode 100644 index 000000000000..c81a8e78c603 --- /dev/null +++ b/datafusion/doc/README.md @@ -0,0 +1,32 @@ + + +# DataFusion Execution + +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. + +This crate is a submodule of DataFusion that provides structures and macros +for documenting user defined functions. + +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. 
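Putting the `ObjectWriterBuilder` pieces above together, here is a minimal end-to-end sketch of building and using a writer outside the orchestration code; the in-memory store, path, and 8 MiB buffer are arbitrary choices, and it assumes the returned writer is a tokio `AsyncWrite` as declared in `write/mod.rs`:

```rust
use std::sync::Arc;

use datafusion_common::Result;
use datafusion_datasource::file_compression_type::FileCompressionType;
use datafusion_datasource::write::ObjectWriterBuilder;
use object_store::memory::InMemory;
use object_store::path::Path;
use tokio::io::AsyncWriteExt;

#[tokio::main]
async fn main() -> Result<()> {
    let location = Path::from("out/part-0.csv");
    let object_store = Arc::new(InMemory::new());

    // Same entry point the orchestration code now uses; the buffer size is optional.
    let mut writer = ObjectWriterBuilder::new(
        FileCompressionType::UNCOMPRESSED,
        &location,
        object_store,
    )
    .with_buffer_size(Some(8 * 1024 * 1024)) // 8 MiB instead of the BufWriter default
    .build()?;

    writer.write_all(b"a,b\n1,2\n").await?;
    writer.shutdown().await?;
    Ok(())
}
```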
+ +[df]: https://crates.io/crates/datafusion +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/doc/src/lib.rs b/datafusion/doc/src/lib.rs index 68ed1e2352ca..f9b916c2b3ab 100644 --- a/datafusion/doc/src/lib.rs +++ b/datafusion/doc/src/lib.rs @@ -93,7 +93,7 @@ impl Documentation { self.doc_section.label, self.doc_section .description - .map(|s| format!(", description = \"{}\"", s)) + .map(|s| format!(", description = \"{s}\"")) .unwrap_or_default(), ) .as_ref(), @@ -110,7 +110,7 @@ impl Documentation { &self .sql_example .clone() - .map(|s| format!("\n sql_example = r#\"{}\"#,", s)) + .map(|s| format!("\n sql_example = r#\"{s}\"#,")) .unwrap_or_default(), ); @@ -120,7 +120,7 @@ impl Documentation { args.iter().for_each(|(name, value)| { if value.contains(st_arg_token) { if name.starts_with("The ") { - result.push_str(format!("\n standard_argument(\n name = \"{}\"),", name).as_ref()); + result.push_str(format!("\n standard_argument(\n name = \"{name}\"),").as_ref()); } else { result.push_str(format!("\n standard_argument(\n name = \"{}\",\n prefix = \"{}\"\n ),", name, value.replace(st_arg_token, "")).as_ref()); } @@ -132,7 +132,7 @@ impl Documentation { if let Some(args) = self.arguments.clone() { args.iter().for_each(|(name, value)| { if !value.contains(st_arg_token) { - result.push_str(format!("\n argument(\n name = \"{}\",\n description = \"{}\"\n ),", name, value).as_ref()); + result.push_str(format!("\n argument(\n name = \"{name}\",\n description = \"{value}\"\n ),").as_ref()); } }); } @@ -140,7 +140,7 @@ impl Documentation { if let Some(alt_syntax) = self.alternative_syntax.clone() { alt_syntax.iter().for_each(|syntax| { result.push_str( - format!("\n alternative_syntax = \"{}\",", syntax).as_ref(), + format!("\n alternative_syntax = \"{syntax}\",").as_ref(), ); }); } @@ -148,8 +148,7 @@ impl Documentation { // Related UDFs if let Some(related_udf) = self.related_udfs.clone() { related_udf.iter().for_each(|udf| { - result - .push_str(format!("\n related_udf(name = \"{}\"),", udf).as_ref()); + result.push_str(format!("\n related_udf(name = \"{udf}\"),").as_ref()); }); } diff --git a/datafusion/execution/Cargo.toml b/datafusion/execution/Cargo.toml index 20e507e98b68..5988d3a33660 100644 --- a/datafusion/execution/Cargo.toml +++ b/datafusion/execution/Cargo.toml @@ -52,3 +52,4 @@ url = { workspace = true } [dev-dependencies] chrono = { workspace = true } +insta = { workspace = true } diff --git a/datafusion/execution/README.md b/datafusion/execution/README.md index 8a03255ee4ad..dd82e206e6d5 100644 --- a/datafusion/execution/README.md +++ b/datafusion/execution/README.md @@ -23,4 +23,9 @@ This crate is a submodule of DataFusion that provides execution runtime such as the memory pools and disk manager. +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. 
+ [df]: https://crates.io/crates/datafusion +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/execution/src/config.rs b/datafusion/execution/src/config.rs index 53646dc5b468..c1ee2820c0b4 100644 --- a/datafusion/execution/src/config.rs +++ b/datafusion/execution/src/config.rs @@ -23,7 +23,7 @@ use std::{ }; use datafusion_common::{ - config::{ConfigExtension, ConfigOptions}, + config::{ConfigExtension, ConfigOptions, SpillCompression}, Result, ScalarValue, }; @@ -193,9 +193,11 @@ impl SessionConfig { /// /// [`target_partitions`]: datafusion_common::config::ExecutionOptions::target_partitions pub fn with_target_partitions(mut self, n: usize) -> Self { - // partition count must be greater than zero - assert!(n > 0); - self.options.execution.target_partitions = n; + self.options.execution.target_partitions = if n == 0 { + datafusion_common::config::ExecutionOptions::default().target_partitions + } else { + n + }; self } @@ -256,6 +258,11 @@ impl SessionConfig { self.options.execution.collect_statistics } + /// Compression codec for spill file + pub fn spill_compression(&self) -> SpillCompression { + self.options.execution.spill_compression + } + /// Selects a name for the default catalog and schema pub fn with_default_catalog_and_schema( mut self, @@ -419,6 +426,14 @@ impl SessionConfig { self } + /// Set the compression codec [`spill_compression`] used when spilling data to disk. + /// + /// [`spill_compression`]: datafusion_common::config::ExecutionOptions::spill_compression + pub fn with_spill_compression(mut self, spill_compression: SpillCompression) -> Self { + self.options.execution.spill_compression = spill_compression; + self + } + /// Set the size of [`sort_in_place_threshold_bytes`] to control /// how sort does things. /// diff --git a/datafusion/execution/src/disk_manager.rs b/datafusion/execution/src/disk_manager.rs index 2b21a6dbf175..82f2d75ac1b5 100644 --- a/datafusion/execution/src/disk_manager.rs +++ b/datafusion/execution/src/disk_manager.rs @@ -22,7 +22,7 @@ use datafusion_common::{ }; use log::debug; use parking_lot::Mutex; -use rand::{thread_rng, Rng}; +use rand::{rng, Rng}; use std::path::{Path, PathBuf}; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; @@ -32,7 +32,95 @@ use crate::memory_pool::human_readable_size; const DEFAULT_MAX_TEMP_DIRECTORY_SIZE: u64 = 100 * 1024 * 1024 * 1024; // 100GB +/// Builder pattern for the [DiskManager] structure +#[derive(Clone, Debug)] +pub struct DiskManagerBuilder { + /// The storage mode of the disk manager + mode: DiskManagerMode, + /// The maximum amount of data (in bytes) stored inside the temporary directories. 
+ /// Default to 100GB + max_temp_directory_size: u64, +} + +impl Default for DiskManagerBuilder { + fn default() -> Self { + Self { + mode: DiskManagerMode::OsTmpDirectory, + max_temp_directory_size: DEFAULT_MAX_TEMP_DIRECTORY_SIZE, + } + } +} + +impl DiskManagerBuilder { + pub fn set_mode(&mut self, mode: DiskManagerMode) { + self.mode = mode; + } + + pub fn with_mode(mut self, mode: DiskManagerMode) -> Self { + self.set_mode(mode); + self + } + + pub fn set_max_temp_directory_size(&mut self, value: u64) { + self.max_temp_directory_size = value; + } + + pub fn with_max_temp_directory_size(mut self, value: u64) -> Self { + self.set_max_temp_directory_size(value); + self + } + + /// Create a DiskManager given the builder + pub fn build(self) -> Result { + match self.mode { + DiskManagerMode::OsTmpDirectory => Ok(DiskManager { + local_dirs: Mutex::new(Some(vec![])), + max_temp_directory_size: self.max_temp_directory_size, + used_disk_space: Arc::new(AtomicU64::new(0)), + }), + DiskManagerMode::Directories(conf_dirs) => { + let local_dirs = create_local_dirs(conf_dirs)?; + debug!( + "Created local dirs {local_dirs:?} as DataFusion working directory" + ); + Ok(DiskManager { + local_dirs: Mutex::new(Some(local_dirs)), + max_temp_directory_size: self.max_temp_directory_size, + used_disk_space: Arc::new(AtomicU64::new(0)), + }) + } + DiskManagerMode::Disabled => Ok(DiskManager { + local_dirs: Mutex::new(None), + max_temp_directory_size: self.max_temp_directory_size, + used_disk_space: Arc::new(AtomicU64::new(0)), + }), + } + } +} + +#[derive(Clone, Debug)] +pub enum DiskManagerMode { + /// Create a new [DiskManager] that creates temporary files within + /// a temporary directory chosen by the OS + OsTmpDirectory, + + /// Create a new [DiskManager] that creates temporary files within + /// the specified directories. One of the directories will be chosen + /// at random for each temporary file created. 
+ Directories(Vec), + + /// Disable disk manager, attempts to create temporary files will error + Disabled, +} + +impl Default for DiskManagerMode { + fn default() -> Self { + Self::OsTmpDirectory + } +} + /// Configuration for temporary disk access +#[deprecated(since = "48.0.0", note = "Use DiskManagerBuilder instead")] #[derive(Debug, Clone)] pub enum DiskManagerConfig { /// Use the provided [DiskManager] instance @@ -50,12 +138,14 @@ pub enum DiskManagerConfig { Disabled, } +#[allow(deprecated)] impl Default for DiskManagerConfig { fn default() -> Self { Self::NewOs } } +#[allow(deprecated)] impl DiskManagerConfig { /// Create temporary files in a temporary directory chosen by the OS pub fn new() -> Self { @@ -91,7 +181,14 @@ pub struct DiskManager { } impl DiskManager { + /// Creates a builder for [DiskManager] + pub fn builder() -> DiskManagerBuilder { + DiskManagerBuilder::default() + } + /// Create a DiskManager given the configuration + #[allow(deprecated)] + #[deprecated(since = "48.0.0", note = "Use DiskManager::builder() instead")] pub fn try_new(config: DiskManagerConfig) -> Result> { match config { DiskManagerConfig::Existing(manager) => Ok(manager), @@ -103,8 +200,7 @@ impl DiskManager { DiskManagerConfig::NewSpecified(conf_dirs) => { let local_dirs = create_local_dirs(conf_dirs)?; debug!( - "Created local dirs {:?} as DataFusion working directory", - local_dirs + "Created local dirs {local_dirs:?} as DataFusion working directory" ); Ok(Arc::new(Self { local_dirs: Mutex::new(Some(local_dirs)), @@ -120,10 +216,10 @@ impl DiskManager { } } - pub fn with_max_temp_directory_size( - mut self, + pub fn set_max_temp_directory_size( + &mut self, max_temp_directory_size: u64, - ) -> Result { + ) -> Result<()> { // If the disk manager is disabled and `max_temp_directory_size` is not 0, // this operation is not meaningful, fail early. 
if self.local_dirs.lock().is_none() && max_temp_directory_size != 0 { @@ -133,6 +229,26 @@ impl DiskManager { } self.max_temp_directory_size = max_temp_directory_size; + Ok(()) + } + + pub fn set_arc_max_temp_directory_size( + this: &mut Arc, + max_temp_directory_size: u64, + ) -> Result<()> { + if let Some(inner) = Arc::get_mut(this) { + inner.set_max_temp_directory_size(max_temp_directory_size)?; + Ok(()) + } else { + config_err!("DiskManager should be a single instance") + } + } + + pub fn with_max_temp_directory_size( + mut self, + max_temp_directory_size: u64, + ) -> Result { + self.set_max_temp_directory_size(max_temp_directory_size)?; Ok(self) } @@ -175,7 +291,7 @@ impl DiskManager { local_dirs.push(Arc::new(tempdir)); } - let dir_index = thread_rng().gen_range(0..local_dirs.len()); + let dir_index = rng().random_range(0..local_dirs.len()); Ok(RefCountedTempFile { _parent_temp_dir: Arc::clone(&local_dirs[dir_index]), tempfile: Builder::new() @@ -250,6 +366,10 @@ impl RefCountedTempFile { Ok(()) } + + pub fn current_disk_usage(&self) -> u64 { + self.current_file_disk_usage + } } /// When the temporary file is dropped, subtract its disk usage from the disk manager's total @@ -286,8 +406,7 @@ mod tests { #[test] fn lazy_temp_dir_creation() -> Result<()> { // A default configuration should not create temp files until requested - let config = DiskManagerConfig::new(); - let dm = DiskManager::try_new(config)?; + let dm = Arc::new(DiskManagerBuilder::default().build()?); assert_eq!(0, local_dir_snapshot(&dm).len()); @@ -319,11 +438,14 @@ mod tests { let local_dir2 = TempDir::new()?; let local_dir3 = TempDir::new()?; let local_dirs = vec![local_dir1.path(), local_dir2.path(), local_dir3.path()]; - let config = DiskManagerConfig::new_specified( - local_dirs.iter().map(|p| p.into()).collect(), + let dm = Arc::new( + DiskManagerBuilder::default() + .with_mode(DiskManagerMode::Directories( + local_dirs.iter().map(|p| p.into()).collect(), + )) + .build()?, ); - let dm = DiskManager::try_new(config)?; assert!(dm.tmp_files_enabled()); let actual = dm.create_tmp_file("Testing")?; @@ -335,8 +457,12 @@ mod tests { #[test] fn test_disabled_disk_manager() { - let config = DiskManagerConfig::Disabled; - let manager = DiskManager::try_new(config).unwrap(); + let manager = Arc::new( + DiskManagerBuilder::default() + .with_mode(DiskManagerMode::Disabled) + .build() + .unwrap(), + ); assert!(!manager.tmp_files_enabled()); assert_eq!( manager.create_tmp_file("Testing").unwrap_err().strip_backtrace(), @@ -347,11 +473,9 @@ mod tests { #[test] fn test_disk_manager_create_spill_folder() { let dir = TempDir::new().unwrap(); - let config = DiskManagerConfig::new_specified(vec![dir.path().to_owned()]); - - DiskManager::try_new(config) - .unwrap() - .create_tmp_file("Testing") + DiskManagerBuilder::default() + .with_mode(DiskManagerMode::Directories(vec![dir.path().to_path_buf()])) + .build() .unwrap(); } @@ -374,8 +498,7 @@ mod tests { #[test] fn test_temp_file_still_alive_after_disk_manager_dropped() -> Result<()> { // Test for the case using OS arranged temporary directory - let config = DiskManagerConfig::new(); - let dm = DiskManager::try_new(config)?; + let dm = Arc::new(DiskManagerBuilder::default().build()?); let temp_file = dm.create_tmp_file("Testing")?; let temp_file_path = temp_file.path().to_owned(); assert!(temp_file_path.exists()); @@ -391,10 +514,13 @@ mod tests { let local_dir2 = TempDir::new()?; let local_dir3 = TempDir::new()?; let local_dirs = [local_dir1.path(), local_dir2.path(), 
local_dir3.path()]; - let config = DiskManagerConfig::new_specified( - local_dirs.iter().map(|p| p.into()).collect(), + let dm = Arc::new( + DiskManagerBuilder::default() + .with_mode(DiskManagerMode::Directories( + local_dirs.iter().map(|p| p.into()).collect(), + )) + .build()?, ); - let dm = DiskManager::try_new(config)?; let temp_file = dm.create_tmp_file("Testing")?; let temp_file_path = temp_file.path().to_owned(); assert!(temp_file_path.exists()); diff --git a/datafusion/execution/src/memory_pool/mod.rs b/datafusion/execution/src/memory_pool/mod.rs index 625a779b3eea..e620b2326796 100644 --- a/datafusion/execution/src/memory_pool/mod.rs +++ b/datafusion/execution/src/memory_pool/mod.rs @@ -57,8 +57,8 @@ pub use pool::*; /// `GroupByHashExec`. It does NOT track and limit memory used internally by /// other operators such as `DataSourceExec` or the `RecordBatch`es that flow /// between operators. Furthermore, operators should not reserve memory for the -/// batches they produce. Instead, if a parent operator needs to hold batches -/// from its children in memory for an extended period, it is the parent +/// batches they produce. Instead, if a consumer operator needs to hold batches +/// from its producers in memory for an extended period, it is the consumer /// operator's responsibility to reserve the necessary memory for those batches. /// /// In order to avoid allocating memory until the OS or the container system @@ -98,6 +98,67 @@ pub use pool::*; /// operator will spill the intermediate buffers to disk, and release memory /// from the memory pool, and continue to retry memory reservation. /// +/// # Related Structs +/// +/// To better understand memory management in DataFusion, here are the key structs +/// and their relationships: +/// +/// - [`MemoryConsumer`]: A named allocation traced by a particular operator. If an +/// execution is parallelized, and there are multiple partitions of the same +/// operator, each partition will have a separate `MemoryConsumer`. +/// - `SharedRegistration`: A registration of a `MemoryConsumer` with a `MemoryPool`. +/// `SharedRegistration` and `MemoryPool` have a many-to-one relationship. `MemoryPool` +/// implementation can decide how to allocate memory based on the registered consumers. +/// (e.g. `FairSpillPool` will try to share available memory evenly among all registered +/// consumers) +/// - [`MemoryReservation`]: Each `MemoryConsumer`/operator can have multiple +/// `MemoryReservation`s for different internal data structures. The relationship +/// between `MemoryConsumer` and `MemoryReservation` is one-to-many. This design +/// enables cleaner operator implementations: +/// - Different `MemoryReservation`s can be used for different purposes +/// - `MemoryReservation` follows RAII principles - to release a reservation, +/// simply drop the `MemoryReservation` object. When all `MemoryReservation`s +/// for a `SharedRegistration` are dropped, the `SharedRegistration` is dropped +/// when its reference count reaches zero, automatically unregistering the +/// `MemoryConsumer` from the `MemoryPool`. +/// +/// ## Relationship Diagram +/// +/// ```text +/// ┌──────────────────┐ ┌──────────────────┐ +/// │MemoryReservation │ │MemoryReservation │ +/// └───┬──────────────┘ └──────────────────┘ ...... 
+/// │belongs to │ +/// │ ┌───────────────────────┘ │ │ +/// │ │ │ │ +/// ▼ ▼ ▼ ▼ +/// ┌────────────────────────┐ ┌────────────────────────┐ +/// │ SharedRegistration │ │ SharedRegistration │ +/// │ ┌────────────────┐ │ │ ┌────────────────┐ │ +/// │ │ │ │ │ │ │ │ +/// │ │ MemoryConsumer │ │ │ │ MemoryConsumer │ │ +/// │ │ │ │ │ │ │ │ +/// │ └────────────────┘ │ │ └────────────────┘ │ +/// └────────────┬───────────┘ └────────────┬───────────┘ +/// │ │ +/// │ register│into +/// │ │ +/// └─────────────┐ ┌──────────────┘ +/// │ │ +/// ▼ ▼ +/// ╔═══════════════════════════════════════════════════╗ +/// ║ ║ +/// ║ MemoryPool ║ +/// ║ ║ +/// ╚═══════════════════════════════════════════════════╝ +/// ``` +/// +/// For example, there are two parallel partitions of an operator X: each partition +/// corresponds to a `MemoryConsumer` in the above diagram. Inside each partition of +/// operator X, there are typically several `MemoryReservation`s - one for each +/// internal data structure that needs memory tracking (e.g., 1 reservation for the hash +/// table, and 1 reservation for buffered input, etc.). +/// /// # Implementing `MemoryPool` /// /// You can implement a custom allocation policy by implementing the @@ -141,6 +202,25 @@ pub trait MemoryPool: Send + Sync + std::fmt::Debug { /// Return the total amount of memory reserved fn reserved(&self) -> usize; + + /// Return the memory limit of the pool + /// + /// The default implementation of `MemoryPool::memory_limit` + /// will return `MemoryLimit::Unknown`. + /// If you are using your custom memory pool, but have the requirement to + /// know the memory usage limit of the pool, please implement this method + /// to return it(`Memory::Finite(limit)`). + fn memory_limit(&self) -> MemoryLimit { + MemoryLimit::Unknown + } +} + +/// Memory limit of `MemoryPool` +pub enum MemoryLimit { + Infinite, + /// Bounded memory limit in bytes. + Finite(usize), + Unknown, } /// A memory consumer is a named allocation traced by a particular diff --git a/datafusion/execution/src/memory_pool/pool.rs b/datafusion/execution/src/memory_pool/pool.rs index cd6863939d27..11467f69be1c 100644 --- a/datafusion/execution/src/memory_pool/pool.rs +++ b/datafusion/execution/src/memory_pool/pool.rs @@ -15,7 +15,9 @@ // specific language governing permissions and limitations // under the License. -use crate::memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation}; +use crate::memory_pool::{ + human_readable_size, MemoryConsumer, MemoryLimit, MemoryPool, MemoryReservation, +}; use datafusion_common::HashMap; use datafusion_common::{resources_datafusion_err, DataFusionError, Result}; use log::debug; @@ -48,6 +50,10 @@ impl MemoryPool for UnboundedMemoryPool { fn reserved(&self) -> usize { self.used.load(Ordering::Relaxed) } + + fn memory_limit(&self) -> MemoryLimit { + MemoryLimit::Infinite + } } /// A [`MemoryPool`] that implements a greedy first-come first-serve limit. @@ -100,6 +106,10 @@ impl MemoryPool for GreedyMemoryPool { fn reserved(&self) -> usize { self.used.load(Ordering::Relaxed) } + + fn memory_limit(&self) -> MemoryLimit { + MemoryLimit::Finite(self.pool_size) + } } /// A [`MemoryPool`] that prevents spillable reservations from using more than @@ -233,6 +243,10 @@ impl MemoryPool for FairSpillPool { let state = self.state.lock(); state.spillable + state.unspillable } + + fn memory_limit(&self) -> MemoryLimit { + MemoryLimit::Finite(self.pool_size) + } } /// Constructs a resources error based upon the individual [`MemoryReservation`]. 
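To make the consumer/reservation lifecycle described in the doc comment above, and the new `memory_limit` accessor, concrete, here is a minimal sketch against `GreedyMemoryPool` (the consumer name and sizes are arbitrary):

```rust
use std::sync::Arc;

use datafusion_execution::memory_pool::{
    GreedyMemoryPool, MemoryConsumer, MemoryLimit, MemoryPool,
};

fn main() {
    // A pool that hands out up to 1 MiB, first come first served.
    let pool: Arc<dyn MemoryPool> = Arc::new(GreedyMemoryPool::new(1024 * 1024));

    // One consumer per operator partition; a consumer may hold several reservations.
    let mut reservation =
        MemoryConsumer::new("hash_table[partition 0]").register(&pool);
    reservation.try_grow(512 * 1024).expect("within the limit");
    assert_eq!(pool.reserved(), 512 * 1024);

    // New in this change: pools can report their configured bound.
    match pool.memory_limit() {
        MemoryLimit::Finite(bytes) => assert_eq!(bytes, 1024 * 1024),
        MemoryLimit::Infinite | MemoryLimit::Unknown => unreachable!("greedy pool is bounded"),
    }

    // Reservations follow RAII: dropping one returns its memory to the pool.
    drop(reservation);
    assert_eq!(pool.reserved(), 0);
}
```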
@@ -246,7 +260,8 @@ fn insufficient_capacity_err( additional: usize, available: usize, ) -> DataFusionError { - resources_datafusion_err!("Failed to allocate additional {} bytes for {} with {} bytes already allocated for this reservation - {} bytes remain available for the total pool", additional, reservation.registration.consumer.name, reservation.size, available) + resources_datafusion_err!("Failed to allocate additional {} for {} with {} already allocated for this reservation - {} remain available for the total pool", + human_readable_size(additional), reservation.registration.consumer.name, human_readable_size(reservation.size), human_readable_size(available)) } #[derive(Debug)] @@ -328,10 +343,14 @@ impl TrackConsumersPool { consumers[0..std::cmp::min(top, consumers.len())] .iter() .map(|((id, name, can_spill), size)| { - format!("{name}#{id}(can spill: {can_spill}) consumed {size} bytes") + format!( + " {name}#{id}(can spill: {can_spill}) consumed {}", + human_readable_size(*size) + ) }) .collect::>() - .join(", ") + .join(",\n") + + "." } } @@ -408,20 +427,34 @@ impl MemoryPool for TrackConsumersPool { fn reserved(&self) -> usize { self.inner.reserved() } + + fn memory_limit(&self) -> MemoryLimit { + self.inner.memory_limit() + } } fn provide_top_memory_consumers_to_error_msg( error_msg: String, top_consumers: String, ) -> String { - format!("Additional allocation failed with top memory consumers (across reservations) as: {}. Error: {}", top_consumers, error_msg) + format!("Additional allocation failed with top memory consumers (across reservations) as:\n{top_consumers}\nError: {error_msg}") } #[cfg(test)] mod tests { use super::*; + use insta::{allow_duplicates, assert_snapshot, Settings}; use std::sync::Arc; + fn make_settings() -> Settings { + let mut settings = Settings::clone_current(); + settings.add_filter( + r"([^\s]+)\#\d+\(can spill: (true|false)\)", + "$1#[ID](can spill: $2)", + ); + settings + } + #[test] fn test_fair() { let pool = Arc::new(FairSpillPool::new(100)) as _; @@ -440,10 +473,10 @@ mod tests { assert_eq!(pool.reserved(), 4000); let err = r2.try_grow(1).unwrap_err().strip_backtrace(); - assert_eq!(err, "Resources exhausted: Failed to allocate additional 1 bytes for r2 with 2000 bytes already allocated for this reservation - 0 bytes remain available for the total pool"); + assert_snapshot!(err, @"Resources exhausted: Failed to allocate additional 1.0 B for r2 with 2000.0 B already allocated for this reservation - 0.0 B remain available for the total pool"); let err = r2.try_grow(1).unwrap_err().strip_backtrace(); - assert_eq!(err, "Resources exhausted: Failed to allocate additional 1 bytes for r2 with 2000 bytes already allocated for this reservation - 0 bytes remain available for the total pool"); + assert_snapshot!(err, @"Resources exhausted: Failed to allocate additional 1.0 B for r2 with 2000.0 B already allocated for this reservation - 0.0 B remain available for the total pool"); r1.shrink(1990); r2.shrink(2000); @@ -468,12 +501,12 @@ mod tests { .register(&pool); let err = r3.try_grow(70).unwrap_err().strip_backtrace(); - assert_eq!(err, "Resources exhausted: Failed to allocate additional 70 bytes for r3 with 0 bytes already allocated for this reservation - 40 bytes remain available for the total pool"); + assert_snapshot!(err, @"Resources exhausted: Failed to allocate additional 70.0 B for r3 with 0.0 B already allocated for this reservation - 40.0 B remain available for the total pool"); //Shrinking r2 to zero doesn't allow a3 to allocate more than 45 
r2.free(); let err = r3.try_grow(70).unwrap_err().strip_backtrace(); - assert_eq!(err, "Resources exhausted: Failed to allocate additional 70 bytes for r3 with 0 bytes already allocated for this reservation - 40 bytes remain available for the total pool"); + assert_snapshot!(err, @"Resources exhausted: Failed to allocate additional 70.0 B for r3 with 0.0 B already allocated for this reservation - 40.0 B remain available for the total pool"); // But dropping r2 does drop(r2); @@ -486,11 +519,13 @@ mod tests { let mut r4 = MemoryConsumer::new("s4").register(&pool); let err = r4.try_grow(30).unwrap_err().strip_backtrace(); - assert_eq!(err, "Resources exhausted: Failed to allocate additional 30 bytes for s4 with 0 bytes already allocated for this reservation - 20 bytes remain available for the total pool"); + assert_snapshot!(err, @"Resources exhausted: Failed to allocate additional 30.0 B for s4 with 0.0 B already allocated for this reservation - 20.0 B remain available for the total pool"); } #[test] fn test_tracked_consumers_pool() { + let setting = make_settings(); + let _bound = setting.bind_to_scope(); let pool: Arc = Arc::new(TrackConsumersPool::new( GreedyMemoryPool::new(100), NonZeroUsize::new(3).unwrap(), @@ -523,20 +558,22 @@ mod tests { // Test: reports if new reservation causes error // using the previously set sizes for other consumers let mut r5 = MemoryConsumer::new("r5").register(&pool); - let expected = format!("Additional allocation failed with top memory consumers (across reservations) as: r1#{}(can spill: false) consumed 50 bytes, r3#{}(can spill: false) consumed 20 bytes, r2#{}(can spill: false) consumed 15 bytes. Error: Failed to allocate additional 150 bytes for r5 with 0 bytes already allocated for this reservation - 5 bytes remain available for the total pool", r1.consumer().id(), r3.consumer().id(), r2.consumer().id()); let res = r5.try_grow(150); - assert!( - matches!( - &res, - Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(&expected) - ), - "should provide list of top memory consumers, instead found {:?}", - res - ); + assert!(res.is_err()); + let error = res.unwrap_err().strip_backtrace(); + assert_snapshot!(error, @r" + Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: + r1#[ID](can spill: false) consumed 50.0 B, + r3#[ID](can spill: false) consumed 20.0 B, + r2#[ID](can spill: false) consumed 15.0 B. + Error: Failed to allocate additional 150.0 B for r5 with 0.0 B already allocated for this reservation - 5.0 B remain available for the total pool + "); } #[test] fn test_tracked_consumers_pool_register() { + let setting = make_settings(); + let _bound = setting.bind_to_scope(); let pool: Arc = Arc::new(TrackConsumersPool::new( GreedyMemoryPool::new(100), NonZeroUsize::new(3).unwrap(), @@ -546,15 +583,14 @@ mod tests { // Test: see error message when no consumers recorded yet let mut r0 = MemoryConsumer::new(same_name).register(&pool); - let expected = format!("Additional allocation failed with top memory consumers (across reservations) as: foo#{}(can spill: false) consumed 0 bytes. 
Error: Failed to allocate additional 150 bytes for foo with 0 bytes already allocated for this reservation - 100 bytes remain available for the total pool", r0.consumer().id()); let res = r0.try_grow(150); - assert!( - matches!( - &res, - Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(&expected) - ), - "should provide proper error when no reservations have been made yet, instead found {:?}", res - ); + assert!(res.is_err()); + let error = res.unwrap_err().strip_backtrace(); + assert_snapshot!(error, @r" + Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: + foo#[ID](can spill: false) consumed 0.0 B. + Error: Failed to allocate additional 150.0 B for foo with 0.0 B already allocated for this reservation - 100.0 B remain available for the total pool + "); // API: multiple registrations using the same hashed consumer, // will be recognized *differently* in the TrackConsumersPool. @@ -564,102 +600,101 @@ mod tests { let mut r1 = new_consumer_same_name.register(&pool); // TODO: the insufficient_capacity_err() message is per reservation, not per consumer. // a followup PR will clarify this message "0 bytes already allocated for this reservation" - let expected = format!("Additional allocation failed with top memory consumers (across reservations) as: foo#{}(can spill: false) consumed 10 bytes, foo#{}(can spill: false) consumed 0 bytes. Error: Failed to allocate additional 150 bytes for foo with 0 bytes already allocated for this reservation - 90 bytes remain available for the total pool", r0.consumer().id(), r1.consumer().id()); let res = r1.try_grow(150); - assert!( - matches!( - &res, - Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(&expected) - ), - "should provide proper error for 2 consumers, instead found {:?}", - res - ); + assert!(res.is_err()); + let error = res.unwrap_err().strip_backtrace(); + assert_snapshot!(error, @r" + Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: + foo#[ID](can spill: false) consumed 10.0 B, + foo#[ID](can spill: false) consumed 0.0 B. + Error: Failed to allocate additional 150.0 B for foo with 0.0 B already allocated for this reservation - 90.0 B remain available for the total pool + "); // Test: will accumulate size changes per consumer, not per reservation r1.grow(20); - let expected = format!("Additional allocation failed with top memory consumers (across reservations) as: foo#{}(can spill: false) consumed 20 bytes, foo#{}(can spill: false) consumed 10 bytes. Error: Failed to allocate additional 150 bytes for foo with 20 bytes already allocated for this reservation - 70 bytes remain available for the total pool", r1.consumer().id(), r0.consumer().id()); + let res = r1.try_grow(150); - assert!( - matches!( - &res, - Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(&expected) - ), - "should provide proper error for 2 consumers(one foo=20 bytes, another foo=10 bytes, available=70), instead found {:?}", res - ); + assert!(res.is_err()); + let error = res.unwrap_err().strip_backtrace(); + assert_snapshot!(error, @r" + Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: + foo#[ID](can spill: false) consumed 20.0 B, + foo#[ID](can spill: false) consumed 10.0 B. 
+ Error: Failed to allocate additional 150.0 B for foo with 20.0 B already allocated for this reservation - 70.0 B remain available for the total pool + "); // Test: different hashed consumer, (even with the same name), // will be recognized as different in the TrackConsumersPool let consumer_with_same_name_but_different_hash = MemoryConsumer::new(same_name).with_can_spill(true); let mut r2 = consumer_with_same_name_but_different_hash.register(&pool); - let expected = format!("Additional allocation failed with top memory consumers (across reservations) as: foo#{}(can spill: false) consumed 20 bytes, foo#{}(can spill: false) consumed 10 bytes, foo#{}(can spill: true) consumed 0 bytes. Error: Failed to allocate additional 150 bytes for foo with 0 bytes already allocated for this reservation - 70 bytes remain available for the total pool", r1.consumer().id(), r0.consumer().id(), r2.consumer().id()); let res = r2.try_grow(150); - assert!( - matches!( - &res, - Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(&expected) - ), - "should provide proper error with 3 separate consumers(1 = 20 bytes, 2 = 10 bytes, 3 = 0 bytes), instead found {:?}", res - ); + assert!(res.is_err()); + let error = res.unwrap_err().strip_backtrace(); + assert_snapshot!(error, @r" + Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: + foo#[ID](can spill: false) consumed 20.0 B, + foo#[ID](can spill: false) consumed 10.0 B, + foo#[ID](can spill: true) consumed 0.0 B. + Error: Failed to allocate additional 150.0 B for foo with 0.0 B already allocated for this reservation - 70.0 B remain available for the total pool + "); } #[test] fn test_tracked_consumers_pool_deregister() { fn test_per_pool_type(pool: Arc) { // Baseline: see the 2 memory consumers + let setting = make_settings(); + let _bound = setting.bind_to_scope(); let mut r0 = MemoryConsumer::new("r0").register(&pool); r0.grow(10); let r1_consumer = MemoryConsumer::new("r1"); let mut r1 = r1_consumer.register(&pool); r1.grow(20); - let expected = format!("Additional allocation failed with top memory consumers (across reservations) as: r1#{}(can spill: false) consumed 20 bytes, r0#{}(can spill: false) consumed 10 bytes. Error: Failed to allocate additional 150 bytes for r0 with 10 bytes already allocated for this reservation - 70 bytes remain available for the total pool", r1.consumer().id(), r0.consumer().id()); let res = r0.try_grow(150); - assert!( - matches!( - &res, - Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(&expected) - ), - "should provide proper error with both consumers, instead found {:?}", - res - ); + assert!(res.is_err()); + let error = res.unwrap_err().strip_backtrace(); + allow_duplicates!(assert_snapshot!(error, @r" + Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: + r1#[ID](can spill: false) consumed 20.0 B, + r0#[ID](can spill: false) consumed 10.0 B. 
+ Error: Failed to allocate additional 150.0 B for r0 with 10.0 B already allocated for this reservation - 70.0 B remain available for the total pool + ")); // Test: unregister one // only the remaining one should be listed drop(r1); - let expected_consumers = format!("Additional allocation failed with top memory consumers (across reservations) as: r0#{}(can spill: false) consumed 10 bytes", r0.consumer().id()); let res = r0.try_grow(150); - assert!( - matches!( - &res, - Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(&expected_consumers) - ), - "should provide proper error with only 1 consumer left registered, instead found {:?}", res - ); + assert!(res.is_err()); + let error = res.unwrap_err().strip_backtrace(); + allow_duplicates!(assert_snapshot!(error, @r" + Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: + r0#[ID](can spill: false) consumed 10.0 B. + Error: Failed to allocate additional 150.0 B for r0 with 10.0 B already allocated for this reservation - 90.0 B remain available for the total pool + ")); // Test: actual message we see is the `available is 70`. When it should be `available is 90`. // This is because the pool.shrink() does not automatically occur within the inner_pool.deregister(). - let expected_90_available = "Failed to allocate additional 150 bytes for r0 with 10 bytes already allocated for this reservation - 90 bytes remain available for the total pool"; let res = r0.try_grow(150); - assert!( - matches!( - &res, - Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(expected_90_available) - ), - "should find that the inner pool will still count all bytes for the deregistered consumer until the reservation is dropped, instead found {:?}", res - ); + assert!(res.is_err()); + let error = res.unwrap_err().strip_backtrace(); + allow_duplicates!(assert_snapshot!(error, @r" + Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: + r0#[ID](can spill: false) consumed 10.0 B. + Error: Failed to allocate additional 150.0 B for r0 with 10.0 B already allocated for this reservation - 90.0 B remain available for the total pool + ")); // Test: the registration needs to free itself (or be dropped), // for the proper error message - let expected_90_available = "Failed to allocate additional 150 bytes for r0 with 10 bytes already allocated for this reservation - 90 bytes remain available for the total pool"; let res = r0.try_grow(150); - assert!( - matches!( - &res, - Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(expected_90_available) - ), - "should correctly account the total bytes after reservation is free, instead found {:?}", res - ); + assert!(res.is_err()); + let error = res.unwrap_err().strip_backtrace(); + allow_duplicates!(assert_snapshot!(error, @r" + Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: + r0#[ID](can spill: false) consumed 10.0 B. 
+ Error: Failed to allocate additional 150.0 B for r0 with 10.0 B already allocated for this reservation - 90.0 B remain available for the total pool + ")); } let tracked_spill_pool: Arc = Arc::new(TrackConsumersPool::new( @@ -677,6 +712,8 @@ mod tests { #[test] fn test_tracked_consumers_pool_use_beyond_errors() { + let setting = make_settings(); + let _bound = setting.bind_to_scope(); let upcasted: Arc = Arc::new(TrackConsumersPool::new( GreedyMemoryPool::new(100), @@ -700,12 +737,10 @@ mod tests { .unwrap(); // Test: can get runtime metrics, even without an error thrown - let expected = format!("r3#{}(can spill: false) consumed 45 bytes, r1#{}(can spill: false) consumed 20 bytes", r3.consumer().id(), r1.consumer().id()); let res = downcasted.report_top(2); - assert_eq!( - res, expected, - "should provide list of top memory consumers, instead found {:?}", - res - ); + assert_snapshot!(res, @r" + r3#[ID](can spill: false) consumed 45.0 B, + r1#[ID](can spill: false) consumed 20.0 B. + "); } } diff --git a/datafusion/execution/src/runtime_env.rs b/datafusion/execution/src/runtime_env.rs index 95f14f485792..b086430a4ef7 100644 --- a/datafusion/execution/src/runtime_env.rs +++ b/datafusion/execution/src/runtime_env.rs @@ -18,8 +18,10 @@ //! Execution [`RuntimeEnv`] environment that manages access to object //! store, memory manager, disk manager. +#[allow(deprecated)] +use crate::disk_manager::DiskManagerConfig; use crate::{ - disk_manager::{DiskManager, DiskManagerConfig}, + disk_manager::{DiskManager, DiskManagerBuilder, DiskManagerMode}, memory_pool::{ GreedyMemoryPool, MemoryPool, TrackConsumersPool, UnboundedMemoryPool, }, @@ -27,7 +29,7 @@ use crate::{ }; use crate::cache::cache_manager::{CacheManager, CacheManagerConfig}; -use datafusion_common::Result; +use datafusion_common::{config::ConfigEntry, Result}; use object_store::ObjectStore; use std::path::PathBuf; use std::sync::Arc; @@ -170,8 +172,11 @@ pub type RuntimeConfig = RuntimeEnvBuilder; /// /// See example on [`RuntimeEnv`] pub struct RuntimeEnvBuilder { + #[allow(deprecated)] /// DiskManager to manage temporary disk file usage pub disk_manager: DiskManagerConfig, + /// DiskManager builder to manager temporary disk file usage + pub disk_manager_builder: Option, /// [`MemoryPool`] from which to allocate memory /// /// Defaults to using an [`UnboundedMemoryPool`] if `None` @@ -193,18 +198,27 @@ impl RuntimeEnvBuilder { pub fn new() -> Self { Self { disk_manager: Default::default(), + disk_manager_builder: Default::default(), memory_pool: Default::default(), cache_manager: Default::default(), object_store_registry: Arc::new(DefaultObjectStoreRegistry::default()), } } + #[allow(deprecated)] + #[deprecated(since = "48.0.0", note = "Use with_disk_manager_builder instead")] /// Customize disk manager pub fn with_disk_manager(mut self, disk_manager: DiskManagerConfig) -> Self { self.disk_manager = disk_manager; self } + /// Customize the disk manager builder + pub fn with_disk_manager_builder(mut self, disk_manager: DiskManagerBuilder) -> Self { + self.disk_manager_builder = Some(disk_manager); + self + } + /// Customize memory policy pub fn with_memory_pool(mut self, memory_pool: Arc) -> Self { self.memory_pool = Some(memory_pool); @@ -242,13 +256,17 @@ impl RuntimeEnvBuilder { /// Use the specified path to create any needed temporary files pub fn with_temp_file_path(self, path: impl Into) -> Self { - self.with_disk_manager(DiskManagerConfig::new_specified(vec![path.into()])) + self.with_disk_manager_builder( + 
DiskManagerBuilder::default() + .with_mode(DiskManagerMode::Directories(vec![path.into()])), + ) } /// Build a RuntimeEnv pub fn build(self) -> Result { let Self { disk_manager, + disk_manager_builder, memory_pool, cache_manager, object_store_registry, @@ -258,7 +276,12 @@ impl RuntimeEnvBuilder { Ok(RuntimeEnv { memory_pool, - disk_manager: DiskManager::try_new(disk_manager)?, + disk_manager: if let Some(builder) = disk_manager_builder { + Arc::new(builder.build()?) + } else { + #[allow(deprecated)] + DiskManager::try_new(disk_manager)? + }, cache_manager: CacheManager::try_new(&cache_manager)?, object_store_registry, }) @@ -268,4 +291,58 @@ impl RuntimeEnvBuilder { pub fn build_arc(self) -> Result> { self.build().map(Arc::new) } + + /// Create a new RuntimeEnvBuilder from an existing RuntimeEnv + pub fn from_runtime_env(runtime_env: &RuntimeEnv) -> Self { + let cache_config = CacheManagerConfig { + table_files_statistics_cache: runtime_env + .cache_manager + .get_file_statistic_cache(), + list_files_cache: runtime_env.cache_manager.get_list_files_cache(), + }; + + Self { + #[allow(deprecated)] + disk_manager: DiskManagerConfig::Existing(Arc::clone( + &runtime_env.disk_manager, + )), + disk_manager_builder: None, + memory_pool: Some(Arc::clone(&runtime_env.memory_pool)), + cache_manager: cache_config, + object_store_registry: Arc::clone(&runtime_env.object_store_registry), + } + } + + /// Returns a list of all available runtime configurations with their current values and descriptions + pub fn entries(&self) -> Vec { + // Memory pool configuration + vec![ConfigEntry { + key: "datafusion.runtime.memory_limit".to_string(), + value: None, // Default is system-dependent + description: "Maximum memory limit for query execution. Supports suffixes K (kilobytes), M (megabytes), and G (gigabytes). Example: '2G' for 2 gigabytes.", + }] + } + + /// Generate documentation that can be included in the user guide + pub fn generate_config_markdown() -> String { + use std::fmt::Write as _; + + let s = Self::default(); + + let mut docs = "| key | default | description |\n".to_string(); + docs += "|-----|---------|-------------|\n"; + let mut entries = s.entries(); + entries.sort_unstable_by(|a, b| a.key.cmp(&b.key)); + + for entry in &entries { + let _ = writeln!( + &mut docs, + "| {} | {} | {} |", + entry.key, + entry.value.as_deref().unwrap_or("NULL"), + entry.description + ); + } + docs + } } diff --git a/datafusion/expr-common/README.md b/datafusion/expr-common/README.md new file mode 100644 index 000000000000..5f95627ca0d4 --- /dev/null +++ b/datafusion/expr-common/README.md @@ -0,0 +1,31 @@ + + +# DataFusion Logical Plan and Expressions + +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. + +This crate is a submodule of DataFusion that provides common logical expressions + +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. 
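Taken together, the `RuntimeEnvBuilder` changes above move disk configuration from `DiskManagerConfig` to a builder-style API. A small usage sketch follows, assuming `DiskManagerBuilder` and `DiskManagerMode` are publicly exported from `datafusion_execution::disk_manager`; the spill directory path is made up.

```rust
use datafusion_common::Result;
use datafusion_execution::disk_manager::{DiskManagerBuilder, DiskManagerMode};
use datafusion_execution::runtime_env::RuntimeEnvBuilder;

fn main() -> Result<()> {
    // Direct spill files to an explicit directory instead of the OS temp dir,
    // mirroring what `with_temp_file_path` now does internally.
    let runtime_env = RuntimeEnvBuilder::new()
        .with_disk_manager_builder(
            DiskManagerBuilder::default()
                .with_mode(DiskManagerMode::Directories(vec!["/tmp/df-spill".into()])),
        )
        .build_arc()?;

    // The builder can also render its tunable settings as a markdown table
    // for the user guide.
    println!("{}", RuntimeEnvBuilder::generate_config_markdown());

    drop(runtime_env);
    Ok(())
}
```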
+ +[df]: https://crates.io/crates/datafusion +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/expr-common/src/accumulator.rs b/datafusion/expr-common/src/accumulator.rs index 3a63c3289481..2829a9416f03 100644 --- a/datafusion/expr-common/src/accumulator.rs +++ b/datafusion/expr-common/src/accumulator.rs @@ -42,7 +42,6 @@ use std::fmt::Debug; /// [`state`] and combine the state from multiple accumulators /// via [`merge_batch`], as part of efficient multi-phase grouping. /// -/// [`GroupsAccumulator`]: crate::GroupsAccumulator /// [`update_batch`]: Self::update_batch /// [`retract_batch`]: Self::retract_batch /// [`state`]: Self::state diff --git a/datafusion/expr-common/src/columnar_value.rs b/datafusion/expr-common/src/columnar_value.rs index cb7cbdbac291..a21ad5bbbcc3 100644 --- a/datafusion/expr-common/src/columnar_value.rs +++ b/datafusion/expr-common/src/columnar_value.rs @@ -237,7 +237,7 @@ impl fmt::Display for ColumnarValue { }; if let Ok(formatted) = formatted { - write!(f, "{}", formatted) + write!(f, "{formatted}") } else { write!(f, "Error formatting columnar value") } diff --git a/datafusion/expr-common/src/groups_accumulator.rs b/datafusion/expr-common/src/groups_accumulator.rs index 5ff1c1d07216..9bcc1edff882 100644 --- a/datafusion/expr-common/src/groups_accumulator.rs +++ b/datafusion/expr-common/src/groups_accumulator.rs @@ -21,7 +21,7 @@ use arrow::array::{ArrayRef, BooleanArray}; use datafusion_common::{not_impl_err, Result}; /// Describes how many rows should be emitted during grouping. -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum EmitTo { /// Emit all groups All, diff --git a/datafusion/expr-common/src/interval_arithmetic.rs b/datafusion/expr-common/src/interval_arithmetic.rs index 6af4322df29e..d656c676bd01 100644 --- a/datafusion/expr-common/src/interval_arithmetic.rs +++ b/datafusion/expr-common/src/interval_arithmetic.rs @@ -606,7 +606,7 @@ impl Interval { upper: ScalarValue::Boolean(Some(upper)), }) } - _ => internal_err!("Incompatible data types for logical conjunction"), + _ => internal_err!("Incompatible data types for logical disjunction"), } } @@ -949,6 +949,18 @@ impl Display for Interval { } } +impl From for Interval { + fn from(value: ScalarValue) -> Self { + Self::new(value.clone(), value) + } +} + +impl From<&ScalarValue> for Interval { + fn from(value: &ScalarValue) -> Self { + Self::new(value.to_owned(), value.to_owned()) + } +} + /// Applies the given binary operator the `lhs` and `rhs` arguments. pub fn apply_operator(op: &Operator, lhs: &Interval, rhs: &Interval) -> Result { match *op { @@ -959,6 +971,7 @@ pub fn apply_operator(op: &Operator, lhs: &Interval, rhs: &Interval) -> Result lhs.lt(rhs), Operator::LtEq => lhs.lt_eq(rhs), Operator::And => lhs.and(rhs), + Operator::Or => lhs.or(rhs), Operator::Plus => lhs.add(rhs), Operator::Minus => lhs.sub(rhs), Operator::Multiply => lhs.mul(rhs), @@ -1683,9 +1696,9 @@ impl Display for NullableInterval { match self { Self::Null { .. 
} => write!(f, "NullableInterval: {{NULL}}"), Self::MaybeNull { values } => { - write!(f, "NullableInterval: {} U {{NULL}}", values) + write!(f, "NullableInterval: {values} U {{NULL}}") } - Self::NotNull { values } => write!(f, "NullableInterval: {}", values), + Self::NotNull { values } => write!(f, "NullableInterval: {values}"), } } } @@ -2706,8 +2719,8 @@ mod tests { ), ]; for (first, second, expected) in possible_cases { - println!("{}", first); - println!("{}", second); + println!("{first}"); + println!("{second}"); assert_eq!(first.union(second)?, expected) } @@ -3704,14 +3717,14 @@ mod tests { #[test] fn test_interval_display() { let interval = Interval::make(Some(0.25_f32), Some(0.50_f32)).unwrap(); - assert_eq!(format!("{}", interval), "[0.25, 0.5]"); + assert_eq!(format!("{interval}"), "[0.25, 0.5]"); let interval = Interval::try_new( ScalarValue::Float32(Some(f32::NEG_INFINITY)), ScalarValue::Float32(Some(f32::INFINITY)), ) .unwrap(); - assert_eq!(format!("{}", interval), "[NULL, NULL]"); + assert_eq!(format!("{interval}"), "[NULL, NULL]"); } macro_rules! capture_mode_change { diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index a7c9330201bc..5e1705d8ff61 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -843,6 +843,7 @@ impl Signature { volatility, } } + /// Any one of a list of [TypeSignature]s. pub fn one_of(type_signatures: Vec, volatility: Volatility) -> Self { Signature { @@ -850,7 +851,8 @@ impl Signature { volatility, } } - /// Specialized Signature for ArrayAppend and similar functions + + /// Specialized [Signature] for ArrayAppend and similar functions. pub fn array_and_element(volatility: Volatility) -> Self { Signature { type_signature: TypeSignature::ArraySignature( @@ -865,7 +867,41 @@ impl Signature { volatility, } } - /// Specialized Signature for Array functions with an optional index + + /// Specialized [Signature] for ArrayPrepend and similar functions. + pub fn element_and_array(volatility: Volatility) -> Self { + Signature { + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Element, + ArrayFunctionArgument::Array, + ], + array_coercion: Some(ListCoercion::FixedSizedListToList), + }, + ), + volatility, + } + } + + /// Specialized [Signature] for functions that take a fixed number of arrays. + pub fn arrays( + n: usize, + coercion: Option, + volatility: Volatility, + ) -> Self { + Signature { + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::Array { + arguments: vec![ArrayFunctionArgument::Array; n], + array_coercion: coercion, + }, + ), + volatility, + } + } + + /// Specialized [Signature] for Array functions with an optional index. pub fn array_and_element_and_optional_index(volatility: Volatility) -> Self { Signature { type_signature: TypeSignature::OneOf(vec![ @@ -889,7 +925,7 @@ impl Signature { } } - /// Specialized Signature for ArrayElement and similar functions + /// Specialized [Signature] for ArrayElement and similar functions. 
pub fn array_and_index(volatility: Volatility) -> Self { Signature { type_signature: TypeSignature::ArraySignature( @@ -898,23 +934,16 @@ impl Signature { ArrayFunctionArgument::Array, ArrayFunctionArgument::Index, ], - array_coercion: None, + array_coercion: Some(ListCoercion::FixedSizedListToList), }, ), volatility, } } - /// Specialized Signature for ArrayEmpty and similar functions + + /// Specialized [Signature] for ArrayEmpty and similar functions. pub fn array(volatility: Volatility) -> Self { - Signature { - type_signature: TypeSignature::ArraySignature( - ArrayFunctionSignature::Array { - arguments: vec![ArrayFunctionArgument::Array], - array_coercion: None, - }, - ), - volatility, - } + Signature::arrays(1, Some(ListCoercion::FixedSizedListToList), volatility) } } @@ -940,8 +969,7 @@ mod tests { for case in positive_cases { assert!( case.supports_zero_argument(), - "Expected {:?} to support zero arguments", - case + "Expected {case:?} to support zero arguments" ); } @@ -960,8 +988,7 @@ mod tests { for case in negative_cases { assert!( !case.supports_zero_argument(), - "Expected {:?} not to support zero arguments", - case + "Expected {case:?} not to support zero arguments" ); } } diff --git a/datafusion/expr-common/src/statistics.rs b/datafusion/expr-common/src/statistics.rs index 7e0bc88087ef..14f2f331ef5b 100644 --- a/datafusion/expr-common/src/statistics.rs +++ b/datafusion/expr-common/src/statistics.rs @@ -1559,18 +1559,14 @@ mod tests { assert_eq!( new_generic_from_binary_op(&op, &dist_a, &dist_b)?.range()?, apply_operator(&op, a, b)?, - "Failed for {:?} {op} {:?}", - dist_a, - dist_b + "Failed for {dist_a:?} {op} {dist_b:?}" ); } for op in [Gt, GtEq, Lt, LtEq, Eq, NotEq] { assert_eq!( create_bernoulli_from_comparison(&op, &dist_a, &dist_b)?.range()?, apply_operator(&op, a, b)?, - "Failed for {:?} {op} {:?}", - dist_a, - dist_b + "Failed for {dist_a:?} {op} {dist_b:?}" ); } } diff --git a/datafusion/expr-common/src/type_coercion/aggregates.rs b/datafusion/expr-common/src/type_coercion/aggregates.rs index 13d52959aba6..e9377ce7de5a 100644 --- a/datafusion/expr-common/src/type_coercion/aggregates.rs +++ b/datafusion/expr-common/src/type_coercion/aggregates.rs @@ -17,7 +17,7 @@ use crate::signature::TypeSignature; use arrow::datatypes::{ - DataType, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, + DataType, FieldRef, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, }; @@ -82,48 +82,48 @@ pub static TIMES: &[DataType] = &[ DataType::Time64(TimeUnit::Nanosecond), ]; -/// Validate the length of `input_types` matches the `signature` for `agg_fun`. +/// Validate the length of `input_fields` matches the `signature` for `agg_fun`. /// -/// This method DOES NOT validate the argument types - only that (at least one, +/// This method DOES NOT validate the argument fields - only that (at least one, /// in the case of [`TypeSignature::OneOf`]) signature matches the desired /// number of input types. 
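The new `Signature::arrays` and `Signature::element_and_array` constructors above cut boilerplate for array functions. A small sketch of the simplest case, using the `datafusion_expr` re-exports and passing `None` for the coercion so no extra imports are needed; the function shape named in the comment is only an example.

```rust
use datafusion_expr::{Signature, Volatility};

fn main() {
    // A signature for a function taking two array arguments,
    // e.g. something shaped like `array_intersect(a, b)`.
    let sig = Signature::arrays(2, None, Volatility::Immutable);
    println!("{:?}", sig.type_signature);
}
```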
pub fn check_arg_count( func_name: &str, - input_types: &[DataType], + input_fields: &[FieldRef], signature: &TypeSignature, ) -> Result<()> { match signature { TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => { - if input_types.len() != *agg_count { + if input_fields.len() != *agg_count { return plan_err!( "The function {func_name} expects {:?} arguments, but {:?} were provided", agg_count, - input_types.len() + input_fields.len() ); } } TypeSignature::Exact(types) => { - if types.len() != input_types.len() { + if types.len() != input_fields.len() { return plan_err!( "The function {func_name} expects {:?} arguments, but {:?} were provided", types.len(), - input_types.len() + input_fields.len() ); } } TypeSignature::OneOf(variants) => { let ok = variants .iter() - .any(|v| check_arg_count(func_name, input_types, v).is_ok()); + .any(|v| check_arg_count(func_name, input_fields, v).is_ok()); if !ok { return plan_err!( "The function {func_name} does not accept {:?} function arguments.", - input_types.len() + input_fields.len() ); } } TypeSignature::VariadicAny => { - if input_types.is_empty() { + if input_fields.is_empty() { return plan_err!( "The function {func_name} expects at least one argument" ); @@ -210,6 +210,7 @@ pub fn avg_return_type(func_name: &str, arg_type: &DataType) -> Result let new_scale = DECIMAL256_MAX_SCALE.min(*scale + 4); Ok(DataType::Decimal256(new_precision, new_scale)) } + DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)), arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64), DataType::Dictionary(_, dict_value_type) => { avg_return_type(func_name, dict_value_type.as_ref()) @@ -231,6 +232,7 @@ pub fn avg_sum_type(arg_type: &DataType) -> Result { let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10); Ok(DataType::Decimal256(new_precision, *scale)) } + DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)), arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64), DataType::Dictionary(_, dict_value_type) => { avg_sum_type(dict_value_type.as_ref()) @@ -298,6 +300,7 @@ pub fn coerce_avg_type(func_name: &str, arg_types: &[DataType]) -> Result Ok(DataType::Decimal128(*p, *s)), DataType::Decimal256(p, s) => Ok(DataType::Decimal256(*p, *s)), d if d.is_numeric() => Ok(DataType::Float64), + DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)), DataType::Dictionary(_, v) => coerced_type(func_name, v.as_ref()), _ => { plan_err!( diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index fdc61cd665ef..955c28c42a3f 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -462,7 +462,7 @@ pub fn type_union_resolution(data_types: &[DataType]) -> Option { // If all the data_types are null, return string if data_types.iter().all(|t| t == &DataType::Null) { - return Some(DataType::Utf8); + return Some(DataType::Utf8View); } // Ignore Nulls, if any data_type category is not the same, return None @@ -931,6 +931,7 @@ fn coerce_numeric_type_to_decimal(numeric_type: &DataType) -> Option { Int32 | UInt32 => Some(Decimal128(10, 0)), Int64 | UInt64 => Some(Decimal128(20, 0)), // TODO if we convert the floating-point data to the decimal type, it maybe overflow. 
+ Float16 => Some(Decimal128(6, 3)), Float32 => Some(Decimal128(14, 7)), Float64 => Some(Decimal128(30, 15)), _ => None, @@ -949,6 +950,7 @@ fn coerce_numeric_type_to_decimal256(numeric_type: &DataType) -> Option Some(Decimal256(10, 0)), Int64 | UInt64 => Some(Decimal256(20, 0)), // TODO if we convert the floating-point data to the decimal type, it maybe overflow. + Float16 => Some(Decimal256(6, 3)), Float32 => Some(Decimal256(14, 7)), Float64 => Some(Decimal256(30, 15)), _ => None, @@ -1044,6 +1046,7 @@ fn numerical_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Some(Float64), (_, Float32) | (Float32, _) => Some(Float32), + (_, Float16) | (Float16, _) => Some(Float16), // The following match arms encode the following logic: Given the two // integral types, we choose the narrowest possible integral type that // accommodates all values of both types. Note that to avoid information @@ -1138,7 +1141,7 @@ fn dictionary_comparison_coercion( /// 2. Data type of the other side should be able to cast to string type fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { use arrow::datatypes::DataType::*; - string_coercion(lhs_type, rhs_type).or(match (lhs_type, rhs_type) { + string_coercion(lhs_type, rhs_type).or_else(|| match (lhs_type, rhs_type) { (Utf8View, from_type) | (from_type, Utf8View) => { string_concat_internal_coercion(from_type, &Utf8View) } @@ -1199,7 +1202,8 @@ pub fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option { use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { - (Utf8 | LargeUtf8, other_type) | (other_type, Utf8 | LargeUtf8) + (Utf8 | LargeUtf8 | Utf8View, other_type) + | (other_type, Utf8 | LargeUtf8 | Utf8View) if other_type.is_numeric() => { Some(other_type.clone()) @@ -1297,6 +1301,13 @@ fn binary_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Some(LargeBinary) } (Binary, Utf8) | (Utf8, Binary) => Some(Binary), + + // Cast FixedSizeBinary to Binary + (FixedSizeBinary(_), Binary) | (Binary, FixedSizeBinary(_)) => Some(Binary), + (FixedSizeBinary(_), BinaryView) | (BinaryView, FixedSizeBinary(_)) => { + Some(BinaryView) + } + _ => None, } } @@ -1574,6 +1585,10 @@ mod tests { coerce_numeric_type_to_decimal(&DataType::Int64).unwrap(), DataType::Decimal128(20, 0) ); + assert_eq!( + coerce_numeric_type_to_decimal(&DataType::Float16).unwrap(), + DataType::Decimal128(6, 3) + ); assert_eq!( coerce_numeric_type_to_decimal(&DataType::Float32).unwrap(), DataType::Decimal128(14, 7) @@ -2052,6 +2067,13 @@ mod tests { Operator::Plus, Float32 ); + // (_, Float16) | (Float16, _) => Some(Float16), + test_coercion_binary_rule_multiple!( + Float16, + [Float16, Int64, UInt64, Int32, UInt32, Int16, UInt16, Int8, UInt8], + Operator::Plus, + Float16 + ); // (UInt64, Int64 | Int32 | Int16 | Int8) | (Int64 | Int32 | Int16 | Int8, UInt64) => Some(Decimal128(20, 0)), test_coercion_binary_rule_multiple!( UInt64, @@ -2190,6 +2212,18 @@ mod tests { DataType::Boolean ); // float + test_coercion_binary_rule!( + DataType::Float16, + DataType::Int64, + Operator::Eq, + DataType::Float16 + ); + test_coercion_binary_rule!( + DataType::Float16, + DataType::Float64, + Operator::Eq, + DataType::Float64 + ); test_coercion_binary_rule!( DataType::Float32, DataType::Int64, diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index 37e1ed1936fb..812544587bf9 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -42,6 +42,7 @@ recursive_protection = ["dep:recursive"] [dependencies] arrow = 
{ workspace = true } +async-trait = { workspace = true } chrono = { workspace = true } datafusion-common = { workspace = true } datafusion-doc = { workspace = true } @@ -58,3 +59,4 @@ sqlparser = { workspace = true } [dev-dependencies] ctor = { workspace = true } env_logger = { workspace = true } +insta = { workspace = true } diff --git a/datafusion/expr/README.md b/datafusion/expr/README.md index b086f930e871..860c36769ee5 100644 --- a/datafusion/expr/README.md +++ b/datafusion/expr/README.md @@ -23,4 +23,9 @@ This crate is a submodule of DataFusion that provides data types and utilities for logical plans and expressions. +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + [df]: https://crates.io/crates/datafusion +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/expr/src/async_udf.rs b/datafusion/expr/src/async_udf.rs new file mode 100644 index 000000000000..d900c1634523 --- /dev/null +++ b/datafusion/expr/src/async_udf.rs @@ -0,0 +1,120 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl}; +use arrow::array::ArrayRef; +use arrow::datatypes::{DataType, FieldRef}; +use async_trait::async_trait; +use datafusion_common::config::ConfigOptions; +use datafusion_common::error::Result; +use datafusion_common::internal_err; +use datafusion_expr_common::columnar_value::ColumnarValue; +use datafusion_expr_common::signature::Signature; +use std::any::Any; +use std::fmt::{Debug, Display}; +use std::sync::Arc; + +/// A scalar UDF that can invoke using async methods +/// +/// Note this is less efficient than the ScalarUDFImpl, but it can be used +/// to register remote functions in the context. +/// +/// The name is chosen to mirror ScalarUDFImpl +#[async_trait] +pub trait AsyncScalarUDFImpl: ScalarUDFImpl { + /// The ideal batch size for this function. + /// + /// This is used to determine what size of data to be evaluated at once. + /// If None, the whole batch will be evaluated at once. + fn ideal_batch_size(&self) -> Option { + None + } + + /// Invoke the function asynchronously with the async arguments + async fn invoke_async_with_args( + &self, + args: ScalarFunctionArgs, + option: &ConfigOptions, + ) -> Result; +} + +/// A scalar UDF that must be invoked using async methods +/// +/// Note this is not meant to be used directly, but is meant to be an implementation detail +/// for AsyncUDFImpl. 
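To make the new `async_udf` module concrete, here is a hedged sketch of an implementation. The `RemoteLookup` name and its behaviour are invented, and the exact generic parameters (`Arc<dyn AsyncScalarUDFImpl>` for `AsyncScalarUDF::new`, `Result<ArrayRef>` for `invoke_async_with_args`, `Option<usize>` for `ideal_batch_size`, and the `number_rows` field on `ScalarFunctionArgs`) are assumptions inferred from the imports and surrounding code rather than confirmed signatures.

```rust
use std::any::Any;
use std::sync::Arc;

use arrow::array::{ArrayRef, StringArray};
use arrow::datatypes::DataType;
use async_trait::async_trait;
use datafusion_common::config::ConfigOptions;
use datafusion_common::{internal_err, Result};
use datafusion_expr::async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl};
use datafusion_expr::{
    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
};

/// A toy UDF that pretends to resolve each input row via a remote service.
#[derive(Debug)]
struct RemoteLookup {
    signature: Signature,
}

impl RemoteLookup {
    fn new() -> Self {
        Self {
            signature: Signature::exact(vec![DataType::Utf8], Volatility::Volatile),
        }
    }
}

impl ScalarUDFImpl for RemoteLookup {
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn name(&self) -> &str {
        "remote_lookup"
    }
    fn signature(&self) -> &Signature {
        &self.signature
    }
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Utf8)
    }
    fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        internal_err!("remote_lookup must be invoked asynchronously")
    }
}

#[async_trait]
impl AsyncScalarUDFImpl for RemoteLookup {
    fn ideal_batch_size(&self) -> Option<usize> {
        // Evaluate at most 1024 rows per remote round trip.
        Some(1024)
    }

    async fn invoke_async_with_args(
        &self,
        args: ScalarFunctionArgs,
        _options: &ConfigOptions,
    ) -> Result<ArrayRef> {
        // A real implementation would await an external call here.
        Ok(Arc::new(StringArray::from(vec!["looked-up"; args.number_rows])))
    }
}

fn main() {
    // Wrapping the async impl yields a regular ScalarUDF that can be registered as usual.
    let udf = AsyncScalarUDF::new(Arc::new(RemoteLookup::new())).into_scalar_udf();
    println!("registered async UDF: {}", udf.name());
}
```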
+#[derive(Debug)] +pub struct AsyncScalarUDF { + inner: Arc, +} + +impl AsyncScalarUDF { + pub fn new(inner: Arc) -> Self { + Self { inner } + } + + /// The ideal batch size for this function + pub fn ideal_batch_size(&self) -> Option { + self.inner.ideal_batch_size() + } + + /// Turn this AsyncUDF into a ScalarUDF, suitable for + /// registering in the context + pub fn into_scalar_udf(self) -> ScalarUDF { + ScalarUDF::new_from_impl(self) + } + + /// Invoke the function asynchronously with the async arguments + pub async fn invoke_async_with_args( + &self, + args: ScalarFunctionArgs, + option: &ConfigOptions, + ) -> Result { + self.inner.invoke_async_with_args(args, option).await + } +} + +impl ScalarUDFImpl for AsyncScalarUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + self.inner.name() + } + + fn signature(&self) -> &Signature { + self.inner.signature() + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + self.inner.return_type(arg_types) + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + self.inner.return_field_from_args(args) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!("async functions should not be called directly") + } +} + +impl Display for AsyncScalarUDF { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "AsyncScalarUDF: {}", self.inner.name()) + } +} diff --git a/datafusion/expr/src/conditional_expressions.rs b/datafusion/expr/src/conditional_expressions.rs index 9cb51612d0ca..69525ea52137 100644 --- a/datafusion/expr/src/conditional_expressions.rs +++ b/datafusion/expr/src/conditional_expressions.rs @@ -72,7 +72,7 @@ impl CaseBuilder { let then_types: Vec = then_expr .iter() .map(|e| match e { - Expr::Literal(_) => e.get_type(&DFSchema::empty()), + Expr::Literal(_, _) => e.get_type(&DFSchema::empty()), _ => Ok(DataType::Null), }) .collect::>>()?; diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 9f6855b69824..c50268d99676 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -17,18 +17,20 @@ //! Logical Expressions: [`Expr`] -use std::collections::HashSet; +use std::cmp::Ordering; +use std::collections::{BTreeMap, HashSet}; use std::fmt::{self, Display, Formatter, Write}; use std::hash::{Hash, Hasher}; use std::mem; use std::sync::Arc; use crate::expr_fn::binary_expr; +use crate::function::WindowFunctionSimplification; use crate::logical_plan::Subquery; use crate::Volatility; use crate::{udaf, ExprSchemable, Operator, Signature, WindowFrame, WindowUDF}; -use arrow::datatypes::{DataType, FieldRef}; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::cse::{HashNode, NormalizeEq, Normalizeable}; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeContainer, TreeNodeRecursion, @@ -50,7 +52,7 @@ use sqlparser::ast::{ /// BinaryExpr { /// left: Expr::Column("A"), /// op: Operator::Plus, -/// right: Expr::Literal(ScalarValue::Int32(Some(1))) +/// right: Expr::Literal(ScalarValue::Int32(Some(1)), None) /// } /// ``` /// @@ -112,10 +114,10 @@ use sqlparser::ast::{ /// # use datafusion_expr::{lit, col, Expr}; /// // All literals are strongly typed in DataFusion. 
To make an `i64` 42: /// let expr = lit(42i64); -/// assert_eq!(expr, Expr::Literal(ScalarValue::Int64(Some(42)))); -/// assert_eq!(expr, Expr::Literal(ScalarValue::Int64(Some(42)))); +/// assert_eq!(expr, Expr::Literal(ScalarValue::Int64(Some(42)), None)); +/// assert_eq!(expr, Expr::Literal(ScalarValue::Int64(Some(42)), None)); /// // To make a (typed) NULL: -/// let expr = Expr::Literal(ScalarValue::Int64(None)); +/// let expr = Expr::Literal(ScalarValue::Int64(None), None); /// // to make an (untyped) NULL (the optimizer will coerce this to the correct type): /// let expr = lit(ScalarValue::Null); /// ``` @@ -149,7 +151,7 @@ use sqlparser::ast::{ /// if let Expr::BinaryExpr(binary_expr) = expr { /// assert_eq!(*binary_expr.left, col("c1")); /// let scalar = ScalarValue::Int32(Some(42)); -/// assert_eq!(*binary_expr.right, Expr::Literal(scalar)); +/// assert_eq!(*binary_expr.right, Expr::Literal(scalar, None)); /// assert_eq!(binary_expr.op, Operator::Eq); /// } /// ``` @@ -193,7 +195,7 @@ use sqlparser::ast::{ /// ``` /// # use datafusion_expr::{lit, col}; /// let expr = col("c1") + lit(42); -/// assert_eq!(format!("{expr:?}"), "BinaryExpr(BinaryExpr { left: Column(Column { relation: None, name: \"c1\" }), op: Plus, right: Literal(Int32(42)) })"); +/// assert_eq!(format!("{expr:?}"), "BinaryExpr(BinaryExpr { left: Column(Column { relation: None, name: \"c1\" }), op: Plus, right: Literal(Int32(42), None) })"); /// ``` /// /// ## Use the `Display` trait (detailed expression) @@ -239,7 +241,7 @@ use sqlparser::ast::{ /// let mut scalars = HashSet::new(); /// // apply recursively visits all nodes in the expression tree /// expr.apply(|e| { -/// if let Expr::Literal(scalar) = e { +/// if let Expr::Literal(scalar, _) = e { /// scalars.insert(scalar); /// } /// // The return value controls whether to continue visiting the tree @@ -274,7 +276,7 @@ use sqlparser::ast::{ /// assert!(rewritten.transformed); /// // to 42 = 5 AND b = 6 /// assert_eq!(rewritten.data, lit(42).eq(lit(5)).and(col("b").eq(lit(6)))); -#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] +#[derive(Clone, PartialEq, PartialOrd, Eq, Debug, Hash)] pub enum Expr { /// An expression with a specific name. Alias(Alias), @@ -282,8 +284,8 @@ pub enum Expr { Column(Column), /// A named reference to a variable in a registry. ScalarVariable(DataType, Vec), - /// A constant value. - Literal(ScalarValue), + /// A constant value along with associated [`FieldMetadata`]. + Literal(ScalarValue, Option), /// A binary expression such as "age > 21" BinaryExpr(BinaryExpr), /// LIKE expression @@ -312,27 +314,7 @@ pub enum Expr { Negative(Box), /// Whether an expression is between a given range. Between(Between), - /// The CASE expression is similar to a series of nested if/else and there are two forms that - /// can be used. The first form consists of a series of boolean "when" expressions with - /// corresponding "then" expressions, and an optional "else" expression. - /// - /// ```text - /// CASE WHEN condition THEN result - /// [WHEN ...] - /// [ELSE result] - /// END - /// ``` - /// - /// The second form uses a base expression and then a series of "when" clauses that match on a - /// literal value. - /// - /// ```text - /// CASE expression - /// WHEN value THEN result - /// [WHEN ...] - /// [ELSE result] - /// END - /// ``` + /// A CASE expression (see docs on [`Case`]) Case(Case), /// Casts the expression to a given type and will return a runtime error if the expression cannot be cast. 
/// This expression is guaranteed to have a fixed type. @@ -340,7 +322,7 @@ pub enum Expr { /// Casts the expression to a given type and will return a null value if the expression cannot be cast. /// This expression is guaranteed to have a fixed type. TryCast(TryCast), - /// Represents the call of a scalar function with a set of arguments. + /// Call a scalar function with a set of arguments. ScalarFunction(ScalarFunction), /// Calls an aggregate function with arguments, and optional /// `ORDER BY`, `FILTER`, `DISTINCT` and `NULL TREATMENT`. @@ -349,8 +331,8 @@ pub enum Expr { /// /// [`ExprFunctionExt`]: crate::expr_fn::ExprFunctionExt AggregateFunction(AggregateFunction), - /// Represents the call of a window function with arguments. - WindowFunction(WindowFunction), + /// Call a window function with a set of arguments. + WindowFunction(Box), /// Returns whether the list contains the expr value. InList(InList), /// EXISTS subquery @@ -378,7 +360,7 @@ pub enum Expr { /// A place holder for parameters in a prepared statement /// (e.g. `$foo` or `$1`) Placeholder(Placeholder), - /// A place holder which hold a reference to a qualified field + /// A placeholder which holds a reference to a qualified field /// in the outer query, used for correlated sub queries. OuterReferenceColumn(DataType, Column), /// Unnest expression @@ -387,7 +369,7 @@ pub enum Expr { impl Default for Expr { fn default() -> Self { - Expr::Literal(ScalarValue::Null) + Expr::Literal(ScalarValue::Null, None) } } @@ -398,6 +380,13 @@ impl From for Expr { } } +/// Create an [`Expr`] from a [`WindowFunction`] +impl From for Expr { + fn from(value: WindowFunction) -> Self { + Expr::WindowFunction(Box::new(value)) + } +} + /// Create an [`Expr`] from an optional qualifier and a [`FieldRef`]. This is /// useful for creating [`Expr`] from a [`DFSchema`]. /// @@ -424,6 +413,192 @@ impl<'a> TreeNodeContainer<'a, Self> for Expr { } } +/// Literal metadata +/// +/// Stores metadata associated with a literal expressions +/// and is designed to be fast to `clone`. +/// +/// This structure is used to store metadata associated with a literal expression, and it +/// corresponds to the `metadata` field on [`Field`]. +/// +/// # Example: Create [`FieldMetadata`] from a [`Field`] +/// ``` +/// # use std::collections::HashMap; +/// # use datafusion_expr::expr::FieldMetadata; +/// # use arrow::datatypes::{Field, DataType}; +/// # let field = Field::new("c1", DataType::Int32, true) +/// # .with_metadata(HashMap::from([("foo".to_string(), "bar".to_string())])); +/// // Create a new `FieldMetadata` instance from a `Field` +/// let metadata = FieldMetadata::new_from_field(&field); +/// // There is also a `From` impl: +/// let metadata = FieldMetadata::from(&field); +/// ``` +/// +/// # Example: Update a [`Field`] with [`FieldMetadata`] +/// ``` +/// # use datafusion_expr::expr::FieldMetadata; +/// # use arrow::datatypes::{Field, DataType}; +/// # let field = Field::new("c1", DataType::Int32, true); +/// # let metadata = FieldMetadata::new_from_field(&field); +/// // Add any metadata from `FieldMetadata` to `Field` +/// let updated_field = metadata.add_to_field(field); +/// ``` +/// +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] +pub struct FieldMetadata { + /// The inner metadata of a literal expression, which is a map of string + /// keys to string values. + /// + /// Note this is not a `HashMap because `HashMap` does not provide + /// implementations for traits like `Debug` and `Hash`. 
+ inner: Arc>, +} + +impl Default for FieldMetadata { + fn default() -> Self { + Self::new_empty() + } +} + +impl FieldMetadata { + /// Create a new empty metadata instance. + pub fn new_empty() -> Self { + Self { + inner: Arc::new(BTreeMap::new()), + } + } + + /// Merges two optional `FieldMetadata` instances, overwriting any existing + /// keys in `m` with keys from `n` if present + pub fn merge_options( + m: Option<&FieldMetadata>, + n: Option<&FieldMetadata>, + ) -> Option { + match (m, n) { + (Some(m), Some(n)) => { + let mut merged = m.clone(); + merged.extend(n.clone()); + Some(merged) + } + (Some(m), None) => Some(m.clone()), + (None, Some(n)) => Some(n.clone()), + (None, None) => None, + } + } + + /// Create a new metadata instance from a `Field`'s metadata. + pub fn new_from_field(field: &Field) -> Self { + let inner = field + .metadata() + .iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect(); + Self { + inner: Arc::new(inner), + } + } + + /// Create a new metadata instance from a map of string keys to string values. + pub fn new(inner: BTreeMap) -> Self { + Self { + inner: Arc::new(inner), + } + } + + /// Get the inner metadata as a reference to a `BTreeMap`. + pub fn inner(&self) -> &BTreeMap { + &self.inner + } + + /// Return the inner metadata + pub fn into_inner(self) -> Arc> { + self.inner + } + + /// Adds metadata from `other` into `self`, overwriting any existing keys. + pub fn extend(&mut self, other: Self) { + if other.is_empty() { + return; + } + let other = Arc::unwrap_or_clone(other.into_inner()); + Arc::make_mut(&mut self.inner).extend(other); + } + + /// Returns true if the metadata is empty. + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + /// Returns the number of key-value pairs in the metadata. + pub fn len(&self) -> usize { + self.inner.len() + } + + /// Convert this `FieldMetadata` into a `HashMap` + pub fn to_hashmap(&self) -> std::collections::HashMap { + self.inner + .iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect() + } + + /// Updates the metadata on the Field with this metadata, if it is not empty. + pub fn add_to_field(&self, field: Field) -> Field { + if self.inner.is_empty() { + return field; + } + + field.with_metadata(self.to_hashmap()) + } +} + +impl From<&Field> for FieldMetadata { + fn from(field: &Field) -> Self { + Self::new_from_field(field) + } +} + +impl From> for FieldMetadata { + fn from(inner: BTreeMap) -> Self { + Self::new(inner) + } +} + +impl From> for FieldMetadata { + fn from(map: std::collections::HashMap) -> Self { + Self::new(map.into_iter().collect()) + } +} + +/// From reference +impl From<&std::collections::HashMap> for FieldMetadata { + fn from(map: &std::collections::HashMap) -> Self { + let inner = map + .iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect(); + Self::new(inner) + } +} + +/// From hashbrown map +impl From> for FieldMetadata { + fn from(map: HashMap) -> Self { + let inner = map.into_iter().collect(); + Self::new(inner) + } +} + +impl From<&HashMap> for FieldMetadata { + fn from(map: &HashMap) -> Self { + let inner = map + .into_iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect(); + Self::new(inner) + } +} + /// UNNEST expression. 
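Since `Expr::Literal` now carries an optional `FieldMetadata`, a short sketch of how the new type plugs into literals and aliases may help. The key/value content is made up, and `as_literal` is the accessor added further down in this diff.

```rust
use std::collections::HashMap;

use datafusion_common::ScalarValue;
use datafusion_expr::expr::FieldMetadata;
use datafusion_expr::Expr;

fn main() {
    let metadata = FieldMetadata::from(HashMap::from([(
        "unit".to_string(),
        "seconds".to_string(),
    )]));

    // Literals now carry `Option<FieldMetadata>` alongside the value.
    let expr = Expr::Literal(ScalarValue::Int64(Some(42)), Some(metadata.clone()));
    assert_eq!(expr.as_literal(), Some(&ScalarValue::Int64(Some(42))));

    // Aliases take the same structure instead of a raw HashMap.
    let aliased = expr.alias_with_metadata("answer", Some(metadata));
    println!("{aliased}");
}
```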
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct Unnest { @@ -450,7 +625,7 @@ pub struct Alias { pub expr: Box, pub relation: Option, pub name: String, - pub metadata: Option>, + pub metadata: Option, } impl Hash for Alias { @@ -462,13 +637,13 @@ impl Hash for Alias { } impl PartialOrd for Alias { - fn partial_cmp(&self, other: &Self) -> Option { + fn partial_cmp(&self, other: &Self) -> Option { let cmp = self.expr.partial_cmp(&other.expr); - let Some(std::cmp::Ordering::Equal) = cmp else { + let Some(Ordering::Equal) = cmp else { return cmp; }; let cmp = self.relation.partial_cmp(&other.relation); - let Some(std::cmp::Ordering::Equal) = cmp else { + let Some(Ordering::Equal) = cmp else { return cmp; }; self.name.partial_cmp(&other.name) @@ -490,10 +665,7 @@ impl Alias { } } - pub fn with_metadata( - mut self, - metadata: Option>, - ) -> Self { + pub fn with_metadata(mut self, metadata: Option) -> Self { self.metadata = metadata; self } @@ -551,6 +723,28 @@ impl Display for BinaryExpr { } /// CASE expression +/// +/// The CASE expression is similar to a series of nested if/else and there are two forms that +/// can be used. The first form consists of a series of boolean "when" expressions with +/// corresponding "then" expressions, and an optional "else" expression. +/// +/// ```text +/// CASE WHEN condition THEN result +/// [WHEN ...] +/// [ELSE result] +/// END +/// ``` +/// +/// The second form uses a base expression and then a series of "when" clauses that match on a +/// literal value. +/// +/// ```text +/// CASE expression +/// WHEN value THEN result +/// [WHEN ...] +/// [ELSE result] +/// END +/// ``` #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Hash)] pub struct Case { /// Optional base expression that can be compared to literal values in the "when" expressions @@ -631,7 +825,9 @@ impl Between { } } -/// ScalarFunction expression invokes a built-in scalar function +/// Invoke a [`ScalarUDF`] with a set of arguments +/// +/// [`ScalarUDF`]: crate::ScalarUDF #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct ScalarFunction { /// The function @@ -648,7 +844,9 @@ impl ScalarFunction { } impl ScalarFunction { - /// Create a new ScalarFunction expression with a user-defined function (UDF) + /// Create a new `ScalarFunction` from a [`ScalarUDF`] + /// + /// [`ScalarUDF`]: crate::ScalarUDF pub fn new_udf(udf: Arc, args: Vec) -> Self { Self { func: udf, args } } @@ -838,19 +1036,19 @@ pub enum WindowFunctionDefinition { impl WindowFunctionDefinition { /// Returns the datatype of the window function - pub fn return_type( + pub fn return_field( &self, - input_expr_types: &[DataType], + input_expr_fields: &[FieldRef], _input_expr_nullable: &[bool], display_name: &str, - ) -> Result { + ) -> Result { match self { WindowFunctionDefinition::AggregateUDF(fun) => { - fun.return_type(input_expr_types) + fun.return_field(input_expr_fields) + } + WindowFunctionDefinition::WindowUDF(fun) => { + fun.field(WindowUDFFieldArgs::new(input_expr_fields, display_name)) } - WindowFunctionDefinition::WindowUDF(fun) => fun - .field(WindowUDFFieldArgs::new(input_expr_types, display_name)) - .map(|field| field.data_type().clone()), } } @@ -869,6 +1067,16 @@ impl WindowFunctionDefinition { WindowFunctionDefinition::AggregateUDF(fun) => fun.name(), } } + + /// Return the the inner window simplification function, if any + /// + /// See [`WindowFunctionSimplification`] for more information + pub fn simplify(&self) -> Option { + match self { + 
WindowFunctionDefinition::AggregateUDF(_) => None, + WindowFunctionDefinition::WindowUDF(udwf) => udwf.simplify(), + } + } } impl Display for WindowFunctionDefinition { @@ -940,6 +1148,13 @@ impl WindowFunction { }, } } + + /// Return the the inner window simplification function, if any + /// + /// See [`WindowFunctionSimplification`] for more information + pub fn simplify(&self) -> Option { + self.fun.simplify() + } } /// EXISTS expression @@ -1397,15 +1612,17 @@ impl Expr { /// # Example /// ``` /// # use datafusion_expr::col; - /// use std::collections::HashMap; + /// # use std::collections::HashMap; + /// # use datafusion_expr::expr::FieldMetadata; /// let metadata = HashMap::from([("key".to_string(), "value".to_string())]); + /// let metadata = FieldMetadata::from(metadata); /// let expr = col("foo").alias_with_metadata("bar", Some(metadata)); /// ``` /// pub fn alias_with_metadata( self, name: impl Into, - metadata: Option>, + metadata: Option, ) -> Expr { Expr::Alias(Alias::new(self, None::<&str>, name.into()).with_metadata(metadata)) } @@ -1427,8 +1644,10 @@ impl Expr { /// # Example /// ``` /// # use datafusion_expr::col; - /// use std::collections::HashMap; + /// # use std::collections::HashMap; + /// # use datafusion_expr::expr::FieldMetadata; /// let metadata = HashMap::from([("key".to_string(), "value".to_string())]); + /// let metadata = FieldMetadata::from(metadata); /// let expr = col("foo").alias_qualified_with_metadata(Some("tbl"), "bar", Some(metadata)); /// ``` /// @@ -1436,7 +1655,7 @@ impl Expr { self, relation: Option>, name: impl Into, - metadata: Option>, + metadata: Option, ) -> Expr { Expr::Alias(Alias::new(self, relation, name.into()).with_metadata(metadata)) } @@ -1506,8 +1725,16 @@ impl Expr { |expr| { // f_up: unalias on up so we can remove nested aliases like // `(x as foo) as bar` - if let Expr::Alias(Alias { expr, .. 
}) = expr { - Ok(Transformed::yes(*expr)) + if let Expr::Alias(alias) = expr { + match alias + .metadata + .as_ref() + .map(|h| h.is_empty()) + .unwrap_or(true) + { + true => Ok(Transformed::yes(*alias.expr)), + false => Ok(Transformed::no(Expr::Alias(alias))), + } } else { Ok(Transformed::no(expr)) } @@ -1747,23 +1974,38 @@ impl Expr { pub fn infer_placeholder_types(self, schema: &DFSchema) -> Result<(Expr, bool)> { let mut has_placeholder = false; self.transform(|mut expr| { - // Default to assuming the arguments are the same type - if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = &mut expr { - rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?; - rewrite_placeholder(right.as_mut(), left.as_ref(), schema)?; - }; - if let Expr::Between(Between { - expr, - negated: _, - low, - high, - }) = &mut expr - { - rewrite_placeholder(low.as_mut(), expr.as_ref(), schema)?; - rewrite_placeholder(high.as_mut(), expr.as_ref(), schema)?; - } - if let Expr::Placeholder(_) = &expr { - has_placeholder = true; + match &mut expr { + // Default to assuming the arguments are the same type + Expr::BinaryExpr(BinaryExpr { left, op: _, right }) => { + rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?; + rewrite_placeholder(right.as_mut(), left.as_ref(), schema)?; + } + Expr::Between(Between { + expr, + negated: _, + low, + high, + }) => { + rewrite_placeholder(low.as_mut(), expr.as_ref(), schema)?; + rewrite_placeholder(high.as_mut(), expr.as_ref(), schema)?; + } + Expr::InList(InList { + expr, + list, + negated: _, + }) => { + for item in list.iter_mut() { + rewrite_placeholder(item, expr.as_ref(), schema)?; + } + } + Expr::Like(Like { expr, pattern, .. }) + | Expr::SimilarTo(Like { expr, pattern, .. }) => { + rewrite_placeholder(pattern.as_mut(), expr.as_ref(), schema)?; + } + Expr::Placeholder(_) => { + has_placeholder = true; + } + _ => {} } Ok(Transformed::yes(expr)) }) @@ -1827,6 +2069,15 @@ impl Expr { _ => None, } } + + /// Check if the Expr is literal and get the literal value if it is. 
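/// A minimal usage sketch (assuming the `lit` helper re-exported by this crate,
/// which now builds `Expr::Literal(value, None)`):
///
/// ```
/// # use datafusion_expr::lit;
/// let expr = lit(1i64);
/// assert!(expr.as_literal().is_some());
/// ```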
+ pub fn as_literal(&self) -> Option<&ScalarValue> { + if let Expr::Literal(lit, _) = self { + Some(lit) + } else { + None + } + } } impl Normalizeable for Expr { @@ -2065,32 +2316,29 @@ impl NormalizeEq for Expr { _ => false, } } - ( - Expr::WindowFunction(WindowFunction { + (Expr::WindowFunction(left), Expr::WindowFunction(other)) => { + let WindowFunction { fun: self_fun, - params: self_params, - }), - Expr::WindowFunction(WindowFunction { + params: + WindowFunctionParams { + args: self_args, + window_frame: self_window_frame, + partition_by: self_partition_by, + order_by: self_order_by, + null_treatment: self_null_treatment, + }, + } = left.as_ref(); + let WindowFunction { fun: other_fun, - params: other_params, - }), - ) => { - let ( - WindowFunctionParams { - args: self_args, - window_frame: self_window_frame, - partition_by: self_partition_by, - order_by: self_order_by, - null_treatment: self_null_treatment, - }, - WindowFunctionParams { - args: other_args, - window_frame: other_window_frame, - partition_by: other_partition_by, - order_by: other_order_by, - null_treatment: other_null_treatment, - }, - ) = (self_params, other_params); + params: + WindowFunctionParams { + args: other_args, + window_frame: other_window_frame, + partition_by: other_partition_by, + order_by: other_order_by, + null_treatment: other_null_treatment, + }, + } = other.as_ref(); self_fun.name() == other_fun.name() && self_window_frame == other_window_frame @@ -2256,7 +2504,7 @@ impl HashNode for Expr { data_type.hash(state); name.hash(state); } - Expr::Literal(scalar_value) => { + Expr::Literal(scalar_value, _) => { scalar_value.hash(state); } Expr::BinaryExpr(BinaryExpr { @@ -2335,14 +2583,18 @@ impl HashNode for Expr { distinct.hash(state); null_treatment.hash(state); } - Expr::WindowFunction(WindowFunction { fun, params }) => { - let WindowFunctionParams { - args: _args, - partition_by: _, - order_by: _, - window_frame, - null_treatment, - } = params; + Expr::WindowFunction(window_fun) => { + let WindowFunction { + fun, + params: + WindowFunctionParams { + args: _args, + partition_by: _, + order_by: _, + window_frame, + null_treatment, + }, + } = window_fun.as_ref(); fun.hash(state); window_frame.hash(state); null_treatment.hash(state); @@ -2432,7 +2684,7 @@ impl Display for SchemaDisplay<'_> { // TODO: remove the next line after `Expr::Wildcard` is removed #[expect(deprecated)] Expr::Column(_) - | Expr::Literal(_) + | Expr::Literal(_, _) | Expr::ScalarVariable(..) | Expr::OuterReferenceColumn(..) 
| Expr::Placeholder(_) @@ -2443,7 +2695,7 @@ impl Display for SchemaDisplay<'_> { write!(f, "{name}") } Err(e) => { - write!(f, "got error from schema_name {}", e) + write!(f, "got error from schema_name {e}") } } } @@ -2594,7 +2846,7 @@ impl Display for SchemaDisplay<'_> { write!(f, "{name}") } Err(e) => { - write!(f, "got error from schema_name {}", e) + write!(f, "got error from schema_name {e}") } } } @@ -2625,52 +2877,62 @@ impl Display for SchemaDisplay<'_> { Ok(()) } - Expr::WindowFunction(WindowFunction { fun, params }) => match fun { - WindowFunctionDefinition::AggregateUDF(fun) => { - match fun.window_function_schema_name(params) { - Ok(name) => { - write!(f, "{name}") - } - Err(e) => { - write!(f, "got error from window_function_schema_name {}", e) + Expr::WindowFunction(window_fun) => { + let WindowFunction { fun, params } = window_fun.as_ref(); + match fun { + WindowFunctionDefinition::AggregateUDF(fun) => { + match fun.window_function_schema_name(params) { + Ok(name) => { + write!(f, "{name}") + } + Err(e) => { + write!( + f, + "got error from window_function_schema_name {e}" + ) + } } } - } - _ => { - let WindowFunctionParams { - args, - partition_by, - order_by, - window_frame, - null_treatment, - } = params; - - write!( - f, - "{}({})", - fun, - schema_name_from_exprs_comma_separated_without_space(args)? - )?; + _ => { + let WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + null_treatment, + } = params; - if let Some(null_treatment) = null_treatment { - write!(f, " {}", null_treatment)?; - } - - if !partition_by.is_empty() { write!( f, - " PARTITION BY [{}]", - schema_name_from_exprs(partition_by)? + "{}({})", + fun, + schema_name_from_exprs_comma_separated_without_space(args)? )?; - } - if !order_by.is_empty() { - write!(f, " ORDER BY [{}]", schema_name_from_sorts(order_by)?)?; - }; + if let Some(null_treatment) = null_treatment { + write!(f, " {null_treatment}")?; + } + + if !partition_by.is_empty() { + write!( + f, + " PARTITION BY [{}]", + schema_name_from_exprs(partition_by)? + )?; + } - write!(f, " {window_frame}") + if !order_by.is_empty() { + write!( + f, + " ORDER BY [{}]", + schema_name_from_sorts(order_by)? + )?; + }; + + write!(f, " {window_frame}") + } } - }, + } } } } @@ -2681,7 +2943,7 @@ struct SqlDisplay<'a>(&'a Expr); impl Display for SqlDisplay<'_> { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match self.0 { - Expr::Literal(scalar) => scalar.fmt(f), + Expr::Literal(scalar, _) => scalar.fmt(f), Expr::Alias(Alias { name, .. 
}) => write!(f, "{name}"), Expr::Between(Between { expr, @@ -2847,7 +3109,7 @@ impl Display for SqlDisplay<'_> { write!(f, "{name}") } Err(e) => { - write!(f, "got error from schema_name {}", e) + write!(f, "got error from schema_name {e}") } } } @@ -2948,7 +3210,12 @@ impl Display for Expr { write!(f, "{OUTER_REFERENCE_COLUMN_PREFIX}({c})") } Expr::ScalarVariable(_, var_names) => write!(f, "{}", var_names.join(".")), - Expr::Literal(v) => write!(f, "{v:?}"), + Expr::Literal(v, metadata) => { + match metadata.as_ref().map(|m| m.is_empty()).unwrap_or(true) { + false => write!(f, "{v:?} {:?}", metadata.as_ref().unwrap()), + true => write!(f, "{v:?}"), + } + } Expr::Case(case) => { write!(f, "CASE ")?; if let Some(e) = &case.expr { @@ -3005,54 +3272,60 @@ impl Display for Expr { // Expr::ScalarFunction(ScalarFunction { func, args }) => { // write!(f, "{}", func.display_name(args).unwrap()) // } - Expr::WindowFunction(WindowFunction { fun, params }) => match fun { - WindowFunctionDefinition::AggregateUDF(fun) => { - match fun.window_function_display_name(params) { - Ok(name) => { - write!(f, "{}", name) - } - Err(e) => { - write!(f, "got error from window_function_display_name {}", e) + Expr::WindowFunction(window_fun) => { + let WindowFunction { fun, params } = window_fun.as_ref(); + match fun { + WindowFunctionDefinition::AggregateUDF(fun) => { + match fun.window_function_display_name(params) { + Ok(name) => { + write!(f, "{name}") + } + Err(e) => { + write!( + f, + "got error from window_function_display_name {e}" + ) + } } } - } - WindowFunctionDefinition::WindowUDF(fun) => { - let WindowFunctionParams { - args, - partition_by, - order_by, - window_frame, - null_treatment, - } = params; - - fmt_function(f, &fun.to_string(), false, args, true)?; - - if let Some(nt) = null_treatment { - write!(f, "{}", nt)?; - } + WindowFunctionDefinition::WindowUDF(fun) => { + let WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + null_treatment, + } = params; + + fmt_function(f, &fun.to_string(), false, args, true)?; + + if let Some(nt) = null_treatment { + write!(f, "{nt}")?; + } - if !partition_by.is_empty() { - write!(f, " PARTITION BY [{}]", expr_vec_fmt!(partition_by))?; - } - if !order_by.is_empty() { - write!(f, " ORDER BY [{}]", expr_vec_fmt!(order_by))?; + if !partition_by.is_empty() { + write!(f, " PARTITION BY [{}]", expr_vec_fmt!(partition_by))?; + } + if !order_by.is_empty() { + write!(f, " ORDER BY [{}]", expr_vec_fmt!(order_by))?; + } + write!( + f, + " {} BETWEEN {} AND {}", + window_frame.units, + window_frame.start_bound, + window_frame.end_bound + ) } - write!( - f, - " {} BETWEEN {} AND {}", - window_frame.units, - window_frame.start_bound, - window_frame.end_bound - ) } - }, + } Expr::AggregateFunction(AggregateFunction { func, params }) => { match func.display_name(params) { Ok(name) => { - write!(f, "{}", name) + write!(f, "{name}") } Err(e) => { - write!(f, "got error from display_name {}", e) + write!(f, "got error from display_name {e}") } } } @@ -3185,10 +3458,117 @@ mod test { case, lit, qualified_wildcard, wildcard, wildcard_with_options, ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Volatility, }; + use arrow::datatypes::{Field, Schema}; use sqlparser::ast; use sqlparser::ast::{Ident, IdentWithAlias}; use std::any::Any; + #[test] + fn infer_placeholder_in_clause() { + // SELECT * FROM employees WHERE department_id IN ($1, $2, $3); + let column = col("department_id"); + let param_placeholders = vec![ + Expr::Placeholder(Placeholder { + 
id: "$1".to_string(), + data_type: None, + }), + Expr::Placeholder(Placeholder { + id: "$2".to_string(), + data_type: None, + }), + Expr::Placeholder(Placeholder { + id: "$3".to_string(), + data_type: None, + }), + ]; + let in_list = Expr::InList(InList { + expr: Box::new(column), + list: param_placeholders, + negated: false, + }); + + let schema = Arc::new(Schema::new(vec![ + Field::new("name", DataType::Utf8, true), + Field::new("department_id", DataType::Int32, true), + ])); + let df_schema = DFSchema::try_from(schema).unwrap(); + + let (inferred_expr, contains_placeholder) = + in_list.infer_placeholder_types(&df_schema).unwrap(); + + assert!(contains_placeholder); + + match inferred_expr { + Expr::InList(in_list) => { + for expr in in_list.list { + match expr { + Expr::Placeholder(placeholder) => { + assert_eq!( + placeholder.data_type, + Some(DataType::Int32), + "Placeholder {} should infer Int32", + placeholder.id + ); + } + _ => panic!("Expected Placeholder expression"), + } + } + } + _ => panic!("Expected InList expression"), + } + } + + #[test] + fn infer_placeholder_like_and_similar_to() { + // name LIKE $1 + let schema = + Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, true)])); + let df_schema = DFSchema::try_from(schema).unwrap(); + + let like = Like { + expr: Box::new(col("name")), + pattern: Box::new(Expr::Placeholder(Placeholder { + id: "$1".to_string(), + data_type: None, + })), + negated: false, + case_insensitive: false, + escape_char: None, + }; + + let expr = Expr::Like(like.clone()); + + let (inferred_expr, _) = expr.infer_placeholder_types(&df_schema).unwrap(); + match inferred_expr { + Expr::Like(like) => match *like.pattern { + Expr::Placeholder(placeholder) => { + assert_eq!(placeholder.data_type, Some(DataType::Utf8)); + } + _ => panic!("Expected Placeholder"), + }, + _ => panic!("Expected Like"), + } + + // name SIMILAR TO $1 + let expr = Expr::SimilarTo(like); + + let (inferred_expr, _) = expr.infer_placeholder_types(&df_schema).unwrap(); + match inferred_expr { + Expr::SimilarTo(like) => match *like.pattern { + Expr::Placeholder(placeholder) => { + assert_eq!( + placeholder.data_type, + Some(DataType::Utf8), + "Placeholder {} should infer Utf8", + placeholder.id + ); + } + _ => panic!("Expected Placeholder expression"), + }, + _ => panic!("Expected SimilarTo expression"), + } + } + #[test] #[allow(deprecated)] fn format_case_when() -> Result<()> { @@ -3206,7 +3586,7 @@ mod test { #[allow(deprecated)] fn format_cast() -> Result<()> { let expr = Expr::Cast(Cast { - expr: Box::new(Expr::Literal(ScalarValue::Float32(Some(1.23)))), + expr: Box::new(Expr::Literal(ScalarValue::Float32(Some(1.23)), None)), data_type: DataType::Utf8, }); let expected_canonical = "CAST(Float32(1.23) AS Utf8)"; @@ -3464,4 +3844,19 @@ mod test { rename: opt_rename, } } + + #[test] + fn test_size_of_expr() { + // because Expr is such a widely used struct in DataFusion + // it is important to keep its size as small as possible + // + // If this test fails when you change `Expr`, please try + // `Box`ing the fields to make `Expr` smaller + // See https://github.com/apache/datafusion/issues/16199 for details + assert_eq!(size_of::(), 128); + assert_eq!(size_of::(), 64); + assert_eq!(size_of::(), 24); // 3 ptrs + assert_eq!(size_of::>(), 24); + assert_eq!(size_of::>(), 8); + } } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 966aba7d1195..e8885ed6b724 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -37,7 
+37,7 @@ use crate::{ use arrow::compute::kernels::cast_utils::{ parse_interval_day_time, parse_interval_month_day_nano, parse_interval_year_month, }; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::{plan_err, Column, Result, ScalarValue, Spans, TableReference}; use datafusion_functions_window_common::field::WindowUDFFieldArgs; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; @@ -492,6 +492,7 @@ pub fn create_udaf( .into_iter() .enumerate() .map(|(i, t)| Field::new(format!("{i}"), t, true)) + .map(Arc::new) .collect::>(); AggregateUDF::from(SimpleAggregateUDF::new( name, @@ -510,7 +511,7 @@ pub struct SimpleAggregateUDF { signature: Signature, return_type: DataType, accumulator: AccumulatorFactoryFunction, - state_fields: Vec, + state_fields: Vec, } impl Debug for SimpleAggregateUDF { @@ -533,7 +534,7 @@ impl SimpleAggregateUDF { return_type: DataType, volatility: Volatility, accumulator: AccumulatorFactoryFunction, - state_fields: Vec, + state_fields: Vec, ) -> Self { let name = name.into(); let signature = Signature::exact(input_type, volatility); @@ -553,7 +554,7 @@ impl SimpleAggregateUDF { signature: Signature, return_type: DataType, accumulator: AccumulatorFactoryFunction, - state_fields: Vec, + state_fields: Vec, ) -> Self { let name = name.into(); Self { @@ -590,7 +591,7 @@ impl AggregateUDFImpl for SimpleAggregateUDF { (self.accumulator)(acc_args) } - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { Ok(self.state_fields.clone()) } } @@ -678,28 +679,28 @@ impl WindowUDFImpl for SimpleWindowUDF { (self.partition_evaluator_factory)() } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - Ok(Field::new( + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Arc::new(Field::new( field_args.name(), self.return_type.clone(), true, - )) + ))) } } pub fn interval_year_month_lit(value: &str) -> Expr { let interval = parse_interval_year_month(value).ok(); - Expr::Literal(ScalarValue::IntervalYearMonth(interval)) + Expr::Literal(ScalarValue::IntervalYearMonth(interval), None) } pub fn interval_datetime_lit(value: &str) -> Expr { let interval = parse_interval_day_time(value).ok(); - Expr::Literal(ScalarValue::IntervalDayTime(interval)) + Expr::Literal(ScalarValue::IntervalDayTime(interval), None) } pub fn interval_month_day_nano_lit(value: &str) -> Expr { let interval = parse_interval_month_day_nano(value).ok(); - Expr::Literal(ScalarValue::IntervalMonthDayNano(interval)) + Expr::Literal(ScalarValue::IntervalMonthDayNano(interval), None) } /// Extensions for configuring [`Expr::AggregateFunction`] or [`Expr::WindowFunction`] @@ -831,14 +832,14 @@ impl ExprFuncBuilder { params: WindowFunctionParams { args, .. 
}, }) => { let has_order_by = order_by.as_ref().map(|o| !o.is_empty()); - Expr::WindowFunction(WindowFunction { + Expr::from(WindowFunction { fun, params: WindowFunctionParams { args, partition_by: partition_by.unwrap_or_default(), order_by: order_by.unwrap_or_default(), window_frame: window_frame - .unwrap_or(WindowFrame::new(has_order_by)), + .unwrap_or_else(|| WindowFrame::new(has_order_by)), null_treatment, }, }) @@ -895,7 +896,7 @@ impl ExprFunctionExt for Expr { ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))) } Expr::WindowFunction(udwf) => { - ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))) + ExprFuncBuilder::new(Some(ExprFuncKind::Window(*udwf))) } _ => ExprFuncBuilder::new(None), }; @@ -935,7 +936,7 @@ impl ExprFunctionExt for Expr { ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))) } Expr::WindowFunction(udwf) => { - ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))) + ExprFuncBuilder::new(Some(ExprFuncKind::Window(*udwf))) } _ => ExprFuncBuilder::new(None), }; @@ -948,7 +949,7 @@ impl ExprFunctionExt for Expr { fn partition_by(self, partition_by: Vec) -> ExprFuncBuilder { match self { Expr::WindowFunction(udwf) => { - let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))); + let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(*udwf))); builder.partition_by = Some(partition_by); builder } @@ -959,7 +960,7 @@ impl ExprFunctionExt for Expr { fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder { match self { Expr::WindowFunction(udwf) => { - let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))); + let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(*udwf))); builder.window_frame = Some(window_frame); builder } diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 90dcbce46b01..05a9425452a1 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -354,6 +354,7 @@ mod test { use std::ops::Add; use super::*; + use crate::literal::lit_with_metadata; use crate::{col, lit, Cast}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::tree_node::TreeNodeRewriter; @@ -383,13 +384,13 @@ mod test { // rewrites all "foo" string literals to "bar" let transformer = |expr: Expr| -> Result> { match expr { - Expr::Literal(ScalarValue::Utf8(Some(utf8_val))) => { + Expr::Literal(ScalarValue::Utf8(Some(utf8_val)), metadata) => { let utf8_val = if utf8_val == "foo" { "bar".to_string() } else { utf8_val }; - Ok(Transformed::yes(lit(utf8_val))) + Ok(Transformed::yes(lit_with_metadata(utf8_val, metadata))) } // otherwise, return None _ => Ok(Transformed::no(expr)), diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index a349c83a4934..8ca479bb6f9b 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -17,24 +17,23 @@ use super::{Between, Expr, Like}; use crate::expr::{ - AggregateFunction, AggregateFunctionParams, Alias, BinaryExpr, Cast, InList, - InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, + AggregateFunction, AggregateFunctionParams, Alias, BinaryExpr, Cast, FieldMetadata, + InList, InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, WindowFunctionParams, }; use crate::type_coercion::functions::{ - data_types_with_aggregate_udf, data_types_with_scalar_udf, data_types_with_window_udf, + data_types_with_scalar_udf, fields_with_aggregate_udf, fields_with_window_udf, 
}; -use crate::udf::ReturnTypeArgs; +use crate::udf::ReturnFieldArgs; use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition}; use arrow::compute::can_cast_types; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::{ not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema, Result, Spans, TableReference, }; use datafusion_expr_common::type_coercion::binary::BinaryTypeCoercer; use datafusion_functions_window_common::field::WindowUDFFieldArgs; -use std::collections::HashMap; use std::sync::Arc; /// Trait to allow expr to typable with respect to a schema @@ -46,7 +45,7 @@ pub trait ExprSchemable { fn nullable(&self, input_schema: &dyn ExprSchema) -> Result; /// Given a schema, return the expr's optional metadata - fn metadata(&self, schema: &dyn ExprSchema) -> Result>; + fn metadata(&self, schema: &dyn ExprSchema) -> Result; /// Convert to a field with respect to a schema fn to_field( @@ -115,7 +114,7 @@ impl ExprSchemable for Expr { Expr::Column(c) => Ok(schema.data_type(c)?.clone()), Expr::OuterReferenceColumn(ty, _) => Ok(ty.clone()), Expr::ScalarVariable(ty, _) => Ok(ty.clone()), - Expr::Literal(l) => Ok(l.data_type()), + Expr::Literal(l, _) => Ok(l.data_type()), Expr::Case(case) => { for (_, then_expr) in &case.when_then_expr { let then_type = then_expr.get_type(schema)?; @@ -158,12 +157,16 @@ impl ExprSchemable for Expr { func, params: AggregateFunctionParams { args, .. }, }) => { - let data_types = args + let fields = args .iter() - .map(|e| e.get_type(schema)) + .map(|e| e.to_field(schema).map(|(_, f)| f)) .collect::>>()?; - let new_types = data_types_with_aggregate_udf(&data_types, func) + let new_fields = fields_with_aggregate_udf(&fields, func) .map_err(|err| { + let data_types = fields + .iter() + .map(|f| f.data_type().clone()) + .collect::>(); plan_datafusion_err!( "{} {}", match err { @@ -176,8 +179,10 @@ impl ExprSchemable for Expr { &data_types ) ) - })?; - Ok(func.return_type(&new_types)?) + })? + .into_iter() + .collect::>(); + Ok(func.return_field(&new_fields)?.data_type().clone()) } Expr::Not(_) | Expr::IsNull(_) @@ -272,7 +277,7 @@ impl ExprSchemable for Expr { Expr::Column(c) => input_schema.nullable(c), Expr::OuterReferenceColumn(_, _) => Ok(true), - Expr::Literal(value) => Ok(value.is_null()), + Expr::Literal(value, _) => Ok(value.is_null()), Expr::Case(case) => { // This expression is nullable if any of the input expressions are nullable let then_nullable = case @@ -340,22 +345,9 @@ impl ExprSchemable for Expr { } } - fn metadata(&self, schema: &dyn ExprSchema) -> Result> { - match self { - Expr::Column(c) => Ok(schema.metadata(c)?.clone()), - Expr::Alias(Alias { expr, metadata, .. }) => { - let mut ret = expr.metadata(schema)?; - if let Some(metadata) = metadata { - if !metadata.is_empty() { - ret.extend(metadata.clone()); - return Ok(ret); - } - } - Ok(ret) - } - Expr::Cast(Cast { expr, .. }) => expr.metadata(schema), - _ => Ok(HashMap::new()), - } + fn metadata(&self, schema: &dyn ExprSchema) -> Result { + self.to_field(schema) + .map(|(_, field)| FieldMetadata::from(field.metadata())) } /// Returns the datatype and nullability of the expression based on [ExprSchema]. @@ -372,23 +364,66 @@ impl ExprSchemable for Expr { &self, schema: &dyn ExprSchema, ) -> Result<(DataType, bool)> { - match self { - Expr::Alias(Alias { expr, name, .. }) => match &**expr { - Expr::Placeholder(Placeholder { data_type, .. 
}) => match &data_type { - None => schema - .data_type_and_nullable(&Column::from_name(name)) - .map(|(d, n)| (d.clone(), n)), - Some(dt) => Ok((dt.clone(), expr.nullable(schema)?)), - }, - _ => expr.data_type_and_nullable(schema), - }, - Expr::Negative(expr) => expr.data_type_and_nullable(schema), - Expr::Column(c) => schema - .data_type_and_nullable(c) - .map(|(d, n)| (d.clone(), n)), - Expr::OuterReferenceColumn(ty, _) => Ok((ty.clone(), true)), - Expr::ScalarVariable(ty, _) => Ok((ty.clone(), true)), - Expr::Literal(l) => Ok((l.data_type(), l.is_null())), + let field = self.to_field(schema)?.1; + + Ok((field.data_type().clone(), field.is_nullable())) + } + + /// Returns a [arrow::datatypes::Field] compatible with this expression. + /// + /// So for example, a projected expression `col(c1) + col(c2)` is + /// placed in an output field **named** col("c1 + c2") + fn to_field( + &self, + schema: &dyn ExprSchema, + ) -> Result<(Option, Arc)> { + let (relation, schema_name) = self.qualified_name(); + #[allow(deprecated)] + let field = match self { + Expr::Alias(Alias { + expr, + name, + metadata, + .. + }) => { + let field = match &**expr { + Expr::Placeholder(Placeholder { data_type, .. }) => { + match &data_type { + None => schema + .data_type_and_nullable(&Column::from_name(name)) + .map(|(d, n)| Field::new(&schema_name, d.clone(), n)), + Some(dt) => Ok(Field::new( + &schema_name, + dt.clone(), + expr.nullable(schema)?, + )), + } + } + _ => expr.to_field(schema).map(|(_, f)| f.as_ref().clone()), + }?; + + let mut combined_metadata = expr.metadata(schema)?; + if let Some(metadata) = metadata { + combined_metadata.extend(metadata.clone()); + } + + Ok(Arc::new(combined_metadata.add_to_field(field))) + } + Expr::Negative(expr) => expr.to_field(schema).map(|(_, f)| f), + Expr::Column(c) => schema.field_from_column(c).map(|f| Arc::new(f.clone())), + Expr::OuterReferenceColumn(ty, _) => { + Ok(Arc::new(Field::new(&schema_name, ty.clone(), true))) + } + Expr::ScalarVariable(ty, _) => { + Ok(Arc::new(Field::new(&schema_name, ty.clone(), true))) + } + Expr::Literal(l, metadata) => { + let mut field = Field::new(&schema_name, l.data_type(), l.is_null()); + if let Some(metadata) = metadata { + field = metadata.add_to_field(field); + } + Ok(Arc::new(field)) + } Expr::IsNull(_) | Expr::IsNotNull(_) | Expr::IsTrue(_) @@ -397,11 +432,12 @@ impl ExprSchemable for Expr { | Expr::IsNotTrue(_) | Expr::IsNotFalse(_) | Expr::IsNotUnknown(_) - | Expr::Exists { .. } => Ok((DataType::Boolean, false)), - Expr::ScalarSubquery(subquery) => Ok(( - subquery.subquery.schema().field(0).data_type().clone(), - subquery.subquery.schema().field(0).is_nullable(), - )), + | Expr::Exists { .. 
} => { + Ok(Arc::new(Field::new(&schema_name, DataType::Boolean, false))) + } + Expr::ScalarSubquery(subquery) => { + Ok(Arc::new(subquery.subquery.schema().field(0).clone())) + } Expr::BinaryExpr(BinaryExpr { ref left, ref right, @@ -412,17 +448,63 @@ impl ExprSchemable for Expr { let mut coercer = BinaryTypeCoercer::new(&lhs_type, op, &rhs_type); coercer.set_lhs_spans(left.spans().cloned().unwrap_or_default()); coercer.set_rhs_spans(right.spans().cloned().unwrap_or_default()); - Ok((coercer.get_result_type()?, lhs_nullable || rhs_nullable)) + Ok(Arc::new(Field::new( + &schema_name, + coercer.get_result_type()?, + lhs_nullable || rhs_nullable, + ))) } Expr::WindowFunction(window_function) => { - self.data_type_and_nullable_with_window_function(schema, window_function) + let (dt, nullable) = self.data_type_and_nullable_with_window_function( + schema, + window_function, + )?; + Ok(Arc::new(Field::new(&schema_name, dt, nullable))) + } + Expr::AggregateFunction(aggregate_function) => { + let AggregateFunction { + func, + params: AggregateFunctionParams { args, .. }, + .. + } = aggregate_function; + + let fields = args + .iter() + .map(|e| e.to_field(schema).map(|(_, f)| f)) + .collect::>>()?; + // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` + let new_fields = fields_with_aggregate_udf(&fields, func) + .map_err(|err| { + let arg_types = fields + .iter() + .map(|f| f.data_type()) + .cloned() + .collect::>(); + plan_datafusion_err!( + "{} {}", + match err { + DataFusionError::Plan(msg) => msg, + err => err.to_string(), + }, + utils::generate_signature_error_msg( + func.name(), + func.signature().clone(), + &arg_types, + ) + ) + })? + .into_iter() + .collect::>(); + + func.return_field(&new_fields) } Expr::ScalarFunction(ScalarFunction { func, args }) => { - let (arg_types, nullables): (Vec, Vec) = args + let (arg_types, fields): (Vec, Vec>) = args .iter() - .map(|e| e.data_type_and_nullable(schema)) + .map(|e| e.to_field(schema).map(|(_, f)| f)) .collect::>>()? .into_iter() + .map(|f| (f.data_type().clone(), f)) .unzip(); // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` let new_data_types = data_types_with_scalar_udf(&arg_types, func) @@ -440,42 +522,54 @@ impl ExprSchemable for Expr { ) ) })?; + let new_fields = fields + .into_iter() + .zip(new_data_types) + .map(|(f, d)| f.as_ref().clone().with_data_type(d)) + .map(Arc::new) + .collect::>(); let arguments = args .iter() .map(|e| match e { - Expr::Literal(sv) => Some(sv), + Expr::Literal(sv, _) => Some(sv), _ => None, }) .collect::>(); - let args = ReturnTypeArgs { - arg_types: &new_data_types, + let args = ReturnFieldArgs { + arg_fields: &new_fields, scalar_arguments: &arguments, - nullables: &nullables, }; - let (return_type, nullable) = - func.return_type_from_args(args)?.into_parts(); - Ok((return_type, nullable)) + func.return_field_from_args(args) } - _ => Ok((self.get_type(schema)?, self.nullable(schema)?)), - } - } - - /// Returns a [arrow::datatypes::Field] compatible with this expression. 
- /// - /// So for example, a projected expression `col(c1) + col(c2)` is - /// placed in an output field **named** col("c1 + c2") - fn to_field( - &self, - input_schema: &dyn ExprSchema, - ) -> Result<(Option, Arc)> { - let (relation, schema_name) = self.qualified_name(); - let (data_type, nullable) = self.data_type_and_nullable(input_schema)?; - let field = Field::new(schema_name, data_type, nullable) - .with_metadata(self.metadata(input_schema)?) - .into(); - Ok((relation, field)) + // _ => Ok((self.get_type(schema)?, self.nullable(schema)?)), + Expr::Cast(Cast { expr, data_type }) => expr + .to_field(schema) + .map(|(_, f)| f.as_ref().clone().with_data_type(data_type.clone())) + .map(Arc::new), + Expr::Like(_) + | Expr::SimilarTo(_) + | Expr::Not(_) + | Expr::Between(_) + | Expr::Case(_) + | Expr::TryCast(_) + | Expr::InList(_) + | Expr::InSubquery(_) + | Expr::Wildcard { .. } + | Expr::GroupingSet(_) + | Expr::Placeholder(_) + | Expr::Unnest(_) => Ok(Arc::new(Field::new( + &schema_name, + self.get_type(schema)?, + self.nullable(schema)?, + ))), + }?; + + Ok(( + relation, + Arc::new(field.as_ref().clone().with_name(schema_name)), + )) } /// Wraps this expression in a cast to a target [arrow::datatypes::DataType]. @@ -528,13 +622,18 @@ impl Expr { .. } = window_function; - let data_types = args + let fields = args .iter() - .map(|e| e.get_type(schema)) + .map(|e| e.to_field(schema).map(|(_, f)| f)) .collect::>>()?; match fun { WindowFunctionDefinition::AggregateUDF(udaf) => { - let new_types = data_types_with_aggregate_udf(&data_types, udaf) + let data_types = fields + .iter() + .map(|f| f.data_type()) + .cloned() + .collect::>(); + let new_fields = fields_with_aggregate_udf(&fields, udaf) .map_err(|err| { plan_datafusion_err!( "{} {}", @@ -548,16 +647,22 @@ impl Expr { &data_types ) ) - })?; + })? + .into_iter() + .collect::>(); - let return_type = udaf.return_type(&new_types)?; - let nullable = udaf.is_nullable(); + let return_field = udaf.return_field(&new_fields)?; - Ok((return_type, nullable)) + Ok((return_field.data_type().clone(), return_field.is_nullable())) } WindowFunctionDefinition::WindowUDF(udwf) => { - let new_types = - data_types_with_window_udf(&data_types, udwf).map_err(|err| { + let data_types = fields + .iter() + .map(|f| f.data_type()) + .cloned() + .collect::>(); + let new_fields = fields_with_window_udf(&fields, udwf) + .map_err(|err| { plan_datafusion_err!( "{} {}", match err { @@ -570,9 +675,11 @@ impl Expr { &data_types ) ) - })?; + })? + .into_iter() + .collect::>(); let (_, function_name) = self.qualified_name(); - let field_args = WindowUDFFieldArgs::new(&new_types, &function_name); + let field_args = WindowUDFFieldArgs::new(&new_fields, &function_name); udwf.field(field_args) .map(|field| (field.data_type().clone(), field.is_nullable())) @@ -626,7 +733,7 @@ mod tests { use super::*; use crate::{col, lit}; - use datafusion_common::{internal_err, DFSchema, ScalarValue}; + use datafusion_common::{internal_err, DFSchema, HashMap, ScalarValue}; macro_rules! 
test_is_expr_nullable { ($EXPR_TYPE:ident) => {{ @@ -732,6 +839,7 @@ mod tests { fn test_expr_metadata() { let mut meta = HashMap::new(); meta.insert("bar".to_string(), "buzz".to_string()); + let meta = FieldMetadata::from(meta); let expr = col("foo"); let schema = MockExprSchema::new() .with_data_type(DataType::Int32) @@ -750,41 +858,36 @@ mod tests { ); let schema = DFSchema::from_unqualified_fields( - vec![Field::new("foo", DataType::Int32, true).with_metadata(meta.clone())] - .into(), - HashMap::new(), + vec![meta.add_to_field(Field::new("foo", DataType::Int32, true))].into(), + std::collections::HashMap::new(), ) .unwrap(); // verify to_field method populates metadata - assert_eq!(&meta, expr.to_field(&schema).unwrap().1.metadata()); + assert_eq!(meta, expr.metadata(&schema).unwrap()); } #[derive(Debug)] struct MockExprSchema { - nullable: bool, - data_type: DataType, + field: Field, error_on_nullable: bool, - metadata: HashMap, } impl MockExprSchema { fn new() -> Self { Self { - nullable: false, - data_type: DataType::Null, + field: Field::new("mock_field", DataType::Null, false), error_on_nullable: false, - metadata: HashMap::new(), } } fn with_nullable(mut self, nullable: bool) -> Self { - self.nullable = nullable; + self.field = self.field.with_nullable(nullable); self } fn with_data_type(mut self, data_type: DataType) -> Self { - self.data_type = data_type; + self.field = self.field.with_data_type(data_type); self } @@ -793,8 +896,8 @@ mod tests { self } - fn with_metadata(mut self, metadata: HashMap) -> Self { - self.metadata = metadata; + fn with_metadata(mut self, metadata: FieldMetadata) -> Self { + self.field = metadata.add_to_field(self.field); self } } @@ -804,20 +907,12 @@ mod tests { if self.error_on_nullable { internal_err!("nullable error") } else { - Ok(self.nullable) + Ok(self.field.is_nullable()) } } - fn data_type(&self, _col: &Column) -> Result<&DataType> { - Ok(&self.data_type) - } - - fn metadata(&self, _col: &Column) -> Result<&HashMap> { - Ok(&self.metadata) - } - - fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> { - Ok((self.data_type(col)?, self.nullable(col)?)) + fn field_from_column(&self, _col: &Column) -> Result<&Field> { + Ok(&self.field) } } } diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index d3cc881af361..0c822bbb337b 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -63,6 +63,7 @@ pub mod simplify; pub mod sort_properties { pub use datafusion_expr_common::sort_properties::*; } +pub mod async_udf; pub mod statistics { pub use datafusion_expr_common::statistics::*; } @@ -94,7 +95,9 @@ pub use function::{ AccumulatorFactoryFunction, PartitionEvaluatorFactory, ReturnTypeFunction, ScalarFunctionImplementation, StateTypeFunction, }; -pub use literal::{lit, lit_timestamp_nano, Literal, TimestampLiteral}; +pub use literal::{ + lit, lit_timestamp_nano, lit_with_metadata, Literal, TimestampLiteral, +}; pub use logical_plan::*; pub use partition_evaluator::PartitionEvaluator; pub use sqlparser; @@ -104,8 +107,7 @@ pub use udaf::{ SetMonotonicity, StatisticsArgs, }; pub use udf::{ - scalar_doc_sections, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDF, - ScalarUDFImpl, + scalar_doc_sections, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, }; pub use udwf::{window_doc_sections, ReversedUDWF, WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/literal.rs 
b/datafusion/expr/src/literal.rs index 90ba5a9a693c..c4bd43bc0a62 100644 --- a/datafusion/expr/src/literal.rs +++ b/datafusion/expr/src/literal.rs @@ -17,6 +17,7 @@ //! Literal module contains foundational types that are used to represent literals in DataFusion. +use crate::expr::FieldMetadata; use crate::Expr; use datafusion_common::ScalarValue; @@ -25,6 +26,25 @@ pub fn lit(n: T) -> Expr { n.lit() } +pub fn lit_with_metadata(n: T, metadata: Option) -> Expr { + let Some(metadata) = metadata else { + return n.lit(); + }; + + let Expr::Literal(sv, prior_metadata) = n.lit() else { + unreachable!(); + }; + let new_metadata = match prior_metadata { + Some(mut prior) => { + prior.extend(metadata); + prior + } + None => metadata, + }; + + Expr::Literal(sv, Some(new_metadata)) +} + /// Create a literal timestamp expression pub fn lit_timestamp_nano(n: T) -> Expr { n.lit_timestamp_nano() @@ -43,37 +63,37 @@ pub trait TimestampLiteral { impl Literal for &str { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::from(*self)) + Expr::Literal(ScalarValue::from(*self), None) } } impl Literal for String { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::from(self.as_ref())) + Expr::Literal(ScalarValue::from(self.as_ref()), None) } } impl Literal for &String { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::from(self.as_ref())) + Expr::Literal(ScalarValue::from(self.as_ref()), None) } } impl Literal for Vec { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Binary(Some((*self).to_owned()))) + Expr::Literal(ScalarValue::Binary(Some((*self).to_owned())), None) } } impl Literal for &[u8] { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::Binary(Some((*self).to_owned()))) + Expr::Literal(ScalarValue::Binary(Some((*self).to_owned())), None) } } impl Literal for ScalarValue { fn lit(&self) -> Expr { - Expr::Literal(self.clone()) + Expr::Literal(self.clone(), None) } } @@ -82,7 +102,7 @@ macro_rules! make_literal { #[doc = $DOC] impl Literal for $TYPE { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::$SCALAR(Some(self.clone()))) + Expr::Literal(ScalarValue::$SCALAR(Some(self.clone())), None) } } }; @@ -93,7 +113,7 @@ macro_rules! make_nonzero_literal { #[doc = $DOC] impl Literal for $TYPE { fn lit(&self) -> Expr { - Expr::Literal(ScalarValue::$SCALAR(Some(self.get()))) + Expr::Literal(ScalarValue::$SCALAR(Some(self.get())), None) } } }; @@ -104,10 +124,10 @@ macro_rules! 
make_timestamp_literal { #[doc = $DOC] impl TimestampLiteral for $TYPE { fn lit_timestamp_nano(&self) -> Expr { - Expr::Literal(ScalarValue::TimestampNanosecond( - Some((self.clone()).into()), + Expr::Literal( + ScalarValue::TimestampNanosecond(Some((self.clone()).into()), None), None, - )) + ) } } }; diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 64931df5a83f..836911bd9f3b 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -56,7 +56,8 @@ use datafusion_common::file_options::file_type::FileType; use datafusion_common::{ exec_err, get_target_functional_dependencies, internal_err, not_impl_err, plan_datafusion_err, plan_err, Column, Constraints, DFSchema, DFSchemaRef, - DataFusionError, Result, ScalarValue, TableReference, ToDFSchema, UnnestOptions, + DataFusionError, NullEquality, Result, ScalarValue, TableReference, ToDFSchema, + UnnestOptions, }; use datafusion_expr_common::type_coercion::binary::type_union_resolution; @@ -341,8 +342,11 @@ impl LogicalPlanBuilder { // wrap cast if data type is not same as common type. for row in &mut values { for (j, field_type) in fields.iter().map(|f| f.data_type()).enumerate() { - if let Expr::Literal(ScalarValue::Null) = row[j] { - row[j] = Expr::Literal(ScalarValue::try_from(field_type)?); + if let Expr::Literal(ScalarValue::Null, metadata) = &row[j] { + row[j] = Expr::Literal( + ScalarValue::try_from(field_type)?, + metadata.clone(), + ); } else { row[j] = std::mem::take(&mut row[j]).cast_to(field_type, schema)?; } @@ -501,6 +505,21 @@ impl LogicalPlanBuilder { if table_scan.filters.is_empty() { if let Some(p) = table_scan.source.get_logical_plan() { let sub_plan = p.into_owned(); + + if let Some(proj) = table_scan.projection { + let projection_exprs = proj + .into_iter() + .map(|i| { + Expr::Column(Column::from( + sub_plan.schema().qualified_field(i), + )) + }) + .collect::>(); + return Self::new(sub_plan) + .project(projection_exprs)? + .alias(table_scan.table_name); + } + // Ensures that the reference to the inlined table remains the // same, meaning we don't have to change any of the parent nodes // that reference this table. @@ -586,7 +605,7 @@ impl LogicalPlanBuilder { /// Apply a filter which is used for a having clause pub fn having(self, expr: impl Into) -> Result { let expr = normalize_col(expr.into(), &self.plan)?; - Filter::try_new_with_having(expr, self.plan) + Filter::try_new(expr, self.plan) .map(LogicalPlan::Filter) .map(Self::from) } @@ -885,7 +904,13 @@ impl LogicalPlanBuilder { join_keys: (Vec>, Vec>), filter: Option, ) -> Result { - self.join_detailed(right, join_type, join_keys, filter, false) + self.join_detailed( + right, + join_type, + join_keys, + filter, + NullEquality::NullEqualsNothing, + ) } /// Apply a join using the specified expressions. @@ -941,15 +966,11 @@ impl LogicalPlanBuilder { join_type, (Vec::::new(), Vec::::new()), filter, - false, + NullEquality::NullEqualsNothing, ) } - pub(crate) fn normalize( - plan: &LogicalPlan, - column: impl Into, - ) -> Result { - let column = column.into(); + pub(crate) fn normalize(plan: &LogicalPlan, column: Column) -> Result { if column.relation.is_some() { // column is already normalized return Ok(column); @@ -969,16 +990,14 @@ impl LogicalPlanBuilder { /// The behavior is the same as [`join`](Self::join) except that it allows /// specifying the null equality behavior. 
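///
/// A rough migration sketch, based on the call sites updated elsewhere in this
/// diff (the boolean flag is replaced by the `NullEquality` enum):
///
/// ```text
/// join_detailed(.., false) => join_detailed(.., NullEquality::NullEqualsNothing)
/// join_detailed(.., true)  => join_detailed(.., NullEquality::NullEqualsNull)
/// ```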
/// - /// If `null_equals_null=true`, rows where both join keys are `null` will be - /// emitted. Otherwise rows where either or both join keys are `null` will be - /// omitted. + /// The `null_equality` dictates how `null` values are joined. pub fn join_detailed( self, right: LogicalPlan, join_type: JoinType, join_keys: (Vec>, Vec>), filter: Option, - null_equals_null: bool, + null_equality: NullEquality, ) -> Result { if join_keys.0.len() != join_keys.1.len() { return plan_err!("left_keys and right_keys were not the same length"); @@ -1095,7 +1114,7 @@ impl LogicalPlanBuilder { join_type, join_constraint: JoinConstraint::On, schema: DFSchemaRef::new(join_schema), - null_equals_null, + null_equality, }))) } @@ -1104,7 +1123,7 @@ impl LogicalPlanBuilder { self, right: LogicalPlan, join_type: JoinType, - using_keys: Vec + Clone>, + using_keys: Vec, ) -> Result { let left_keys: Vec = using_keys .clone() @@ -1117,19 +1136,29 @@ impl LogicalPlanBuilder { .collect::>()?; let on: Vec<(_, _)> = left_keys.into_iter().zip(right_keys).collect(); - let join_schema = - build_join_schema(self.plan.schema(), right.schema(), &join_type)?; let mut join_on: Vec<(Expr, Expr)> = vec![]; let mut filters: Option = None; for (l, r) in &on { if self.plan.schema().has_column(l) && right.schema().has_column(r) - && can_hash(self.plan.schema().field_from_column(l)?.data_type()) + && can_hash( + datafusion_common::ExprSchema::field_from_column( + self.plan.schema(), + l, + )? + .data_type(), + ) { join_on.push((Expr::Column(l.clone()), Expr::Column(r.clone()))); } else if self.plan.schema().has_column(l) && right.schema().has_column(r) - && can_hash(self.plan.schema().field_from_column(r)?.data_type()) + && can_hash( + datafusion_common::ExprSchema::field_from_column( + self.plan.schema(), + r, + )? + .data_type(), + ) { join_on.push((Expr::Column(r.clone()), Expr::Column(l.clone()))); } else { @@ -1151,33 +1180,33 @@ impl LogicalPlanBuilder { DataFusionError::Internal("filters should not be None here".to_string()) })?) } else { - Ok(Self::new(LogicalPlan::Join(Join { - left: self.plan, - right: Arc::new(right), - on: join_on, - filter: filters, + let join = Join::try_new( + self.plan, + Arc::new(right), + join_on, + filters, join_type, - join_constraint: JoinConstraint::Using, - schema: DFSchemaRef::new(join_schema), - null_equals_null: false, - }))) + JoinConstraint::Using, + NullEquality::NullEqualsNothing, + )?; + + Ok(Self::new(LogicalPlan::Join(join))) } } /// Apply a cross join pub fn cross_join(self, right: LogicalPlan) -> Result { - let join_schema = - build_join_schema(self.plan.schema(), right.schema(), &JoinType::Inner)?; - Ok(Self::new(LogicalPlan::Join(Join { - left: self.plan, - right: Arc::new(right), - on: vec![], - filter: None, - join_type: JoinType::Inner, - join_constraint: JoinConstraint::On, - null_equals_null: false, - schema: DFSchemaRef::new(join_schema), - }))) + let join = Join::try_new( + self.plan, + Arc::new(right), + vec![], + None, + JoinType::Inner, + JoinConstraint::On, + NullEquality::NullEqualsNothing, + )?; + + Ok(Self::new(LogicalPlan::Join(join))) } /// Repartition @@ -1312,12 +1341,24 @@ impl LogicalPlanBuilder { .unzip(); if is_all { LogicalPlanBuilder::from(left_plan) - .join_detailed(right_plan, join_type, join_keys, None, true)? + .join_detailed( + right_plan, + join_type, + join_keys, + None, + NullEquality::NullEqualsNull, + )? .build() } else { LogicalPlanBuilder::from(left_plan) .distinct()? - .join_detailed(right_plan, join_type, join_keys, None, true)? 
+ .join_detailed( + right_plan, + join_type, + join_keys, + None, + NullEquality::NullEqualsNull, + )? .build() } } @@ -1338,7 +1379,7 @@ impl LogicalPlanBuilder { /// to columns from the existing input. `r`, the second element of the tuple, /// must only refer to columns from the right input. /// - /// `filter` contains any other other filter expression to apply during the + /// `filter` contains any other filter expression to apply during the /// join. Note that `equi_exprs` predicates are evaluated more efficiently /// than the filter expressions, so they are preferred. pub fn join_with_expr_keys( @@ -1388,19 +1429,17 @@ impl LogicalPlanBuilder { }) .collect::>>()?; - let join_schema = - build_join_schema(self.plan.schema(), right.schema(), &join_type)?; - - Ok(Self::new(LogicalPlan::Join(Join { - left: self.plan, - right: Arc::new(right), - on: join_key_pairs, + let join = Join::try_new( + self.plan, + Arc::new(right), + join_key_pairs, filter, join_type, - join_constraint: JoinConstraint::On, - schema: DFSchemaRef::new(join_schema), - null_equals_null: false, - }))) + JoinConstraint::On, + NullEquality::NullEqualsNothing, + )?; + + Ok(Self::new(LogicalPlan::Join(join))) } /// Unnest the given column. @@ -1490,7 +1529,7 @@ pub fn change_redundant_column(fields: &Fields) -> Vec { // Loop until we find a name that hasn't been used while seen.contains(&new_name) { *count += 1; - new_name = format!("{}:{}", base_name, count); + new_name = format!("{base_name}:{count}"); } seen.insert(new_name.clone()); @@ -1597,18 +1636,29 @@ pub fn build_join_schema( .map(|(q, f)| (q.cloned(), Arc::clone(f))) .collect() } + JoinType::RightMark => right_fields + .map(|(q, f)| (q.cloned(), Arc::clone(f))) + .chain(once(mark_field(left))) + .collect(), }; let func_dependencies = left.functional_dependencies().join( right.functional_dependencies(), join_type, left.fields().len(), ); - let metadata = left + + let (schema1, schema2) = match join_type { + JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => (left, right), + _ => (right, left), + }; + + let metadata = schema1 .metadata() .clone() .into_iter() - .chain(right.metadata().clone()) + .chain(schema2.metadata().clone()) .collect(); + let dfschema = DFSchema::new_with_metadata(qualified_fields, metadata)?; dfschema.with_functional_dependencies(func_dependencies) } @@ -2237,6 +2287,7 @@ mod tests { use crate::test::function_stub::sum; use datafusion_common::{Constraint, RecursionUnnestOption, SchemaError}; + use insta::assert_snapshot; #[test] fn plan_builder_simple() -> Result<()> { @@ -2246,11 +2297,11 @@ mod tests { .project(vec![col("id")])? 
.build()?; - let expected = "Projection: employee_csv.id\ - \n Filter: employee_csv.state = Utf8(\"CO\")\ - \n TableScan: employee_csv projection=[id, state]"; - - assert_eq!(expected, format!("{plan}")); + assert_snapshot!(plan, @r#" + Projection: employee_csv.id + Filter: employee_csv.state = Utf8("CO") + TableScan: employee_csv projection=[id, state] + "#); Ok(()) } @@ -2262,12 +2313,7 @@ mod tests { let plan = LogicalPlanBuilder::scan("employee_csv", table_source(&schema), projection) .unwrap(); - let expected = DFSchema::try_from_qualified_schema( - TableReference::bare("employee_csv"), - &schema, - ) - .unwrap(); - assert_eq!(&expected, plan.schema().as_ref()); + assert_snapshot!(plan.schema().as_ref(), @"fields:[employee_csv.id, employee_csv.first_name, employee_csv.last_name, employee_csv.state, employee_csv.salary], metadata:{}"); // Note scan of "EMPLOYEE_CSV" is treated as a SQL identifier // (and thus normalized to "employee"csv") as well @@ -2275,7 +2321,7 @@ mod tests { let plan = LogicalPlanBuilder::scan("EMPLOYEE_CSV", table_source(&schema), projection) .unwrap(); - assert_eq!(&expected, plan.schema().as_ref()); + assert_snapshot!(plan.schema().as_ref(), @"fields:[employee_csv.id, employee_csv.first_name, employee_csv.last_name, employee_csv.state, employee_csv.salary], metadata:{}"); } #[test] @@ -2284,9 +2330,9 @@ mod tests { let projection = None; let err = LogicalPlanBuilder::scan("", table_source(&schema), projection).unwrap_err(); - assert_eq!( + assert_snapshot!( err.strip_backtrace(), - "Error during planning: table_name cannot be empty" + @"Error during planning: table_name cannot be empty" ); } @@ -2300,10 +2346,10 @@ mod tests { ])? .build()?; - let expected = "Sort: employee_csv.state ASC NULLS FIRST, employee_csv.salary DESC NULLS LAST\ - \n TableScan: employee_csv projection=[state, salary]"; - - assert_eq!(expected, format!("{plan}")); + assert_snapshot!(plan, @r" + Sort: employee_csv.state ASC NULLS FIRST, employee_csv.salary DESC NULLS LAST + TableScan: employee_csv projection=[state, salary] + "); Ok(()) } @@ -2320,15 +2366,15 @@ mod tests { .union(plan.build()?)? .build()?; - let expected = "Union\ - \n Union\ - \n Union\ - \n TableScan: employee_csv projection=[state, salary]\ - \n TableScan: employee_csv projection=[state, salary]\ - \n TableScan: employee_csv projection=[state, salary]\ - \n TableScan: employee_csv projection=[state, salary]"; - - assert_eq!(expected, format!("{plan}")); + assert_snapshot!(plan, @r" + Union + Union + Union + TableScan: employee_csv projection=[state, salary] + TableScan: employee_csv projection=[state, salary] + TableScan: employee_csv projection=[state, salary] + TableScan: employee_csv projection=[state, salary] + "); Ok(()) } @@ -2345,19 +2391,18 @@ mod tests { .union_distinct(plan.build()?)? 
.build()?; - let expected = "\ - Distinct:\ - \n Union\ - \n Distinct:\ - \n Union\ - \n Distinct:\ - \n Union\ - \n TableScan: employee_csv projection=[state, salary]\ - \n TableScan: employee_csv projection=[state, salary]\ - \n TableScan: employee_csv projection=[state, salary]\ - \n TableScan: employee_csv projection=[state, salary]"; - - assert_eq!(expected, format!("{plan}")); + assert_snapshot!(plan, @r" + Distinct: + Union + Distinct: + Union + Distinct: + Union + TableScan: employee_csv projection=[state, salary] + TableScan: employee_csv projection=[state, salary] + TableScan: employee_csv projection=[state, salary] + TableScan: employee_csv projection=[state, salary] + "); Ok(()) } @@ -2371,13 +2416,12 @@ mod tests { .distinct()? .build()?; - let expected = "\ - Distinct:\ - \n Projection: employee_csv.id\ - \n Filter: employee_csv.state = Utf8(\"CO\")\ - \n TableScan: employee_csv projection=[id, state]"; - - assert_eq!(expected, format!("{plan}")); + assert_snapshot!(plan, @r#" + Distinct: + Projection: employee_csv.id + Filter: employee_csv.state = Utf8("CO") + TableScan: employee_csv projection=[id, state] + "#); Ok(()) } @@ -2397,14 +2441,15 @@ mod tests { .filter(exists(Arc::new(subquery)))? .build()?; - let expected = "Filter: EXISTS ()\ - \n Subquery:\ - \n Filter: foo.a = bar.a\ - \n Projection: foo.a\ - \n TableScan: foo\ - \n Projection: bar.a\ - \n TableScan: bar"; - assert_eq!(expected, format!("{outer_query}")); + assert_snapshot!(outer_query, @r" + Filter: EXISTS () + Subquery: + Filter: foo.a = bar.a + Projection: foo.a + TableScan: foo + Projection: bar.a + TableScan: bar + "); Ok(()) } @@ -2425,14 +2470,15 @@ mod tests { .filter(in_subquery(col("a"), Arc::new(subquery)))? .build()?; - let expected = "Filter: bar.a IN ()\ - \n Subquery:\ - \n Filter: foo.a = bar.a\ - \n Projection: foo.a\ - \n TableScan: foo\ - \n Projection: bar.a\ - \n TableScan: bar"; - assert_eq!(expected, format!("{outer_query}")); + assert_snapshot!(outer_query, @r" + Filter: bar.a IN () + Subquery: + Filter: foo.a = bar.a + Projection: foo.a + TableScan: foo + Projection: bar.a + TableScan: bar + "); Ok(()) } @@ -2452,13 +2498,14 @@ mod tests { .project(vec![scalar_subquery(Arc::new(subquery))])? .build()?; - let expected = "Projection: ()\ - \n Subquery:\ - \n Filter: foo.a = bar.a\ - \n Projection: foo.b\ - \n TableScan: foo\ - \n TableScan: bar"; - assert_eq!(expected, format!("{outer_query}")); + assert_snapshot!(outer_query, @r" + Projection: () + Subquery: + Filter: foo.a = bar.a + Projection: foo.b + TableScan: foo + TableScan: bar + "); Ok(()) } @@ -2552,13 +2599,11 @@ mod tests { let plan2 = table_scan(TableReference::none(), &employee_schema(), Some(vec![3, 4]))?; - let expected = "Error during planning: INTERSECT/EXCEPT query must have the same number of columns. \ - Left is 1 and right is 2."; let err_msg1 = LogicalPlanBuilder::intersect(plan1.build()?, plan2.build()?, true) .unwrap_err(); - assert_eq!(err_msg1.strip_backtrace(), expected); + assert_snapshot!(err_msg1.strip_backtrace(), @"Error during planning: INTERSECT/EXCEPT query must have the same number of columns. Left is 1 and right is 2."); Ok(()) } @@ -2569,19 +2614,29 @@ mod tests { let err = nested_table_scan("test_table")? 
.unnest_column("scalar") .unwrap_err(); - assert!(err - .to_string() - .starts_with("Internal error: trying to unnest on invalid data type UInt32")); + + let DataFusionError::Internal(desc) = err else { + return plan_err!("Plan should have returned an DataFusionError::Internal"); + }; + + let desc = desc + .split(DataFusionError::BACK_TRACE_SEP) + .collect::>() + .first() + .unwrap_or(&"") + .to_string(); + + assert_snapshot!(desc, @"trying to unnest on invalid data type UInt32"); // Unnesting the strings list. let plan = nested_table_scan("test_table")? .unnest_column("strings")? .build()?; - let expected = "\ - Unnest: lists[test_table.strings|depth=1] structs[]\ - \n TableScan: test_table"; - assert_eq!(expected, format!("{plan}")); + assert_snapshot!(plan, @r" + Unnest: lists[test_table.strings|depth=1] structs[] + TableScan: test_table + "); // Check unnested field is a scalar let field = plan.schema().field_with_name(None, "strings").unwrap(); @@ -2592,16 +2647,16 @@ mod tests { .unnest_column("struct_singular")? .build()?; - let expected = "\ - Unnest: lists[] structs[test_table.struct_singular]\ - \n TableScan: test_table"; - assert_eq!(expected, format!("{plan}")); + assert_snapshot!(plan, @r" + Unnest: lists[] structs[test_table.struct_singular] + TableScan: test_table + "); for field_name in &["a", "b"] { // Check unnested struct field is a scalar let field = plan .schema() - .field_with_name(None, &format!("struct_singular.{}", field_name)) + .field_with_name(None, &format!("struct_singular.{field_name}")) .unwrap(); assert_eq!(&DataType::UInt32, field.data_type()); } @@ -2613,12 +2668,12 @@ mod tests { .unnest_column("struct_singular")? .build()?; - let expected = "\ - Unnest: lists[] structs[test_table.struct_singular]\ - \n Unnest: lists[test_table.structs|depth=1] structs[]\ - \n Unnest: lists[test_table.strings|depth=1] structs[]\ - \n TableScan: test_table"; - assert_eq!(expected, format!("{plan}")); + assert_snapshot!(plan, @r" + Unnest: lists[] structs[test_table.struct_singular] + Unnest: lists[test_table.structs|depth=1] structs[] + Unnest: lists[test_table.strings|depth=1] structs[] + TableScan: test_table + "); // Check unnested struct list field should be a struct. let field = plan.schema().field_with_name(None, "structs").unwrap(); @@ -2634,10 +2689,10 @@ mod tests { .unnest_columns_with_options(cols, UnnestOptions::default())? .build()?; - let expected = "\ - Unnest: lists[test_table.strings|depth=1, test_table.structs|depth=1] structs[test_table.struct_singular]\ - \n TableScan: test_table"; - assert_eq!(expected, format!("{plan}")); + assert_snapshot!(plan, @r" + Unnest: lists[test_table.strings|depth=1, test_table.structs|depth=1] structs[test_table.struct_singular] + TableScan: test_table + "); // Unnesting missing column should fail. let plan = nested_table_scan("test_table")?.unnest_column("missing"); @@ -2661,10 +2716,10 @@ mod tests { )? 
.build()?; - let expected = "\ - Unnest: lists[test_table.stringss|depth=1, test_table.stringss|depth=2] structs[test_table.struct_singular]\ - \n TableScan: test_table"; - assert_eq!(expected, format!("{plan}")); + assert_snapshot!(plan, @r" + Unnest: lists[test_table.stringss|depth=1, test_table.stringss|depth=2] structs[test_table.struct_singular] + TableScan: test_table + "); // Check output columns has correct type let field = plan @@ -2684,7 +2739,7 @@ mod tests { for field_name in &["a", "b"] { let field = plan .schema() - .field_with_name(None, &format!("struct_singular.{}", field_name)) + .field_with_name(None, &format!("struct_singular.{field_name}")) .unwrap(); assert_eq!(&DataType::UInt32, field.data_type()); } @@ -2736,10 +2791,24 @@ mod tests { let join = LogicalPlanBuilder::from(left).cross_join(right)?.build()?; - let _ = LogicalPlanBuilder::from(join.clone()) + let plan = LogicalPlanBuilder::from(join.clone()) .union(join)? .build()?; + assert_snapshot!(plan, @r" + Union + Cross Join: + SubqueryAlias: left + Values: (Int32(1)) + SubqueryAlias: right + Values: (Int32(1)) + Cross Join: + SubqueryAlias: left + Values: (Int32(1)) + SubqueryAlias: right + Values: (Int32(1)) + "); + Ok(()) } @@ -2799,10 +2868,10 @@ mod tests { .aggregate(vec![col("id")], vec![sum(col("salary"))])? .build()?; - let expected = - "Aggregate: groupBy=[[employee_csv.id]], aggr=[[sum(employee_csv.salary)]]\ - \n TableScan: employee_csv projection=[id, state, salary]"; - assert_eq!(expected, format!("{plan}")); + assert_snapshot!(plan, @r" + Aggregate: groupBy=[[employee_csv.id]], aggr=[[sum(employee_csv.salary)]] + TableScan: employee_csv projection=[id, state, salary] + "); Ok(()) } @@ -2821,10 +2890,37 @@ mod tests { .aggregate(vec![col("id")], vec![sum(col("salary"))])? .build()?; - let expected = - "Aggregate: groupBy=[[employee_csv.id, employee_csv.state, employee_csv.salary]], aggr=[[sum(employee_csv.salary)]]\ - \n TableScan: employee_csv projection=[id, state, salary]"; - assert_eq!(expected, format!("{plan}")); + assert_snapshot!(plan, @r" + Aggregate: groupBy=[[employee_csv.id, employee_csv.state, employee_csv.salary]], aggr=[[sum(employee_csv.salary)]] + TableScan: employee_csv projection=[id, state, salary] + "); + + Ok(()) + } + + #[test] + fn test_join_metadata() -> Result<()> { + let left_schema = DFSchema::new_with_metadata( + vec![(None, Arc::new(Field::new("a", DataType::Int32, false)))], + HashMap::from([("key".to_string(), "left".to_string())]), + )?; + let right_schema = DFSchema::new_with_metadata( + vec![(None, Arc::new(Field::new("b", DataType::Int32, false)))], + HashMap::from([("key".to_string(), "right".to_string())]), + )?; + + let join_schema = + build_join_schema(&left_schema, &right_schema, &JoinType::Left)?; + assert_eq!( + join_schema.metadata(), + &HashMap::from([("key".to_string(), "left".to_string())]) + ); + let join_schema = + build_join_schema(&left_schema, &right_schema, &JoinType::Right)?; + assert_eq!( + join_schema.metadata(), + &HashMap::from([("key".to_string(), "right".to_string())]) + ); Ok(()) } diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index 14758b61e859..f1e455f46db3 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -341,7 +341,7 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { let eclipse = if values.len() > 5 { "..." 
} else { "" }; - let values_str = format!("{}{}", str_values, eclipse); + let values_str = format!("{str_values}{eclipse}"); json!({ "Node Type": "Values", "Values": values_str @@ -429,7 +429,7 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { }) => { let op_str = options .iter() - .map(|(k, v)| format!("{}={}", k, v)) + .map(|(k, v)| format!("{k}={v}")) .collect::>() .join(", "); json!({ @@ -722,13 +722,14 @@ impl<'n> TreeNodeVisitor<'n> for PgJsonVisitor<'_, '_> { #[cfg(test)] mod tests { use arrow::datatypes::{DataType, Field}; + use insta::assert_snapshot; use super::*; #[test] fn test_display_empty_schema() { let schema = Schema::empty(); - assert_eq!("[]", format!("{}", display_schema(&schema))); + assert_snapshot!(display_schema(&schema), @"[]"); } #[test] @@ -738,9 +739,6 @@ mod tests { Field::new("first_name", DataType::Utf8, true), ]); - assert_eq!( - "[id:Int32, first_name:Utf8;N]", - format!("{}", display_schema(&schema)) - ); + assert_snapshot!(display_schema(&schema), @"[id:Int32, first_name:Utf8;N]"); } } diff --git a/datafusion/expr/src/logical_plan/dml.rs b/datafusion/expr/src/logical_plan/dml.rs index d4d50ac4eae4..f3c95e696b4b 100644 --- a/datafusion/expr/src/logical_plan/dml.rs +++ b/datafusion/expr/src/logical_plan/dml.rs @@ -89,8 +89,28 @@ impl Hash for CopyTo { } } -/// The operator that modifies the content of a database (adapted from -/// substrait WriteRel) +/// Modifies the content of a database +/// +/// This operator is used to perform DML operations such as INSERT, DELETE, +/// UPDATE, and CTAS (CREATE TABLE AS SELECT). +/// +/// * `INSERT` - Appends new rows to the existing table. Calls +/// [`TableProvider::insert_into`] +/// +/// * `DELETE` - Removes rows from the table. Currently NOT supported by the +/// [`TableProvider`] trait or builtin sources. +/// +/// * `UPDATE` - Modifies existing rows in the table. Currently NOT supported by +/// the [`TableProvider`] trait or builtin sources. +/// +/// * `CREATE TABLE AS SELECT` - Creates a new table and populates it with data +/// from a query. This is similar to the `INSERT` operation, but it creates a new +/// table instead of modifying an existing one. +/// +/// Note that the structure is adapted from substrait WriteRel) +/// +/// [`TableProvider`]: https://docs.rs/datafusion/latest/datafusion/datasource/trait.TableProvider.html +/// [`TableProvider::insert_into`]: https://docs.rs/datafusion/latest/datafusion/datasource/trait.TableProvider.html#method.insert_into #[derive(Clone)] pub struct DmlStatement { /// The table name @@ -177,11 +197,18 @@ impl PartialOrd for DmlStatement { } } +/// The type of DML operation to perform. +/// +/// See [`DmlStatement`] for more details. 
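The expanded `DmlStatement`/`WriteOp` documentation above notes that only `INSERT` is currently executed through `TableProvider::insert_into`, while `DELETE` and `UPDATE` are not yet supported by the built-in sources. A minimal end-to-end sketch of that INSERT path, not taken from this change and assuming the top-level `datafusion` crate with a Tokio runtime and an in-memory table:

    use datafusion::error::Result;
    use datafusion::prelude::SessionContext;

    #[tokio::main]
    async fn main() -> Result<()> {
        let ctx = SessionContext::new();
        // Create a single-column in-memory table to insert into.
        ctx.sql("CREATE TABLE t AS VALUES (1), (2)").await?;
        // The INSERT below is planned as a DmlStatement with WriteOp::Insert and
        // executed by calling TableProvider::insert_into on the MemTable `t`.
        ctx.sql("INSERT INTO t VALUES (3)").await?.collect().await?;
        Ok(())
    }
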
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum WriteOp { + /// `INSERT INTO` operation Insert(InsertOp), + /// `DELETE` operation Delete, + /// `UPDATE` operation Update, + /// `CREATE TABLE AS SELECT` operation Ctas, } diff --git a/datafusion/expr/src/logical_plan/invariants.rs b/datafusion/expr/src/logical_plan/invariants.rs index 0c30c9785766..d8d6739b0e8f 100644 --- a/datafusion/expr/src/logical_plan/invariants.rs +++ b/datafusion/expr/src/logical_plan/invariants.rs @@ -310,7 +310,10 @@ fn check_inner_plan(inner_plan: &LogicalPlan) -> Result<()> { check_inner_plan(left)?; check_no_outer_references(right) } - JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => { + JoinType::Right + | JoinType::RightSemi + | JoinType::RightAnti + | JoinType::RightMark => { check_no_outer_references(left)?; check_inner_plan(right) } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 76b45d5d723a..876c14f1000f 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -56,8 +56,8 @@ use datafusion_common::tree_node::{ use datafusion_common::{ aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints, DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependence, - FunctionalDependencies, ParamValues, Result, ScalarValue, Spans, TableReference, - UnnestOptions, + FunctionalDependencies, NullEquality, ParamValues, Result, ScalarValue, Spans, + TableReference, UnnestOptions, }; use indexmap::IndexSet; @@ -556,7 +556,9 @@ impl LogicalPlan { JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { left.head_output_expr() } - JoinType::RightSemi | JoinType::RightAnti => right.head_output_expr(), + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => { + right.head_output_expr() + } }, LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => { static_term.head_output_expr() @@ -630,12 +632,9 @@ impl LogicalPlan { // todo it isn't clear why the schema is not recomputed here Ok(LogicalPlan::Values(Values { schema, values })) } - LogicalPlan::Filter(Filter { - predicate, - input, - having, - }) => Filter::try_new_internal(predicate, input, having) - .map(LogicalPlan::Filter), + LogicalPlan::Filter(Filter { predicate, input }) => { + Filter::try_new(predicate, input).map(LogicalPlan::Filter) + } LogicalPlan::Repartition(_) => Ok(self), LogicalPlan::Window(Window { input, @@ -658,7 +657,7 @@ impl LogicalPlan { join_constraint, on, schema: _, - null_equals_null, + null_equality, }) => { let schema = build_join_schema(left.schema(), right.schema(), &join_type)?; @@ -679,7 +678,7 @@ impl LogicalPlan { on: new_on, filter, schema: DFSchemaRef::new(schema), - null_equals_null, + null_equality, })) } LogicalPlan::Subquery(_) => Ok(self), @@ -897,7 +896,7 @@ impl LogicalPlan { join_type, join_constraint, on, - null_equals_null, + null_equality, .. 
}) => { let (left, right) = self.only_two_inputs(inputs)?; @@ -936,7 +935,7 @@ impl LogicalPlan { on: new_on, filter: filter_expr, schema: DFSchemaRef::new(schema), - null_equals_null: *null_equals_null, + null_equality: *null_equality, })) } LogicalPlan::Subquery(Subquery { @@ -991,7 +990,7 @@ impl LogicalPlan { Ok(LogicalPlan::Ddl(DdlStatement::CreateMemoryTable( CreateMemoryTable { input: Arc::new(input), - constraints: Constraints::empty(), + constraints: Constraints::default(), name: name.clone(), if_not_exists: *if_not_exists, or_replace: *or_replace, @@ -1308,7 +1307,7 @@ impl LogicalPlan { // Empty group_expr will return Some(1) if group_expr .iter() - .all(|expr| matches!(expr, Expr::Literal(_))) + .all(|expr| matches!(expr, Expr::Literal(_, _))) { Some(1) } else { @@ -1343,7 +1342,9 @@ impl LogicalPlan { JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { left.max_rows() } - JoinType::RightSemi | JoinType::RightAnti => right.max_rows(), + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => { + right.max_rows() + } }, LogicalPlan::Repartition(Repartition { input, .. }) => input.max_rows(), LogicalPlan::Union(Union { inputs, .. }) => { @@ -1458,7 +1459,7 @@ impl LogicalPlan { let transformed_expr = e.transform_up(|e| { if let Expr::Placeholder(Placeholder { id, .. }) = e { let value = param_values.get_placeholders_with_values(&id)?; - Ok(Transformed::yes(Expr::Literal(value))) + Ok(Transformed::yes(Expr::Literal(value, None))) } else { Ok(Transformed::no(e)) } @@ -1721,7 +1722,7 @@ impl LogicalPlan { LogicalPlan::RecursiveQuery(RecursiveQuery { is_distinct, .. }) => { - write!(f, "RecursiveQuery: is_distinct={}", is_distinct) + write!(f, "RecursiveQuery: is_distinct={is_distinct}") } LogicalPlan::Values(Values { ref values, .. }) => { let str_values: Vec<_> = values @@ -1818,12 +1819,12 @@ impl LogicalPlan { Ok(()) } LogicalPlan::Projection(Projection { ref expr, .. }) => { - write!(f, "Projection: ")?; + write!(f, "Projection:")?; for (i, expr_item) in expr.iter().enumerate() { if i > 0 { - write!(f, ", ")?; + write!(f, ",")?; } - write!(f, "{expr_item}")?; + write!(f, " {expr_item}")?; } Ok(()) } @@ -1964,7 +1965,7 @@ impl LogicalPlan { }; write!( f, - "Limit: skip={}, fetch={}", skip_str,fetch_str, + "Limit: skip={skip_str}, fetch={fetch_str}", ) } LogicalPlan::Subquery(Subquery { .. }) => { @@ -2259,8 +2260,6 @@ pub struct Filter { pub predicate: Expr, /// The incoming logical plan pub input: Arc, - /// The flag to indicate if the filter is a having clause - pub having: bool, } impl Filter { @@ -2269,13 +2268,14 @@ impl Filter { /// Notes: as Aliases have no effect on the output of a filter operator, /// they are removed from the predicate expression. pub fn try_new(predicate: Expr, input: Arc) -> Result { - Self::try_new_internal(predicate, input, false) + Self::try_new_internal(predicate, input) } /// Create a new filter operator for a having clause. /// This is similar to a filter, but its having flag is set to true. 
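With the `having` flag removed, a HAVING predicate is planned as an ordinary filter, and the deprecated `try_new_with_having` below simply forwards to `try_new`. A rough sketch of the single remaining construction path, using the public `table_scan` builder helper that the tests in this file also use (imports and the literal threshold are illustrative):

    use std::sync::Arc;
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::Result;
    use datafusion_expr::logical_plan::builder::table_scan;
    use datafusion_expr::{col, lit, Filter, LogicalPlan};

    fn having_as_plain_filter() -> Result<LogicalPlan> {
        let schema = Schema::new(vec![Field::new("salary", DataType::Int64, false)]);
        let scan = table_scan(Some("t"), &schema, None)?.build()?;
        // Both WHERE and HAVING predicates now go through Filter::try_new,
        // which validates that the predicate evaluates to a boolean.
        let filter = Filter::try_new(col("salary").gt(lit(100_i64)), Arc::new(scan))?;
        Ok(LogicalPlan::Filter(filter))
    }
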
+ #[deprecated(since = "48.0.0", note = "Use `try_new` instead")] pub fn try_new_with_having(predicate: Expr, input: Arc) -> Result { - Self::try_new_internal(predicate, input, true) + Self::try_new_internal(predicate, input) } fn is_allowed_filter_type(data_type: &DataType) -> bool { @@ -2289,11 +2289,7 @@ impl Filter { } } - fn try_new_internal( - predicate: Expr, - input: Arc, - having: bool, - ) -> Result { + fn try_new_internal(predicate: Expr, input: Arc) -> Result { // Filter predicates must return a boolean value so we try and validate that here. // Note that it is not always possible to resolve the predicate expression during plan // construction (such as with correlated subqueries) so we make a best effort here and @@ -2309,7 +2305,6 @@ impl Filter { Ok(Self { predicate: predicate.unalias_nested().data, input, - having, }) } @@ -2431,18 +2426,23 @@ impl Window { .iter() .enumerate() .filter_map(|(idx, expr)| { - if let Expr::WindowFunction(WindowFunction { + let Expr::WindowFunction(window_fun) = expr else { + return None; + }; + let WindowFunction { fun: WindowFunctionDefinition::WindowUDF(udwf), params: WindowFunctionParams { partition_by, .. }, - }) = expr - { - // When there is no PARTITION BY, row number will be unique - // across the entire table. - if udwf.name() == "row_number" && partition_by.is_empty() { - return Some(idx + input_len); - } + } = window_fun.as_ref() + else { + return None; + }; + // When there is no PARTITION BY, row number will be unique + // across the entire table. + if udwf.name() == "row_number" && partition_by.is_empty() { + Some(idx + input_len) + } else { + None } - None }) .map(|idx| { FunctionalDependence::new(vec![idx], vec![], false) @@ -2702,7 +2702,9 @@ impl Union { { expr.push(Expr::Column(column)); } else { - expr.push(Expr::Literal(ScalarValue::Null).alias(column.name())); + expr.push( + Expr::Literal(ScalarValue::Null, None).alias(column.name()), + ); } } wrapped_inputs.push(Arc::new(LogicalPlan::Projection( @@ -2860,7 +2862,7 @@ impl Union { // Generate unique field name let name = if let Some(count) = name_counts.get_mut(&base_name) { *count += 1; - format!("{}_{}", base_name, count) + format!("{base_name}_{count}") } else { name_counts.insert(base_name.clone(), 0); base_name @@ -3228,7 +3230,7 @@ impl Limit { pub fn get_skip_type(&self) -> Result { match self.skip.as_deref() { Some(expr) => match *expr { - Expr::Literal(ScalarValue::Int64(s)) => { + Expr::Literal(ScalarValue::Int64(s), _) => { // `skip = NULL` is equivalent to `skip = 0` let s = s.unwrap_or(0); if s >= 0 { @@ -3248,14 +3250,16 @@ impl Limit { pub fn get_fetch_type(&self) -> Result { match self.fetch.as_deref() { Some(expr) => match *expr { - Expr::Literal(ScalarValue::Int64(Some(s))) => { + Expr::Literal(ScalarValue::Int64(Some(s)), _) => { if s >= 0 { Ok(FetchType::Literal(Some(s as usize))) } else { plan_err!("LIMIT must be >= 0, '{}' was provided", s) } } - Expr::Literal(ScalarValue::Int64(None)) => Ok(FetchType::Literal(None)), + Expr::Literal(ScalarValue::Int64(None), _) => { + Ok(FetchType::Literal(None)) + } _ => Ok(FetchType::UnsupportedExpr), }, None => Ok(FetchType::Literal(None)), @@ -3657,7 +3661,7 @@ fn calc_func_dependencies_for_project( .unwrap_or(vec![])) } _ => { - let name = format!("{}", expr); + let name = format!("{expr}"); Ok(input_fields .iter() .position(|item| *item == name) @@ -3704,11 +3708,52 @@ pub struct Join { pub join_constraint: JoinConstraint, /// The output schema, containing fields from the left and right inputs pub schema: 
DFSchemaRef, - /// If null_equals_null is true, null == null else null != null - pub null_equals_null: bool, + /// Defines the null equality for the join. + pub null_equality: NullEquality, } impl Join { + /// Creates a new Join operator with automatically computed schema. + /// + /// This constructor computes the schema based on the join type and inputs, + /// removing the need to manually specify the schema or call `recompute_schema`. + /// + /// # Arguments + /// + /// * `left` - Left input plan + /// * `right` - Right input plan + /// * `on` - Join condition as a vector of (left_expr, right_expr) pairs + /// * `filter` - Optional filter expression (for non-equijoin conditions) + /// * `join_type` - Type of join (Inner, Left, Right, etc.) + /// * `join_constraint` - Join constraint (On, Using) + /// * `null_equality` - How to handle nulls in join comparisons + /// + /// # Returns + /// + /// A new Join operator with the computed schema + pub fn try_new( + left: Arc, + right: Arc, + on: Vec<(Expr, Expr)>, + filter: Option, + join_type: JoinType, + join_constraint: JoinConstraint, + null_equality: NullEquality, + ) -> Result { + let join_schema = build_join_schema(left.schema(), right.schema(), &join_type)?; + + Ok(Join { + left, + right, + on, + filter, + join_type, + join_constraint, + schema: Arc::new(join_schema), + null_equality, + }) + } + /// Create Join with input which wrapped with projection, this method is used to help create physical join. pub fn try_new_with_project_input( original: &LogicalPlan, @@ -3738,7 +3783,7 @@ impl Join { join_type: original_join.join_type, join_constraint: original_join.join_constraint, schema: Arc::new(join_schema), - null_equals_null: original_join.null_equals_null, + null_equality: original_join.null_equality, }) } } @@ -3760,8 +3805,8 @@ impl PartialOrd for Join { pub join_type: &'a JoinType, /// Join constraint pub join_constraint: &'a JoinConstraint, - /// If null_equals_null is true, null == null else null != null - pub null_equals_null: &'a bool, + /// The null handling behavior for equalities + pub null_equality: &'a NullEquality, } let comparable_self = ComparableJoin { left: &self.left, @@ -3770,7 +3815,7 @@ impl PartialOrd for Join { filter: &self.filter, join_type: &self.join_type, join_constraint: &self.join_constraint, - null_equals_null: &self.null_equals_null, + null_equality: &self.null_equality, }; let comparable_other = ComparableJoin { left: &other.left, @@ -3779,7 +3824,7 @@ impl PartialOrd for Join { filter: &other.filter, join_type: &other.join_type, join_constraint: &other.join_constraint, - null_equals_null: &other.null_equals_null, + null_equality: &other.null_equality, }; comparable_self.partial_cmp(&comparable_other) } @@ -3965,6 +4010,7 @@ mod tests { TransformedResult, TreeNodeRewriter, TreeNodeVisitor, }; use datafusion_common::{not_impl_err, Constraint, ScalarValue}; + use insta::{assert_debug_snapshot, assert_snapshot}; use crate::test::function_stub::count; @@ -3992,13 +4038,13 @@ mod tests { fn test_display_indent() -> Result<()> { let plan = display_plan()?; - let expected = "Projection: employee_csv.id\ - \n Filter: employee_csv.state IN ()\ - \n Subquery:\ - \n TableScan: employee_csv projection=[state]\ - \n TableScan: employee_csv projection=[id, state]"; - - assert_eq!(expected, format!("{}", plan.display_indent())); + assert_snapshot!(plan.display_indent(), @r" + Projection: employee_csv.id + Filter: employee_csv.state IN () + Subquery: + TableScan: employee_csv projection=[state] + TableScan: employee_csv 
projection=[id, state] + "); Ok(()) } @@ -4006,13 +4052,13 @@ mod tests { fn test_display_indent_schema() -> Result<()> { let plan = display_plan()?; - let expected = "Projection: employee_csv.id [id:Int32]\ - \n Filter: employee_csv.state IN () [id:Int32, state:Utf8]\ - \n Subquery: [state:Utf8]\ - \n TableScan: employee_csv projection=[state] [state:Utf8]\ - \n TableScan: employee_csv projection=[id, state] [id:Int32, state:Utf8]"; - - assert_eq!(expected, format!("{}", plan.display_indent_schema())); + assert_snapshot!(plan.display_indent_schema(), @r" + Projection: employee_csv.id [id:Int32] + Filter: employee_csv.state IN () [id:Int32, state:Utf8] + Subquery: [state:Utf8] + TableScan: employee_csv projection=[state] [state:Utf8] + TableScan: employee_csv projection=[id, state] [id:Int32, state:Utf8] + "); Ok(()) } @@ -4027,12 +4073,12 @@ mod tests { .project(vec![col("id"), exists(plan1).alias("exists")])? .build(); - let expected = "Projection: employee_csv.id, EXISTS () AS exists\ - \n Subquery:\ - \n TableScan: employee_csv projection=[state]\ - \n TableScan: employee_csv projection=[id, state]"; - - assert_eq!(expected, format!("{}", plan?.display_indent())); + assert_snapshot!(plan?.display_indent(), @r" + Projection: employee_csv.id, EXISTS () AS exists + Subquery: + TableScan: employee_csv projection=[state] + TableScan: employee_csv projection=[id, state] + "); Ok(()) } @@ -4040,46 +4086,42 @@ mod tests { fn test_display_graphviz() -> Result<()> { let plan = display_plan()?; - let expected_graphviz = r#" -// Begin DataFusion GraphViz Plan, -// display it online here: https://dreampuf.github.io/GraphvizOnline - -digraph { - subgraph cluster_1 - { - graph[label="LogicalPlan"] - 2[shape=box label="Projection: employee_csv.id"] - 3[shape=box label="Filter: employee_csv.state IN ()"] - 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back] - 4[shape=box label="Subquery:"] - 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back] - 5[shape=box label="TableScan: employee_csv projection=[state]"] - 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back] - 6[shape=box label="TableScan: employee_csv projection=[id, state]"] - 3 -> 6 [arrowhead=none, arrowtail=normal, dir=back] - } - subgraph cluster_7 - { - graph[label="Detailed LogicalPlan"] - 8[shape=box label="Projection: employee_csv.id\nSchema: [id:Int32]"] - 9[shape=box label="Filter: employee_csv.state IN ()\nSchema: [id:Int32, state:Utf8]"] - 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back] - 10[shape=box label="Subquery:\nSchema: [state:Utf8]"] - 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back] - 11[shape=box label="TableScan: employee_csv projection=[state]\nSchema: [state:Utf8]"] - 10 -> 11 [arrowhead=none, arrowtail=normal, dir=back] - 12[shape=box label="TableScan: employee_csv projection=[id, state]\nSchema: [id:Int32, state:Utf8]"] - 9 -> 12 [arrowhead=none, arrowtail=normal, dir=back] - } -} -// End DataFusion GraphViz Plan -"#; - // just test for a few key lines in the output rather than the // whole thing to make test maintenance easier. 
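The old comment about checking only "a few key lines" is what the inline-snapshot migration in this hunk addresses: `insta` keeps the full expected output in the source and rewrites it when a reviewed change is accepted. A small sketch of the pattern, assuming `insta` as a dev-dependency (snapshots are refreshed with `cargo insta review` or `INSTA_UPDATE=always cargo test`):

    use insta::assert_snapshot;

    #[test]
    fn inline_snapshot_example() {
        let rendered = "Projection: t.a\n  TableScan: t";
        // The @r"..." literal is the expected value stored inline; insta
        // dedents it before comparing and proposes an in-place update of the
        // literal whenever the rendered output changes.
        assert_snapshot!(rendered, @r"
        Projection: t.a
          TableScan: t
        ");
    }
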
- let graphviz = format!("{}", plan.display_graphviz()); - - assert_eq!(expected_graphviz, graphviz); + assert_snapshot!(plan.display_graphviz(), @r#" + // Begin DataFusion GraphViz Plan, + // display it online here: https://dreampuf.github.io/GraphvizOnline + + digraph { + subgraph cluster_1 + { + graph[label="LogicalPlan"] + 2[shape=box label="Projection: employee_csv.id"] + 3[shape=box label="Filter: employee_csv.state IN ()"] + 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back] + 4[shape=box label="Subquery:"] + 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back] + 5[shape=box label="TableScan: employee_csv projection=[state]"] + 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back] + 6[shape=box label="TableScan: employee_csv projection=[id, state]"] + 3 -> 6 [arrowhead=none, arrowtail=normal, dir=back] + } + subgraph cluster_7 + { + graph[label="Detailed LogicalPlan"] + 8[shape=box label="Projection: employee_csv.id\nSchema: [id:Int32]"] + 9[shape=box label="Filter: employee_csv.state IN ()\nSchema: [id:Int32, state:Utf8]"] + 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back] + 10[shape=box label="Subquery:\nSchema: [state:Utf8]"] + 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back] + 11[shape=box label="TableScan: employee_csv projection=[state]\nSchema: [state:Utf8]"] + 10 -> 11 [arrowhead=none, arrowtail=normal, dir=back] + 12[shape=box label="TableScan: employee_csv projection=[id, state]\nSchema: [id:Int32, state:Utf8]"] + 9 -> 12 [arrowhead=none, arrowtail=normal, dir=back] + } + } + // End DataFusion GraphViz Plan + "#); Ok(()) } @@ -4087,60 +4129,58 @@ digraph { fn test_display_pg_json() -> Result<()> { let plan = display_plan()?; - let expected_pg_json = r#"[ - { - "Plan": { - "Expressions": [ - "employee_csv.id" - ], - "Node Type": "Projection", - "Output": [ - "id" - ], - "Plans": [ - { - "Condition": "employee_csv.state IN ()", - "Node Type": "Filter", - "Output": [ - "id", - "state" - ], - "Plans": [ - { - "Node Type": "Subquery", + assert_snapshot!(plan.display_pg_json(), @r#" + [ + { + "Plan": { + "Expressions": [ + "employee_csv.id" + ], + "Node Type": "Projection", "Output": [ - "state" + "id" ], "Plans": [ { - "Node Type": "TableScan", + "Condition": "employee_csv.state IN ()", + "Node Type": "Filter", "Output": [ + "id", "state" ], - "Plans": [], - "Relation Name": "employee_csv" + "Plans": [ + { + "Node Type": "Subquery", + "Output": [ + "state" + ], + "Plans": [ + { + "Node Type": "TableScan", + "Output": [ + "state" + ], + "Plans": [], + "Relation Name": "employee_csv" + } + ] + }, + { + "Node Type": "TableScan", + "Output": [ + "id", + "state" + ], + "Plans": [], + "Relation Name": "employee_csv" + } + ] } ] - }, - { - "Node Type": "TableScan", - "Output": [ - "id", - "state" - ], - "Plans": [], - "Relation Name": "employee_csv" } - ] - } - ] - } - } -]"#; - - let pg_json = format!("{}", plan.display_pg_json()); - - assert_eq!(expected_pg_json, pg_json); + } + ] + "#); Ok(()) } @@ -4189,17 +4229,16 @@ digraph { let res = plan.visit_with_subqueries(&mut visitor); assert!(res.is_ok()); - assert_eq!( - visitor.strings, - vec![ - "pre_visit Projection", - "pre_visit Filter", - "pre_visit TableScan", - "post_visit TableScan", - "post_visit Filter", - "post_visit Projection", - ] - ); + assert_debug_snapshot!(visitor.strings, @r#" + [ + "pre_visit Projection", + "pre_visit Filter", + "pre_visit TableScan", + "post_visit TableScan", + "post_visit Filter", + "post_visit Projection", + ] + "#); } #[derive(Debug, Default)] @@ -4265,9 +4304,14 @@ digraph { let res = 
plan.visit_with_subqueries(&mut visitor); assert!(res.is_ok()); - assert_eq!( + assert_debug_snapshot!( visitor.inner.strings, - vec!["pre_visit Projection", "pre_visit Filter"] + @r#" + [ + "pre_visit Projection", + "pre_visit Filter", + ] + "# ); } @@ -4281,14 +4325,16 @@ digraph { let res = plan.visit_with_subqueries(&mut visitor); assert!(res.is_ok()); - assert_eq!( + assert_debug_snapshot!( visitor.inner.strings, - vec![ - "pre_visit Projection", - "pre_visit Filter", - "pre_visit TableScan", - "post_visit TableScan", - ] + @r#" + [ + "pre_visit Projection", + "pre_visit Filter", + "pre_visit TableScan", + "post_visit TableScan", + ] + "# ); } @@ -4330,13 +4376,18 @@ digraph { }; let plan = test_plan(); let res = plan.visit_with_subqueries(&mut visitor).unwrap_err(); - assert_eq!( - "This feature is not implemented: Error in pre_visit", - res.strip_backtrace() + assert_snapshot!( + res.strip_backtrace(), + @"This feature is not implemented: Error in pre_visit" ); - assert_eq!( + assert_debug_snapshot!( visitor.inner.strings, - vec!["pre_visit Projection", "pre_visit Filter"] + @r#" + [ + "pre_visit Projection", + "pre_visit Filter", + ] + "# ); } @@ -4348,18 +4399,20 @@ digraph { }; let plan = test_plan(); let res = plan.visit_with_subqueries(&mut visitor).unwrap_err(); - assert_eq!( - "This feature is not implemented: Error in post_visit", - res.strip_backtrace() + assert_snapshot!( + res.strip_backtrace(), + @"This feature is not implemented: Error in post_visit" ); - assert_eq!( + assert_debug_snapshot!( visitor.inner.strings, - vec![ - "pre_visit Projection", - "pre_visit Filter", - "pre_visit TableScan", - "post_visit TableScan", - ] + @r#" + [ + "pre_visit Projection", + "pre_visit Filter", + "pre_visit TableScan", + "post_visit TableScan", + ] + "# ); } @@ -4374,7 +4427,7 @@ digraph { })), empty_schema, ); - assert_eq!(p.err().unwrap().strip_backtrace(), "Error during planning: Projection has mismatch between number of expressions (1) and number of fields in schema (0)"); + assert_snapshot!(p.unwrap_err().strip_backtrace(), @"Error during planning: Projection has mismatch between number of expressions (1) and number of fields in schema (0)"); Ok(()) } @@ -4494,7 +4547,7 @@ digraph { let col = schema.field_names()[0].clone(); let filter = Filter::try_new( - Expr::Column(col.into()).eq(Expr::Literal(ScalarValue::Int32(Some(1)))), + Expr::Column(col.into()).eq(Expr::Literal(ScalarValue::Int32(Some(1)), None)), scan, ) .unwrap(); @@ -4561,11 +4614,12 @@ digraph { .data() .unwrap(); - let expected = "Explain\ - \n Filter: foo = Boolean(true)\ - \n TableScan: ?table?"; let actual = format!("{}", plan.display_indent()); - assert_eq!(expected.to_string(), actual) + assert_snapshot!(actual, @r" + Explain + Filter: foo = Boolean(true) + TableScan: ?table? 
+ ") } #[test] @@ -4620,12 +4674,14 @@ digraph { skip: None, fetch: Some(Box::new(Expr::Literal( ScalarValue::new_ten(&DataType::UInt32).unwrap(), + None, ))), input: Arc::clone(&input), }), LogicalPlan::Limit(Limit { skip: Some(Box::new(Expr::Literal( ScalarValue::new_ten(&DataType::UInt32).unwrap(), + None, ))), fetch: None, input: Arc::clone(&input), @@ -4633,9 +4689,11 @@ digraph { LogicalPlan::Limit(Limit { skip: Some(Box::new(Expr::Literal( ScalarValue::new_one(&DataType::UInt32).unwrap(), + None, ))), fetch: Some(Box::new(Expr::Literal( ScalarValue::new_ten(&DataType::UInt32).unwrap(), + None, ))), input, }), @@ -4837,7 +4895,7 @@ digraph { join_type: JoinType::Inner, join_constraint: JoinConstraint::On, schema: Arc::new(left_schema.join(&right_schema)?), - null_equals_null: false, + null_equality: NullEquality::NullEqualsNothing, })) } @@ -4916,4 +4974,374 @@ digraph { Ok(()) } + + #[test] + fn test_join_try_new() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + + let left_scan = table_scan(Some("t1"), &schema, None)?.build()?; + + let right_scan = table_scan(Some("t2"), &schema, None)?.build()?; + + let join_types = vec![ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::Full, + JoinType::LeftSemi, + JoinType::LeftAnti, + JoinType::RightSemi, + JoinType::RightAnti, + JoinType::LeftMark, + ]; + + for join_type in join_types { + let join = Join::try_new( + Arc::new(left_scan.clone()), + Arc::new(right_scan.clone()), + vec![(col("t1.a"), col("t2.a"))], + Some(col("t1.b").gt(col("t2.b"))), + join_type, + JoinConstraint::On, + NullEquality::NullEqualsNothing, + )?; + + match join_type { + JoinType::LeftSemi | JoinType::LeftAnti => { + assert_eq!(join.schema.fields().len(), 2); + + let fields = join.schema.fields(); + assert_eq!( + fields[0].name(), + "a", + "First field should be 'a' from left table" + ); + assert_eq!( + fields[1].name(), + "b", + "Second field should be 'b' from left table" + ); + } + JoinType::RightSemi | JoinType::RightAnti => { + assert_eq!(join.schema.fields().len(), 2); + + let fields = join.schema.fields(); + assert_eq!( + fields[0].name(), + "a", + "First field should be 'a' from right table" + ); + assert_eq!( + fields[1].name(), + "b", + "Second field should be 'b' from right table" + ); + } + JoinType::LeftMark => { + assert_eq!(join.schema.fields().len(), 3); + + let fields = join.schema.fields(); + assert_eq!( + fields[0].name(), + "a", + "First field should be 'a' from left table" + ); + assert_eq!( + fields[1].name(), + "b", + "Second field should be 'b' from left table" + ); + assert_eq!( + fields[2].name(), + "mark", + "Third field should be the mark column" + ); + + assert!(!fields[0].is_nullable()); + assert!(!fields[1].is_nullable()); + assert!(!fields[2].is_nullable()); + } + _ => { + assert_eq!(join.schema.fields().len(), 4); + + let fields = join.schema.fields(); + assert_eq!( + fields[0].name(), + "a", + "First field should be 'a' from left table" + ); + assert_eq!( + fields[1].name(), + "b", + "Second field should be 'b' from left table" + ); + assert_eq!( + fields[2].name(), + "a", + "Third field should be 'a' from right table" + ); + assert_eq!( + fields[3].name(), + "b", + "Fourth field should be 'b' from right table" + ); + + if join_type == JoinType::Left { + // Left side fields (first two) shouldn't be nullable + assert!(!fields[0].is_nullable()); + assert!(!fields[1].is_nullable()); + // Right side fields (third and fourth) 
should be nullable + assert!(fields[2].is_nullable()); + assert!(fields[3].is_nullable()); + } else if join_type == JoinType::Right { + // Left side fields (first two) should be nullable + assert!(fields[0].is_nullable()); + assert!(fields[1].is_nullable()); + // Right side fields (third and fourth) shouldn't be nullable + assert!(!fields[2].is_nullable()); + assert!(!fields[3].is_nullable()); + } else if join_type == JoinType::Full { + assert!(fields[0].is_nullable()); + assert!(fields[1].is_nullable()); + assert!(fields[2].is_nullable()); + assert!(fields[3].is_nullable()); + } + } + } + + assert_eq!(join.on, vec![(col("t1.a"), col("t2.a"))]); + assert_eq!(join.filter, Some(col("t1.b").gt(col("t2.b")))); + assert_eq!(join.join_type, join_type); + assert_eq!(join.join_constraint, JoinConstraint::On); + assert_eq!(join.null_equality, NullEquality::NullEqualsNothing); + } + + Ok(()) + } + + #[test] + fn test_join_try_new_with_using_constraint_and_overlapping_columns() -> Result<()> { + let left_schema = Schema::new(vec![ + Field::new("id", DataType::Int32, false), // Common column in both tables + Field::new("name", DataType::Utf8, false), // Unique to left + Field::new("value", DataType::Int32, false), // Common column, different meaning + ]); + + let right_schema = Schema::new(vec![ + Field::new("id", DataType::Int32, false), // Common column in both tables + Field::new("category", DataType::Utf8, false), // Unique to right + Field::new("value", DataType::Float64, true), // Common column, different meaning + ]); + + let left_plan = table_scan(Some("t1"), &left_schema, None)?.build()?; + + let right_plan = table_scan(Some("t2"), &right_schema, None)?.build()?; + + // Test 1: USING constraint with a common column + { + // In the logical plan, both copies of the `id` column are preserved + // The USING constraint is handled later during physical execution, where the common column appears once + let join = Join::try_new( + Arc::new(left_plan.clone()), + Arc::new(right_plan.clone()), + vec![(col("t1.id"), col("t2.id"))], + None, + JoinType::Inner, + JoinConstraint::Using, + NullEquality::NullEqualsNothing, + )?; + + let fields = join.schema.fields(); + + assert_eq!(fields.len(), 6); + + assert_eq!( + fields[0].name(), + "id", + "First field should be 'id' from left table" + ); + assert_eq!( + fields[1].name(), + "name", + "Second field should be 'name' from left table" + ); + assert_eq!( + fields[2].name(), + "value", + "Third field should be 'value' from left table" + ); + assert_eq!( + fields[3].name(), + "id", + "Fourth field should be 'id' from right table" + ); + assert_eq!( + fields[4].name(), + "category", + "Fifth field should be 'category' from right table" + ); + assert_eq!( + fields[5].name(), + "value", + "Sixth field should be 'value' from right table" + ); + + assert_eq!(join.join_constraint, JoinConstraint::Using); + } + + // Test 2: Complex join condition with expressions + { + // Complex condition: join on id equality AND where left.value < right.value + let join = Join::try_new( + Arc::new(left_plan.clone()), + Arc::new(right_plan.clone()), + vec![(col("t1.id"), col("t2.id"))], // Equijoin condition + Some(col("t1.value").lt(col("t2.value"))), // Non-equi filter condition + JoinType::Inner, + JoinConstraint::On, + NullEquality::NullEqualsNothing, + )?; + + let fields = join.schema.fields(); + assert_eq!(fields.len(), 6); + + assert_eq!( + fields[0].name(), + "id", + "First field should be 'id' from left table" + ); + assert_eq!( + fields[1].name(), + "name", + "Second field 
should be 'name' from left table" + ); + assert_eq!( + fields[2].name(), + "value", + "Third field should be 'value' from left table" + ); + assert_eq!( + fields[3].name(), + "id", + "Fourth field should be 'id' from right table" + ); + assert_eq!( + fields[4].name(), + "category", + "Fifth field should be 'category' from right table" + ); + assert_eq!( + fields[5].name(), + "value", + "Sixth field should be 'value' from right table" + ); + + assert_eq!(join.filter, Some(col("t1.value").lt(col("t2.value")))); + } + + // Test 3: Join with null equality behavior set to true + { + let join = Join::try_new( + Arc::new(left_plan.clone()), + Arc::new(right_plan.clone()), + vec![(col("t1.id"), col("t2.id"))], + None, + JoinType::Inner, + JoinConstraint::On, + NullEquality::NullEqualsNull, + )?; + + assert_eq!(join.null_equality, NullEquality::NullEqualsNull); + } + + Ok(()) + } + + #[test] + fn test_join_try_new_schema_validation() -> Result<()> { + let left_schema = Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, false), + Field::new("value", DataType::Float64, true), + ]); + + let right_schema = Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("category", DataType::Utf8, true), + Field::new("code", DataType::Int16, false), + ]); + + let left_plan = table_scan(Some("t1"), &left_schema, None)?.build()?; + + let right_plan = table_scan(Some("t2"), &right_schema, None)?.build()?; + + let join_types = vec![ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::Full, + ]; + + for join_type in join_types { + let join = Join::try_new( + Arc::new(left_plan.clone()), + Arc::new(right_plan.clone()), + vec![(col("t1.id"), col("t2.id"))], + Some(col("t1.value").gt(lit(5.0))), + join_type, + JoinConstraint::On, + NullEquality::NullEqualsNothing, + )?; + + let fields = join.schema.fields(); + assert_eq!(fields.len(), 6, "Expected 6 fields for {join_type:?} join"); + + for (i, field) in fields.iter().enumerate() { + let expected_nullable = match (i, &join_type) { + // Left table fields (indices 0, 1, 2) + (0, JoinType::Right | JoinType::Full) => true, // id becomes nullable in RIGHT/FULL + (1, JoinType::Right | JoinType::Full) => true, // name becomes nullable in RIGHT/FULL + (2, _) => true, // value is already nullable + + // Right table fields (indices 3, 4, 5) + (3, JoinType::Left | JoinType::Full) => true, // id becomes nullable in LEFT/FULL + (4, _) => true, // category is already nullable + (5, JoinType::Left | JoinType::Full) => true, // code becomes nullable in LEFT/FULL + + _ => false, + }; + + assert_eq!( + field.is_nullable(), + expected_nullable, + "Field {} ({}) nullability incorrect for {:?} join", + i, + field.name(), + join_type + ); + } + } + + let using_join = Join::try_new( + Arc::new(left_plan.clone()), + Arc::new(right_plan.clone()), + vec![(col("t1.id"), col("t2.id"))], + None, + JoinType::Inner, + JoinConstraint::Using, + NullEquality::NullEqualsNothing, + )?; + + assert_eq!( + using_join.schema.fields().len(), + 6, + "USING join should have all fields" + ); + assert_eq!(using_join.join_constraint, JoinConstraint::Using); + + Ok(()) + } } diff --git a/datafusion/expr/src/logical_plan/statement.rs b/datafusion/expr/src/logical_plan/statement.rs index 82acebee3de6..72eb6b39bb47 100644 --- a/datafusion/expr/src/logical_plan/statement.rs +++ b/datafusion/expr/src/logical_plan/statement.rs @@ -110,7 +110,7 @@ impl Statement { Statement::Prepare(Prepare { name, data_types, .. 
}) => { - write!(f, "Prepare: {name:?} {data_types:?} ") + write!(f, "Prepare: {name:?} {data_types:?}") } Statement::Execute(Execute { name, parameters, .. @@ -123,7 +123,7 @@ impl Statement { ) } Statement::Deallocate(Deallocate { name }) => { - write!(f, "Deallocate: {}", name) + write!(f, "Deallocate: {name}") } } } diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 7f6e1e025387..527248ad39c2 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -85,17 +85,9 @@ impl TreeNode for LogicalPlan { schema, }) }), - LogicalPlan::Filter(Filter { - predicate, - input, - having, - }) => input.map_elements(f)?.update_data(|input| { - LogicalPlan::Filter(Filter { - predicate, - input, - having, - }) - }), + LogicalPlan::Filter(Filter { predicate, input }) => input + .map_elements(f)? + .update_data(|input| LogicalPlan::Filter(Filter { predicate, input })), LogicalPlan::Repartition(Repartition { input, partitioning_scheme, @@ -140,7 +132,7 @@ impl TreeNode for LogicalPlan { join_type, join_constraint, schema, - null_equals_null, + null_equality, }) => (left, right).map_elements(f)?.update_data(|(left, right)| { LogicalPlan::Join(Join { left, @@ -150,7 +142,7 @@ impl TreeNode for LogicalPlan { join_type, join_constraint, schema, - null_equals_null, + null_equality, }) }), LogicalPlan::Limit(Limit { skip, fetch, input }) => input @@ -444,11 +436,11 @@ impl LogicalPlan { filters.apply_elements(f) } LogicalPlan::Unnest(unnest) => { - let columns = unnest.exec_columns.clone(); - - let exprs = columns + let exprs = unnest + .exec_columns .iter() - .map(|c| Expr::Column(c.clone())) + .cloned() + .map(Expr::Column) .collect::>(); exprs.apply_elements(f) } @@ -509,17 +501,10 @@ impl LogicalPlan { LogicalPlan::Values(Values { schema, values }) => values .map_elements(f)? .update_data(|values| LogicalPlan::Values(Values { schema, values })), - LogicalPlan::Filter(Filter { - predicate, - input, - having, - }) => f(predicate)?.update_data(|predicate| { - LogicalPlan::Filter(Filter { - predicate, - input, - having, - }) - }), + LogicalPlan::Filter(Filter { predicate, input }) => f(predicate)? + .update_data(|predicate| { + LogicalPlan::Filter(Filter { predicate, input }) + }), LogicalPlan::Repartition(Repartition { input, partitioning_scheme, @@ -576,7 +561,7 @@ impl LogicalPlan { join_type, join_constraint, schema, - null_equals_null, + null_equality, }) => (on, filter).map_elements(f)?.update_data(|(on, filter)| { LogicalPlan::Join(Join { left, @@ -586,7 +571,7 @@ impl LogicalPlan { join_type, join_constraint, schema, - null_equals_null, + null_equality, }) }), LogicalPlan::Sort(Sort { expr, input, fetch }) => expr diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index a2ed0592efdb..4c03f919312e 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -312,6 +312,7 @@ pub struct RawWindowExpr { /// Result of planning a raw expr with [`ExprPlanner`] #[derive(Debug, Clone)] +#[allow(clippy::large_enum_variant)] pub enum PlannerResult { /// The raw expression was successfully planned as a new [`Expr`] Planned(Expr), diff --git a/datafusion/expr/src/simplify.rs b/datafusion/expr/src/simplify.rs index 467ce8bf53e2..411dbbdc4034 100644 --- a/datafusion/expr/src/simplify.rs +++ b/datafusion/expr/src/simplify.rs @@ -110,6 +110,7 @@ impl SimplifyInfo for SimplifyContext<'_> { /// Was the expression simplified? 
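`ExprSimplifyResult`, declared just below, is what a UDF's simplify hook hands back to the optimizer; the added `#[allow(clippy::large_enum_variant)]` only silences the lint about its variants differing in size. A hedged sketch of the usual shape of such a hook (the function name and constant are illustrative, not part of this change):

    use datafusion_common::Result;
    use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
    use datafusion_expr::{lit, Expr};

    // Illustrative helper: constant-fold a zero-argument call, otherwise hand
    // the arguments back unchanged so the optimizer leaves the call in place.
    fn simplify_my_udf(
        args: Vec<Expr>,
        _info: &dyn SimplifyInfo,
    ) -> Result<ExprSimplifyResult> {
        if args.is_empty() {
            Ok(ExprSimplifyResult::Simplified(lit(42_i64)))
        } else {
            Ok(ExprSimplifyResult::Original(args))
        }
    }
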
#[derive(Debug)] +#[allow(clippy::large_enum_variant)] pub enum ExprSimplifyResult { /// The function call was simplified to an entirely new Expr Simplified(Expr), diff --git a/datafusion/expr/src/test/function_stub.rs b/datafusion/expr/src/test/function_stub.rs index a753f4c376c6..f310f31be352 100644 --- a/datafusion/expr/src/test/function_stub.rs +++ b/datafusion/expr/src/test/function_stub.rs @@ -22,7 +22,7 @@ use std::any::Any; use arrow::datatypes::{ - DataType, Field, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, + DataType, FieldRef, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, }; use datafusion_common::{exec_err, not_impl_err, utils::take_function_args, Result}; @@ -175,14 +175,10 @@ impl AggregateUDFImpl for Sum { unreachable!("stub should not have accumulate()") } - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { unreachable!("stub should not have state_fields()") } - fn aliases(&self) -> &[String] { - &[] - } - fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { false } @@ -254,7 +250,7 @@ impl AggregateUDFImpl for Count { false } - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { not_impl_err!("no impl for stub") } @@ -336,7 +332,7 @@ impl AggregateUDFImpl for Min { Ok(DataType::Int64) } - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { not_impl_err!("no impl for stub") } @@ -344,10 +340,6 @@ impl AggregateUDFImpl for Min { not_impl_err!("no impl for stub") } - fn aliases(&self) -> &[String] { - &[] - } - fn create_groups_accumulator( &self, _args: AccumulatorArgs, @@ -421,7 +413,7 @@ impl AggregateUDFImpl for Max { Ok(DataType::Int64) } - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { not_impl_err!("no impl for stub") } @@ -429,10 +421,6 @@ impl AggregateUDFImpl for Max { not_impl_err!("no impl for stub") } - fn aliases(&self) -> &[String] { - &[] - } - fn create_groups_accumulator( &self, _args: AccumulatorArgs, @@ -491,9 +479,10 @@ impl AggregateUDFImpl for Avg { not_impl_err!("no impl for stub") } - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { not_impl_err!("no impl for stub") } + fn aliases(&self) -> &[String] { &self.aliases } diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index f20dab7e165f..f953aec5a1e3 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -73,7 +73,7 @@ impl TreeNode for Expr { // Treat OuterReferenceColumn as a leaf expression | Expr::OuterReferenceColumn(_, _) | Expr::ScalarVariable(_, _) - | Expr::Literal(_) + | Expr::Literal(_, _) | Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::Wildcard { .. } @@ -92,14 +92,16 @@ impl TreeNode for Expr { (expr, when_then_expr, else_expr).apply_ref_elements(f), Expr::AggregateFunction(AggregateFunction { params: AggregateFunctionParams { args, filter, order_by, ..}, .. }) => (args, filter, order_by).apply_ref_elements(f), - Expr::WindowFunction(WindowFunction { - params : WindowFunctionParams { + Expr::WindowFunction(window_fun) => { + let WindowFunctionParams { args, partition_by, order_by, - ..}, ..}) => { + .. + } = &window_fun.as_ref().params; (args, partition_by, order_by).apply_ref_elements(f) } + Expr::InList(InList { expr, list, .. 
}) => { (expr, list).apply_ref_elements(f) } @@ -124,7 +126,7 @@ impl TreeNode for Expr { | Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::ScalarVariable(_, _) - | Expr::Literal(_) => Transformed::no(self), + | Expr::Literal(_, _) => Transformed::no(self), Expr::Unnest(Unnest { expr, .. }) => expr .map_elements(f)? .update_data(|expr| Expr::Unnest(Unnest { expr })), @@ -230,27 +232,30 @@ impl TreeNode for Expr { ))) })? } - Expr::WindowFunction(WindowFunction { - fun, - params: - WindowFunctionParams { - args, - partition_by, - order_by, - window_frame, - null_treatment, + Expr::WindowFunction(window_fun) => { + let WindowFunction { + fun, + params: + WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + null_treatment, + }, + } = *window_fun; + (args, partition_by, order_by).map_elements(f)?.update_data( + |(new_args, new_partition_by, new_order_by)| { + Expr::from(WindowFunction::new(fun, new_args)) + .partition_by(new_partition_by) + .order_by(new_order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .build() + .unwrap() }, - }) => (args, partition_by, order_by).map_elements(f)?.update_data( - |(new_args, new_partition_by, new_order_by)| { - Expr::WindowFunction(WindowFunction::new(fun, new_args)) - .partition_by(new_partition_by) - .order_by(new_order_by) - .window_frame(window_frame) - .null_treatment(null_treatment) - .build() - .unwrap() - }, - ), + ) + } Expr::AggregateFunction(AggregateFunction { func, params: diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 3b34718062eb..763a4e6539fd 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -15,19 +15,22 @@ // specific language governing permissions and limitations // under the License. -use super::binary::{binary_numeric_coercion, comparison_coercion}; +use super::binary::binary_numeric_coercion; use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF}; +use arrow::datatypes::FieldRef; use arrow::{ compute::can_cast_types, - datatypes::{DataType, Field, TimeUnit}, + datatypes::{DataType, TimeUnit}, }; use datafusion_common::types::LogicalType; -use datafusion_common::utils::{coerced_fixed_size_list_to_list, ListCoercion}; +use datafusion_common::utils::{ + base_type, coerced_fixed_size_list_to_list, ListCoercion, +}; use datafusion_common::{ - exec_err, internal_datafusion_err, internal_err, plan_err, types::NativeType, - utils::list_ndims, Result, + exec_err, internal_err, plan_err, types::NativeType, utils::list_ndims, Result, }; use datafusion_expr_common::signature::ArrayFunctionArgument; +use datafusion_expr_common::type_coercion::binary::type_union_resolution; use datafusion_expr_common::{ signature::{ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD}, type_coercion::binary::comparison_coercion_numeric, @@ -75,19 +78,19 @@ pub fn data_types_with_scalar_udf( /// Performs type coercion for aggregate function arguments. /// -/// Returns the data types to which each argument must be coerced to +/// Returns the fields to which each argument must be coerced to /// match `signature`. /// /// For more details on coercion in general, please see the /// [`type_coercion`](crate::type_coercion) module. 
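The switch below from `DataType` slices to `FieldRef` slices lets coercion keep each argument's name, nullability and metadata, rewriting only the data type. A tiny sketch of the arrow-rs `Field` invariant the new code relies on (standalone, not part of this change):

    use std::collections::HashMap;
    use std::sync::Arc;
    use arrow::datatypes::{DataType, Field, FieldRef};

    fn main() {
        let arg: FieldRef = Arc::new(
            Field::new("x", DataType::Int32, true)
                .with_metadata(HashMap::from([("unit".to_string(), "ms".to_string())])),
        );
        // with_data_type replaces only the type; name, nullability and metadata
        // are preserved, which is exactly what the coercion below forwards.
        let coerced: FieldRef =
            Arc::new(arg.as_ref().clone().with_data_type(DataType::Int64));
        assert_eq!(coerced.name(), "x");
        assert!(coerced.is_nullable());
        assert_eq!(coerced.metadata().get("unit").map(String::as_str), Some("ms"));
    }
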
-pub fn data_types_with_aggregate_udf( - current_types: &[DataType], +pub fn fields_with_aggregate_udf( + current_fields: &[FieldRef], func: &AggregateUDF, -) -> Result> { +) -> Result> { let signature = func.signature(); let type_signature = &signature.type_signature; - if current_types.is_empty() && type_signature != &TypeSignature::UserDefined { + if current_fields.is_empty() && type_signature != &TypeSignature::UserDefined { if type_signature.supports_zero_argument() { return Ok(vec![]); } else if type_signature.used_to_support_zero_arguments() { @@ -97,17 +100,32 @@ pub fn data_types_with_aggregate_udf( return plan_err!("'{}' does not support zero arguments", func.name()); } } + let current_types = current_fields + .iter() + .map(|f| f.data_type()) + .cloned() + .collect::>(); let valid_types = - get_valid_types_with_aggregate_udf(type_signature, current_types, func)?; + get_valid_types_with_aggregate_udf(type_signature, ¤t_types, func)?; if valid_types .iter() - .any(|data_type| data_type == current_types) + .any(|data_type| data_type == ¤t_types) { - return Ok(current_types.to_vec()); + return Ok(current_fields.to_vec()); } - try_coerce_types(func.name(), valid_types, current_types, type_signature) + let updated_types = + try_coerce_types(func.name(), valid_types, ¤t_types, type_signature)?; + + Ok(current_fields + .iter() + .zip(updated_types) + .map(|(current_field, new_type)| { + current_field.as_ref().clone().with_data_type(new_type) + }) + .map(Arc::new) + .collect()) } /// Performs type coercion for window function arguments. @@ -117,14 +135,14 @@ pub fn data_types_with_aggregate_udf( /// /// For more details on coercion in general, please see the /// [`type_coercion`](crate::type_coercion) module. -pub fn data_types_with_window_udf( - current_types: &[DataType], +pub fn fields_with_window_udf( + current_fields: &[FieldRef], func: &WindowUDF, -) -> Result> { +) -> Result> { let signature = func.signature(); let type_signature = &signature.type_signature; - if current_types.is_empty() && type_signature != &TypeSignature::UserDefined { + if current_fields.is_empty() && type_signature != &TypeSignature::UserDefined { if type_signature.supports_zero_argument() { return Ok(vec![]); } else if type_signature.used_to_support_zero_arguments() { @@ -135,16 +153,31 @@ pub fn data_types_with_window_udf( } } + let current_types = current_fields + .iter() + .map(|f| f.data_type()) + .cloned() + .collect::>(); let valid_types = - get_valid_types_with_window_udf(type_signature, current_types, func)?; + get_valid_types_with_window_udf(type_signature, ¤t_types, func)?; if valid_types .iter() - .any(|data_type| data_type == current_types) + .any(|data_type| data_type == ¤t_types) { - return Ok(current_types.to_vec()); + return Ok(current_fields.to_vec()); } - try_coerce_types(func.name(), valid_types, current_types, type_signature) + let updated_types = + try_coerce_types(func.name(), valid_types, ¤t_types, type_signature)?; + + Ok(current_fields + .iter() + .zip(updated_types) + .map(|(current_field, new_type)| { + current_field.as_ref().clone().with_data_type(new_type) + }) + .map(Arc::new) + .collect()) } /// Performs type coercion for function arguments. 
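The reworked array-signature arm of `get_valid_types` in the next hunk collects every element type and defers to `type_union_resolution` to pick one common element type before rebuilding the list types. A one-line sketch of that helper's behaviour; the exact unified type shown here is an assumption about this particular pair, not asserted by the diff:

    use arrow::datatypes::DataType;
    use datafusion_expr_common::type_coercion::binary::type_union_resolution;

    fn main() {
        // Int32 and Float64 are expected to unify to Float64; incompatible
        // inputs yield None, which the caller turns into "no valid signature".
        let unified = type_union_resolution(&[DataType::Int32, DataType::Float64]);
        assert_eq!(unified, Some(DataType::Float64));
    }
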
@@ -364,98 +397,67 @@ fn get_valid_types( return Ok(vec![vec![]]); } - let array_idx = arguments.iter().enumerate().find_map(|(idx, arg)| { - if *arg == ArrayFunctionArgument::Array { - Some(idx) - } else { - None - } - }); - let Some(array_idx) = array_idx else { - return Err(internal_datafusion_err!("Function '{function_name}' expected at least one argument array argument")); - }; - let Some(array_type) = array(¤t_types[array_idx]) else { - return Ok(vec![vec![]]); - }; - - // We need to find the coerced base type, mainly for cases like: - // `array_append(List(null), i64)` -> `List(i64)` - let mut new_base_type = datafusion_common::utils::base_type(&array_type); - for (current_type, argument_type) in current_types.iter().zip(arguments.iter()) { - match argument_type { - ArrayFunctionArgument::Element | ArrayFunctionArgument::Array => { - new_base_type = - coerce_array_types(function_name, current_type, &new_base_type)?; + let mut large_list = false; + let mut fixed_size = array_coercion != Some(&ListCoercion::FixedSizedListToList); + let mut list_sizes = Vec::with_capacity(arguments.len()); + let mut element_types = Vec::with_capacity(arguments.len()); + for (argument, current_type) in arguments.iter().zip(current_types.iter()) { + match argument { + ArrayFunctionArgument::Index | ArrayFunctionArgument::String => (), + ArrayFunctionArgument::Element => { + element_types.push(current_type.clone()) } - ArrayFunctionArgument::Index | ArrayFunctionArgument::String => {} + ArrayFunctionArgument::Array => match current_type { + DataType::Null => element_types.push(DataType::Null), + DataType::List(field) => { + element_types.push(field.data_type().clone()); + fixed_size = false; + } + DataType::LargeList(field) => { + element_types.push(field.data_type().clone()); + large_list = true; + fixed_size = false; + } + DataType::FixedSizeList(field, size) => { + element_types.push(field.data_type().clone()); + list_sizes.push(*size) + } + arg_type => { + plan_err!("{function_name} does not support type {arg_type}")? 
+ } + }, } } - let new_array_type = datafusion_common::utils::coerced_type_with_base_type_only( - &array_type, - &new_base_type, - array_coercion, - ); - let new_elem_type = match new_array_type { - DataType::List(ref field) - | DataType::LargeList(ref field) - | DataType::FixedSizeList(ref field, _) => field.data_type(), - _ => return Ok(vec![vec![]]), + let Some(element_type) = type_union_resolution(&element_types) else { + return Ok(vec![vec![]]); }; - let mut valid_types = Vec::with_capacity(arguments.len()); - for (current_type, argument_type) in current_types.iter().zip(arguments.iter()) { - let valid_type = match argument_type { - ArrayFunctionArgument::Element => new_elem_type.clone(), + if !fixed_size { + list_sizes.clear() + } + + let mut list_sizes = list_sizes.into_iter(); + let valid_types = arguments.iter().zip(current_types.iter()).map( + |(argument_type, current_type)| match argument_type { ArrayFunctionArgument::Index => DataType::Int64, ArrayFunctionArgument::String => DataType::Utf8, + ArrayFunctionArgument::Element => element_type.clone(), ArrayFunctionArgument::Array => { - let Some(current_type) = array(current_type) else { - return Ok(vec![vec![]]); - }; - let new_type = - datafusion_common::utils::coerced_type_with_base_type_only( - ¤t_type, - &new_base_type, - array_coercion, - ); - // All array arguments must be coercible to the same type - if new_type != new_array_type { - return Ok(vec![vec![]]); + if current_type.is_null() { + DataType::Null + } else if large_list { + DataType::new_large_list(element_type.clone(), true) + } else if let Some(size) = list_sizes.next() { + DataType::new_fixed_size_list(element_type.clone(), size, true) + } else { + DataType::new_list(element_type.clone(), true) } - new_type } - }; - valid_types.push(valid_type); - } - - Ok(vec![valid_types]) - } - - fn array(array_type: &DataType) -> Option { - match array_type { - DataType::List(_) | DataType::LargeList(_) => Some(array_type.clone()), - DataType::FixedSizeList(field, _) => Some(DataType::List(Arc::clone(field))), - DataType::Null => Some(DataType::List(Arc::new(Field::new_list_field( - DataType::Int64, - true, - )))), - _ => None, - } - } + }, + ); - fn coerce_array_types( - function_name: &str, - current_type: &DataType, - base_type: &DataType, - ) -> Result { - let current_base_type = datafusion_common::utils::base_type(current_type); - let new_base_type = comparison_coercion(base_type, ¤t_base_type); - new_base_type.ok_or_else(|| { - internal_datafusion_err!( - "Function '{function_name}' does not support coercion from {base_type:?} to {current_base_type:?}" - ) - }) + Ok(vec![valid_types.collect()]) } fn recursive_array(array_type: &DataType) -> Option { @@ -800,7 +802,7 @@ pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool { /// /// Expect uni-directional coercion, for example, i32 is coerced to i64, but i64 is not coerced to i32. /// -/// Unlike [comparison_coercion], the coerced type is usually `wider` for lossless conversion. +/// Unlike [crate::binary::comparison_coercion], the coerced type is usually `wider` for lossless conversion. fn coerced_from<'a>( type_into: &'a DataType, type_from: &'a DataType, @@ -867,7 +869,7 @@ fn coerced_from<'a>( // Only accept list and largelist with the same number of dimensions unless the type is Null. 
// List or LargeList with different dimensions should be handled in TypeSignature or other places before this (List(_) | LargeList(_), _) - if datafusion_common::utils::base_type(type_from).eq(&Null) + if base_type(type_from).is_null() || list_ndims(type_from) == list_ndims(type_into) => { Some(type_into.clone()) @@ -906,7 +908,6 @@ fn coerced_from<'a>( #[cfg(test)] mod tests { - use crate::Volatility; use super::*; @@ -1193,4 +1194,155 @@ mod tests { Some(type_into.clone()) ); } + + #[test] + fn test_get_valid_types_array_and_array() -> Result<()> { + let function = "array_and_array"; + let signature = Signature::arrays( + 2, + Some(ListCoercion::FixedSizedListToList), + Volatility::Immutable, + ); + + let data_types = vec![ + DataType::new_list(DataType::Int32, true), + DataType::new_large_list(DataType::Float64, true), + ]; + assert_eq!( + get_valid_types(function, &signature.type_signature, &data_types)?, + vec![vec![ + DataType::new_large_list(DataType::Float64, true), + DataType::new_large_list(DataType::Float64, true), + ]] + ); + + let data_types = vec![ + DataType::new_fixed_size_list(DataType::Int64, 3, true), + DataType::new_fixed_size_list(DataType::Int32, 5, true), + ]; + assert_eq!( + get_valid_types(function, &signature.type_signature, &data_types)?, + vec![vec![ + DataType::new_list(DataType::Int64, true), + DataType::new_list(DataType::Int64, true), + ]] + ); + + let data_types = vec![ + DataType::new_fixed_size_list(DataType::Null, 3, true), + DataType::new_large_list(DataType::Utf8, true), + ]; + assert_eq!( + get_valid_types(function, &signature.type_signature, &data_types)?, + vec![vec![ + DataType::new_large_list(DataType::Utf8, true), + DataType::new_large_list(DataType::Utf8, true), + ]] + ); + + Ok(()) + } + + #[test] + fn test_get_valid_types_array_and_element() -> Result<()> { + let function = "array_and_element"; + let signature = Signature::array_and_element(Volatility::Immutable); + + let data_types = + vec![DataType::new_list(DataType::Int32, true), DataType::Float64]; + assert_eq!( + get_valid_types(function, &signature.type_signature, &data_types)?, + vec![vec![ + DataType::new_list(DataType::Float64, true), + DataType::Float64, + ]] + ); + + let data_types = vec![ + DataType::new_large_list(DataType::Int32, true), + DataType::Null, + ]; + assert_eq!( + get_valid_types(function, &signature.type_signature, &data_types)?, + vec![vec![ + DataType::new_large_list(DataType::Int32, true), + DataType::Int32, + ]] + ); + + let data_types = vec![ + DataType::new_fixed_size_list(DataType::Null, 3, true), + DataType::Utf8, + ]; + assert_eq!( + get_valid_types(function, &signature.type_signature, &data_types)?, + vec![vec![ + DataType::new_list(DataType::Utf8, true), + DataType::Utf8, + ]] + ); + + Ok(()) + } + + #[test] + fn test_get_valid_types_element_and_array() -> Result<()> { + let function = "element_and_array"; + let signature = Signature::element_and_array(Volatility::Immutable); + + let data_types = vec![ + DataType::new_large_list(DataType::Null, false), + DataType::new_list(DataType::new_list(DataType::Int64, true), true), + ]; + assert_eq!( + get_valid_types(function, &signature.type_signature, &data_types)?, + vec![vec![ + DataType::new_large_list(DataType::Int64, true), + DataType::new_list(DataType::new_large_list(DataType::Int64, true), true), + ]] + ); + + Ok(()) + } + + #[test] + fn test_get_valid_types_fixed_size_arrays() -> Result<()> { + let function = "fixed_size_arrays"; + let signature = Signature::arrays(2, None, Volatility::Immutable); + + 
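+        // With no `ListCoercion` hint, fixed-size inputs keep their own sizes and only the
+        // element types are unified; mixing a fixed-size list with a variable-size list
+        // falls back to `List`, and element types with no common super type produce no
+        // candidate types at all.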
let data_types = vec![ + DataType::new_fixed_size_list(DataType::Int64, 3, true), + DataType::new_fixed_size_list(DataType::Int32, 5, true), + ]; + assert_eq!( + get_valid_types(function, &signature.type_signature, &data_types)?, + vec![vec![ + DataType::new_fixed_size_list(DataType::Int64, 3, true), + DataType::new_fixed_size_list(DataType::Int64, 5, true), + ]] + ); + + let data_types = vec![ + DataType::new_fixed_size_list(DataType::Int64, 3, true), + DataType::new_list(DataType::Int32, true), + ]; + assert_eq!( + get_valid_types(function, &signature.type_signature, &data_types)?, + vec![vec![ + DataType::new_list(DataType::Int64, true), + DataType::new_list(DataType::Int64, true), + ]] + ); + + let data_types = vec![ + DataType::new_fixed_size_list(DataType::Utf8, 3, true), + DataType::new_list(DataType::new_list(DataType::Int32, true), true), + ]; + assert_eq!( + get_valid_types(function, &signature.type_signature, &data_types)?, + vec![vec![]] + ); + + Ok(()) + } } diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index b75e8fd3cd3c..d1bf45ce2fe8 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -24,7 +24,7 @@ use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::Arc; use std::vec; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue, Statistics}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; @@ -224,6 +224,13 @@ impl AggregateUDF { self.inner.return_type(args) } + /// Return the field of the function given its input fields + /// + /// See [`AggregateUDFImpl::return_field`] for more details. + pub fn return_field(&self, args: &[FieldRef]) -> Result { + self.inner.return_field(args) + } + /// Return an accumulator the given aggregate, given its return datatype pub fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { self.inner.accumulator(acc_args) @@ -234,7 +241,7 @@ impl AggregateUDF { /// for more details. /// /// This is used to support multi-phase aggregations - pub fn state_fields(&self, args: StateFieldsArgs) -> Result> { + pub fn state_fields(&self, args: StateFieldsArgs) -> Result> { self.inner.state_fields(args) } @@ -315,6 +322,16 @@ impl AggregateUDF { self.inner.default_value(data_type) } + /// See [`AggregateUDFImpl::supports_null_handling_clause`] for more details. + pub fn supports_null_handling_clause(&self) -> bool { + self.inner.supports_null_handling_clause() + } + + /// See [`AggregateUDFImpl::is_ordered_set_aggregate`] for more details. + pub fn is_ordered_set_aggregate(&self) -> bool { + self.inner.is_ordered_set_aggregate() + } + /// Returns the documentation for this Aggregate UDF. /// /// Documentation can be accessed programmatically as well as @@ -346,8 +363,8 @@ where /// # Basic Example /// ``` /// # use std::any::Any; -/// # use std::sync::LazyLock; -/// # use arrow::datatypes::DataType; +/// # use std::sync::{Arc, LazyLock}; +/// # use arrow::datatypes::{DataType, FieldRef}; /// # use datafusion_common::{DataFusionError, plan_err, Result}; /// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility, Expr, Documentation}; /// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::{AccumulatorArgs, StateFieldsArgs}}; @@ -391,10 +408,10 @@ where /// } /// // This is the accumulator factory; DataFusion uses it to create new accumulators. 
/// fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { unimplemented!() } -/// fn state_fields(&self, args: StateFieldsArgs) -> Result> { +/// fn state_fields(&self, args: StateFieldsArgs) -> Result> { /// Ok(vec![ -/// Field::new("value", args.return_type.clone(), true), -/// Field::new("ordering", DataType::UInt32, true) +/// Arc::new(args.return_field.as_ref().clone().with_name("value")), +/// Arc::new(Field::new("ordering", DataType::UInt32, true)) /// ]) /// } /// fn documentation(&self) -> Option<&Documentation> { @@ -432,6 +449,14 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { null_treatment, } = params; + // exclude the first function argument(= column) in ordered set aggregate function, + // because it is duplicated with the WITHIN GROUP clause in schema name. + let args = if self.is_ordered_set_aggregate() { + &args[1..] + } else { + &args[..] + }; + let mut schema_name = String::new(); schema_name.write_fmt(format_args!( @@ -442,7 +467,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { ))?; if let Some(null_treatment) = null_treatment { - schema_name.write_fmt(format_args!(" {}", null_treatment))?; + schema_name.write_fmt(format_args!(" {null_treatment}"))?; } if let Some(filter) = filter { @@ -450,8 +475,14 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { }; if let Some(order_by) = order_by { + let clause = match self.is_ordered_set_aggregate() { + true => "WITHIN GROUP", + false => "ORDER BY", + }; + schema_name.write_fmt(format_args!( - " ORDER BY [{}]", + " {} [{}]", + clause, schema_name_from_sorts(order_by)? ))?; }; @@ -481,7 +512,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { ))?; if let Some(null_treatment) = null_treatment { - schema_name.write_fmt(format_args!(" {}", null_treatment))?; + schema_name.write_fmt(format_args!(" {null_treatment}"))?; } if let Some(filter) = filter { @@ -525,7 +556,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { ))?; if let Some(null_treatment) = null_treatment { - schema_name.write_fmt(format_args!(" {}", null_treatment))?; + schema_name.write_fmt(format_args!(" {null_treatment}"))?; } if !partition_by.is_empty() { @@ -572,7 +603,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { ))?; if let Some(nt) = null_treatment { - display_name.write_fmt(format_args!(" {}", nt))?; + display_name.write_fmt(format_args!(" {nt}"))?; } if let Some(fe) = filter { display_name.write_fmt(format_args!(" FILTER (WHERE {fe})"))?; @@ -619,7 +650,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { ))?; if let Some(null_treatment) = null_treatment { - display_name.write_fmt(format_args!(" {}", null_treatment))?; + display_name.write_fmt(format_args!(" {null_treatment}"))?; } if !partition_by.is_empty() { @@ -650,6 +681,35 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// the arguments fn return_type(&self, arg_types: &[DataType]) -> Result; + /// What type will be returned by this function, given the arguments? + /// + /// By default, this function calls [`Self::return_type`] with the + /// types of each argument. + /// + /// # Notes + /// + /// Most UDFs should implement [`Self::return_type`] and not this + /// function as the output type for most functions only depends on the types + /// of their inputs (e.g. `sum(f64)` is always `f64`). + /// + /// This function can be used for more advanced cases such as: + /// + /// 1. specifying nullability + /// 2. return types based on the **values** of the arguments (rather than + /// their **types**. + /// 3. 
return types based on metadata within the fields of the inputs + fn return_field(&self, arg_fields: &[FieldRef]) -> Result { + let arg_types: Vec<_> = + arg_fields.iter().map(|f| f.data_type()).cloned().collect(); + let data_type = self.return_type(&arg_types)?; + + Ok(Arc::new(Field::new( + self.name(), + data_type, + self.is_nullable(), + ))) + } + /// Whether the aggregate function is nullable. /// /// Nullable means that the function could return `null` for any inputs. @@ -688,15 +748,16 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// The name of the fields must be unique within the query and thus should /// be derived from `name`. See [`format_state_name`] for a utility function /// to generate a unique name. - fn state_fields(&self, args: StateFieldsArgs) -> Result> { - let fields = vec![Field::new( - format_state_name(args.name, "value"), - args.return_type.clone(), - true, - )]; + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + let fields = vec![args + .return_field + .as_ref() + .clone() + .with_name(format_state_name(args.name, "value"))]; Ok(fields .into_iter() + .map(Arc::new) .chain(args.ordering_fields.to_vec()) .collect()) } @@ -891,6 +952,18 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { ScalarValue::try_from(data_type) } + /// If this function supports `[IGNORE NULLS | RESPECT NULLS]` clause, return true + /// If the function does not, return false + fn supports_null_handling_clause(&self) -> bool { + true + } + + /// If this function is ordered-set aggregate function, return true + /// If the function is not, return false + fn is_ordered_set_aggregate(&self) -> bool { + false + } + /// Returns the documentation for this Aggregate UDF. /// /// Documentation can be accessed programmatically as well as @@ -978,7 +1051,7 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl { &self.aliases } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { self.inner.state_fields(args) } @@ -1111,7 +1184,7 @@ pub enum SetMonotonicity { #[cfg(test)] mod test { use crate::{AggregateUDF, AggregateUDFImpl}; - use arrow::datatypes::{DataType, Field}; + use arrow::datatypes::{DataType, FieldRef}; use datafusion_common::Result; use datafusion_expr_common::accumulator::Accumulator; use datafusion_expr_common::signature::{Signature, Volatility}; @@ -1157,7 +1230,7 @@ mod test { ) -> Result> { unimplemented!() } - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { unimplemented!() } } @@ -1197,7 +1270,7 @@ mod test { ) -> Result> { unimplemented!() } - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { unimplemented!() } } diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 9b2400774a3d..8b6ffba04ff6 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -17,11 +17,12 @@ //! 
[`ScalarUDF`]: Scalar User Defined Functions +use crate::async_udf::AsyncScalarUDF; use crate::expr::schema_name_from_exprs_comma_separated_without_space; use crate::simplify::{ExprSimplifyResult, SimplifyInfo}; use crate::sort_properties::{ExprProperties, SortProperties}; use crate::{ColumnarValue, Documentation, Expr, Signature}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::{not_impl_err, ExprSchema, Result, ScalarValue}; use datafusion_expr_common::interval_arithmetic::Interval; use std::any::Any; @@ -34,7 +35,7 @@ use std::sync::Arc; /// /// A scalar function produces a single row output for each row of input. This /// struct contains the information DataFusion needs to plan and invoke -/// functions you supply such name, type signature, return type, and actual +/// functions you supply such as name, type signature, return type, and actual /// implementation. /// /// 1. For simple use cases, use [`create_udf`] (examples in [`simple_udf.rs`]). @@ -42,11 +43,11 @@ use std::sync::Arc; /// 2. For advanced use cases, use [`ScalarUDFImpl`] which provides full API /// access (examples in [`advanced_udf.rs`]). /// -/// See [`Self::call`] to invoke a `ScalarUDF` with arguments. +/// See [`Self::call`] to create an `Expr` which invokes a `ScalarUDF` with arguments. /// /// # API Note /// -/// This is a separate struct from `ScalarUDFImpl` to maintain backwards +/// This is a separate struct from [`ScalarUDFImpl`] to maintain backwards /// compatibility with the older API. /// /// [`create_udf`]: crate::expr_fn::create_udf @@ -170,7 +171,7 @@ impl ScalarUDF { /// /// # Notes /// - /// If a function implement [`ScalarUDFImpl::return_type_from_args`], + /// If a function implement [`ScalarUDFImpl::return_field_from_args`], /// its [`ScalarUDFImpl::return_type`] should raise an error. /// /// See [`ScalarUDFImpl::return_type`] for more details. @@ -180,9 +181,9 @@ impl ScalarUDF { /// Return the datatype this function returns given the input argument types. /// - /// See [`ScalarUDFImpl::return_type_from_args`] for more details. - pub fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - self.inner.return_type_from_args(args) + /// See [`ScalarUDFImpl::return_field_from_args`] for more details. + pub fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + self.inner.return_field_from_args(args) } /// Do the function rewrite @@ -280,6 +281,11 @@ impl ScalarUDF { pub fn documentation(&self) -> Option<&Documentation> { self.inner.documentation() } + + /// Return true if this function is an async function + pub fn as_async(&self) -> Option<&AsyncScalarUDF> { + self.inner().as_any().downcast_ref::() + } } impl From for ScalarUDF @@ -293,14 +299,26 @@ where /// Arguments passed to [`ScalarUDFImpl::invoke_with_args`] when invoking a /// scalar function. 
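+///
+/// Besides the evaluated argument values, these arguments carry the `Field` for
+/// each input and the resolved return `Field`, so implementations can read
+/// per-argument metadata and the final return type directly.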
-pub struct ScalarFunctionArgs<'a> { +#[derive(Debug, Clone)] +pub struct ScalarFunctionArgs { /// The evaluated arguments to the function pub args: Vec, + /// Field associated with each arg, if it exists + pub arg_fields: Vec, /// The number of rows in record batch being evaluated pub number_rows: usize, - /// The return type of the scalar function returned (from `return_type` or `return_type_from_args`) - /// when creating the physical expression from the logical expression - pub return_type: &'a DataType, + /// The return field of the scalar function returned (from `return_type` + /// or `return_field_from_args`) when creating the physical expression + /// from the logical expression + pub return_field: FieldRef, +} + +impl ScalarFunctionArgs { + /// The return type of the function. See [`Self::return_field`] for more + /// details. + pub fn return_type(&self) -> &DataType { + self.return_field.data_type() + } } /// Information about arguments passed to the function @@ -309,64 +327,18 @@ pub struct ScalarFunctionArgs<'a> { /// such as the type of the arguments, any scalar arguments and if the /// arguments can (ever) be null /// -/// See [`ScalarUDFImpl::return_type_from_args`] for more information +/// See [`ScalarUDFImpl::return_field_from_args`] for more information #[derive(Debug)] -pub struct ReturnTypeArgs<'a> { +pub struct ReturnFieldArgs<'a> { /// The data types of the arguments to the function - pub arg_types: &'a [DataType], - /// Is argument `i` to the function a scalar (constant) + pub arg_fields: &'a [FieldRef], + /// Is argument `i` to the function a scalar (constant)? /// - /// If argument `i` is not a scalar, it will be None + /// If the argument `i` is not a scalar, it will be None /// /// For example, if a function is called like `my_function(column_a, 5)` /// this field will be `[None, Some(ScalarValue::Int32(Some(5)))]` pub scalar_arguments: &'a [Option<&'a ScalarValue>], - /// Can argument `i` (ever) null? - pub nullables: &'a [bool], -} - -/// Return metadata for this function. -/// -/// See [`ScalarUDFImpl::return_type_from_args`] for more information -#[derive(Debug)] -pub struct ReturnInfo { - return_type: DataType, - nullable: bool, -} - -impl ReturnInfo { - pub fn new(return_type: DataType, nullable: bool) -> Self { - Self { - return_type, - nullable, - } - } - - pub fn new_nullable(return_type: DataType) -> Self { - Self { - return_type, - nullable: true, - } - } - - pub fn new_non_nullable(return_type: DataType) -> Self { - Self { - return_type, - nullable: false, - } - } - - pub fn return_type(&self) -> &DataType { - &self.return_type - } - - pub fn nullable(&self) -> bool { - self.nullable - } - - pub fn into_parts(self) -> (DataType, bool) { - (self.return_type, self.nullable) - } } /// Trait for implementing user defined scalar functions. @@ -480,7 +452,7 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// /// # Notes /// - /// If you provide an implementation for [`Self::return_type_from_args`], + /// If you provide an implementation for [`Self::return_field_from_args`], /// DataFusion will not call `return_type` (this function). In such cases /// is recommended to return [`DataFusionError::Internal`]. /// @@ -494,9 +466,10 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// /// # Notes /// - /// Most UDFs should implement [`Self::return_type`] and not this - /// function as the output type for most functions only depends on the types - /// of their inputs (e.g. `sqrt(f32)` is always `f32`). 
+ /// For the majority of UDFs, implementing [`Self::return_type`] is sufficient, + /// as the result type is typically a deterministic function of the input types + /// (e.g., `sqrt(f32)` consistently yields `f32`). Implementing this method directly + /// is generally unnecessary unless the return type depends on runtime values. /// /// This function can be used for more advanced cases such as: /// @@ -504,6 +477,27 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// 2. return types based on the **values** of the arguments (rather than /// their **types**. /// + /// # Example creating `Field` + /// + /// Note the name of the [`Field`] is ignored, except for structured types such as + /// `DataType::Struct`. + /// + /// ```rust + /// # use std::sync::Arc; + /// # use arrow::datatypes::{DataType, Field, FieldRef}; + /// # use datafusion_common::Result; + /// # use datafusion_expr::ReturnFieldArgs; + /// # struct Example{} + /// # impl Example { + /// fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + /// // report output is only nullable if any one of the arguments are nullable + /// let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + /// let field = Arc::new(Field::new("ignored_name", DataType::Int32, true)); + /// Ok(field) + /// } + /// # } + /// ``` + /// /// # Output Type based on Values /// /// For example, the following two function calls get the same argument @@ -518,14 +512,20 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// This function **must** consistently return the same type for the same /// logical input even if the input is simplified (e.g. it must return the same /// value for `('foo' | 'bar')` as it does for ('foobar'). - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - let return_type = self.return_type(args.arg_types)?; - Ok(ReturnInfo::new_nullable(return_type)) + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let data_types = args + .arg_fields + .iter() + .map(|f| f.data_type()) + .cloned() + .collect::>(); + let return_type = self.return_type(&data_types)?; + Ok(Arc::new(Field::new(self.name(), return_type, true))) } #[deprecated( since = "45.0.0", - note = "Use `return_type_from_args` instead. if you use `is_nullable` that returns non-nullable with `return_type`, you would need to switch to `return_type_from_args`, you might have error" + note = "Use `return_field_from_args` instead. if you use `is_nullable` that returns non-nullable with `return_type`, you would need to switch to `return_field_from_args`, you might have error" )] fn is_nullable(&self, _args: &[Expr], _schema: &dyn ExprSchema) -> bool { true @@ -584,13 +584,15 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { } /// Returns true if some of this `exprs` subexpressions may not be evaluated - /// and thus any side effects (like divide by zero) may not be encountered - /// Setting this to true prevents certain optimizations such as common subexpression elimination + /// and thus any side effects (like divide by zero) may not be encountered. + /// + /// Setting this to true prevents certain optimizations such as common + /// subexpression elimination fn short_circuits(&self) -> bool { false } - /// Computes the output interval for a [`ScalarUDFImpl`], given the input + /// Computes the output [`Interval`] for a [`ScalarUDFImpl`], given the input /// intervals. 
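+    ///
+    /// The default implementation returns an unbounded interval over `DataType::Null`,
+    /// which is always a correct, if maximally conservative, answer.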
/// /// # Parameters @@ -606,9 +608,11 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { Interval::make_unbounded(&DataType::Null) } - /// Updates bounds for child expressions, given a known interval for this - /// function. This is used to propagate constraints down through an expression - /// tree. + /// Updates bounds for child expressions, given a known [`Interval`]s for this + /// function. + /// + /// This function is used to propagate constraints down through an + /// expression tree. /// /// # Parameters /// @@ -657,20 +661,25 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { } } - /// Whether the function preserves lexicographical ordering based on the input ordering + /// Returns true if the function preserves lexicographical ordering based on + /// the input ordering. + /// + /// For example, `concat(a || b)` preserves lexicographical ordering, but `abs(a)` does not. fn preserves_lex_ordering(&self, _inputs: &[ExprProperties]) -> Result { Ok(false) } /// Coerce arguments of a function call to types that the function can evaluate. /// - /// This function is only called if [`ScalarUDFImpl::signature`] returns [`crate::TypeSignature::UserDefined`]. Most - /// UDFs should return one of the other variants of `TypeSignature` which handle common - /// cases + /// This function is only called if [`ScalarUDFImpl::signature`] returns + /// [`crate::TypeSignature::UserDefined`]. Most UDFs should return one of + /// the other variants of [`TypeSignature`] which handle common cases. /// /// See the [type coercion module](crate::type_coercion) /// documentation for more details on type coercion /// + /// [`TypeSignature`]: crate::TypeSignature + /// /// For example, if your function requires a floating point arguments, but the user calls /// it like `my_func(1::int)` (i.e. with `1` as an integer), coerce_types can return `[DataType::Float64]` /// to ensure the argument is converted to `1::double` @@ -714,8 +723,8 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// Returns the documentation for this Scalar UDF. /// - /// Documentation can be accessed programmatically as well as - /// generating publicly facing documentation. + /// Documentation can be accessed programmatically as well as generating + /// publicly facing documentation. 
fn documentation(&self) -> Option<&Documentation> { None } @@ -765,18 +774,18 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { self.inner.return_type(arg_types) } - fn aliases(&self) -> &[String] { - &self.aliases - } - - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - self.inner.return_type_from_args(args) + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + self.inner.return_field_from_args(args) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { self.inner.invoke_with_args(args) } + fn aliases(&self) -> &[String] { + &self.aliases + } + fn simplify( &self, args: Vec, diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 4da63d7955f5..155de232285e 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -26,7 +26,7 @@ use std::{ sync::Arc, }; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, FieldRef}; use crate::expr::WindowFunction; use crate::{ @@ -133,7 +133,7 @@ impl WindowUDF { pub fn call(&self, args: Vec) -> Expr { let fun = crate::WindowFunctionDefinition::WindowUDF(Arc::new(self.clone())); - Expr::WindowFunction(WindowFunction::new(fun, args)) + Expr::from(WindowFunction::new(fun, args)) } /// Returns this function's name @@ -179,7 +179,7 @@ impl WindowUDF { /// Returns the field of the final result of evaluating this window function. /// /// See [`WindowUDFImpl::field`] for more details. - pub fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + pub fn field(&self, field_args: WindowUDFFieldArgs) -> Result { self.inner.field(field_args) } @@ -236,7 +236,7 @@ where /// ``` /// # use std::any::Any; /// # use std::sync::LazyLock; -/// # use arrow::datatypes::{DataType, Field}; +/// # use arrow::datatypes::{DataType, Field, FieldRef}; /// # use datafusion_common::{DataFusionError, plan_err, Result}; /// # use datafusion_expr::{col, Signature, Volatility, PartitionEvaluator, WindowFrame, ExprFunctionExt, Documentation}; /// # use datafusion_expr::{WindowUDFImpl, WindowUDF}; @@ -279,9 +279,9 @@ where /// ) -> Result> { /// unimplemented!() /// } -/// fn field(&self, field_args: WindowUDFFieldArgs) -> Result { -/// if let Some(DataType::Int32) = field_args.get_input_type(0) { -/// Ok(Field::new(field_args.name(), DataType::Int32, false)) +/// fn field(&self, field_args: WindowUDFFieldArgs) -> Result { +/// if let Some(DataType::Int32) = field_args.get_input_field(0).map(|f| f.data_type().clone()) { +/// Ok(Field::new(field_args.name(), DataType::Int32, false).into()) /// } else { /// plan_err!("smooth_it only accepts Int32 arguments") /// } @@ -386,12 +386,12 @@ pub trait WindowUDFImpl: Debug + Send + Sync { hasher.finish() } - /// The [`Field`] of the final result of evaluating this window function. + /// The [`FieldRef`] of the final result of evaluating this window function. /// /// Call `field_args.name()` to get the fully qualified name for defining - /// the [`Field`]. For a complete example see the implementation in the + /// the [`FieldRef`]. For a complete example see the implementation in the /// [Basic Example](WindowUDFImpl#basic-example) section. - fn field(&self, field_args: WindowUDFFieldArgs) -> Result; + fn field(&self, field_args: WindowUDFFieldArgs) -> Result; /// Allows the window UDF to define a custom result ordering. 
/// @@ -537,7 +537,7 @@ impl WindowUDFImpl for AliasedWindowUDFImpl { hasher.finish() } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { self.inner.field(field_args) } @@ -588,7 +588,7 @@ pub mod window_doc_sections { #[cfg(test)] mod test { use crate::{PartitionEvaluator, WindowUDF, WindowUDFImpl}; - use arrow::datatypes::{DataType, Field}; + use arrow::datatypes::{DataType, FieldRef}; use datafusion_common::Result; use datafusion_expr_common::signature::{Signature, Volatility}; use datafusion_functions_window_common::field::WindowUDFFieldArgs; @@ -630,7 +630,7 @@ mod test { ) -> Result> { unimplemented!() } - fn field(&self, _field_args: WindowUDFFieldArgs) -> Result { + fn field(&self, _field_args: WindowUDFFieldArgs) -> Result { unimplemented!() } } @@ -669,7 +669,7 @@ mod test { ) -> Result> { unimplemented!() } - fn field(&self, _field_args: WindowUDFFieldArgs) -> Result { + fn field(&self, _field_args: WindowUDFFieldArgs) -> Result { unimplemented!() } } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 552ce1502d46..8950f5e450e0 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -21,7 +21,7 @@ use std::cmp::Ordering; use std::collections::{BTreeSet, HashSet}; use std::sync::Arc; -use crate::expr::{Alias, Sort, WildcardOptions, WindowFunction, WindowFunctionParams}; +use crate::expr::{Alias, Sort, WildcardOptions, WindowFunctionParams}; use crate::expr_rewriter::strip_outer_reference; use crate::{ and, BinaryExpr, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, Operator, @@ -276,7 +276,7 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { Expr::Unnest(_) | Expr::ScalarVariable(_, _) | Expr::Alias(_) - | Expr::Literal(_) + | Expr::Literal(_, _) | Expr::BinaryExpr { .. } | Expr::Like { .. } | Expr::SimilarTo { .. } @@ -579,7 +579,8 @@ pub fn group_window_expr_by_sort_keys( ) -> Result)>> { let mut result = vec![]; window_expr.into_iter().try_for_each(|expr| match &expr { - Expr::WindowFunction( WindowFunction{ params: WindowFunctionParams { partition_by, order_by, ..}, .. }) => { + Expr::WindowFunction(window_fun) => { + let WindowFunctionParams{ partition_by, order_by, ..} = &window_fun.as_ref().params; let sort_key = generate_sort_key(partition_by, order_by)?; if let Some((_, values)) = result.iter_mut().find( |group: &&mut (WindowSortKey, Vec)| matches!(group, (key, _) if *key == sort_key), @@ -608,7 +609,7 @@ pub fn find_aggregate_exprs<'a>(exprs: impl IntoIterator) -> Ve /// Collect all deeply nested `Expr::WindowFunction`. They are returned in order of occurrence /// (depth first), with duplicates omitted. -pub fn find_window_exprs(exprs: &[Expr]) -> Vec { +pub fn find_window_exprs<'a>(exprs: impl IntoIterator) -> Vec { find_exprs_in_exprs(exprs, &|nested_expr| { matches!(nested_expr, Expr::WindowFunction { .. 
}) }) @@ -784,7 +785,7 @@ pub(crate) fn find_column_indexes_referenced_by_expr( indexes.push(idx); } } - Expr::Literal(_) => { + Expr::Literal(_, _) => { indexes.push(usize::MAX); } _ => {} @@ -1263,9 +1264,11 @@ pub fn collect_subquery_cols( mod tests { use super::*; use crate::{ - col, cube, expr_vec_fmt, grouping_set, lit, rollup, - test::function_stub::max_udaf, test::function_stub::min_udaf, - test::function_stub::sum_udaf, Cast, ExprFunctionExt, WindowFunctionDefinition, + col, cube, + expr::WindowFunction, + expr_vec_fmt, grouping_set, lit, rollup, + test::function_stub::{max_udaf, min_udaf, sum_udaf}, + Cast, ExprFunctionExt, WindowFunctionDefinition, }; use arrow::datatypes::{UnionFields, UnionMode}; @@ -1279,19 +1282,19 @@ mod tests { #[test] fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> { - let max1 = Expr::WindowFunction(WindowFunction::new( + let max1 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); - let max2 = Expr::WindowFunction(WindowFunction::new( + let max2 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); - let min3 = Expr::WindowFunction(WindowFunction::new( + let min3 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(min_udaf()), vec![col("name")], )); - let sum4 = Expr::WindowFunction(WindowFunction::new( + let sum4 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), vec![col("age")], )); @@ -1309,25 +1312,25 @@ mod tests { let age_asc = Sort::new(col("age"), true, true); let name_desc = Sort::new(col("name"), false, true); let created_at_desc = Sort::new(col("created_at"), false, true); - let max1 = Expr::WindowFunction(WindowFunction::new( + let max1 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )) .order_by(vec![age_asc.clone(), name_desc.clone()]) .build() .unwrap(); - let max2 = Expr::WindowFunction(WindowFunction::new( + let max2 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); - let min3 = Expr::WindowFunction(WindowFunction::new( + let min3 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(min_udaf()), vec![col("name")], )) .order_by(vec![age_asc.clone(), name_desc.clone()]) .build() .unwrap(); - let sum4 = Expr::WindowFunction(WindowFunction::new( + let sum4 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), vec![col("age")], )) diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index 8771b25137cf..b91bbddd8bac 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -160,7 +160,7 @@ impl WindowFrame { } else { WindowFrameUnits::Range }, - start_bound: WindowFrameBound::Preceding(ScalarValue::Null), + start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(None)), end_bound: WindowFrameBound::CurrentRow, causal: strict, } @@ -351,11 +351,15 @@ impl WindowFrameBound { ast::WindowFrameBound::Preceding(Some(v)) => { Self::Preceding(convert_frame_bound_to_scalar_value(*v, units)?) } - ast::WindowFrameBound::Preceding(None) => Self::Preceding(ScalarValue::Null), + ast::WindowFrameBound::Preceding(None) => { + Self::Preceding(ScalarValue::UInt64(None)) + } ast::WindowFrameBound::Following(Some(v)) => { Self::Following(convert_frame_bound_to_scalar_value(*v, units)?) 
} - ast::WindowFrameBound::Following(None) => Self::Following(ScalarValue::Null), + ast::WindowFrameBound::Following(None) => { + Self::Following(ScalarValue::UInt64(None)) + } ast::WindowFrameBound::CurrentRow => Self::CurrentRow, }) } @@ -570,9 +574,9 @@ mod tests { #[test] fn test_window_frame_bound_creation() -> Result<()> { // Unbounded - test_bound!(Rows, None, ScalarValue::Null); - test_bound!(Groups, None, ScalarValue::Null); - test_bound!(Range, None, ScalarValue::Null); + test_bound!(Rows, None, ScalarValue::UInt64(None)); + test_bound!(Groups, None, ScalarValue::UInt64(None)); + test_bound!(Range, None, ScalarValue::UInt64(None)); // Number let number = Some(Box::new(ast::Expr::Value( diff --git a/datafusion/expr/src/window_state.rs b/datafusion/expr/src/window_state.rs index f1d0ead23ab1..a101b8fe4df6 100644 --- a/datafusion/expr/src/window_state.rs +++ b/datafusion/expr/src/window_state.rs @@ -193,11 +193,7 @@ impl WindowFrameContext { // UNBOUNDED PRECEDING WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => 0, WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => { - if idx >= n as usize { - idx - n as usize - } else { - 0 - } + idx.saturating_sub(n as usize) } WindowFrameBound::CurrentRow => idx, // UNBOUNDED FOLLOWING @@ -211,7 +207,7 @@ impl WindowFrameContext { } // ERRONEOUS FRAMES WindowFrameBound::Preceding(_) | WindowFrameBound::Following(_) => { - return internal_err!("Rows should be Uint") + return internal_err!("Rows should be UInt64") } }; let end = match window_frame.end_bound { @@ -236,7 +232,7 @@ impl WindowFrameContext { } // ERRONEOUS FRAMES WindowFrameBound::Preceding(_) | WindowFrameBound::Following(_) => { - return internal_err!("Rows should be Uint") + return internal_err!("Rows should be UInt64") } }; Ok(Range { start, end }) @@ -602,11 +598,7 @@ impl WindowFrameStateGroups { // Find the group index of the frame boundary: let group_idx = if SEARCH_SIDE { - if self.current_group_idx > delta { - self.current_group_idx - delta - } else { - 0 - } + self.current_group_idx.saturating_sub(delta) } else { self.current_group_idx + delta }; @@ -683,9 +675,9 @@ mod tests { (range_columns, sort_options) } - fn assert_expected( - expected_results: Vec<(Range, usize)>, + fn assert_group_ranges( window_frame: &Arc, + expected_results: Vec<(Range, usize)>, ) -> Result<()> { let mut window_frame_groups = WindowFrameStateGroups::default(); let (range_columns, _) = get_test_data(); @@ -705,6 +697,136 @@ mod tests { Ok(()) } + fn assert_frame_ranges( + window_frame: &Arc, + expected_results: Vec>, + ) -> Result<()> { + let mut window_frame_context = + WindowFrameContext::new(Arc::clone(window_frame), vec![]); + let (range_columns, _) = get_test_data(); + let n_row = range_columns[0].len(); + let mut last_range = Range { start: 0, end: 0 }; + for (idx, expected_range) in expected_results.into_iter().enumerate() { + let range = window_frame_context.calculate_range( + &range_columns, + &last_range, + n_row, + idx, + )?; + assert_eq!(range, expected_range); + last_range = range; + } + Ok(()) + } + + #[test] + fn test_default_window_frame_group_boundaries() -> Result<()> { + let window_frame = Arc::new(WindowFrame::new(None)); + assert_group_ranges( + &window_frame, + vec![ + (Range { start: 0, end: 9 }, 0), + (Range { start: 0, end: 9 }, 0), + (Range { start: 0, end: 9 }, 0), + (Range { start: 0, end: 9 }, 0), + (Range { start: 0, end: 9 }, 0), + (Range { start: 0, end: 9 }, 0), + (Range { start: 0, end: 9 }, 0), + (Range { start: 0, end: 9 }, 0), + (Range { 
start: 0, end: 9 }, 0), + ], + )?; + + assert_frame_ranges( + &window_frame, + vec![ + Range { start: 0, end: 9 }, + Range { start: 0, end: 9 }, + Range { start: 0, end: 9 }, + Range { start: 0, end: 9 }, + Range { start: 0, end: 9 }, + Range { start: 0, end: 9 }, + Range { start: 0, end: 9 }, + Range { start: 0, end: 9 }, + Range { start: 0, end: 9 }, + ], + )?; + + Ok(()) + } + + #[test] + fn test_unordered_window_frame_group_boundaries() -> Result<()> { + let window_frame = Arc::new(WindowFrame::new(Some(false))); + assert_group_ranges( + &window_frame, + vec![ + (Range { start: 0, end: 1 }, 0), + (Range { start: 0, end: 2 }, 1), + (Range { start: 0, end: 4 }, 2), + (Range { start: 0, end: 4 }, 2), + (Range { start: 0, end: 5 }, 3), + (Range { start: 0, end: 8 }, 4), + (Range { start: 0, end: 8 }, 4), + (Range { start: 0, end: 8 }, 4), + (Range { start: 0, end: 9 }, 5), + ], + )?; + + assert_frame_ranges( + &window_frame, + vec![ + Range { start: 0, end: 9 }, + Range { start: 0, end: 9 }, + Range { start: 0, end: 9 }, + Range { start: 0, end: 9 }, + Range { start: 0, end: 9 }, + Range { start: 0, end: 9 }, + Range { start: 0, end: 9 }, + Range { start: 0, end: 9 }, + Range { start: 0, end: 9 }, + ], + )?; + + Ok(()) + } + + #[test] + fn test_ordered_window_frame_group_boundaries() -> Result<()> { + let window_frame = Arc::new(WindowFrame::new(Some(true))); + assert_group_ranges( + &window_frame, + vec![ + (Range { start: 0, end: 1 }, 0), + (Range { start: 0, end: 2 }, 1), + (Range { start: 0, end: 4 }, 2), + (Range { start: 0, end: 4 }, 2), + (Range { start: 0, end: 5 }, 3), + (Range { start: 0, end: 8 }, 4), + (Range { start: 0, end: 8 }, 4), + (Range { start: 0, end: 8 }, 4), + (Range { start: 0, end: 9 }, 5), + ], + )?; + + assert_frame_ranges( + &window_frame, + vec![ + Range { start: 0, end: 1 }, + Range { start: 0, end: 2 }, + Range { start: 0, end: 3 }, + Range { start: 0, end: 4 }, + Range { start: 0, end: 5 }, + Range { start: 0, end: 6 }, + Range { start: 0, end: 7 }, + Range { start: 0, end: 8 }, + Range { start: 0, end: 9 }, + ], + )?; + + Ok(()) + } + #[test] fn test_window_frame_group_boundaries() -> Result<()> { let window_frame = Arc::new(WindowFrame::new_bounds( @@ -712,18 +834,20 @@ mod tests { WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))), WindowFrameBound::Following(ScalarValue::UInt64(Some(1))), )); - let expected_results = vec![ - (Range { start: 0, end: 2 }, 0), - (Range { start: 0, end: 4 }, 1), - (Range { start: 1, end: 5 }, 2), - (Range { start: 1, end: 5 }, 2), - (Range { start: 2, end: 8 }, 3), - (Range { start: 4, end: 9 }, 4), - (Range { start: 4, end: 9 }, 4), - (Range { start: 4, end: 9 }, 4), - (Range { start: 5, end: 9 }, 5), - ]; - assert_expected(expected_results, &window_frame) + assert_group_ranges( + &window_frame, + vec![ + (Range { start: 0, end: 2 }, 0), + (Range { start: 0, end: 4 }, 1), + (Range { start: 1, end: 5 }, 2), + (Range { start: 1, end: 5 }, 2), + (Range { start: 2, end: 8 }, 3), + (Range { start: 4, end: 9 }, 4), + (Range { start: 4, end: 9 }, 4), + (Range { start: 4, end: 9 }, 4), + (Range { start: 5, end: 9 }, 5), + ], + ) } #[test] @@ -733,18 +857,20 @@ mod tests { WindowFrameBound::Following(ScalarValue::UInt64(Some(1))), WindowFrameBound::Following(ScalarValue::UInt64(Some(2))), )); - let expected_results = vec![ - (Range:: { start: 1, end: 4 }, 0), - (Range:: { start: 2, end: 5 }, 1), - (Range:: { start: 4, end: 8 }, 2), - (Range:: { start: 4, end: 8 }, 2), - (Range:: { start: 5, end: 9 }, 3), - (Range:: { start: 
8, end: 9 }, 4), - (Range:: { start: 8, end: 9 }, 4), - (Range:: { start: 8, end: 9 }, 4), - (Range:: { start: 9, end: 9 }, 5), - ]; - assert_expected(expected_results, &window_frame) + assert_group_ranges( + &window_frame, + vec![ + (Range:: { start: 1, end: 4 }, 0), + (Range:: { start: 2, end: 5 }, 1), + (Range:: { start: 4, end: 8 }, 2), + (Range:: { start: 4, end: 8 }, 2), + (Range:: { start: 5, end: 9 }, 3), + (Range:: { start: 8, end: 9 }, 4), + (Range:: { start: 8, end: 9 }, 4), + (Range:: { start: 8, end: 9 }, 4), + (Range:: { start: 9, end: 9 }, 5), + ], + ) } #[test] @@ -754,17 +880,19 @@ mod tests { WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))), WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))), )); - let expected_results = vec![ - (Range:: { start: 0, end: 0 }, 0), - (Range:: { start: 0, end: 1 }, 1), - (Range:: { start: 0, end: 2 }, 2), - (Range:: { start: 0, end: 2 }, 2), - (Range:: { start: 1, end: 4 }, 3), - (Range:: { start: 2, end: 5 }, 4), - (Range:: { start: 2, end: 5 }, 4), - (Range:: { start: 2, end: 5 }, 4), - (Range:: { start: 4, end: 8 }, 5), - ]; - assert_expected(expected_results, &window_frame) + assert_group_ranges( + &window_frame, + vec![ + (Range:: { start: 0, end: 0 }, 0), + (Range:: { start: 0, end: 1 }, 1), + (Range:: { start: 0, end: 2 }, 2), + (Range:: { start: 0, end: 2 }, 2), + (Range:: { start: 1, end: 4 }, 3), + (Range:: { start: 2, end: 5 }, 4), + (Range:: { start: 2, end: 5 }, 4), + (Range:: { start: 2, end: 5 }, 4), + (Range:: { start: 4, end: 8 }, 5), + ], + ) } } diff --git a/datafusion/ffi/Cargo.toml b/datafusion/ffi/Cargo.toml index 29f40df51444..a8335769ec29 100644 --- a/datafusion/ffi/Cargo.toml +++ b/datafusion/ffi/Cargo.toml @@ -44,7 +44,9 @@ arrow-schema = { workspace = true } async-ffi = { version = "0.5.0", features = ["abi_stable"] } async-trait = { workspace = true } datafusion = { workspace = true, default-features = false } +datafusion-functions-aggregate-common = { workspace = true } datafusion-proto = { workspace = true } +datafusion-proto-common = { workspace = true } futures = { workspace = true } log = { workspace = true } prost = { workspace = true } @@ -56,3 +58,4 @@ doc-comment = { workspace = true } [features] integration-tests = [] +tarpaulin_include = [] # Exists only to prevent warnings on stable and still have accurate coverage diff --git a/datafusion/ffi/src/arrow_wrappers.rs b/datafusion/ffi/src/arrow_wrappers.rs index a18e6df59bf1..7b3751dcae82 100644 --- a/datafusion/ffi/src/arrow_wrappers.rs +++ b/datafusion/ffi/src/arrow_wrappers.rs @@ -21,7 +21,8 @@ use abi_stable::StableAbi; use arrow::{ array::{make_array, ArrayRef}, datatypes::{Schema, SchemaRef}, - ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema}, + error::ArrowError, + ffi::{from_ffi, to_ffi, FFI_ArrowArray, FFI_ArrowSchema}, }; use log::error; @@ -36,7 +37,7 @@ impl From for WrappedSchema { let ffi_schema = match FFI_ArrowSchema::try_from(value.as_ref()) { Ok(s) => s, Err(e) => { - error!("Unable to convert DataFusion Schema to FFI_ArrowSchema in FFI_PlanProperties. {}", e); + error!("Unable to convert DataFusion Schema to FFI_ArrowSchema in FFI_PlanProperties. {e}"); FFI_ArrowSchema::empty() } }; @@ -44,16 +45,19 @@ impl From for WrappedSchema { WrappedSchema(ffi_schema) } } +/// Some functions are expected to always succeed, like getting the schema from a TableProvider. +/// Since going through the FFI always has the potential to fail, we need to catch these errors, +/// give the user a warning, and return some kind of result. 
In this case we default to an +/// empty schema. +#[cfg(not(tarpaulin_include))] +fn catch_df_schema_error(e: ArrowError) -> Schema { + error!("Unable to convert from FFI_ArrowSchema to DataFusion Schema in FFI_PlanProperties. {e}"); + Schema::empty() +} impl From for SchemaRef { fn from(value: WrappedSchema) -> Self { - let schema = match Schema::try_from(&value.0) { - Ok(s) => s, - Err(e) => { - error!("Unable to convert from FFI_ArrowSchema to DataFusion Schema in FFI_PlanProperties. {}", e); - Schema::empty() - } - }; + let schema = Schema::try_from(&value.0).unwrap_or_else(catch_df_schema_error); Arc::new(schema) } } @@ -71,7 +75,7 @@ pub struct WrappedArray { } impl TryFrom for ArrayRef { - type Error = arrow::error::ArrowError; + type Error = ArrowError; fn try_from(value: WrappedArray) -> Result { let data = unsafe { from_ffi(value.array, &value.schema.0)? }; @@ -79,3 +83,14 @@ impl TryFrom for ArrayRef { Ok(make_array(data)) } } + +impl TryFrom<&ArrayRef> for WrappedArray { + type Error = ArrowError; + + fn try_from(array: &ArrayRef) -> Result { + let (array, schema) = to_ffi(&array.to_data())?; + let schema = WrappedSchema(schema); + + Ok(WrappedArray { array, schema }) + } +} diff --git a/datafusion/ffi/src/lib.rs b/datafusion/ffi/src/lib.rs index d877e182a1d8..ff641e8315c7 100644 --- a/datafusion/ffi/src/lib.rs +++ b/datafusion/ffi/src/lib.rs @@ -34,8 +34,10 @@ pub mod schema_provider; pub mod session_config; pub mod table_provider; pub mod table_source; +pub mod udaf; pub mod udf; pub mod udtf; +pub mod udwf; pub mod util; pub mod volatility; diff --git a/datafusion/ffi/src/plan_properties.rs b/datafusion/ffi/src/plan_properties.rs index 3592c16b8fab..832e82dda35b 100644 --- a/datafusion/ffi/src/plan_properties.rs +++ b/datafusion/ffi/src/plan_properties.rs @@ -188,12 +188,12 @@ impl TryFrom for PlanProperties { let proto_output_ordering = PhysicalSortExprNodeCollection::decode(df_result!(ffi_orderings)?.as_ref()) .map_err(|e| DataFusionError::External(Box::new(e)))?; - let orderings = Some(parse_physical_sort_exprs( + let sort_exprs = parse_physical_sort_exprs( &proto_output_ordering.physical_sort_expr_nodes, &default_ctx, &schema, &codex, - )?); + )?; let partitioning_vec = unsafe { df_result!((ffi_props.output_partitioning)(&ffi_props))? 
}; @@ -211,11 +211,10 @@ impl TryFrom for PlanProperties { .to_string(), ))?; - let eq_properties = match orderings { - Some(ordering) => { - EquivalenceProperties::new_with_orderings(Arc::new(schema), &[ordering]) - } - None => EquivalenceProperties::new(Arc::new(schema)), + let eq_properties = if sort_exprs.is_empty() { + EquivalenceProperties::new(Arc::new(schema)) + } else { + EquivalenceProperties::new_with_orderings(Arc::new(schema), [sort_exprs]) }; let emission_type: EmissionType = @@ -300,7 +299,7 @@ impl From for EmissionType { #[cfg(test)] mod tests { - use datafusion::physical_plan::Partitioning; + use datafusion::{physical_expr::PhysicalSortExpr, physical_plan::Partitioning}; use super::*; @@ -310,9 +309,13 @@ mod tests { let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)])); + let mut eqp = EquivalenceProperties::new(Arc::clone(&schema)); + let _ = eqp.reorder([PhysicalSortExpr::new_default( + datafusion::physical_plan::expressions::col("a", &schema)?, + )]); let original_props = PlanProperties::new( - EquivalenceProperties::new(schema), - Partitioning::UnknownPartitioning(3), + eqp, + Partitioning::RoundRobinBatch(3), EmissionType::Incremental, Boundedness::Bounded, ); @@ -321,7 +324,7 @@ mod tests { let foreign_props: PlanProperties = local_props_ptr.try_into()?; - assert!(format!("{:?}", foreign_props) == format!("{:?}", original_props)); + assert_eq!(format!("{foreign_props:?}"), format!("{original_props:?}")); Ok(()) } diff --git a/datafusion/ffi/src/record_batch_stream.rs b/datafusion/ffi/src/record_batch_stream.rs index 939c4050028c..78d65a816fcc 100644 --- a/datafusion/ffi/src/record_batch_stream.rs +++ b/datafusion/ffi/src/record_batch_stream.rs @@ -196,3 +196,49 @@ impl Stream for FFI_RecordBatchStream { } } } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion::{ + common::record_batch, error::Result, execution::SendableRecordBatchStream, + test_util::bounded_stream, + }; + + use super::FFI_RecordBatchStream; + use futures::StreamExt; + + #[tokio::test] + async fn test_round_trip_record_batch_stream() -> Result<()> { + let record_batch = record_batch!( + ("a", Int32, vec![1, 2, 3]), + ("b", Float64, vec![Some(4.0), None, Some(5.0)]) + )?; + let original_rbs = bounded_stream(record_batch.clone(), 1); + + let ffi_rbs: FFI_RecordBatchStream = original_rbs.into(); + let mut ffi_rbs: SendableRecordBatchStream = Box::pin(ffi_rbs); + + let schema = ffi_rbs.schema(); + assert_eq!( + schema, + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Float64, true) + ])) + ); + + let batch = ffi_rbs.next().await; + assert!(batch.is_some()); + assert!(batch.as_ref().unwrap().is_ok()); + assert_eq!(batch.unwrap().unwrap(), record_batch); + + // There should only be one batch + let no_batch = ffi_rbs.next().await; + assert!(no_batch.is_none()); + + Ok(()) + } +} diff --git a/datafusion/ffi/src/tests/async_provider.rs b/datafusion/ffi/src/tests/async_provider.rs index cf05d596308f..60434a7dda12 100644 --- a/datafusion/ffi/src/tests/async_provider.rs +++ b/datafusion/ffi/src/tests/async_provider.rs @@ -260,7 +260,7 @@ impl Stream for AsyncTestRecordBatchStream { if let Err(e) = this.batch_request.try_send(true) { return std::task::Poll::Ready(Some(Err(DataFusionError::Execution( - format!("Unable to send batch request, {}", e), + format!("Unable to send batch request, {e}"), )))); } @@ -270,7 +270,7 @@ impl Stream for 
AsyncTestRecordBatchStream { None => std::task::Poll::Ready(None), }, Err(e) => std::task::Poll::Ready(Some(Err(DataFusionError::Execution( - format!("Unable receive record batch: {}", e), + format!("Unable to receive record batch: {e}"), )))), } } diff --git a/datafusion/ffi/src/tests/mod.rs b/datafusion/ffi/src/tests/mod.rs index 7a36ee52bdb4..db596f51fcd9 100644 --- a/datafusion/ffi/src/tests/mod.rs +++ b/datafusion/ffi/src/tests/mod.rs @@ -29,6 +29,10 @@ use catalog::create_catalog_provider; use crate::{catalog_provider::FFI_CatalogProvider, udtf::FFI_TableFunction}; +use crate::udaf::FFI_AggregateUDF; + +use crate::udwf::FFI_WindowUDF; + use super::{table_provider::FFI_TableProvider, udf::FFI_ScalarUDF}; use arrow::array::RecordBatch; use async_provider::create_async_table_provider; @@ -37,7 +41,10 @@ use datafusion::{ common::record_batch, }; use sync_provider::create_sync_table_provider; -use udf_udaf_udwf::{create_ffi_abs_func, create_ffi_random_func, create_ffi_table_func}; +use udf_udaf_udwf::{ + create_ffi_abs_func, create_ffi_random_func, create_ffi_rank_func, + create_ffi_stddev_func, create_ffi_sum_func, create_ffi_table_func, +}; mod async_provider; pub mod catalog; @@ -65,6 +72,14 @@ pub struct ForeignLibraryModule { pub create_table_function: extern "C" fn() -> FFI_TableFunction, + /// Create an aggregate UDAF using sum + pub create_sum_udaf: extern "C" fn() -> FFI_AggregateUDF, + + /// Createa grouping UDAF using stddev + pub create_stddev_udaf: extern "C" fn() -> FFI_AggregateUDF, + + pub create_rank_udwf: extern "C" fn() -> FFI_WindowUDF, + pub version: extern "C" fn() -> u64, } @@ -112,6 +127,9 @@ pub fn get_foreign_library_module() -> ForeignLibraryModuleRef { create_scalar_udf: create_ffi_abs_func, create_nullary_udf: create_ffi_random_func, create_table_function: create_ffi_table_func, + create_sum_udaf: create_ffi_sum_func, + create_stddev_udaf: create_ffi_stddev_func, + create_rank_udwf: create_ffi_rank_func, version: super::version, } .leak_into_prefix() diff --git a/datafusion/ffi/src/tests/udf_udaf_udwf.rs b/datafusion/ffi/src/tests/udf_udaf_udwf.rs index c3cb1bcc3533..55e31ef3ab77 100644 --- a/datafusion/ffi/src/tests/udf_udaf_udwf.rs +++ b/datafusion/ffi/src/tests/udf_udaf_udwf.rs @@ -15,12 +15,17 @@ // specific language governing permissions and limitations // under the License. 
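+// The additions below wrap concrete DataFusion implementations in their FFI
+// counterparts for the integration-test module: the `Sum` and `Stddev` aggregate
+// UDAFs and a `Rank` window UDF constructed with the name "rank_demo".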
-use crate::{udf::FFI_ScalarUDF, udtf::FFI_TableFunction}; +use crate::{ + udaf::FFI_AggregateUDF, udf::FFI_ScalarUDF, udtf::FFI_TableFunction, + udwf::FFI_WindowUDF, +}; use datafusion::{ catalog::TableFunctionImpl, functions::math::{abs::AbsFunc, random::RandomFunc}, + functions_aggregate::{stddev::Stddev, sum::Sum}, functions_table::generate_series::RangeFunc, - logical_expr::ScalarUDF, + functions_window::rank::Rank, + logical_expr::{AggregateUDF, ScalarUDF, WindowUDF}, }; use std::sync::Arc; @@ -42,3 +47,27 @@ pub(crate) extern "C" fn create_ffi_table_func() -> FFI_TableFunction { FFI_TableFunction::new(udtf, None) } + +pub(crate) extern "C" fn create_ffi_sum_func() -> FFI_AggregateUDF { + let udaf: Arc = Arc::new(Sum::new().into()); + + udaf.into() +} + +pub(crate) extern "C" fn create_ffi_stddev_func() -> FFI_AggregateUDF { + let udaf: Arc = Arc::new(Stddev::new().into()); + + udaf.into() +} + +pub(crate) extern "C" fn create_ffi_rank_func() -> FFI_WindowUDF { + let udwf: Arc = Arc::new( + Rank::new( + "rank_demo".to_string(), + datafusion::functions_window::rank::RankType::Basic, + ) + .into(), + ); + + udwf.into() +} diff --git a/datafusion/ffi/src/udaf/accumulator.rs b/datafusion/ffi/src/udaf/accumulator.rs new file mode 100644 index 000000000000..80b872159f48 --- /dev/null +++ b/datafusion/ffi/src/udaf/accumulator.rs @@ -0,0 +1,366 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ffi::c_void, ops::Deref}; + +use abi_stable::{ + std_types::{RResult, RString, RVec}, + StableAbi, +}; +use arrow::{array::ArrayRef, error::ArrowError}; +use datafusion::{ + error::{DataFusionError, Result}, + logical_expr::Accumulator, + scalar::ScalarValue, +}; +use prost::Message; + +use crate::{arrow_wrappers::WrappedArray, df_result, rresult, rresult_return}; + +/// A stable struct for sharing [`Accumulator`] across FFI boundaries. +/// For an explanation of each field, see the corresponding function +/// defined in [`Accumulator`]. 
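+///
+/// Values crossing the FFI boundary are kept ABI-stable: input and state arrays
+/// travel as [`WrappedArray`]s, while the `ScalarValue`s produced by `evaluate`
+/// and `state` are serialized to protobuf bytes (`RVec<u8>`).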
+#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_Accumulator { + pub update_batch: unsafe extern "C" fn( + accumulator: &mut Self, + values: RVec, + ) -> RResult<(), RString>, + + // Evaluate and return a ScalarValues as protobuf bytes + pub evaluate: + unsafe extern "C" fn(accumulator: &mut Self) -> RResult, RString>, + + pub size: unsafe extern "C" fn(accumulator: &Self) -> usize, + + pub state: + unsafe extern "C" fn(accumulator: &mut Self) -> RResult>, RString>, + + pub merge_batch: unsafe extern "C" fn( + accumulator: &mut Self, + states: RVec, + ) -> RResult<(), RString>, + + pub retract_batch: unsafe extern "C" fn( + accumulator: &mut Self, + values: RVec, + ) -> RResult<(), RString>, + + pub supports_retract_batch: bool, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(accumulator: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the accumulator. + /// A [`ForeignAccumulator`] should never attempt to access this data. + pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_Accumulator {} +unsafe impl Sync for FFI_Accumulator {} + +pub struct AccumulatorPrivateData { + pub accumulator: Box, +} + +impl FFI_Accumulator { + #[inline] + unsafe fn inner_mut(&mut self) -> &mut Box { + let private_data = self.private_data as *mut AccumulatorPrivateData; + &mut (*private_data).accumulator + } + + #[inline] + unsafe fn inner(&self) -> &dyn Accumulator { + let private_data = self.private_data as *const AccumulatorPrivateData; + (*private_data).accumulator.deref() + } +} + +unsafe extern "C" fn update_batch_fn_wrapper( + accumulator: &mut FFI_Accumulator, + values: RVec, +) -> RResult<(), RString> { + let accumulator = accumulator.inner_mut(); + + let values_arrays = values + .into_iter() + .map(|v| v.try_into().map_err(DataFusionError::from)) + .collect::>>(); + let values_arrays = rresult_return!(values_arrays); + + rresult!(accumulator.update_batch(&values_arrays)) +} + +unsafe extern "C" fn evaluate_fn_wrapper( + accumulator: &mut FFI_Accumulator, +) -> RResult, RString> { + let accumulator = accumulator.inner_mut(); + + let scalar_result = rresult_return!(accumulator.evaluate()); + let proto_result: datafusion_proto::protobuf::ScalarValue = + rresult_return!((&scalar_result).try_into()); + + RResult::ROk(proto_result.encode_to_vec().into()) +} + +unsafe extern "C" fn size_fn_wrapper(accumulator: &FFI_Accumulator) -> usize { + accumulator.inner().size() +} + +unsafe extern "C" fn state_fn_wrapper( + accumulator: &mut FFI_Accumulator, +) -> RResult>, RString> { + let accumulator = accumulator.inner_mut(); + + let state = rresult_return!(accumulator.state()); + let state = state + .into_iter() + .map(|state_val| { + datafusion_proto::protobuf::ScalarValue::try_from(&state_val) + .map_err(DataFusionError::from) + .map(|v| RVec::from(v.encode_to_vec())) + }) + .collect::>>() + .map(|state_vec| state_vec.into()); + + rresult!(state) +} + +unsafe extern "C" fn merge_batch_fn_wrapper( + accumulator: &mut FFI_Accumulator, + states: RVec, +) -> RResult<(), RString> { + let accumulator = accumulator.inner_mut(); + + let states = rresult_return!(states + .into_iter() + .map(|state| ArrayRef::try_from(state).map_err(DataFusionError::from)) + .collect::>>()); + + rresult!(accumulator.merge_batch(&states)) +} + +unsafe extern "C" fn retract_batch_fn_wrapper( + accumulator: &mut FFI_Accumulator, + values: RVec, +) -> RResult<(), RString> { + let accumulator = 
accumulator.inner_mut(); + + let values_arrays = values + .into_iter() + .map(|v| v.try_into().map_err(DataFusionError::from)) + .collect::>>(); + let values_arrays = rresult_return!(values_arrays); + + rresult!(accumulator.retract_batch(&values_arrays)) +} + +unsafe extern "C" fn release_fn_wrapper(accumulator: &mut FFI_Accumulator) { + let private_data = + Box::from_raw(accumulator.private_data as *mut AccumulatorPrivateData); + drop(private_data); +} + +impl From> for FFI_Accumulator { + fn from(accumulator: Box) -> Self { + let supports_retract_batch = accumulator.supports_retract_batch(); + let private_data = AccumulatorPrivateData { accumulator }; + + Self { + update_batch: update_batch_fn_wrapper, + evaluate: evaluate_fn_wrapper, + size: size_fn_wrapper, + state: state_fn_wrapper, + merge_batch: merge_batch_fn_wrapper, + retract_batch: retract_batch_fn_wrapper, + supports_retract_batch, + release: release_fn_wrapper, + private_data: Box::into_raw(Box::new(private_data)) as *mut c_void, + } + } +} + +impl Drop for FFI_Accumulator { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +/// This struct is used to access an UDF provided by a foreign +/// library across a FFI boundary. +/// +/// The ForeignAccumulator is to be used by the caller of the UDF, so it has +/// no knowledge or access to the private data. All interaction with the UDF +/// must occur through the functions defined in FFI_Accumulator. +#[derive(Debug)] +pub struct ForeignAccumulator { + accumulator: FFI_Accumulator, +} + +unsafe impl Send for ForeignAccumulator {} +unsafe impl Sync for ForeignAccumulator {} + +impl From for ForeignAccumulator { + fn from(accumulator: FFI_Accumulator) -> Self { + Self { accumulator } + } +} + +impl Accumulator for ForeignAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + unsafe { + let values = values + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + df_result!((self.accumulator.update_batch)( + &mut self.accumulator, + values.into() + )) + } + } + + fn evaluate(&mut self) -> Result { + unsafe { + let scalar_bytes = + df_result!((self.accumulator.evaluate)(&mut self.accumulator))?; + + let proto_scalar = + datafusion_proto::protobuf::ScalarValue::decode(scalar_bytes.as_ref()) + .map_err(|e| DataFusionError::External(Box::new(e)))?; + + ScalarValue::try_from(&proto_scalar).map_err(DataFusionError::from) + } + } + + fn size(&self) -> usize { + unsafe { (self.accumulator.size)(&self.accumulator) } + } + + fn state(&mut self) -> Result> { + unsafe { + let state_protos = + df_result!((self.accumulator.state)(&mut self.accumulator))?; + + state_protos + .into_iter() + .map(|proto_bytes| { + datafusion_proto::protobuf::ScalarValue::decode(proto_bytes.as_ref()) + .map_err(|e| DataFusionError::External(Box::new(e))) + .and_then(|proto_value| { + ScalarValue::try_from(&proto_value) + .map_err(DataFusionError::from) + }) + }) + .collect::>>() + } + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + unsafe { + let states = states + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + df_result!((self.accumulator.merge_batch)( + &mut self.accumulator, + states.into() + )) + } + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + unsafe { + let values = values + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + df_result!((self.accumulator.retract_batch)( + &mut self.accumulator, + values.into() + )) + } + } + + fn supports_retract_batch(&self) -> 
bool { + self.accumulator.supports_retract_batch + } +} + +#[cfg(test)] +mod tests { + use arrow::array::{make_array, Array}; + use datafusion::{ + common::create_array, error::Result, + functions_aggregate::average::AvgAccumulator, logical_expr::Accumulator, + scalar::ScalarValue, + }; + + use super::{FFI_Accumulator, ForeignAccumulator}; + + #[test] + fn test_foreign_avg_accumulator() -> Result<()> { + let original_accum = AvgAccumulator::default(); + let original_size = original_accum.size(); + let original_supports_retract = original_accum.supports_retract_batch(); + + let boxed_accum: Box = Box::new(original_accum); + let ffi_accum: FFI_Accumulator = boxed_accum.into(); + let mut foreign_accum: ForeignAccumulator = ffi_accum.into(); + + // Send in an array to average. There are 5 values and it should average to 30.0 + let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]); + foreign_accum.update_batch(&[values])?; + + let avg = foreign_accum.evaluate()?; + assert_eq!(avg, ScalarValue::Float64(Some(30.0))); + + let state = foreign_accum.state()?; + assert_eq!(state.len(), 2); + assert_eq!(state[0], ScalarValue::UInt64(Some(5))); + assert_eq!(state[1], ScalarValue::Float64(Some(150.0))); + + // To verify merging batches works, create a second state to add in + // This should cause our average to go down to 25.0 + let second_states = vec![ + make_array(create_array!(UInt64, vec![1]).to_data()), + make_array(create_array!(Float64, vec![0.0]).to_data()), + ]; + + foreign_accum.merge_batch(&second_states)?; + let avg = foreign_accum.evaluate()?; + assert_eq!(avg, ScalarValue::Float64(Some(25.0))); + + // If we remove a batch that is equivalent to the state we added + // we should go back to our original value of 30.0 + let values = create_array!(Float64, vec![0.0]); + foreign_accum.retract_batch(&[values])?; + let avg = foreign_accum.evaluate()?; + assert_eq!(avg, ScalarValue::Float64(Some(30.0))); + + assert_eq!(original_size, foreign_accum.size()); + assert_eq!( + original_supports_retract, + foreign_accum.supports_retract_batch() + ); + + Ok(()) + } +} diff --git a/datafusion/ffi/src/udaf/accumulator_args.rs b/datafusion/ffi/src/udaf/accumulator_args.rs new file mode 100644 index 000000000000..874a2ac8b82e --- /dev/null +++ b/datafusion/ffi/src/udaf/accumulator_args.rs @@ -0,0 +1,193 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::sync::Arc; + +use crate::arrow_wrappers::WrappedSchema; +use abi_stable::{ + std_types::{RString, RVec}, + StableAbi, +}; +use arrow::{datatypes::Schema, ffi::FFI_ArrowSchema}; +use arrow_schema::FieldRef; +use datafusion::{ + error::DataFusionError, + logical_expr::function::AccumulatorArgs, + physical_expr::{PhysicalExpr, PhysicalSortExpr}, + prelude::SessionContext, +}; +use datafusion_proto::{ + physical_plan::{ + from_proto::{parse_physical_exprs, parse_physical_sort_exprs}, + to_proto::{serialize_physical_exprs, serialize_physical_sort_exprs}, + DefaultPhysicalExtensionCodec, + }, + protobuf::PhysicalAggregateExprNode, +}; +use prost::Message; + +/// A stable struct for sharing [`AccumulatorArgs`] across FFI boundaries. +/// For an explanation of each field, see the corresponding field +/// defined in [`AccumulatorArgs`]. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_AccumulatorArgs { + return_field: WrappedSchema, + schema: WrappedSchema, + is_reversed: bool, + name: RString, + physical_expr_def: RVec, +} + +impl TryFrom> for FFI_AccumulatorArgs { + type Error = DataFusionError; + + fn try_from(args: AccumulatorArgs) -> Result { + let return_field = + WrappedSchema(FFI_ArrowSchema::try_from(args.return_field.as_ref())?); + let schema = WrappedSchema(FFI_ArrowSchema::try_from(args.schema)?); + + let codec = DefaultPhysicalExtensionCodec {}; + let ordering_req = + serialize_physical_sort_exprs(args.order_bys.to_owned(), &codec)?; + + let expr = serialize_physical_exprs(args.exprs, &codec)?; + + let physical_expr_def = PhysicalAggregateExprNode { + expr, + ordering_req, + distinct: args.is_distinct, + ignore_nulls: args.ignore_nulls, + fun_definition: None, + aggregate_function: None, + }; + let physical_expr_def = physical_expr_def.encode_to_vec().into(); + + Ok(Self { + return_field, + schema, + is_reversed: args.is_reversed, + name: args.name.into(), + physical_expr_def, + }) + } +} + +/// This struct mirrors AccumulatorArgs except that it contains owned data. +/// It is necessary to create this struct so that we can parse the protobuf +/// data across the FFI boundary and turn it into owned data that +/// AccumulatorArgs can then reference. 
+pub struct ForeignAccumulatorArgs { + pub return_field: FieldRef, + pub schema: Schema, + pub ignore_nulls: bool, + pub order_bys: Vec, + pub is_reversed: bool, + pub name: String, + pub is_distinct: bool, + pub exprs: Vec>, +} + +impl TryFrom for ForeignAccumulatorArgs { + type Error = DataFusionError; + + fn try_from(value: FFI_AccumulatorArgs) -> Result { + let proto_def = + PhysicalAggregateExprNode::decode(value.physical_expr_def.as_ref()) + .map_err(|e| DataFusionError::Execution(e.to_string()))?; + + let return_field = Arc::new((&value.return_field.0).try_into()?); + let schema = Schema::try_from(&value.schema.0)?; + + let default_ctx = SessionContext::new(); + let codex = DefaultPhysicalExtensionCodec {}; + + let order_bys = parse_physical_sort_exprs( + &proto_def.ordering_req, + &default_ctx, + &schema, + &codex, + )?; + + let exprs = parse_physical_exprs(&proto_def.expr, &default_ctx, &schema, &codex)?; + + Ok(Self { + return_field, + schema, + ignore_nulls: proto_def.ignore_nulls, + order_bys, + is_reversed: value.is_reversed, + name: value.name.to_string(), + is_distinct: proto_def.distinct, + exprs, + }) + } +} + +impl<'a> From<&'a ForeignAccumulatorArgs> for AccumulatorArgs<'a> { + fn from(value: &'a ForeignAccumulatorArgs) -> Self { + Self { + return_field: Arc::clone(&value.return_field), + schema: &value.schema, + ignore_nulls: value.ignore_nulls, + order_bys: &value.order_bys, + is_reversed: value.is_reversed, + name: value.name.as_str(), + is_distinct: value.is_distinct, + exprs: &value.exprs, + } + } +} + +#[cfg(test)] +mod tests { + use super::{FFI_AccumulatorArgs, ForeignAccumulatorArgs}; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion::{ + error::Result, logical_expr::function::AccumulatorArgs, + physical_expr::PhysicalSortExpr, physical_plan::expressions::col, + }; + + #[test] + fn test_round_trip_accumulator_args() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let orig_args = AccumulatorArgs { + return_field: Field::new("f", DataType::Float64, true).into(), + schema: &schema, + ignore_nulls: false, + order_bys: &[PhysicalSortExpr::new_default(col("a", &schema)?)], + is_reversed: false, + name: "round_trip", + is_distinct: true, + exprs: &[col("a", &schema)?], + }; + let orig_str = format!("{orig_args:?}"); + + let ffi_args: FFI_AccumulatorArgs = orig_args.try_into()?; + let foreign_args: ForeignAccumulatorArgs = ffi_args.try_into()?; + let round_trip_args: AccumulatorArgs = (&foreign_args).into(); + + let round_trip_str = format!("{round_trip_args:?}"); + + // Since AccumulatorArgs doesn't implement Eq, simply compare + // the debug strings. + assert_eq!(orig_str, round_trip_str); + + Ok(()) + } +} diff --git a/datafusion/ffi/src/udaf/groups_accumulator.rs b/datafusion/ffi/src/udaf/groups_accumulator.rs new file mode 100644 index 000000000000..58a18c69db7c --- /dev/null +++ b/datafusion/ffi/src/udaf/groups_accumulator.rs @@ -0,0 +1,513 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ffi::c_void, ops::Deref, sync::Arc}; + +use crate::{ + arrow_wrappers::{WrappedArray, WrappedSchema}, + df_result, rresult, rresult_return, +}; +use abi_stable::{ + std_types::{ROption, RResult, RString, RVec}, + StableAbi, +}; +use arrow::{ + array::{Array, ArrayRef, BooleanArray}, + error::ArrowError, + ffi::to_ffi, +}; +use datafusion::{ + error::{DataFusionError, Result}, + logical_expr::{EmitTo, GroupsAccumulator}, +}; + +/// A stable struct for sharing [`GroupsAccumulator`] across FFI boundaries. +/// For an explanation of each field, see the corresponding function +/// defined in [`GroupsAccumulator`]. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_GroupsAccumulator { + pub update_batch: unsafe extern "C" fn( + accumulator: &mut Self, + values: RVec, + group_indices: RVec, + opt_filter: ROption, + total_num_groups: usize, + ) -> RResult<(), RString>, + + // Evaluate and return a ScalarValues as protobuf bytes + pub evaluate: unsafe extern "C" fn( + accumulator: &mut Self, + emit_to: FFI_EmitTo, + ) -> RResult, + + pub size: unsafe extern "C" fn(accumulator: &Self) -> usize, + + pub state: unsafe extern "C" fn( + accumulator: &mut Self, + emit_to: FFI_EmitTo, + ) -> RResult, RString>, + + pub merge_batch: unsafe extern "C" fn( + accumulator: &mut Self, + values: RVec, + group_indices: RVec, + opt_filter: ROption, + total_num_groups: usize, + ) -> RResult<(), RString>, + + pub convert_to_state: unsafe extern "C" fn( + accumulator: &Self, + values: RVec, + opt_filter: ROption, + ) + -> RResult, RString>, + + pub supports_convert_to_state: bool, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(accumulator: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the accumulator. + /// A [`ForeignGroupsAccumulator`] should never attempt to access this data. + pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_GroupsAccumulator {} +unsafe impl Sync for FFI_GroupsAccumulator {} + +pub struct GroupsAccumulatorPrivateData { + pub accumulator: Box, +} + +impl FFI_GroupsAccumulator { + #[inline] + unsafe fn inner_mut(&mut self) -> &mut Box { + let private_data = self.private_data as *mut GroupsAccumulatorPrivateData; + &mut (*private_data).accumulator + } + + #[inline] + unsafe fn inner(&self) -> &dyn GroupsAccumulator { + let private_data = self.private_data as *const GroupsAccumulatorPrivateData; + (*private_data).accumulator.deref() + } +} + +fn process_values(values: RVec) -> Result>> { + values + .into_iter() + .map(|v| v.try_into().map_err(DataFusionError::from)) + .collect::>>() +} + +/// Convert C-typed opt_filter into the internal type. 
+fn process_opt_filter(opt_filter: ROption) -> Result> { + opt_filter + .into_option() + .map(|filter| { + ArrayRef::try_from(filter) + .map_err(DataFusionError::from) + .map(|arr| BooleanArray::from(arr.into_data())) + }) + .transpose() +} + +unsafe extern "C" fn update_batch_fn_wrapper( + accumulator: &mut FFI_GroupsAccumulator, + values: RVec, + group_indices: RVec, + opt_filter: ROption, + total_num_groups: usize, +) -> RResult<(), RString> { + let accumulator = accumulator.inner_mut(); + let values = rresult_return!(process_values(values)); + let group_indices: Vec = group_indices.into_iter().collect(); + let opt_filter = rresult_return!(process_opt_filter(opt_filter)); + + rresult!(accumulator.update_batch( + &values, + &group_indices, + opt_filter.as_ref(), + total_num_groups + )) +} + +unsafe extern "C" fn evaluate_fn_wrapper( + accumulator: &mut FFI_GroupsAccumulator, + emit_to: FFI_EmitTo, +) -> RResult { + let accumulator = accumulator.inner_mut(); + + let result = rresult_return!(accumulator.evaluate(emit_to.into())); + + rresult!(WrappedArray::try_from(&result)) +} + +unsafe extern "C" fn size_fn_wrapper(accumulator: &FFI_GroupsAccumulator) -> usize { + let accumulator = accumulator.inner(); + accumulator.size() +} + +unsafe extern "C" fn state_fn_wrapper( + accumulator: &mut FFI_GroupsAccumulator, + emit_to: FFI_EmitTo, +) -> RResult, RString> { + let accumulator = accumulator.inner_mut(); + + let state = rresult_return!(accumulator.state(emit_to.into())); + rresult!(state + .into_iter() + .map(|arr| WrappedArray::try_from(&arr).map_err(DataFusionError::from)) + .collect::>>()) +} + +unsafe extern "C" fn merge_batch_fn_wrapper( + accumulator: &mut FFI_GroupsAccumulator, + values: RVec, + group_indices: RVec, + opt_filter: ROption, + total_num_groups: usize, +) -> RResult<(), RString> { + let accumulator = accumulator.inner_mut(); + let values = rresult_return!(process_values(values)); + let group_indices: Vec = group_indices.into_iter().collect(); + let opt_filter = rresult_return!(process_opt_filter(opt_filter)); + + rresult!(accumulator.merge_batch( + &values, + &group_indices, + opt_filter.as_ref(), + total_num_groups + )) +} + +unsafe extern "C" fn convert_to_state_fn_wrapper( + accumulator: &FFI_GroupsAccumulator, + values: RVec, + opt_filter: ROption, +) -> RResult, RString> { + let accumulator = accumulator.inner(); + let values = rresult_return!(process_values(values)); + let opt_filter = rresult_return!(process_opt_filter(opt_filter)); + let state = + rresult_return!(accumulator.convert_to_state(&values, opt_filter.as_ref())); + + rresult!(state + .iter() + .map(|arr| WrappedArray::try_from(arr).map_err(DataFusionError::from)) + .collect::>>()) +} + +unsafe extern "C" fn release_fn_wrapper(accumulator: &mut FFI_GroupsAccumulator) { + let private_data = + Box::from_raw(accumulator.private_data as *mut GroupsAccumulatorPrivateData); + drop(private_data); +} + +impl From> for FFI_GroupsAccumulator { + fn from(accumulator: Box) -> Self { + let supports_convert_to_state = accumulator.supports_convert_to_state(); + let private_data = GroupsAccumulatorPrivateData { accumulator }; + + Self { + update_batch: update_batch_fn_wrapper, + evaluate: evaluate_fn_wrapper, + size: size_fn_wrapper, + state: state_fn_wrapper, + merge_batch: merge_batch_fn_wrapper, + convert_to_state: convert_to_state_fn_wrapper, + supports_convert_to_state, + + release: release_fn_wrapper, + private_data: Box::into_raw(Box::new(private_data)) as *mut c_void, + } + } +} + +impl Drop for 
FFI_GroupsAccumulator { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +/// This struct is used to access an UDF provided by a foreign +/// library across a FFI boundary. +/// +/// The ForeignGroupsAccumulator is to be used by the caller of the UDF, so it has +/// no knowledge or access to the private data. All interaction with the UDF +/// must occur through the functions defined in FFI_GroupsAccumulator. +#[derive(Debug)] +pub struct ForeignGroupsAccumulator { + accumulator: FFI_GroupsAccumulator, +} + +unsafe impl Send for ForeignGroupsAccumulator {} +unsafe impl Sync for ForeignGroupsAccumulator {} + +impl From for ForeignGroupsAccumulator { + fn from(accumulator: FFI_GroupsAccumulator) -> Self { + Self { accumulator } + } +} + +impl GroupsAccumulator for ForeignGroupsAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + unsafe { + let values = values + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + let group_indices = group_indices.iter().cloned().collect(); + let opt_filter = opt_filter + .map(|bool_array| to_ffi(&bool_array.to_data())) + .transpose()? + .map(|(array, schema)| WrappedArray { + array, + schema: WrappedSchema(schema), + }) + .into(); + + df_result!((self.accumulator.update_batch)( + &mut self.accumulator, + values.into(), + group_indices, + opt_filter, + total_num_groups + )) + } + } + + fn size(&self) -> usize { + unsafe { (self.accumulator.size)(&self.accumulator) } + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + unsafe { + let return_array = df_result!((self.accumulator.evaluate)( + &mut self.accumulator, + emit_to.into() + ))?; + + return_array.try_into().map_err(DataFusionError::from) + } + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + unsafe { + let returned_arrays = df_result!((self.accumulator.state)( + &mut self.accumulator, + emit_to.into() + ))?; + + returned_arrays + .into_iter() + .map(|wrapped_array| { + wrapped_array.try_into().map_err(DataFusionError::from) + }) + .collect::>>() + } + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + unsafe { + let values = values + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + let group_indices = group_indices.iter().cloned().collect(); + let opt_filter = opt_filter + .map(|bool_array| to_ffi(&bool_array.to_data())) + .transpose()? + .map(|(array, schema)| WrappedArray { + array, + schema: WrappedSchema(schema), + }) + .into(); + + df_result!((self.accumulator.merge_batch)( + &mut self.accumulator, + values.into(), + group_indices, + opt_filter, + total_num_groups + )) + } + } + + fn convert_to_state( + &self, + values: &[ArrayRef], + opt_filter: Option<&BooleanArray>, + ) -> Result> { + unsafe { + let values = values + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + + let opt_filter = opt_filter + .map(|bool_array| to_ffi(&bool_array.to_data())) + .transpose()? 
+ .map(|(array, schema)| WrappedArray { + array, + schema: WrappedSchema(schema), + }) + .into(); + + let returned_array = df_result!((self.accumulator.convert_to_state)( + &self.accumulator, + values, + opt_filter + ))?; + + returned_array + .into_iter() + .map(|arr| arr.try_into().map_err(DataFusionError::from)) + .collect() + } + } + + fn supports_convert_to_state(&self) -> bool { + self.accumulator.supports_convert_to_state + } +} + +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub enum FFI_EmitTo { + All, + First(usize), +} + +impl From<EmitTo> for FFI_EmitTo { + fn from(value: EmitTo) -> Self { + match value { + EmitTo::All => Self::All, + EmitTo::First(v) => Self::First(v), + } + } +} + +impl From<FFI_EmitTo> for EmitTo { + fn from(value: FFI_EmitTo) -> Self { + match value { + FFI_EmitTo::All => Self::All, + FFI_EmitTo::First(v) => Self::First(v), + } + } +} + +#[cfg(test)] +mod tests { + use arrow::array::{make_array, Array, BooleanArray}; + use datafusion::{ + common::create_array, + error::Result, + logical_expr::{EmitTo, GroupsAccumulator}, + }; + use datafusion_functions_aggregate_common::aggregate::groups_accumulator::bool_op::BooleanGroupsAccumulator; + + use super::{FFI_EmitTo, FFI_GroupsAccumulator, ForeignGroupsAccumulator}; + + #[test] + fn test_foreign_bool_groups_accumulator() -> Result<()> { + let boxed_accum: Box<dyn GroupsAccumulator> = + Box::new(BooleanGroupsAccumulator::new(|a, b| a && b, true)); + let ffi_accum: FFI_GroupsAccumulator = boxed_accum.into(); + let mut foreign_accum: ForeignGroupsAccumulator = ffi_accum.into(); + + // Send in a boolean array over three groups; the filter masks out group 2 entirely, + // so it should evaluate to null. + let values = create_array!(Boolean, vec![true, true, true, false, true, true]); + let opt_filter = + create_array!(Boolean, vec![true, true, true, true, false, false]); + foreign_accum.update_batch( + &[values], + &[0, 0, 1, 1, 2, 2], + Some(opt_filter.as_ref()), + 3, + )?; + + let groups_bool = foreign_accum.evaluate(EmitTo::All)?; + let groups_bool = groups_bool.as_any().downcast_ref::<BooleanArray>().unwrap(); + + assert_eq!( + groups_bool, + create_array!(Boolean, vec![Some(true), Some(false), None]).as_ref() + ); + + let state = foreign_accum.state(EmitTo::All)?; + assert_eq!(state.len(), 1); + + // To verify merging batches works, merge in a second state containing `false` + // for group 0; the merged group should then evaluate to false. + let second_states = + vec![make_array(create_array!(Boolean, vec![false]).to_data())]; + + let opt_filter = create_array!(Boolean, vec![true]); + foreign_accum.merge_batch(&second_states, &[0], Some(opt_filter.as_ref()), 1)?; + let groups_bool = foreign_accum.evaluate(EmitTo::All)?; + assert_eq!(groups_bool.len(), 1); + assert_eq!( + groups_bool.as_ref(), + make_array(create_array!(Boolean, vec![false]).to_data()).as_ref() + ); + + let values = create_array!(Boolean, vec![false]); + let opt_filter = create_array!(Boolean, vec![true]); + let groups_bool = + foreign_accum.convert_to_state(&[values], Some(opt_filter.as_ref()))?; + + assert_eq!( + groups_bool[0].as_ref(), + make_array(create_array!(Boolean, vec![false]).to_data()).as_ref() + ); + + Ok(()) + } + + fn test_emit_to_round_trip(value: EmitTo) -> Result<()> { + let ffi_value: FFI_EmitTo = value.into(); + let round_trip_value: EmitTo = ffi_value.into(); + + assert_eq!(value, round_trip_value); + Ok(()) + } + + /// This test ensures all enum values are properly translated + #[test] + fn test_all_emit_to_round_trip() -> Result<()> { + test_emit_to_round_trip(EmitTo::All)?; + test_emit_to_round_trip(EmitTo::First(10))?;
+ + Ok(()) + } +} diff --git a/datafusion/ffi/src/udaf/mod.rs b/datafusion/ffi/src/udaf/mod.rs new file mode 100644 index 000000000000..eb7a408ab178 --- /dev/null +++ b/datafusion/ffi/src/udaf/mod.rs @@ -0,0 +1,725 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ffi::c_void, sync::Arc}; + +use abi_stable::{ + std_types::{ROption, RResult, RStr, RString, RVec}, + StableAbi, +}; +use accumulator::{FFI_Accumulator, ForeignAccumulator}; +use accumulator_args::{FFI_AccumulatorArgs, ForeignAccumulatorArgs}; +use arrow::datatypes::{DataType, Field}; +use arrow::ffi::FFI_ArrowSchema; +use arrow_schema::FieldRef; +use datafusion::{ + error::DataFusionError, + logical_expr::{ + function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs}, + type_coercion::functions::fields_with_aggregate_udf, + utils::AggregateOrderSensitivity, + Accumulator, GroupsAccumulator, + }, +}; +use datafusion::{ + error::Result, + logical_expr::{AggregateUDF, AggregateUDFImpl, Signature}, +}; +use datafusion_proto_common::from_proto::parse_proto_fields_to_fields; +use groups_accumulator::{FFI_GroupsAccumulator, ForeignGroupsAccumulator}; + +use crate::util::{rvec_wrapped_to_vec_fieldref, vec_fieldref_to_rvec_wrapped}; +use crate::{ + arrow_wrappers::WrappedSchema, + df_result, rresult, rresult_return, + util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped}, + volatility::FFI_Volatility, +}; +use prost::{DecodeError, Message}; + +mod accumulator; +mod accumulator_args; +mod groups_accumulator; + +/// A stable struct for sharing a [`AggregateUDF`] across FFI boundaries. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_AggregateUDF { + /// FFI equivalent to the `name` of a [`AggregateUDF`] + pub name: RString, + + /// FFI equivalent to the `aliases` of a [`AggregateUDF`] + pub aliases: RVec, + + /// FFI equivalent to the `volatility` of a [`AggregateUDF`] + pub volatility: FFI_Volatility, + + /// Determines the return type of the underlying [`AggregateUDF`] based on the + /// argument types. 
+ pub return_type: unsafe extern "C" fn( + udaf: &Self, + arg_types: RVec<WrappedSchema>, + ) -> RResult<WrappedSchema, RString>, + + /// FFI equivalent to the `is_nullable` of a [`AggregateUDF`] + pub is_nullable: bool, + + /// FFI equivalent to [`AggregateUDF::groups_accumulator_supported`] + pub groups_accumulator_supported: + unsafe extern "C" fn(udaf: &FFI_AggregateUDF, args: FFI_AccumulatorArgs) -> bool, + + /// FFI equivalent to [`AggregateUDF::accumulator`] + pub accumulator: unsafe extern "C" fn( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, + ) -> RResult<FFI_Accumulator, RString>, + + /// FFI equivalent to [`AggregateUDF::create_sliding_accumulator`] + pub create_sliding_accumulator: + unsafe extern "C" fn( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, + ) -> RResult<FFI_Accumulator, RString>, + + /// FFI equivalent to [`AggregateUDF::state_fields`] + #[allow(clippy::type_complexity)] + pub state_fields: unsafe extern "C" fn( + udaf: &FFI_AggregateUDF, + name: &RStr, + input_fields: RVec<WrappedSchema>, + return_field: WrappedSchema, + ordering_fields: RVec<RVec<u8>>, + is_distinct: bool, + ) -> RResult<RVec<RVec<u8>>, RString>, + + /// FFI equivalent to [`AggregateUDF::create_groups_accumulator`] + pub create_groups_accumulator: + unsafe extern "C" fn( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, + ) -> RResult<FFI_GroupsAccumulator, RString>, + + /// FFI equivalent to [`AggregateUDF::with_beneficial_ordering`] + pub with_beneficial_ordering: + unsafe extern "C" fn( + udaf: &FFI_AggregateUDF, + beneficial_ordering: bool, + ) -> RResult<ROption<FFI_AggregateUDF>, RString>, + + /// FFI equivalent to [`AggregateUDF::order_sensitivity`] + pub order_sensitivity: + unsafe extern "C" fn(udaf: &FFI_AggregateUDF) -> FFI_AggregateOrderSensitivity, + + /// Performs type coercion. To simplify this interface, all UDFs are treated as having + /// user defined signatures, which will in turn cause coerce_types to be called. This + /// call should be transparent to most users as the internal function performs the + /// appropriate calls on the underlying [`AggregateUDF`] + pub coerce_types: unsafe extern "C" fn( + udf: &Self, + arg_types: RVec<WrappedSchema>, + ) -> RResult<RVec<WrappedSchema>, RString>, + + /// Used to create a clone on the provider of the udaf. This should + /// only need to be called by the receiver of the udaf. + pub clone: unsafe extern "C" fn(udaf: &Self) -> Self, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(udaf: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the udaf. + /// A [`ForeignAggregateUDF`] should never attempt to access this data.
+ pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_AggregateUDF {} +unsafe impl Sync for FFI_AggregateUDF {} + +pub struct AggregateUDFPrivateData { + pub udaf: Arc, +} + +impl FFI_AggregateUDF { + unsafe fn inner(&self) -> &Arc { + let private_data = self.private_data as *const AggregateUDFPrivateData; + &(*private_data).udaf + } +} + +unsafe extern "C" fn return_type_fn_wrapper( + udaf: &FFI_AggregateUDF, + arg_types: RVec, +) -> RResult { + let udaf = udaf.inner(); + + let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types)); + + let return_type = udaf + .return_type(&arg_types) + .and_then(|v| FFI_ArrowSchema::try_from(v).map_err(DataFusionError::from)) + .map(WrappedSchema); + + rresult!(return_type) +} + +unsafe extern "C" fn accumulator_fn_wrapper( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, +) -> RResult { + let udaf = udaf.inner(); + + let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args)); + + rresult!(udaf + .accumulator(accumulator_args.into()) + .map(FFI_Accumulator::from)) +} + +unsafe extern "C" fn create_sliding_accumulator_fn_wrapper( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, +) -> RResult { + let udaf = udaf.inner(); + + let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args)); + + rresult!(udaf + .create_sliding_accumulator(accumulator_args.into()) + .map(FFI_Accumulator::from)) +} + +unsafe extern "C" fn create_groups_accumulator_fn_wrapper( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, +) -> RResult { + let udaf = udaf.inner(); + + let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args)); + + rresult!(udaf + .create_groups_accumulator(accumulator_args.into()) + .map(FFI_GroupsAccumulator::from)) +} + +unsafe extern "C" fn groups_accumulator_supported_fn_wrapper( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, +) -> bool { + let udaf = udaf.inner(); + + ForeignAccumulatorArgs::try_from(args) + .map(|a| udaf.groups_accumulator_supported((&a).into())) + .unwrap_or_else(|e| { + log::warn!("Unable to parse accumulator args. 
{e}"); + false + }) +} + +unsafe extern "C" fn with_beneficial_ordering_fn_wrapper( + udaf: &FFI_AggregateUDF, + beneficial_ordering: bool, +) -> RResult, RString> { + let udaf = udaf.inner().as_ref().clone(); + + let result = rresult_return!(udaf.with_beneficial_ordering(beneficial_ordering)); + let result = rresult_return!(result + .map(|func| func.with_beneficial_ordering(beneficial_ordering)) + .transpose()) + .flatten() + .map(|func| FFI_AggregateUDF::from(Arc::new(func))); + + RResult::ROk(result.into()) +} + +unsafe extern "C" fn state_fields_fn_wrapper( + udaf: &FFI_AggregateUDF, + name: &RStr, + input_fields: RVec, + return_field: WrappedSchema, + ordering_fields: RVec>, + is_distinct: bool, +) -> RResult>, RString> { + let udaf = udaf.inner(); + + let input_fields = &rresult_return!(rvec_wrapped_to_vec_fieldref(&input_fields)); + let return_field = rresult_return!(Field::try_from(&return_field.0)).into(); + + let ordering_fields = &rresult_return!(ordering_fields + .into_iter() + .map(|field_bytes| datafusion_proto_common::Field::decode(field_bytes.as_ref())) + .collect::, DecodeError>>()); + + let ordering_fields = &rresult_return!(parse_proto_fields_to_fields(ordering_fields)) + .into_iter() + .map(Arc::new) + .collect::>(); + + let args = StateFieldsArgs { + name: name.as_str(), + input_fields, + return_field, + ordering_fields, + is_distinct, + }; + + let state_fields = rresult_return!(udaf.state_fields(args)); + let state_fields = rresult_return!(state_fields + .iter() + .map(|f| f.as_ref()) + .map(datafusion_proto::protobuf::Field::try_from) + .map(|v| v.map_err(DataFusionError::from)) + .collect::>>()) + .into_iter() + .map(|field| field.encode_to_vec().into()) + .collect(); + + RResult::ROk(state_fields) +} + +unsafe extern "C" fn order_sensitivity_fn_wrapper( + udaf: &FFI_AggregateUDF, +) -> FFI_AggregateOrderSensitivity { + udaf.inner().order_sensitivity().into() +} + +unsafe extern "C" fn coerce_types_fn_wrapper( + udaf: &FFI_AggregateUDF, + arg_types: RVec, +) -> RResult, RString> { + let udaf = udaf.inner(); + + let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types)); + + let arg_fields = arg_types + .iter() + .map(|dt| Field::new("f", dt.clone(), true)) + .map(Arc::new) + .collect::>(); + let return_types = rresult_return!(fields_with_aggregate_udf(&arg_fields, udaf)) + .into_iter() + .map(|f| f.data_type().to_owned()) + .collect::>(); + + rresult!(vec_datatype_to_rvec_wrapped(&return_types)) +} + +unsafe extern "C" fn release_fn_wrapper(udaf: &mut FFI_AggregateUDF) { + let private_data = Box::from_raw(udaf.private_data as *mut AggregateUDFPrivateData); + drop(private_data); +} + +unsafe extern "C" fn clone_fn_wrapper(udaf: &FFI_AggregateUDF) -> FFI_AggregateUDF { + Arc::clone(udaf.inner()).into() +} + +impl Clone for FFI_AggregateUDF { + fn clone(&self) -> Self { + unsafe { (self.clone)(self) } + } +} + +impl From> for FFI_AggregateUDF { + fn from(udaf: Arc) -> Self { + let name = udaf.name().into(); + let aliases = udaf.aliases().iter().map(|a| a.to_owned().into()).collect(); + let is_nullable = udaf.is_nullable(); + let volatility = udaf.signature().volatility.into(); + + let private_data = Box::new(AggregateUDFPrivateData { udaf }); + + Self { + name, + is_nullable, + volatility, + aliases, + return_type: return_type_fn_wrapper, + accumulator: accumulator_fn_wrapper, + create_sliding_accumulator: create_sliding_accumulator_fn_wrapper, + create_groups_accumulator: create_groups_accumulator_fn_wrapper, + groups_accumulator_supported: 
groups_accumulator_supported_fn_wrapper, + with_beneficial_ordering: with_beneficial_ordering_fn_wrapper, + state_fields: state_fields_fn_wrapper, + order_sensitivity: order_sensitivity_fn_wrapper, + coerce_types: coerce_types_fn_wrapper, + clone: clone_fn_wrapper, + release: release_fn_wrapper, + private_data: Box::into_raw(private_data) as *mut c_void, + } + } +} + +impl Drop for FFI_AggregateUDF { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +/// This struct is used to access an UDF provided by a foreign +/// library across a FFI boundary. +/// +/// The ForeignAggregateUDF is to be used by the caller of the UDF, so it has +/// no knowledge or access to the private data. All interaction with the UDF +/// must occur through the functions defined in FFI_AggregateUDF. +#[derive(Debug)] +pub struct ForeignAggregateUDF { + signature: Signature, + aliases: Vec, + udaf: FFI_AggregateUDF, +} + +unsafe impl Send for ForeignAggregateUDF {} +unsafe impl Sync for ForeignAggregateUDF {} + +impl TryFrom<&FFI_AggregateUDF> for ForeignAggregateUDF { + type Error = DataFusionError; + + fn try_from(udaf: &FFI_AggregateUDF) -> Result { + let signature = Signature::user_defined((&udaf.volatility).into()); + let aliases = udaf.aliases.iter().map(|s| s.to_string()).collect(); + + Ok(Self { + udaf: udaf.clone(), + signature, + aliases, + }) + } +} + +impl AggregateUDFImpl for ForeignAggregateUDF { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + self.udaf.name.as_str() + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?; + + let result = unsafe { (self.udaf.return_type)(&self.udaf, arg_types) }; + + let result = df_result!(result); + + result.and_then(|r| (&r.0).try_into().map_err(DataFusionError::from)) + } + + fn is_nullable(&self) -> bool { + self.udaf.is_nullable + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + let args = acc_args.try_into()?; + unsafe { + df_result!((self.udaf.accumulator)(&self.udaf, args)).map(|accum| { + Box::new(ForeignAccumulator::from(accum)) as Box + }) + } + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + unsafe { + let name = RStr::from_str(args.name); + let input_fields = vec_fieldref_to_rvec_wrapped(args.input_fields)?; + let return_field = + WrappedSchema(FFI_ArrowSchema::try_from(args.return_field.as_ref())?); + let ordering_fields = args + .ordering_fields + .iter() + .map(|f| f.as_ref()) + .map(datafusion_proto::protobuf::Field::try_from) + .map(|v| v.map_err(DataFusionError::from)) + .collect::>>()? 
+ .into_iter() + .map(|proto_field| proto_field.encode_to_vec().into()) + .collect(); + + let fields = df_result!((self.udaf.state_fields)( + &self.udaf, + &name, + input_fields, + return_field, + ordering_fields, + args.is_distinct + ))?; + let fields = fields + .into_iter() + .map(|field_bytes| { + datafusion_proto_common::Field::decode(field_bytes.as_ref()) + .map_err(|e| DataFusionError::Execution(e.to_string())) + }) + .collect::>>()?; + + parse_proto_fields_to_fields(fields.iter()) + .map(|fields| fields.into_iter().map(Arc::new).collect()) + .map_err(|e| DataFusionError::Execution(e.to_string())) + } + } + + fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { + let args = match FFI_AccumulatorArgs::try_from(args) { + Ok(v) => v, + Err(e) => { + log::warn!("Attempting to convert accumulator arguments: {e}"); + return false; + } + }; + + unsafe { (self.udaf.groups_accumulator_supported)(&self.udaf, args) } + } + + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + let args = FFI_AccumulatorArgs::try_from(args)?; + + unsafe { + df_result!((self.udaf.create_groups_accumulator)(&self.udaf, args)).map( + |accum| { + Box::new(ForeignGroupsAccumulator::from(accum)) + as Box + }, + ) + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn create_sliding_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + let args = args.try_into()?; + unsafe { + df_result!((self.udaf.create_sliding_accumulator)(&self.udaf, args)).map( + |accum| Box::new(ForeignAccumulator::from(accum)) as Box, + ) + } + } + + fn with_beneficial_ordering( + self: Arc, + beneficial_ordering: bool, + ) -> Result>> { + unsafe { + let result = df_result!((self.udaf.with_beneficial_ordering)( + &self.udaf, + beneficial_ordering + ))? + .into_option(); + + let result = result + .map(|func| ForeignAggregateUDF::try_from(&func)) + .transpose()?; + + Ok(result.map(|func| Arc::new(func) as Arc)) + } + } + + fn order_sensitivity(&self) -> AggregateOrderSensitivity { + unsafe { (self.udaf.order_sensitivity)(&self.udaf).into() } + } + + fn simplify(&self) -> Option { + None + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + unsafe { + let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?; + let result_types = + df_result!((self.udaf.coerce_types)(&self.udaf, arg_types))?; + Ok(rvec_wrapped_to_vec_datatype(&result_types)?) 
+ } + } +} + +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub enum FFI_AggregateOrderSensitivity { + Insensitive, + HardRequirement, + Beneficial, +} + +impl From for AggregateOrderSensitivity { + fn from(value: FFI_AggregateOrderSensitivity) -> Self { + match value { + FFI_AggregateOrderSensitivity::Insensitive => Self::Insensitive, + FFI_AggregateOrderSensitivity::HardRequirement => Self::HardRequirement, + FFI_AggregateOrderSensitivity::Beneficial => Self::Beneficial, + } + } +} + +impl From for FFI_AggregateOrderSensitivity { + fn from(value: AggregateOrderSensitivity) -> Self { + match value { + AggregateOrderSensitivity::Insensitive => Self::Insensitive, + AggregateOrderSensitivity::HardRequirement => Self::HardRequirement, + AggregateOrderSensitivity::Beneficial => Self::Beneficial, + } + } +} + +#[cfg(test)] +mod tests { + use arrow::datatypes::Schema; + use datafusion::{ + common::create_array, functions_aggregate::sum::Sum, + physical_expr::PhysicalSortExpr, physical_plan::expressions::col, + scalar::ScalarValue, + }; + + use super::*; + + fn create_test_foreign_udaf( + original_udaf: impl AggregateUDFImpl + 'static, + ) -> Result { + let original_udaf = Arc::new(AggregateUDF::from(original_udaf)); + + let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into(); + + let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?; + Ok(foreign_udaf.into()) + } + + #[test] + fn test_round_trip_udaf() -> Result<()> { + let original_udaf = Sum::new(); + let original_name = original_udaf.name().to_owned(); + let original_udaf = Arc::new(AggregateUDF::from(original_udaf)); + + // Convert to FFI format + let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into(); + + // Convert back to native format + let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?; + let foreign_udaf: AggregateUDF = foreign_udaf.into(); + + assert_eq!(original_name, foreign_udaf.name()); + Ok(()) + } + + #[test] + fn test_foreign_udaf_aliases() -> Result<()> { + let foreign_udaf = + create_test_foreign_udaf(Sum::new())?.with_aliases(["my_function"]); + + let return_type = foreign_udaf.return_type(&[DataType::Float64])?; + assert_eq!(return_type, DataType::Float64); + Ok(()) + } + + #[test] + fn test_foreign_udaf_accumulator() -> Result<()> { + let foreign_udaf = create_test_foreign_udaf(Sum::new())?; + + let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]); + let acc_args = AccumulatorArgs { + return_field: Field::new("f", DataType::Float64, true).into(), + schema: &schema, + ignore_nulls: true, + order_bys: &[PhysicalSortExpr::new_default(col("a", &schema)?)], + is_reversed: false, + name: "round_trip", + is_distinct: true, + exprs: &[col("a", &schema)?], + }; + let mut accumulator = foreign_udaf.accumulator(acc_args)?; + let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]); + accumulator.update_batch(&[values])?; + let resultant_value = accumulator.evaluate()?; + assert_eq!(resultant_value, ScalarValue::Float64(Some(150.))); + + Ok(()) + } + + #[test] + fn test_beneficial_ordering() -> Result<()> { + let foreign_udaf = create_test_foreign_udaf( + datafusion::functions_aggregate::first_last::FirstValue::new(), + )?; + + let foreign_udaf = foreign_udaf.with_beneficial_ordering(true)?.unwrap(); + + assert_eq!( + foreign_udaf.order_sensitivity(), + AggregateOrderSensitivity::Beneficial + ); + + let a_field = Arc::new(Field::new("a", DataType::Float64, true)); + let state_fields = foreign_udaf.state_fields(StateFieldsArgs 
{ + name: "a", + input_fields: &[Field::new("f", DataType::Float64, true).into()], + return_field: Field::new("f", DataType::Float64, true).into(), + ordering_fields: &[Arc::clone(&a_field)], + is_distinct: false, + })?; + + assert_eq!(state_fields.len(), 3); + assert_eq!(state_fields[1], a_field); + Ok(()) + } + + #[test] + fn test_sliding_accumulator() -> Result<()> { + let foreign_udaf = create_test_foreign_udaf(Sum::new())?; + + let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]); + let acc_args = AccumulatorArgs { + return_field: Field::new("f", DataType::Float64, true).into(), + schema: &schema, + ignore_nulls: true, + order_bys: &[PhysicalSortExpr::new_default(col("a", &schema)?)], + is_reversed: false, + name: "round_trip", + is_distinct: true, + exprs: &[col("a", &schema)?], + }; + + let mut accumulator = foreign_udaf.create_sliding_accumulator(acc_args)?; + let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]); + accumulator.update_batch(&[values])?; + let resultant_value = accumulator.evaluate()?; + assert_eq!(resultant_value, ScalarValue::Float64(Some(150.))); + + Ok(()) + } + + fn test_round_trip_order_sensitivity(sensitivity: AggregateOrderSensitivity) { + let ffi_sensitivity: FFI_AggregateOrderSensitivity = sensitivity.into(); + let round_trip_sensitivity: AggregateOrderSensitivity = ffi_sensitivity.into(); + + assert_eq!(sensitivity, round_trip_sensitivity); + } + + #[test] + fn test_round_trip_all_order_sensitivities() { + test_round_trip_order_sensitivity(AggregateOrderSensitivity::Insensitive); + test_round_trip_order_sensitivity(AggregateOrderSensitivity::HardRequirement); + test_round_trip_order_sensitivity(AggregateOrderSensitivity::Beneficial); + } +} diff --git a/datafusion/ffi/src/udf/mod.rs b/datafusion/ffi/src/udf/mod.rs index 706b9fabedcb..303acc783b2e 100644 --- a/datafusion/ffi/src/udf/mod.rs +++ b/datafusion/ffi/src/udf/mod.rs @@ -15,23 +15,27 @@ // specific language governing permissions and limitations // under the License. -use std::{ffi::c_void, sync::Arc}; - +use crate::{ + arrow_wrappers::{WrappedArray, WrappedSchema}, + df_result, rresult, rresult_return, + util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped}, + volatility::FFI_Volatility, +}; use abi_stable::{ std_types::{RResult, RString, RVec}, StableAbi, }; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::{ array::ArrayRef, error::ArrowError, ffi::{from_ffi, to_ffi, FFI_ArrowSchema}, }; +use arrow_schema::FieldRef; +use datafusion::logical_expr::ReturnFieldArgs; use datafusion::{ error::DataFusionError, - logical_expr::{ - type_coercion::functions::data_types_with_scalar_udf, ReturnInfo, ReturnTypeArgs, - }, + logical_expr::type_coercion::functions::data_types_with_scalar_udf, }; use datafusion::{ error::Result, @@ -39,19 +43,11 @@ use datafusion::{ ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, }, }; -use return_info::FFI_ReturnInfo; use return_type_args::{ - FFI_ReturnTypeArgs, ForeignReturnTypeArgs, ForeignReturnTypeArgsOwned, -}; - -use crate::{ - arrow_wrappers::{WrappedArray, WrappedSchema}, - df_result, rresult, rresult_return, - util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped}, - volatility::FFI_Volatility, + FFI_ReturnFieldArgs, ForeignReturnFieldArgs, ForeignReturnFieldArgsOwned, }; +use std::{ffi::c_void, sync::Arc}; -pub mod return_info; pub mod return_type_args; /// A stable struct for sharing a [`ScalarUDF`] across FFI boundaries. 
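As a brief aside, the sketch below (not part of this patch) shows how a consumer crate might take an FFI_AggregateUDF exported by a foreign library and register it with a SessionContext. It only reuses conversions introduced above (TryFrom<&FFI_AggregateUDF> for ForeignAggregateUDF, and From for AggregateUDF); the datafusion_ffi::udaf import path and the way the FFI struct is obtained from the foreign module are assumptions here, not something this patch prescribes.

use datafusion::error::Result;
use datafusion::logical_expr::AggregateUDF;
use datafusion::prelude::SessionContext;
use datafusion_ffi::udaf::{FFI_AggregateUDF, ForeignAggregateUDF};

/// Register a foreign aggregate UDF so it can be called from SQL (illustrative only).
fn register_foreign_udaf(ctx: &SessionContext, ffi_udaf: &FFI_AggregateUDF) -> Result<()> {
    // ForeignAggregateUDF implements AggregateUDFImpl, so it converts into a
    // native AggregateUDF and registers like any locally defined aggregate.
    let foreign_udaf: ForeignAggregateUDF = ffi_udaf.try_into()?;
    ctx.register_udaf(AggregateUDF::from(foreign_udaf));
    Ok(())
}

A window UDF shared through the FFI_WindowUDF type introduced later in this patch would presumably follow the same pattern via SessionContext::register_udwf.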
@@ -77,19 +73,21 @@ pub struct FFI_ScalarUDF { /// Determines the return info of the underlying [`ScalarUDF`]. Either this /// or return_type may be implemented on a UDF. - pub return_type_from_args: unsafe extern "C" fn( + pub return_field_from_args: unsafe extern "C" fn( udf: &Self, - args: FFI_ReturnTypeArgs, + args: FFI_ReturnFieldArgs, ) - -> RResult, + -> RResult, /// Execute the underlying [`ScalarUDF`] and return the result as a `FFI_ArrowArray` /// within an AbiStable wrapper. + #[allow(clippy::type_complexity)] pub invoke_with_args: unsafe extern "C" fn( udf: &Self, args: RVec, + arg_fields: RVec, num_rows: usize, - return_type: WrappedSchema, + return_field: WrappedSchema, ) -> RResult, /// See [`ScalarUDFImpl`] for details on short_circuits @@ -140,19 +138,20 @@ unsafe extern "C" fn return_type_fn_wrapper( rresult!(return_type) } -unsafe extern "C" fn return_type_from_args_fn_wrapper( +unsafe extern "C" fn return_field_from_args_fn_wrapper( udf: &FFI_ScalarUDF, - args: FFI_ReturnTypeArgs, -) -> RResult { + args: FFI_ReturnFieldArgs, +) -> RResult { let private_data = udf.private_data as *const ScalarUDFPrivateData; let udf = &(*private_data).udf; - let args: ForeignReturnTypeArgsOwned = rresult_return!((&args).try_into()); - let args_ref: ForeignReturnTypeArgs = (&args).into(); + let args: ForeignReturnFieldArgsOwned = rresult_return!((&args).try_into()); + let args_ref: ForeignReturnFieldArgs = (&args).into(); let return_type = udf - .return_type_from_args((&args_ref).into()) - .and_then(FFI_ReturnInfo::try_from); + .return_field_from_args((&args_ref).into()) + .and_then(|f| FFI_ArrowSchema::try_from(&f).map_err(DataFusionError::from)) + .map(WrappedSchema); rresult!(return_type) } @@ -174,8 +173,9 @@ unsafe extern "C" fn coerce_types_fn_wrapper( unsafe extern "C" fn invoke_with_args_fn_wrapper( udf: &FFI_ScalarUDF, args: RVec, + arg_fields: RVec, number_rows: usize, - return_type: WrappedSchema, + return_field: WrappedSchema, ) -> RResult { let private_data = udf.private_data as *const ScalarUDFPrivateData; let udf = &(*private_data).udf; @@ -189,12 +189,23 @@ unsafe extern "C" fn invoke_with_args_fn_wrapper( .collect::>(); let args = rresult_return!(args); - let return_type = rresult_return!(DataType::try_from(&return_type.0)); + let return_field = rresult_return!(Field::try_from(&return_field.0)).into(); + + let arg_fields = arg_fields + .into_iter() + .map(|wrapped_field| { + Field::try_from(&wrapped_field.0) + .map(Arc::new) + .map_err(DataFusionError::from) + }) + .collect::>>(); + let arg_fields = rresult_return!(arg_fields); let args = ScalarFunctionArgs { args, + arg_fields, number_rows, - return_type: &return_type, + return_field, }; let result = rresult_return!(udf @@ -243,7 +254,7 @@ impl From> for FFI_ScalarUDF { short_circuits, invoke_with_args: invoke_with_args_fn_wrapper, return_type: return_type_fn_wrapper, - return_type_from_args: return_type_from_args_fn_wrapper, + return_field_from_args: return_field_from_args_fn_wrapper, coerce_types: coerce_types_fn_wrapper, clone: clone_fn_wrapper, release: release_fn_wrapper, @@ -316,21 +327,26 @@ impl ScalarUDFImpl for ForeignScalarUDF { result.and_then(|r| (&r.0).try_into().map_err(DataFusionError::from)) } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - let args: FFI_ReturnTypeArgs = args.try_into()?; + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let args: FFI_ReturnFieldArgs = args.try_into()?; - let result = unsafe { (self.udf.return_type_from_args)(&self.udf, args) 
}; + let result = unsafe { (self.udf.return_field_from_args)(&self.udf, args) }; let result = df_result!(result); - result.and_then(|r| r.try_into()) + result.and_then(|r| { + Field::try_from(&r.0) + .map(Arc::new) + .map_err(DataFusionError::from) + }) } fn invoke_with_args(&self, invoke_args: ScalarFunctionArgs) -> Result { let ScalarFunctionArgs { args, + arg_fields, number_rows, - return_type, + return_field, } = invoke_args; let args = args @@ -347,10 +363,27 @@ impl ScalarUDFImpl for ForeignScalarUDF { .collect::, ArrowError>>()? .into(); - let return_type = WrappedSchema(FFI_ArrowSchema::try_from(return_type)?); + let arg_fields_wrapped = arg_fields + .iter() + .map(FFI_ArrowSchema::try_from) + .collect::, ArrowError>>()?; + + let arg_fields = arg_fields_wrapped + .into_iter() + .map(WrappedSchema) + .collect::>(); + + let return_field = return_field.as_ref().clone(); + let return_field = WrappedSchema(FFI_ArrowSchema::try_from(return_field)?); let result = unsafe { - (self.udf.invoke_with_args)(&self.udf, args, number_rows, return_type) + (self.udf.invoke_with_args)( + &self.udf, + args, + arg_fields, + number_rows, + return_field, + ) }; let result = df_result!(result)?; @@ -389,7 +422,7 @@ mod tests { let foreign_udf: ForeignScalarUDF = (&local_udf).try_into()?; - assert!(original_udf.name() == foreign_udf.name()); + assert_eq!(original_udf.name(), foreign_udf.name()); Ok(()) } diff --git a/datafusion/ffi/src/udf/return_type_args.rs b/datafusion/ffi/src/udf/return_type_args.rs index a0897630e2ea..c437c9537be6 100644 --- a/datafusion/ffi/src/udf/return_type_args.rs +++ b/datafusion/ffi/src/udf/return_type_args.rs @@ -19,33 +19,30 @@ use abi_stable::{ std_types::{ROption, RVec}, StableAbi, }; -use arrow::datatypes::DataType; +use arrow_schema::FieldRef; use datafusion::{ - common::exec_datafusion_err, error::DataFusionError, logical_expr::ReturnTypeArgs, + common::exec_datafusion_err, error::DataFusionError, logical_expr::ReturnFieldArgs, scalar::ScalarValue, }; -use crate::{ - arrow_wrappers::WrappedSchema, - util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped}, -}; +use crate::arrow_wrappers::WrappedSchema; +use crate::util::{rvec_wrapped_to_vec_fieldref, vec_fieldref_to_rvec_wrapped}; use prost::Message; -/// A stable struct for sharing a [`ReturnTypeArgs`] across FFI boundaries. +/// A stable struct for sharing a [`ReturnFieldArgs`] across FFI boundaries. 
#[repr(C)] #[derive(Debug, StableAbi)] #[allow(non_camel_case_types)] -pub struct FFI_ReturnTypeArgs { - arg_types: RVec, +pub struct FFI_ReturnFieldArgs { + arg_fields: RVec, scalar_arguments: RVec>>, - nullables: RVec, } -impl TryFrom> for FFI_ReturnTypeArgs { +impl TryFrom> for FFI_ReturnFieldArgs { type Error = DataFusionError; - fn try_from(value: ReturnTypeArgs) -> Result { - let arg_types = vec_datatype_to_rvec_wrapped(value.arg_types)?; + fn try_from(value: ReturnFieldArgs) -> Result { + let arg_fields = vec_fieldref_to_rvec_wrapped(value.arg_fields)?; let scalar_arguments: Result, Self::Error> = value .scalar_arguments .iter() @@ -62,35 +59,31 @@ impl TryFrom> for FFI_ReturnTypeArgs { .collect(); let scalar_arguments = scalar_arguments?.into_iter().map(ROption::from).collect(); - let nullables = value.nullables.into(); Ok(Self { - arg_types, + arg_fields, scalar_arguments, - nullables, }) } } // TODO(tsaucer) It would be good to find a better way around this, but it // appears a restriction based on the need to have a borrowed ScalarValue -// in the arguments when converted to ReturnTypeArgs -pub struct ForeignReturnTypeArgsOwned { - arg_types: Vec, +// in the arguments when converted to ReturnFieldArgs +pub struct ForeignReturnFieldArgsOwned { + arg_fields: Vec, scalar_arguments: Vec>, - nullables: Vec, } -pub struct ForeignReturnTypeArgs<'a> { - arg_types: &'a [DataType], +pub struct ForeignReturnFieldArgs<'a> { + arg_fields: &'a [FieldRef], scalar_arguments: Vec>, - nullables: &'a [bool], } -impl TryFrom<&FFI_ReturnTypeArgs> for ForeignReturnTypeArgsOwned { +impl TryFrom<&FFI_ReturnFieldArgs> for ForeignReturnFieldArgsOwned { type Error = DataFusionError; - fn try_from(value: &FFI_ReturnTypeArgs) -> Result { - let arg_types = rvec_wrapped_to_vec_datatype(&value.arg_types)?; + fn try_from(value: &FFI_ReturnFieldArgs) -> Result { + let arg_fields = rvec_wrapped_to_vec_fieldref(&value.arg_fields)?; let scalar_arguments: Result, Self::Error> = value .scalar_arguments .iter() @@ -107,36 +100,31 @@ impl TryFrom<&FFI_ReturnTypeArgs> for ForeignReturnTypeArgsOwned { .collect(); let scalar_arguments = scalar_arguments?.into_iter().collect(); - let nullables = value.nullables.iter().cloned().collect(); - Ok(Self { - arg_types, + arg_fields, scalar_arguments, - nullables, }) } } -impl<'a> From<&'a ForeignReturnTypeArgsOwned> for ForeignReturnTypeArgs<'a> { - fn from(value: &'a ForeignReturnTypeArgsOwned) -> Self { +impl<'a> From<&'a ForeignReturnFieldArgsOwned> for ForeignReturnFieldArgs<'a> { + fn from(value: &'a ForeignReturnFieldArgsOwned) -> Self { Self { - arg_types: &value.arg_types, + arg_fields: &value.arg_fields, scalar_arguments: value .scalar_arguments .iter() .map(|opt| opt.as_ref()) .collect(), - nullables: &value.nullables, } } } -impl<'a> From<&'a ForeignReturnTypeArgs<'a>> for ReturnTypeArgs<'a> { - fn from(value: &'a ForeignReturnTypeArgs) -> Self { - ReturnTypeArgs { - arg_types: value.arg_types, +impl<'a> From<&'a ForeignReturnFieldArgs<'a>> for ReturnFieldArgs<'a> { + fn from(value: &'a ForeignReturnFieldArgs) -> Self { + ReturnFieldArgs { + arg_fields: value.arg_fields, scalar_arguments: &value.scalar_arguments, - nullables: value.nullables, } } } diff --git a/datafusion/ffi/src/udtf.rs b/datafusion/ffi/src/udtf.rs index 1e06247546be..ceedec2599a2 100644 --- a/datafusion/ffi/src/udtf.rs +++ b/datafusion/ffi/src/udtf.rs @@ -214,7 +214,7 @@ mod tests { let args = args .iter() .map(|arg| { - if let Expr::Literal(scalar) = arg { + if let Expr::Literal(scalar, _) = arg { 
Ok(scalar) } else { exec_err!("Expected only literal arguments to table udf") @@ -243,21 +243,21 @@ mod tests { ScalarValue::Utf8(s) => { let s_vec = vec![s.to_owned(); num_rows]; ( - Field::new(format!("field-{}", idx), DataType::Utf8, true), + Field::new(format!("field-{idx}"), DataType::Utf8, true), Arc::new(StringArray::from(s_vec)) as ArrayRef, ) } ScalarValue::UInt64(v) => { let v_vec = vec![v.to_owned(); num_rows]; ( - Field::new(format!("field-{}", idx), DataType::UInt64, true), + Field::new(format!("field-{idx}"), DataType::UInt64, true), Arc::new(UInt64Array::from(v_vec)) as ArrayRef, ) } ScalarValue::Float64(v) => { let v_vec = vec![v.to_owned(); num_rows]; ( - Field::new(format!("field-{}", idx), DataType::Float64, true), + Field::new(format!("field-{idx}"), DataType::Float64, true), Arc::new(Float64Array::from(v_vec)) as ArrayRef, ) } diff --git a/datafusion/ffi/src/udwf/mod.rs b/datafusion/ffi/src/udwf/mod.rs new file mode 100644 index 000000000000..504bf7a411f1 --- /dev/null +++ b/datafusion/ffi/src/udwf/mod.rs @@ -0,0 +1,432 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ffi::c_void, sync::Arc}; + +use abi_stable::{ + std_types::{ROption, RResult, RString, RVec}, + StableAbi, +}; +use arrow::datatypes::Schema; +use arrow::{ + compute::SortOptions, + datatypes::{DataType, SchemaRef}, +}; +use arrow_schema::{Field, FieldRef}; +use datafusion::{ + error::DataFusionError, + logical_expr::{ + function::WindowUDFFieldArgs, type_coercion::functions::fields_with_window_udf, + PartitionEvaluator, + }, +}; +use datafusion::{ + error::Result, + logical_expr::{Signature, WindowUDF, WindowUDFImpl}, +}; +use partition_evaluator::{FFI_PartitionEvaluator, ForeignPartitionEvaluator}; +use partition_evaluator_args::{ + FFI_PartitionEvaluatorArgs, ForeignPartitionEvaluatorArgs, +}; +mod partition_evaluator; +mod partition_evaluator_args; +mod range; + +use crate::util::{rvec_wrapped_to_vec_fieldref, vec_fieldref_to_rvec_wrapped}; +use crate::{ + arrow_wrappers::WrappedSchema, + df_result, rresult, rresult_return, + util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped}, + volatility::FFI_Volatility, +}; + +/// A stable struct for sharing a [`WindowUDF`] across FFI boundaries. 
+#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_WindowUDF { + /// FFI equivalent to the `name` of a [`WindowUDF`] + pub name: RString, + + /// FFI equivalent to the `aliases` of a [`WindowUDF`] + pub aliases: RVec, + + /// FFI equivalent to the `volatility` of a [`WindowUDF`] + pub volatility: FFI_Volatility, + + pub partition_evaluator: + unsafe extern "C" fn( + udwf: &Self, + args: FFI_PartitionEvaluatorArgs, + ) -> RResult, + + pub field: unsafe extern "C" fn( + udwf: &Self, + input_types: RVec, + display_name: RString, + ) -> RResult, + + /// Performs type coercion. To simplify this interface, all UDFs are treated as having + /// user defined signatures, which will in turn cause coerce_types to be called. This + /// call should be transparent to most users as the internal function performs the + /// appropriate calls on the underlying [`WindowUDF`] + pub coerce_types: unsafe extern "C" fn( + udf: &Self, + arg_types: RVec, + ) -> RResult, RString>, + + pub sort_options: ROption, + + /// Used to create a clone on the provider of the udf. This should + /// only need to be called by the receiver of the udf. + pub clone: unsafe extern "C" fn(udf: &Self) -> Self, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(udf: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the udf. + /// A [`ForeignWindowUDF`] should never attempt to access this data. + pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_WindowUDF {} +unsafe impl Sync for FFI_WindowUDF {} + +pub struct WindowUDFPrivateData { + pub udf: Arc, +} + +impl FFI_WindowUDF { + unsafe fn inner(&self) -> &Arc { + let private_data = self.private_data as *const WindowUDFPrivateData; + &(*private_data).udf + } +} + +unsafe extern "C" fn partition_evaluator_fn_wrapper( + udwf: &FFI_WindowUDF, + args: FFI_PartitionEvaluatorArgs, +) -> RResult { + let inner = udwf.inner(); + + let args = rresult_return!(ForeignPartitionEvaluatorArgs::try_from(args)); + + let evaluator = rresult_return!(inner.partition_evaluator_factory((&args).into())); + + RResult::ROk(evaluator.into()) +} + +unsafe extern "C" fn field_fn_wrapper( + udwf: &FFI_WindowUDF, + input_fields: RVec, + display_name: RString, +) -> RResult { + let inner = udwf.inner(); + + let input_fields = rresult_return!(rvec_wrapped_to_vec_fieldref(&input_fields)); + + let field = rresult_return!(inner.field(WindowUDFFieldArgs::new( + &input_fields, + display_name.as_str() + ))); + + let schema = Arc::new(Schema::new(vec![field])); + + RResult::ROk(WrappedSchema::from(schema)) +} + +unsafe extern "C" fn coerce_types_fn_wrapper( + udwf: &FFI_WindowUDF, + arg_types: RVec, +) -> RResult, RString> { + let inner = udwf.inner(); + + let arg_fields = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types)) + .into_iter() + .map(|dt| Field::new("f", dt, false)) + .map(Arc::new) + .collect::>(); + + let return_fields = rresult_return!(fields_with_window_udf(&arg_fields, inner)); + let return_types = return_fields + .into_iter() + .map(|f| f.data_type().to_owned()) + .collect::>(); + + rresult!(vec_datatype_to_rvec_wrapped(&return_types)) +} + +unsafe extern "C" fn release_fn_wrapper(udwf: &mut FFI_WindowUDF) { + let private_data = Box::from_raw(udwf.private_data as *mut WindowUDFPrivateData); + drop(private_data); +} + +unsafe extern "C" fn clone_fn_wrapper(udwf: &FFI_WindowUDF) -> FFI_WindowUDF { + // let private_data = udf.private_data as *const 
WindowUDFPrivateData; + // let udf_data = &(*private_data); + + // let private_data = Box::new(WindowUDFPrivateData { + // udf: Arc::clone(&udf_data.udf), + // }); + let private_data = Box::new(WindowUDFPrivateData { + udf: Arc::clone(udwf.inner()), + }); + + FFI_WindowUDF { + name: udwf.name.clone(), + aliases: udwf.aliases.clone(), + volatility: udwf.volatility.clone(), + partition_evaluator: partition_evaluator_fn_wrapper, + sort_options: udwf.sort_options.clone(), + coerce_types: coerce_types_fn_wrapper, + field: field_fn_wrapper, + clone: clone_fn_wrapper, + release: release_fn_wrapper, + private_data: Box::into_raw(private_data) as *mut c_void, + } +} + +impl Clone for FFI_WindowUDF { + fn clone(&self) -> Self { + unsafe { (self.clone)(self) } + } +} + +impl From> for FFI_WindowUDF { + fn from(udf: Arc) -> Self { + let name = udf.name().into(); + let aliases = udf.aliases().iter().map(|a| a.to_owned().into()).collect(); + let volatility = udf.signature().volatility.into(); + let sort_options = udf.sort_options().map(|v| (&v).into()).into(); + + let private_data = Box::new(WindowUDFPrivateData { udf }); + + Self { + name, + aliases, + volatility, + partition_evaluator: partition_evaluator_fn_wrapper, + sort_options, + coerce_types: coerce_types_fn_wrapper, + field: field_fn_wrapper, + clone: clone_fn_wrapper, + release: release_fn_wrapper, + private_data: Box::into_raw(private_data) as *mut c_void, + } + } +} + +impl Drop for FFI_WindowUDF { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +/// This struct is used to access an UDF provided by a foreign +/// library across a FFI boundary. +/// +/// The ForeignWindowUDF is to be used by the caller of the UDF, so it has +/// no knowledge or access to the private data. All interaction with the UDF +/// must occur through the functions defined in FFI_WindowUDF. +#[derive(Debug)] +pub struct ForeignWindowUDF { + name: String, + aliases: Vec, + udf: FFI_WindowUDF, + signature: Signature, +} + +unsafe impl Send for ForeignWindowUDF {} +unsafe impl Sync for ForeignWindowUDF {} + +impl TryFrom<&FFI_WindowUDF> for ForeignWindowUDF { + type Error = DataFusionError; + + fn try_from(udf: &FFI_WindowUDF) -> Result { + let name = udf.name.to_owned().into(); + let signature = Signature::user_defined((&udf.volatility).into()); + + let aliases = udf.aliases.iter().map(|s| s.to_string()).collect(); + + Ok(Self { + name, + udf: udf.clone(), + aliases, + signature, + }) + } +} + +impl WindowUDFImpl for ForeignWindowUDF { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + unsafe { + let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?; + let result_types = df_result!((self.udf.coerce_types)(&self.udf, arg_types))?; + Ok(rvec_wrapped_to_vec_datatype(&result_types)?) 
+ } + } + + fn partition_evaluator( + &self, + args: datafusion::logical_expr::function::PartitionEvaluatorArgs, + ) -> Result> { + let evaluator = unsafe { + let args = FFI_PartitionEvaluatorArgs::try_from(args)?; + (self.udf.partition_evaluator)(&self.udf, args) + }; + + df_result!(evaluator).map(|evaluator| { + Box::new(ForeignPartitionEvaluator::from(evaluator)) + as Box + }) + } + + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + unsafe { + let input_types = vec_fieldref_to_rvec_wrapped(field_args.input_fields())?; + let schema = df_result!((self.udf.field)( + &self.udf, + input_types, + field_args.name().into() + ))?; + let schema: SchemaRef = schema.into(); + + match schema.fields().is_empty() { + true => Err(DataFusionError::Execution( + "Unable to retrieve field in WindowUDF via FFI".to_string(), + )), + false => Ok(schema.field(0).to_owned().into()), + } + } + } + + fn sort_options(&self) -> Option { + let options: Option<&FFI_SortOptions> = self.udf.sort_options.as_ref().into(); + options.map(|s| s.into()) + } +} + +#[repr(C)] +#[derive(Debug, StableAbi, Clone)] +#[allow(non_camel_case_types)] +pub struct FFI_SortOptions { + pub descending: bool, + pub nulls_first: bool, +} + +impl From<&SortOptions> for FFI_SortOptions { + fn from(value: &SortOptions) -> Self { + Self { + descending: value.descending, + nulls_first: value.nulls_first, + } + } +} + +impl From<&FFI_SortOptions> for SortOptions { + fn from(value: &FFI_SortOptions) -> Self { + Self { + descending: value.descending, + nulls_first: value.nulls_first, + } + } +} + +#[cfg(test)] +#[cfg(feature = "integration-tests")] +mod tests { + use crate::tests::create_record_batch; + use crate::udwf::{FFI_WindowUDF, ForeignWindowUDF}; + use arrow::array::{create_array, ArrayRef}; + use datafusion::functions_window::lead_lag::{lag_udwf, WindowShift}; + use datafusion::logical_expr::expr::Sort; + use datafusion::logical_expr::{col, ExprFunctionExt, WindowUDF, WindowUDFImpl}; + use datafusion::prelude::SessionContext; + use std::sync::Arc; + + fn create_test_foreign_udwf( + original_udwf: impl WindowUDFImpl + 'static, + ) -> datafusion::common::Result { + let original_udwf = Arc::new(WindowUDF::from(original_udwf)); + + let local_udwf: FFI_WindowUDF = Arc::clone(&original_udwf).into(); + + let foreign_udwf: ForeignWindowUDF = (&local_udwf).try_into()?; + Ok(foreign_udwf.into()) + } + + #[test] + fn test_round_trip_udwf() -> datafusion::common::Result<()> { + let original_udwf = lag_udwf(); + let original_name = original_udwf.name().to_owned(); + + // Convert to FFI format + let local_udwf: FFI_WindowUDF = Arc::clone(&original_udwf).into(); + + // Convert back to native format + let foreign_udwf: ForeignWindowUDF = (&local_udwf).try_into()?; + let foreign_udwf: WindowUDF = foreign_udwf.into(); + + assert_eq!(original_name, foreign_udwf.name()); + Ok(()) + } + + #[tokio::test] + async fn test_lag_udwf() -> datafusion::common::Result<()> { + let udwf = create_test_foreign_udwf(WindowShift::lag())?; + + let ctx = SessionContext::default(); + let df = ctx.read_batch(create_record_batch(-5, 5))?; + + let df = df.select(vec![ + col("a"), + udwf.call(vec![col("a")]) + .order_by(vec![Sort::new(col("a"), true, true)]) + .build() + .unwrap() + .alias("lag_a"), + ])?; + + df.clone().show().await?; + + let result = df.collect().await?; + let expected = + create_array!(Int32, [None, Some(-5), Some(-4), Some(-3), Some(-2)]) + as ArrayRef; + + assert_eq!(result.len(), 1); + assert_eq!(result[0].column(1), &expected); + + Ok(()) + } +} 
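The following is a minimal usage sketch, not part of this patch, of the round trip that the new `FFI_WindowUDF` / `ForeignWindowUDF` pair above is designed for; it mirrors the `test_round_trip_udwf` test in this file, and the `export_udwf` / `import_udwf` helper names are placeholders used only for illustration.

use std::sync::Arc;
use datafusion::logical_expr::WindowUDF;
use datafusion_ffi::udwf::{FFI_WindowUDF, ForeignWindowUDF};

// Provider side (e.g. inside a cdylib): wrap a native WindowUDF in the
// #[repr(C)], StableAbi struct so it can cross the FFI boundary.
fn export_udwf(udwf: Arc<WindowUDF>) -> FFI_WindowUDF {
    FFI_WindowUDF::from(udwf)
}

// Consumer side: rebuild a WindowUDF that forwards every call through the
// function pointers stored in the received FFI_WindowUDF.
fn import_udwf(ffi_udwf: &FFI_WindowUDF) -> datafusion::error::Result<WindowUDF> {
    let foreign: ForeignWindowUDF = ffi_udwf.try_into()?;
    Ok(WindowUDF::new_from_impl(foreign))
}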
diff --git a/datafusion/ffi/src/udwf/partition_evaluator.rs b/datafusion/ffi/src/udwf/partition_evaluator.rs new file mode 100644 index 000000000000..995d00cce30e --- /dev/null +++ b/datafusion/ffi/src/udwf/partition_evaluator.rs @@ -0,0 +1,320 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ffi::c_void, ops::Range}; + +use crate::{arrow_wrappers::WrappedArray, df_result, rresult, rresult_return}; +use abi_stable::{ + std_types::{RResult, RString, RVec}, + StableAbi, +}; +use arrow::{array::ArrayRef, error::ArrowError}; +use datafusion::{ + error::{DataFusionError, Result}, + logical_expr::{window_state::WindowAggState, PartitionEvaluator}, + scalar::ScalarValue, +}; +use prost::Message; + +use super::range::FFI_Range; + +/// A stable struct for sharing [`PartitionEvaluator`] across FFI boundaries. +/// For an explanation of each field, see the corresponding function +/// defined in [`PartitionEvaluator`]. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_PartitionEvaluator { + pub evaluate_all: unsafe extern "C" fn( + evaluator: &mut Self, + values: RVec, + num_rows: usize, + ) -> RResult, + + pub evaluate: unsafe extern "C" fn( + evaluator: &mut Self, + values: RVec, + range: FFI_Range, + ) -> RResult, RString>, + + pub evaluate_all_with_rank: unsafe extern "C" fn( + evaluator: &Self, + num_rows: usize, + ranks_in_partition: RVec, + ) + -> RResult, + + pub get_range: unsafe extern "C" fn( + evaluator: &Self, + idx: usize, + n_rows: usize, + ) -> RResult, + + pub is_causal: bool, + + pub supports_bounded_execution: bool, + pub uses_window_frame: bool, + pub include_rank: bool, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(evaluator: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the evaluator. + /// A [`ForeignPartitionEvaluator`] should never attempt to access this data. 
+ pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_PartitionEvaluator {} +unsafe impl Sync for FFI_PartitionEvaluator {} + +pub struct PartitionEvaluatorPrivateData { + pub evaluator: Box, +} + +impl FFI_PartitionEvaluator { + unsafe fn inner_mut(&mut self) -> &mut Box<(dyn PartitionEvaluator + 'static)> { + let private_data = self.private_data as *mut PartitionEvaluatorPrivateData; + &mut (*private_data).evaluator + } + + unsafe fn inner(&self) -> &(dyn PartitionEvaluator + 'static) { + let private_data = self.private_data as *mut PartitionEvaluatorPrivateData; + (*private_data).evaluator.as_ref() + } +} + +unsafe extern "C" fn evaluate_all_fn_wrapper( + evaluator: &mut FFI_PartitionEvaluator, + values: RVec, + num_rows: usize, +) -> RResult { + let inner = evaluator.inner_mut(); + + let values_arrays = values + .into_iter() + .map(|v| v.try_into().map_err(DataFusionError::from)) + .collect::>>(); + let values_arrays = rresult_return!(values_arrays); + + let return_array = inner + .evaluate_all(&values_arrays, num_rows) + .and_then(|array| WrappedArray::try_from(&array).map_err(DataFusionError::from)); + + rresult!(return_array) +} + +unsafe extern "C" fn evaluate_fn_wrapper( + evaluator: &mut FFI_PartitionEvaluator, + values: RVec, + range: FFI_Range, +) -> RResult, RString> { + let inner = evaluator.inner_mut(); + + let values_arrays = values + .into_iter() + .map(|v| v.try_into().map_err(DataFusionError::from)) + .collect::>>(); + let values_arrays = rresult_return!(values_arrays); + + // let return_array = (inner.evaluate(&values_arrays, &range.into())); + // .and_then(|array| WrappedArray::try_from(&array).map_err(DataFusionError::from)); + let scalar_result = rresult_return!(inner.evaluate(&values_arrays, &range.into())); + let proto_result: datafusion_proto::protobuf::ScalarValue = + rresult_return!((&scalar_result).try_into()); + + RResult::ROk(proto_result.encode_to_vec().into()) +} + +unsafe extern "C" fn evaluate_all_with_rank_fn_wrapper( + evaluator: &FFI_PartitionEvaluator, + num_rows: usize, + ranks_in_partition: RVec, +) -> RResult { + let inner = evaluator.inner(); + + let ranks_in_partition = ranks_in_partition + .into_iter() + .map(Range::from) + .collect::>(); + + let return_array = inner + .evaluate_all_with_rank(num_rows, &ranks_in_partition) + .and_then(|array| WrappedArray::try_from(&array).map_err(DataFusionError::from)); + + rresult!(return_array) +} + +unsafe extern "C" fn get_range_fn_wrapper( + evaluator: &FFI_PartitionEvaluator, + idx: usize, + n_rows: usize, +) -> RResult { + let inner = evaluator.inner(); + let range = inner.get_range(idx, n_rows).map(FFI_Range::from); + + rresult!(range) +} + +unsafe extern "C" fn release_fn_wrapper(evaluator: &mut FFI_PartitionEvaluator) { + let private_data = + Box::from_raw(evaluator.private_data as *mut PartitionEvaluatorPrivateData); + drop(private_data); +} + +impl From> for FFI_PartitionEvaluator { + fn from(evaluator: Box) -> Self { + let is_causal = evaluator.is_causal(); + let supports_bounded_execution = evaluator.supports_bounded_execution(); + let include_rank = evaluator.include_rank(); + let uses_window_frame = evaluator.uses_window_frame(); + + let private_data = PartitionEvaluatorPrivateData { evaluator }; + + Self { + evaluate: evaluate_fn_wrapper, + evaluate_all: evaluate_all_fn_wrapper, + evaluate_all_with_rank: evaluate_all_with_rank_fn_wrapper, + get_range: get_range_fn_wrapper, + is_causal, + supports_bounded_execution, + include_rank, + uses_window_frame, + release: release_fn_wrapper, + 
private_data: Box::into_raw(Box::new(private_data)) as *mut c_void, + } +} + +impl Drop for FFI_PartitionEvaluator { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +/// This struct is used to access a UDF provided by a foreign +/// library across an FFI boundary. +/// +/// The ForeignPartitionEvaluator is to be used by the caller of the UDF, so it has +/// no knowledge or access to the private data. All interaction with the UDF +/// must occur through the functions defined in FFI_PartitionEvaluator. +#[derive(Debug)] +pub struct ForeignPartitionEvaluator { + evaluator: FFI_PartitionEvaluator, +} + +unsafe impl Send for ForeignPartitionEvaluator {} +unsafe impl Sync for ForeignPartitionEvaluator {} + +impl From for ForeignPartitionEvaluator { + fn from(evaluator: FFI_PartitionEvaluator) -> Self { + Self { evaluator } + } +} + +impl PartitionEvaluator for ForeignPartitionEvaluator { + fn memoize(&mut self, _state: &mut WindowAggState) -> Result<()> { + // Exposing `memoize` increases the surface area of the FFI work + // so for now we do not support it. + Ok(()) + } + + fn get_range(&self, idx: usize, n_rows: usize) -> Result> { + let range = unsafe { (self.evaluator.get_range)(&self.evaluator, idx, n_rows) }; + df_result!(range).map(Range::from) + } + + /// Get whether evaluator needs future data for its result (if so returns `false`) or not + fn is_causal(&self) -> bool { + self.evaluator.is_causal + } + + fn evaluate_all(&mut self, values: &[ArrayRef], num_rows: usize) -> Result { + let result = unsafe { + let values = values + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + (self.evaluator.evaluate_all)(&mut self.evaluator, values, num_rows) + }; + + let array = df_result!(result)?; + + Ok(array.try_into()?) + } + + fn evaluate( + &mut self, + values: &[ArrayRef], + range: &Range, + ) -> Result { + unsafe { + let values = values + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + + let scalar_bytes = df_result!((self.evaluator.evaluate)( + &mut self.evaluator, + values, + range.to_owned().into() + ))?; + + let proto_scalar = + datafusion_proto::protobuf::ScalarValue::decode(scalar_bytes.as_ref()) + .map_err(|e| DataFusionError::External(Box::new(e)))?; + + ScalarValue::try_from(&proto_scalar).map_err(DataFusionError::from) + } + } + + fn evaluate_all_with_rank( + &self, + num_rows: usize, + ranks_in_partition: &[Range], + ) -> Result { + let result = unsafe { + let ranks_in_partition = ranks_in_partition + .iter() + .map(|rank| FFI_Range::from(rank.to_owned())) + .collect(); + (self.evaluator.evaluate_all_with_rank)( + &self.evaluator, + num_rows, + ranks_in_partition, + ) + }; + + let array = df_result!(result)?; + + Ok(array.try_into()?) + } + + fn supports_bounded_execution(&self) -> bool { + self.evaluator.supports_bounded_execution + } + + fn uses_window_frame(&self) -> bool { + self.evaluator.uses_window_frame + } + + fn include_rank(&self) -> bool { + self.evaluator.include_rank + } +} + +#[cfg(test)] +mod tests {} diff --git a/datafusion/ffi/src/udwf/partition_evaluator_args.rs b/datafusion/ffi/src/udwf/partition_evaluator_args.rs new file mode 100644 index 000000000000..dffeb23741b6 --- /dev/null +++ b/datafusion/ffi/src/udwf/partition_evaluator_args.rs @@ -0,0 +1,181 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{collections::HashMap, sync::Arc}; + +use crate::arrow_wrappers::WrappedSchema; +use abi_stable::{std_types::RVec, StableAbi}; +use arrow::{ + datatypes::{DataType, Field, Schema, SchemaRef}, + error::ArrowError, + ffi::FFI_ArrowSchema, +}; +use arrow_schema::FieldRef; +use datafusion::{ + error::{DataFusionError, Result}, + logical_expr::function::PartitionEvaluatorArgs, + physical_plan::{expressions::Column, PhysicalExpr}, + prelude::SessionContext, +}; +use datafusion_proto::{ + physical_plan::{ + from_proto::parse_physical_expr, to_proto::serialize_physical_exprs, + DefaultPhysicalExtensionCodec, + }, + protobuf::PhysicalExprNode, +}; +use prost::Message; + +/// A stable struct for sharing [`PartitionEvaluatorArgs`] across FFI boundaries. +/// For an explanation of each field, see the corresponding function +/// defined in [`PartitionEvaluatorArgs`]. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_PartitionEvaluatorArgs { + input_exprs: RVec>, + input_fields: RVec, + is_reversed: bool, + ignore_nulls: bool, + schema: WrappedSchema, +} + +impl TryFrom> for FFI_PartitionEvaluatorArgs { + type Error = DataFusionError; + fn try_from(args: PartitionEvaluatorArgs) -> Result { + // This is a bit of a hack. Since PartitionEvaluatorArgs does not carry a schema + // around, and instead passes the data types directly we are unable to decode the + // protobuf PhysicalExpr correctly. In evaluating the code the only place these + // appear to be really used are the Column data types. So here we will find all + // of the required columns and create a schema that has empty fields except for + // the ones we require. Ideally we would enhance PartitionEvaluatorArgs to just + // pass along the schema, but that is a larger breaking change. + let required_columns: HashMap = args + .input_exprs() + .iter() + .zip(args.input_fields()) + .filter_map(|(expr, field)| { + expr.as_any() + .downcast_ref::() + .map(|column| (column.index(), (column.name(), field.data_type()))) + }) + .collect(); + + let max_column = required_columns.keys().max(); + let fields: Vec<_> = max_column + .map(|max_column| { + (0..(max_column + 1)) + .map(|idx| match required_columns.get(&idx) { + Some((name, data_type)) => { + Field::new(*name, (*data_type).clone(), true) + } + None => Field::new( + format!("ffi_partition_evaluator_col_{idx}"), + DataType::Null, + true, + ), + }) + .collect() + }) + .unwrap_or_default(); + + let schema = Arc::new(Schema::new(fields)); + + let codec = DefaultPhysicalExtensionCodec {}; + let input_exprs = serialize_physical_exprs(args.input_exprs(), &codec)? + .into_iter() + .map(|expr_node| expr_node.encode_to_vec().into()) + .collect(); + + let input_fields = args + .input_fields() + .iter() + .map(|input_type| FFI_ArrowSchema::try_from(input_type).map(WrappedSchema)) + .collect::, ArrowError>>()? 
+ .into(); + + let schema: WrappedSchema = schema.into(); + + Ok(Self { + input_exprs, + input_fields, + schema, + is_reversed: args.is_reversed(), + ignore_nulls: args.ignore_nulls(), + }) + } +} + +/// This struct mirrors PartitionEvaluatorArgs except that it contains owned data. +/// It is necessary to create this struct so that we can parse the protobuf +/// data across the FFI boundary and turn it into owned data that +/// PartitionEvaluatorArgs can then reference. +pub struct ForeignPartitionEvaluatorArgs { + input_exprs: Vec>, + input_fields: Vec, + is_reversed: bool, + ignore_nulls: bool, +} + +impl TryFrom for ForeignPartitionEvaluatorArgs { + type Error = DataFusionError; + + fn try_from(value: FFI_PartitionEvaluatorArgs) -> Result { + let default_ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + + let schema: SchemaRef = value.schema.into(); + + let input_exprs = value + .input_exprs + .into_iter() + .map(|input_expr_bytes| PhysicalExprNode::decode(input_expr_bytes.as_ref())) + .collect::, prost::DecodeError>>() + .map_err(|e| DataFusionError::Execution(e.to_string()))? + .iter() + .map(|expr_node| { + parse_physical_expr(expr_node, &default_ctx, &schema, &codec) + }) + .collect::>>()?; + + let input_fields = input_exprs + .iter() + .map(|expr| expr.return_field(&schema)) + .collect::>>()?; + + Ok(Self { + input_exprs, + input_fields, + is_reversed: value.is_reversed, + ignore_nulls: value.ignore_nulls, + }) + } +} + +impl<'a> From<&'a ForeignPartitionEvaluatorArgs> for PartitionEvaluatorArgs<'a> { + fn from(value: &'a ForeignPartitionEvaluatorArgs) -> Self { + PartitionEvaluatorArgs::new( + &value.input_exprs, + &value.input_fields, + value.is_reversed, + value.ignore_nulls, + ) + } +} + +#[cfg(test)] +mod tests {} diff --git a/datafusion/ffi/src/udf/return_info.rs b/datafusion/ffi/src/udwf/range.rs similarity index 50% rename from datafusion/ffi/src/udf/return_info.rs rename to datafusion/ffi/src/udwf/range.rs index cf76ddd1db76..1ddcc4199fe2 100644 --- a/datafusion/ffi/src/udf/return_info.rs +++ b/datafusion/ffi/src/udwf/range.rs @@ -15,39 +15,50 @@ // specific language governing permissions and limitations // under the License. -use abi_stable::StableAbi; -use arrow::{datatypes::DataType, ffi::FFI_ArrowSchema}; -use datafusion::{error::DataFusionError, logical_expr::ReturnInfo}; +use std::ops::Range; -use crate::arrow_wrappers::WrappedSchema; +use abi_stable::StableAbi; -/// A stable struct for sharing a [`ReturnInfo`] across FFI boundaries. +/// A stable struct for sharing [`Range`] across FFI boundaries. +/// For an explanation of each field, see the corresponding function +/// defined in [`Range`]. 
#[repr(C)] #[derive(Debug, StableAbi)] #[allow(non_camel_case_types)] -pub struct FFI_ReturnInfo { - return_type: WrappedSchema, - nullable: bool, +pub struct FFI_Range { + pub start: usize, + pub end: usize, } -impl TryFrom for FFI_ReturnInfo { - type Error = DataFusionError; +impl From> for FFI_Range { + fn from(value: Range) -> Self { + Self { + start: value.start, + end: value.end, + } + } +} - fn try_from(value: ReturnInfo) -> Result { - let return_type = WrappedSchema(FFI_ArrowSchema::try_from(value.return_type())?); - Ok(Self { - return_type, - nullable: value.nullable(), - }) +impl From for Range { + fn from(value: FFI_Range) -> Self { + Self { + start: value.start, + end: value.end, + } } } -impl TryFrom for ReturnInfo { - type Error = DataFusionError; +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_round_trip_ffi_range() { + let original = Range { start: 10, end: 30 }; - fn try_from(value: FFI_ReturnInfo) -> Result { - let return_type = DataType::try_from(&value.return_type.0)?; + let ffi_range: FFI_Range = original.clone().into(); + let round_trip: Range = ffi_range.into(); - Ok(ReturnInfo::new(return_type, value.nullable)) + assert_eq!(original, round_trip); } } diff --git a/datafusion/ffi/src/util.rs b/datafusion/ffi/src/util.rs index 9d5f2aefe324..abe369c57298 100644 --- a/datafusion/ffi/src/util.rs +++ b/datafusion/ffi/src/util.rs @@ -15,12 +15,14 @@ // specific language governing permissions and limitations // under the License. +use crate::arrow_wrappers::WrappedSchema; use abi_stable::std_types::RVec; +use arrow::datatypes::Field; use arrow::{datatypes::DataType, ffi::FFI_ArrowSchema}; +use arrow_schema::FieldRef; +use std::sync::Arc; -use crate::arrow_wrappers::WrappedSchema; - -/// This macro is a helpful conversion utility to conver from an abi_stable::RResult to a +/// This macro is a helpful conversion utility to convert from an abi_stable::RResult to a /// DataFusion result. #[macro_export] macro_rules! df_result { @@ -64,6 +66,31 @@ macro_rules! rresult_return { }; } +/// This is a utility function to convert a slice of [`Field`] to its equivalent +/// FFI friendly counterpart, [`WrappedSchema`] +pub fn vec_fieldref_to_rvec_wrapped( + fields: &[FieldRef], +) -> Result, arrow::error::ArrowError> { + Ok(fields + .iter() + .map(FFI_ArrowSchema::try_from) + .collect::, arrow::error::ArrowError>>()? + .into_iter() + .map(WrappedSchema) + .collect()) +} + +/// This is a utility function to convert an FFI friendly vector of [`WrappedSchema`] +/// to their equivalent [`Field`]. 
+pub fn rvec_wrapped_to_vec_fieldref( + fields: &RVec, +) -> Result, arrow::error::ArrowError> { + fields + .iter() + .map(|d| Field::try_from(&d.0).map(Arc::new)) + .collect() +} + /// This is a utility function to convert a slice of [`DataType`] to its equivalent /// FFI friendly counterpart, [`WrappedSchema`] pub fn vec_datatype_to_rvec_wrapped( @@ -116,7 +143,7 @@ mod tests { assert!(returned_err_result.is_err()); assert!( returned_err_result.unwrap_err().to_string() - == format!("Execution error: {}", ERROR_VALUE) + == format!("Execution error: {ERROR_VALUE}") ); let ok_result: Result = Ok(VALID_VALUE.to_string()); @@ -129,7 +156,7 @@ mod tests { let returned_err_r_result = wrap_result(err_result); assert!( returned_err_r_result - == RResult::RErr(format!("Execution error: {}", ERROR_VALUE).into()) + == RResult::RErr(format!("Execution error: {ERROR_VALUE}").into()) ); } } diff --git a/datafusion/ffi/src/volatility.rs b/datafusion/ffi/src/volatility.rs index 0aaf68a174cf..f1705da294a3 100644 --- a/datafusion/ffi/src/volatility.rs +++ b/datafusion/ffi/src/volatility.rs @@ -19,7 +19,7 @@ use abi_stable::StableAbi; use datafusion::logical_expr::Volatility; #[repr(C)] -#[derive(Debug, StableAbi)] +#[derive(Debug, StableAbi, Clone)] #[allow(non_camel_case_types)] pub enum FFI_Volatility { Immutable, diff --git a/datafusion/ffi/tests/ffi_integration.rs b/datafusion/ffi/tests/ffi_integration.rs index c6df324e9a17..1ef16fbaa4d8 100644 --- a/datafusion/ffi/tests/ffi_integration.rs +++ b/datafusion/ffi/tests/ffi_integration.rs @@ -19,7 +19,6 @@ /// when the feature integtation-tests is built #[cfg(feature = "integration-tests")] mod tests { - use datafusion::error::{DataFusionError, Result}; use datafusion::prelude::SessionContext; use datafusion_ffi::catalog_provider::ForeignCatalogProvider; diff --git a/datafusion/ffi/tests/ffi_udaf.rs b/datafusion/ffi/tests/ffi_udaf.rs new file mode 100644 index 000000000000..31b1f473913c --- /dev/null +++ b/datafusion/ffi/tests/ffi_udaf.rs @@ -0,0 +1,129 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +/// Add an additional module here for convenience to scope this to only +/// when the feature integration-tests is built +#[cfg(feature = "integration-tests")] +mod tests { + use arrow::array::Float64Array; + use datafusion::common::record_batch; + use datafusion::error::{DataFusionError, Result}; + use datafusion::logical_expr::AggregateUDF; + use datafusion::prelude::{col, SessionContext}; + + use datafusion_ffi::tests::utils::get_module; + use datafusion_ffi::udaf::ForeignAggregateUDF; + + #[tokio::test] + async fn test_ffi_udaf() -> Result<()> { + let module = get_module()?; + + let ffi_sum_func = + module + .create_sum_udaf() + .ok_or(DataFusionError::NotImplemented( + "External library failed to implement create_sum_udaf".to_string(), + ))?(); + let foreign_sum_func: ForeignAggregateUDF = (&ffi_sum_func).try_into()?; + + let udaf: AggregateUDF = foreign_sum_func.into(); + + let ctx = SessionContext::default(); + let record_batch = record_batch!( + ("a", Int32, vec![1, 2, 2, 4, 4, 4, 4]), + ("b", Float64, vec![1.0, 2.0, 2.0, 4.0, 4.0, 4.0, 4.0]) + ) + .unwrap(); + + let df = ctx.read_batch(record_batch)?; + + let df = df + .aggregate( + vec![col("a")], + vec![udaf.call(vec![col("b")]).alias("sum_b")], + )? + .sort_by(vec![col("a")])?; + + let result = df.collect().await?; + + let expected = record_batch!( + ("a", Int32, vec![1, 2, 4]), + ("sum_b", Float64, vec![1.0, 4.0, 16.0]) + )?; + + assert_eq!(result[0], expected); + + Ok(()) + } + + #[tokio::test] + async fn test_ffi_grouping_udaf() -> Result<()> { + let module = get_module()?; + + let ffi_stddev_func = + module + .create_stddev_udaf() + .ok_or(DataFusionError::NotImplemented( + "External library failed to implement create_stddev_udaf".to_string(), + ))?(); + let foreign_stddev_func: ForeignAggregateUDF = (&ffi_stddev_func).try_into()?; + + let udaf: AggregateUDF = foreign_stddev_func.into(); + + let ctx = SessionContext::default(); + let record_batch = record_batch!( + ("a", Int32, vec![1, 2, 2, 4, 4, 4, 4]), + ( + "b", + Float64, + vec![ + 1.0, + 2.0, + 2.0 + 2.0_f64.sqrt(), + 4.0, + 4.0, + 4.0 + 3.0_f64.sqrt(), + 4.0 + 3.0_f64.sqrt() + ] + ) + ) + .unwrap(); + + let df = ctx.read_batch(record_batch)?; + + let df = df + .aggregate( + vec![col("a")], + vec![udaf.call(vec![col("b")]).alias("stddev_b")], + )? + .sort_by(vec![col("a")])?; + + let result = df.collect().await?; + let result = result[0].column_by_name("stddev_b").unwrap(); + let result = result + .as_any() + .downcast_ref::() + .unwrap() + .values(); + + assert!(result.first().unwrap().is_nan()); + assert!((result.get(1).unwrap() - 1.0).abs() < 0.00001); + assert!((result.get(2).unwrap() - 1.0).abs() < 0.00001); + + Ok(()) + } +} diff --git a/datafusion/ffi/tests/ffi_udwf.rs b/datafusion/ffi/tests/ffi_udwf.rs new file mode 100644 index 000000000000..db9ebba0fdfb --- /dev/null +++ b/datafusion/ffi/tests/ffi_udwf.rs @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Add an additional module here for convenience to scope this to only +/// when the feature integration-tests is built +#[cfg(feature = "integration-tests")] +mod tests { + use arrow::array::{create_array, ArrayRef}; + use datafusion::error::{DataFusionError, Result}; + use datafusion::logical_expr::expr::Sort; + use datafusion::logical_expr::{col, ExprFunctionExt, WindowUDF}; + use datafusion::prelude::SessionContext; + use datafusion_ffi::tests::create_record_batch; + use datafusion_ffi::tests::utils::get_module; + use datafusion_ffi::udwf::ForeignWindowUDF; + + #[tokio::test] + async fn test_rank_udwf() -> Result<()> { + let module = get_module()?; + + let ffi_rank_func = + module + .create_rank_udwf() + .ok_or(DataFusionError::NotImplemented( + "External library failed to implement create_rank_udwf" + .to_string(), + ))?(); + let foreign_rank_func: ForeignWindowUDF = (&ffi_rank_func).try_into()?; + + let udwf: WindowUDF = foreign_rank_func.into(); + + let ctx = SessionContext::default(); + let df = ctx.read_batch(create_record_batch(-5, 5))?; + + let df = df.select(vec![ + col("a"), + udwf.call(vec![]) + .order_by(vec![Sort::new(col("a"), true, true)]) + .build() + .unwrap() + .alias("rank_a"), + ])?; + + df.clone().show().await?; + + let result = df.collect().await?; + let expected = create_array!(UInt64, [1, 2, 3, 4, 5]) as ArrayRef; + + assert_eq!(result.len(), 1); + assert_eq!(result[0].column(1), &expected); + + Ok(()) + } +} diff --git a/datafusion/functions-aggregate-common/README.md b/datafusion/functions-aggregate-common/README.md new file mode 100644 index 000000000000..61a81e8085a4 --- /dev/null +++ b/datafusion/functions-aggregate-common/README.md @@ -0,0 +1,31 @@ + + +# DataFusion Aggregate Function Library + +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. + +This crate contains common functionality for implementing aggregate and window functions. + +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + +[df]: https://crates.io/crates/datafusion +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/functions-aggregate-common/src/accumulator.rs b/datafusion/functions-aggregate-common/src/accumulator.rs index a230bb028909..39303889f0fe 100644 --- a/datafusion/functions-aggregate-common/src/accumulator.rs +++ b/datafusion/functions-aggregate-common/src/accumulator.rs @@ -15,11 +15,11 @@ // specific language governing permissions and limitations // under the License. 
-use arrow::datatypes::{DataType, Field, Schema}; +use arrow::datatypes::{DataType, FieldRef, Schema}; use datafusion_common::Result; use datafusion_expr_common::accumulator::Accumulator; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use std::sync::Arc; /// [`AccumulatorArgs`] contains information about how an aggregate @@ -27,8 +27,8 @@ use std::sync::Arc; /// ordering expressions. #[derive(Debug)] pub struct AccumulatorArgs<'a> { - /// The return type of the aggregate function. - pub return_type: &'a DataType, + /// The return field of the aggregate function. + pub return_field: FieldRef, /// The schema of the input arguments pub schema: &'a Schema, @@ -50,9 +50,7 @@ pub struct AccumulatorArgs<'a> { /// ```sql /// SELECT FIRST_VALUE(column1 ORDER BY column2) FROM t; /// ``` - /// - /// If no `ORDER BY` is specified, `ordering_req` will be empty. - pub ordering_req: &'a LexOrdering, + pub order_bys: &'a [PhysicalSortExpr], /// Whether the aggregation is running in reverse order pub is_reversed: bool, @@ -71,6 +69,13 @@ pub struct AccumulatorArgs<'a> { pub exprs: &'a [Arc], } +impl AccumulatorArgs<'_> { + /// Returns the return type of the aggregate function. + pub fn return_type(&self) -> &DataType { + self.return_field.data_type() + } +} + /// Factory that returns an accumulator for the given aggregate function. pub type AccumulatorFactoryFunction = Arc Result> + Send + Sync>; @@ -81,15 +86,22 @@ pub struct StateFieldsArgs<'a> { /// The name of the aggregate function. pub name: &'a str, - /// The input types of the aggregate function. - pub input_types: &'a [DataType], + /// The input fields of the aggregate function. + pub input_fields: &'a [FieldRef], - /// The return type of the aggregate function. - pub return_type: &'a DataType, + /// The return fields of the aggregate function. + pub return_field: FieldRef, /// The ordering fields of the aggregate function. - pub ordering_fields: &'a [Field], + pub ordering_fields: &'a [FieldRef], /// Whether the aggregate function is distinct. pub is_distinct: bool, } + +impl StateFieldsArgs<'_> { + /// The return type of the aggregate function. + pub fn return_type(&self) -> &DataType { + self.return_field.data_type() + } +} diff --git a/datafusion/functions-aggregate-common/src/aggregate/count_distinct.rs b/datafusion/functions-aggregate-common/src/aggregate/count_distinct.rs index 7d772f7c649d..25b40382299b 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/count_distinct.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/count_distinct.rs @@ -16,9 +16,11 @@ // under the License. mod bytes; +mod dict; mod native; pub use bytes::BytesDistinctCountAccumulator; pub use bytes::BytesViewDistinctCountAccumulator; +pub use dict::DictionaryCountAccumulator; pub use native::FloatDistinctCountAccumulator; pub use native::PrimitiveDistinctCountAccumulator; diff --git a/datafusion/functions-aggregate-common/src/aggregate/count_distinct/dict.rs b/datafusion/functions-aggregate-common/src/aggregate/count_distinct/dict.rs new file mode 100644 index 000000000000..089d8d5acded --- /dev/null +++ b/datafusion/functions-aggregate-common/src/aggregate/count_distinct/dict.rs @@ -0,0 +1,70 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, BooleanArray}; +use arrow::downcast_dictionary_array; +use datafusion_common::{arrow_datafusion_err, ScalarValue}; +use datafusion_common::{internal_err, DataFusionError}; +use datafusion_expr_common::accumulator::Accumulator; + +#[derive(Debug)] +pub struct DictionaryCountAccumulator { + inner: Box, +} + +impl DictionaryCountAccumulator { + pub fn new(inner: Box) -> Self { + Self { inner } + } +} + +impl Accumulator for DictionaryCountAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> { + let values: Vec<_> = values + .iter() + .map(|dict| { + downcast_dictionary_array! { + dict => { + let buff: BooleanArray = dict.occupancy().into(); + arrow::compute::filter( + dict.values(), + &buff + ).map_err(|e| arrow_datafusion_err!(e)) + }, + _ => internal_err!("DictionaryCountAccumulator only supports dictionary arrays") + } + }) + .collect::, _>>()?; + self.inner.update_batch(values.as_slice()) + } + + fn evaluate(&mut self) -> datafusion_common::Result { + self.inner.evaluate() + } + + fn size(&self) -> usize { + self.inner.size() + } + + fn state(&mut self) -> datafusion_common::Result> { + self.inner.state() + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> { + self.inner.merge_batch(states) + } +} diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs index e629e99e1657..987ba57f7719 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs @@ -636,7 +636,7 @@ mod test { #[test] fn accumulate_fuzz() { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..100 { Fixture::new_random(&mut rng).run(); } @@ -661,23 +661,23 @@ mod test { impl Fixture { fn new_random(rng: &mut ThreadRng) -> Self { // Number of input values in a batch - let num_values: usize = rng.gen_range(1..200); + let num_values: usize = rng.random_range(1..200); // number of distinct groups - let num_groups: usize = rng.gen_range(2..1000); + let num_groups: usize = rng.random_range(2..1000); let max_group = num_groups - 1; let group_indices: Vec = (0..num_values) - .map(|_| rng.gen_range(0..max_group)) + .map(|_| rng.random_range(0..max_group)) .collect(); - let values: Vec = (0..num_values).map(|_| rng.gen()).collect(); + let values: Vec = (0..num_values).map(|_| rng.random()).collect(); // 10% chance of false // 10% change of null // 80% chance of true let filter: BooleanArray = (0..num_values) .map(|_| { - let filter_value = rng.gen_range(0.0..1.0); + let filter_value = rng.random_range(0.0..1.0); if filter_value < 0.1 { Some(false) } else if 
filter_value < 0.2 { @@ -690,14 +690,14 @@ mod test { // random values with random number and location of nulls // random null percentage - let null_pct: f32 = rng.gen_range(0.0..1.0); + let null_pct: f32 = rng.random_range(0.0..1.0); let values_with_nulls: Vec> = (0..num_values) .map(|_| { - let is_null = null_pct < rng.gen_range(0.0..1.0); + let is_null = null_pct < rng.random_range(0.0..1.0); if is_null { None } else { - Some(rng.gen()) + Some(rng.random()) } }) .collect(); diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs index 6a8946034cbc..c8c7736bba14 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs @@ -20,7 +20,7 @@ use arrow::array::{ Array, ArrayRef, ArrowNumericType, AsArray, BinaryArray, BinaryViewArray, BooleanArray, LargeBinaryArray, LargeStringArray, PrimitiveArray, StringArray, - StringViewArray, + StringViewArray, StructArray, }; use arrow::buffer::NullBuffer; use arrow::datatypes::DataType; @@ -193,6 +193,18 @@ pub fn set_nulls_dyn(input: &dyn Array, nulls: Option) -> Result { + let input = input.as_struct(); + // safety: values / offsets came from a valid struct array + // and we checked nulls has the same length as values + unsafe { + Arc::new(StructArray::new_unchecked( + input.fields().clone(), + input.columns().to_vec(), + nulls, + )) + } + } _ => { return not_impl_err!("Applying nulls {:?}", input.data_type()); } diff --git a/datafusion/functions-aggregate-common/src/lib.rs b/datafusion/functions-aggregate-common/src/lib.rs index da718e7ceefe..203ae98fe1ed 100644 --- a/datafusion/functions-aggregate-common/src/lib.rs +++ b/datafusion/functions-aggregate-common/src/lib.rs @@ -34,6 +34,7 @@ pub mod accumulator; pub mod aggregate; pub mod merge_arrays; +pub mod min_max; pub mod order; pub mod stats; pub mod tdigest; diff --git a/datafusion/functions-aggregate-common/src/min_max.rs b/datafusion/functions-aggregate-common/src/min_max.rs new file mode 100644 index 000000000000..6d9f7f464362 --- /dev/null +++ b/datafusion/functions-aggregate-common/src/min_max.rs @@ -0,0 +1,353 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Basic min/max functionality shared across DataFusion aggregate functions + +use arrow::array::{ + ArrayRef, AsArray as _, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, + Date64Array, Decimal128Array, Decimal256Array, DurationMicrosecondArray, + DurationMillisecondArray, DurationNanosecondArray, DurationSecondArray, Float16Array, + Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, + IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray, + LargeBinaryArray, LargeStringArray, StringArray, StringViewArray, + Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, + Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, + UInt64Array, UInt8Array, +}; +use arrow::compute; +use arrow::datatypes::{DataType, IntervalUnit, TimeUnit}; +use datafusion_common::{downcast_value, Result, ScalarValue}; +use std::cmp::Ordering; + +// Statically-typed version of min/max(array) -> ScalarValue for string types +macro_rules! typed_min_max_batch_string { + ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{ + let array = downcast_value!($VALUES, $ARRAYTYPE); + let value = compute::$OP(array); + let value = value.and_then(|e| Some(e.to_string())); + ScalarValue::$SCALAR(value) + }}; +} + +// Statically-typed version of min/max(array) -> ScalarValue for binary types. +macro_rules! typed_min_max_batch_binary { + ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{ + let array = downcast_value!($VALUES, $ARRAYTYPE); + let value = compute::$OP(array); + let value = value.and_then(|e| Some(e.to_vec())); + ScalarValue::$SCALAR(value) + }}; +} + +// Statically-typed version of min/max(array) -> ScalarValue for non-string types. +macro_rules! typed_min_max_batch { + ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{ + let array = downcast_value!($VALUES, $ARRAYTYPE); + let value = compute::$OP(array); + ScalarValue::$SCALAR(value, $($EXTRA_ARGS.clone()),*) + }}; +} + +// Statically-typed version of min/max(array) -> ScalarValue for non-string types. +// this is a macro to support both operations (min and max). +macro_rules! 
min_max_batch { + ($VALUES:expr, $OP:ident) => {{ + match $VALUES.data_type() { + DataType::Null => ScalarValue::Null, + DataType::Decimal128(precision, scale) => { + typed_min_max_batch!( + $VALUES, + Decimal128Array, + Decimal128, + $OP, + precision, + scale + ) + } + DataType::Decimal256(precision, scale) => { + typed_min_max_batch!( + $VALUES, + Decimal256Array, + Decimal256, + $OP, + precision, + scale + ) + } + // all types that have a natural order + DataType::Float64 => { + typed_min_max_batch!($VALUES, Float64Array, Float64, $OP) + } + DataType::Float32 => { + typed_min_max_batch!($VALUES, Float32Array, Float32, $OP) + } + DataType::Float16 => { + typed_min_max_batch!($VALUES, Float16Array, Float16, $OP) + } + DataType::Int64 => typed_min_max_batch!($VALUES, Int64Array, Int64, $OP), + DataType::Int32 => typed_min_max_batch!($VALUES, Int32Array, Int32, $OP), + DataType::Int16 => typed_min_max_batch!($VALUES, Int16Array, Int16, $OP), + DataType::Int8 => typed_min_max_batch!($VALUES, Int8Array, Int8, $OP), + DataType::UInt64 => typed_min_max_batch!($VALUES, UInt64Array, UInt64, $OP), + DataType::UInt32 => typed_min_max_batch!($VALUES, UInt32Array, UInt32, $OP), + DataType::UInt16 => typed_min_max_batch!($VALUES, UInt16Array, UInt16, $OP), + DataType::UInt8 => typed_min_max_batch!($VALUES, UInt8Array, UInt8, $OP), + DataType::Timestamp(TimeUnit::Second, tz_opt) => { + typed_min_max_batch!( + $VALUES, + TimestampSecondArray, + TimestampSecond, + $OP, + tz_opt + ) + } + DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => typed_min_max_batch!( + $VALUES, + TimestampMillisecondArray, + TimestampMillisecond, + $OP, + tz_opt + ), + DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => typed_min_max_batch!( + $VALUES, + TimestampMicrosecondArray, + TimestampMicrosecond, + $OP, + tz_opt + ), + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => typed_min_max_batch!( + $VALUES, + TimestampNanosecondArray, + TimestampNanosecond, + $OP, + tz_opt + ), + DataType::Date32 => typed_min_max_batch!($VALUES, Date32Array, Date32, $OP), + DataType::Date64 => typed_min_max_batch!($VALUES, Date64Array, Date64, $OP), + DataType::Time32(TimeUnit::Second) => { + typed_min_max_batch!($VALUES, Time32SecondArray, Time32Second, $OP) + } + DataType::Time32(TimeUnit::Millisecond) => { + typed_min_max_batch!( + $VALUES, + Time32MillisecondArray, + Time32Millisecond, + $OP + ) + } + DataType::Time64(TimeUnit::Microsecond) => { + typed_min_max_batch!( + $VALUES, + Time64MicrosecondArray, + Time64Microsecond, + $OP + ) + } + DataType::Time64(TimeUnit::Nanosecond) => { + typed_min_max_batch!( + $VALUES, + Time64NanosecondArray, + Time64Nanosecond, + $OP + ) + } + DataType::Interval(IntervalUnit::YearMonth) => { + typed_min_max_batch!( + $VALUES, + IntervalYearMonthArray, + IntervalYearMonth, + $OP + ) + } + DataType::Interval(IntervalUnit::DayTime) => { + typed_min_max_batch!($VALUES, IntervalDayTimeArray, IntervalDayTime, $OP) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + typed_min_max_batch!( + $VALUES, + IntervalMonthDayNanoArray, + IntervalMonthDayNano, + $OP + ) + } + DataType::Duration(TimeUnit::Second) => { + typed_min_max_batch!($VALUES, DurationSecondArray, DurationSecond, $OP) + } + DataType::Duration(TimeUnit::Millisecond) => { + typed_min_max_batch!( + $VALUES, + DurationMillisecondArray, + DurationMillisecond, + $OP + ) + } + DataType::Duration(TimeUnit::Microsecond) => { + typed_min_max_batch!( + $VALUES, + DurationMicrosecondArray, + DurationMicrosecond, + $OP + ) + } + 
DataType::Duration(TimeUnit::Nanosecond) => { + typed_min_max_batch!( + $VALUES, + DurationNanosecondArray, + DurationNanosecond, + $OP + ) + } + other => { + // This should have been handled before + return datafusion_common::internal_err!( + "Min/Max accumulator not implemented for type {:?}", + other + ); + } + } + }}; +} + +/// dynamically-typed min(array) -> ScalarValue +pub fn min_batch(values: &ArrayRef) -> Result<ScalarValue> { + Ok(match values.data_type() { + DataType::Utf8 => { + typed_min_max_batch_string!(values, StringArray, Utf8, min_string) + } + DataType::LargeUtf8 => { + typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, min_string) + } + DataType::Utf8View => { + typed_min_max_batch_string!( + values, + StringViewArray, + Utf8View, + min_string_view + ) + } + DataType::Boolean => { + typed_min_max_batch!(values, BooleanArray, Boolean, min_boolean) + } + DataType::Binary => { + typed_min_max_batch_binary!(&values, BinaryArray, Binary, min_binary) + } + DataType::LargeBinary => { + typed_min_max_batch_binary!( + &values, + LargeBinaryArray, + LargeBinary, + min_binary + ) + } + DataType::BinaryView => { + typed_min_max_batch_binary!( + &values, + BinaryViewArray, + BinaryView, + min_binary_view + ) + } + DataType::Struct(_) => min_max_batch_generic(values, Ordering::Greater)?, + DataType::List(_) => min_max_batch_generic(values, Ordering::Greater)?, + DataType::LargeList(_) => min_max_batch_generic(values, Ordering::Greater)?, + DataType::FixedSizeList(_, _) => { + min_max_batch_generic(values, Ordering::Greater)? + } + DataType::Dictionary(_, _) => { + let values = values.as_any_dictionary().values(); + min_batch(values)? + } + _ => min_max_batch!(values, min), + }) +} + +/// Generic min/max implementation for complex types +fn min_max_batch_generic(array: &ArrayRef, ordering: Ordering) -> Result<ScalarValue> { + if array.len() == array.null_count() { + return ScalarValue::try_from(array.data_type()); + } + let mut extreme = ScalarValue::try_from_array(array, 0)?; + for i in 1..array.len() { + let current = ScalarValue::try_from_array(array, i)?; + if current.is_null() { + continue; + } + if extreme.is_null() { + extreme = current; + continue; + } + if let Some(cmp) = extreme.partial_cmp(&current) { + if cmp == ordering { + extreme = current; + } + } + } + + Ok(extreme) +} + +/// dynamically-typed max(array) -> ScalarValue +pub fn max_batch(values: &ArrayRef) -> Result<ScalarValue> { + Ok(match values.data_type() { + DataType::Utf8 => { + typed_min_max_batch_string!(values, StringArray, Utf8, max_string) + } + DataType::LargeUtf8 => { + typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, max_string) + } + DataType::Utf8View => { + typed_min_max_batch_string!( + values, + StringViewArray, + Utf8View, + max_string_view + ) + } + DataType::Boolean => { + typed_min_max_batch!(values, BooleanArray, Boolean, max_boolean) + } + DataType::Binary => { + typed_min_max_batch_binary!(&values, BinaryArray, Binary, max_binary) + } + DataType::BinaryView => { + typed_min_max_batch_binary!( + &values, + BinaryViewArray, + BinaryView, + max_binary_view + ) + } + DataType::LargeBinary => { + typed_min_max_batch_binary!( + &values, + LargeBinaryArray, + LargeBinary, + max_binary + ) + } + DataType::Struct(_) => min_max_batch_generic(values, Ordering::Less)?, + DataType::List(_) => min_max_batch_generic(values, Ordering::Less)?, + DataType::LargeList(_) => min_max_batch_generic(values, Ordering::Less)?, + DataType::FixedSizeList(_, _) => min_max_batch_generic(values, Ordering::Less)?, +
DataType::Dictionary(_, _) => { + let values = values.as_any_dictionary().values(); + max_batch(values)? + } + _ => min_max_batch!(values, max), + }) +} diff --git a/datafusion/functions-aggregate-common/src/utils.rs b/datafusion/functions-aggregate-common/src/utils.rs index 083dac615b5d..2f20e916743b 100644 --- a/datafusion/functions-aggregate-common/src/utils.rs +++ b/datafusion/functions-aggregate-common/src/utils.rs @@ -18,7 +18,7 @@ use std::sync::Arc; use arrow::array::{ArrayRef, AsArray}; -use arrow::datatypes::ArrowNativeType; +use arrow::datatypes::{ArrowNativeType, FieldRef}; use arrow::{ array::ArrowNativeTypeOp, compute::SortOptions, @@ -30,7 +30,7 @@ use arrow::{ }; use datafusion_common::{exec_err, DataFusionError, Result}; use datafusion_expr_common::accumulator::Accumulator; -use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; /// Convert scalar values from an accumulator into arrays. pub fn get_accum_scalar_values_as_arrays( @@ -87,13 +87,13 @@ pub fn adjust_output_array(data_type: &DataType, array: ArrayRef) -> Result Vec { - ordering_req +) -> Vec { + order_bys .iter() .zip(data_types.iter()) .map(|(sort_expr, dtype)| { @@ -104,6 +104,7 @@ pub fn ordering_fields( true, ) }) + .map(Arc::new) .collect() } diff --git a/datafusion/functions-aggregate/README.md b/datafusion/functions-aggregate/README.md index 29b313d2a903..244112d4fd7a 100644 --- a/datafusion/functions-aggregate/README.md +++ b/datafusion/functions-aggregate/README.md @@ -21,7 +21,11 @@ [DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. -This crate contains packages of function that can be used to customize the -functionality of DataFusion. +This crate contains implementations of aggregate functions. + +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. 
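The `min_batch`/`max_batch` helpers in the hunk above reduce an entire Arrow array to a single `ScalarValue`. A minimal usage sketch, assuming the helpers live in a `min_max` module of `datafusion-functions-aggregate-common` (the exact path is not shown in this hunk) and using illustrative input values:

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use datafusion_common::{Result, ScalarValue};
// Assumed import path for the helpers added above; the module name may differ.
use datafusion_functions_aggregate_common::min_max::{max_batch, min_batch};

fn main() -> Result<()> {
    // Nulls are ignored by the underlying arrow min/max kernels.
    let values: ArrayRef = Arc::new(Int32Array::from(vec![Some(3), None, Some(7), Some(1)]));
    assert_eq!(min_batch(&values)?, ScalarValue::Int32(Some(1)));
    assert_eq!(max_batch(&values)?, ScalarValue::Int32(Some(7)));
    Ok(())
}
```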
[df]: https://crates.io/crates/datafusion +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/functions-aggregate/benches/array_agg.rs b/datafusion/functions-aggregate/benches/array_agg.rs index e22be611d8d7..6dadb12aba86 100644 --- a/datafusion/functions-aggregate/benches/array_agg.rs +++ b/datafusion/functions-aggregate/benches/array_agg.rs @@ -27,7 +27,7 @@ use datafusion_expr::Accumulator; use datafusion_functions_aggregate::array_agg::ArrayAggAccumulator; use arrow::buffer::OffsetBuffer; -use rand::distributions::{Distribution, Standard}; +use rand::distr::{Distribution, StandardUniform}; use rand::prelude::StdRng; use rand::Rng; use rand::SeedableRng; @@ -43,7 +43,7 @@ fn merge_batch_bench(c: &mut Criterion, name: &str, values: ArrayRef) { b.iter(|| { #[allow(clippy::unit_arg)] black_box( - ArrayAggAccumulator::try_new(&list_item_data_type) + ArrayAggAccumulator::try_new(&list_item_data_type, false) .unwrap() .merge_batch(&[values.clone()]) .unwrap(), @@ -55,23 +55,23 @@ fn merge_batch_bench(c: &mut Criterion, name: &str, values: ArrayRef) { pub fn create_primitive_array(size: usize, null_density: f32) -> PrimitiveArray where T: ArrowPrimitiveType, - Standard: Distribution, + StandardUniform: Distribution, { let mut rng = seedable_rng(); (0..size) .map(|_| { - if rng.gen::() < null_density { + if rng.random::() < null_density { None } else { - Some(rng.gen()) + Some(rng.random()) } }) .collect() } /// Create List array with the given item data type, null density, null locations and zero length lists density -/// Creates an random (but fixed-seeded) array of a given size and null density +/// Creates a random (but fixed-seeded) array of a given size and null density pub fn create_list_array( size: usize, null_density: f32, @@ -79,20 +79,20 @@ pub fn create_list_array( ) -> ListArray where T: ArrowPrimitiveType, - Standard: Distribution, + StandardUniform: Distribution, { let mut nulls_builder = NullBufferBuilder::new(size); - let mut rng = seedable_rng(); + let mut rng = StdRng::seed_from_u64(42); let offsets = OffsetBuffer::from_lengths((0..size).map(|_| { - let is_null = rng.gen::() < null_density; + let is_null = rng.random::() < null_density; - let mut length = rng.gen_range(1..10); + let mut length = rng.random_range(1..10); if is_null { nulls_builder.append_null(); - if rng.gen::() <= zero_length_lists_probability { + if rng.random::() <= zero_length_lists_probability { length = 0; } } else { diff --git a/datafusion/functions-aggregate/benches/count.rs b/datafusion/functions-aggregate/benches/count.rs index 8bde7d04c44d..80cb65be2ed7 100644 --- a/datafusion/functions-aggregate/benches/count.rs +++ b/datafusion/functions-aggregate/benches/count.rs @@ -15,23 +15,29 @@ // specific language governing permissions and limitations // under the License. 
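The `array_agg` benchmark hunk above moves from the rand 0.8 API (`rand::distributions::Standard`, `Rng::gen`) to rand 0.9 (`rand::distr::StandardUniform`, `Rng::random`). A minimal sketch of the new calls, assuming rand 0.9 is the version in the lockfile; the function and its parameters are illustrative, not part of the patch:

```rust
use rand::distr::{Distribution, StandardUniform};
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};

// Generate `size` optional values with roughly `null_density` nulls.
fn sample_with_nulls(size: usize, null_density: f32) -> Vec<Option<f64>> {
    // Fixed seed keeps benchmark inputs reproducible, as in the hunk above.
    let mut rng = StdRng::seed_from_u64(42);
    (0..size)
        .map(|_| {
            if rng.random::<f32>() < null_density {
                None
            } else {
                Some(StandardUniform.sample(&mut rng))
            }
        })
        .collect()
}
```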
+use std::sync::Arc; + use arrow::array::{ArrayRef, BooleanArray}; use arrow::datatypes::{DataType, Field, Int32Type, Schema}; -use arrow::util::bench_util::{create_boolean_array, create_primitive_array}; -use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use datafusion_expr::{function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator}; +use arrow::util::bench_util::{ + create_boolean_array, create_dict_from_values, create_primitive_array, + create_string_array_with_len, +}; + +use datafusion_expr::function::AccumulatorArgs; +use datafusion_expr::{Accumulator, AggregateUDFImpl, GroupsAccumulator}; use datafusion_functions_aggregate::count::Count; use datafusion_physical_expr::expressions::col; -use datafusion_physical_expr_common::sort_expr::LexOrdering; -use std::sync::Arc; -fn prepare_accumulator() -> Box { +use criterion::{black_box, criterion_group, criterion_main, Criterion}; + +fn prepare_group_accumulator() -> Box { let schema = Arc::new(Schema::new(vec![Field::new("f", DataType::Int32, true)])); let accumulator_args = AccumulatorArgs { - return_type: &DataType::Int64, + return_field: Field::new("f", DataType::Int64, true).into(), schema: &schema, ignore_nulls: false, - ordering_req: &LexOrdering::default(), + order_bys: &[], is_reversed: false, name: "COUNT(f)", is_distinct: false, @@ -44,13 +50,34 @@ fn prepare_accumulator() -> Box { .unwrap() } +fn prepare_accumulator() -> Box { + let schema = Arc::new(Schema::new(vec![Field::new( + "f", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + )])); + let accumulator_args = AccumulatorArgs { + return_field: Arc::new(Field::new_list_field(DataType::Int64, true)), + schema: &schema, + ignore_nulls: false, + order_bys: &[], + is_reversed: false, + name: "COUNT(f)", + is_distinct: true, + exprs: &[col("f", &schema).unwrap()], + }; + let count_fn = Count::new(); + + count_fn.accumulator(accumulator_args).unwrap() +} + fn convert_to_state_bench( c: &mut Criterion, name: &str, values: ArrayRef, opt_filter: Option<&BooleanArray>, ) { - let accumulator = prepare_accumulator(); + let accumulator = prepare_group_accumulator(); c.bench_function(name, |b| { b.iter(|| { black_box( @@ -89,6 +116,18 @@ fn count_benchmark(c: &mut Criterion) { values, Some(&filter), ); + + let arr = create_string_array_with_len::(20, 0.0, 50); + let values = + Arc::new(create_dict_from_values::(200_000, 0.8, &arr)) as ArrayRef; + + let mut accumulator = prepare_accumulator(); + c.bench_function("count low cardinality dict 20% nulls, no filter", |b| { + b.iter(|| { + #[allow(clippy::unit_arg)] + black_box(accumulator.update_batch(&[values.clone()]).unwrap()) + }) + }); } criterion_group!(benches, count_benchmark); diff --git a/datafusion/functions-aggregate/benches/sum.rs b/datafusion/functions-aggregate/benches/sum.rs index fab53ae94b25..4517db6b1510 100644 --- a/datafusion/functions-aggregate/benches/sum.rs +++ b/datafusion/functions-aggregate/benches/sum.rs @@ -15,23 +15,26 @@ // specific language governing permissions and limitations // under the License. 
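The count benchmark above also shows the reworked `AccumulatorArgs`: `return_type: &DataType` becomes `return_field: FieldRef` and `ordering_req: &LexOrdering` becomes `order_bys: &[PhysicalSortExpr]`. A hedged sketch of what downstream callers construct after this change, assuming the field set is exactly the one shown in this diff; the helper name and field values are illustrative:

```rust
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_expr::function::AccumulatorArgs;
use datafusion_physical_expr::PhysicalExpr;

// Hypothetical helper mirroring the benchmark setup in the hunk above.
fn count_args<'a>(
    schema: &'a Schema,
    exprs: &'a [Arc<dyn PhysicalExpr>],
) -> AccumulatorArgs<'a> {
    AccumulatorArgs {
        return_field: Field::new("out", DataType::Int64, true).into(),
        schema,
        ignore_nulls: false,
        order_bys: &[], // previously `ordering_req: &LexOrdering::default()`
        is_reversed: false,
        name: "COUNT(f)",
        is_distinct: false,
        exprs,
    }
}
```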
+use std::sync::Arc; + use arrow::array::{ArrayRef, BooleanArray}; use arrow::datatypes::{DataType, Field, Int64Type, Schema}; use arrow::util::bench_util::{create_boolean_array, create_primitive_array}; -use criterion::{black_box, criterion_group, criterion_main, Criterion}; + use datafusion_expr::{function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator}; use datafusion_functions_aggregate::sum::Sum; use datafusion_physical_expr::expressions::col; -use datafusion_physical_expr_common::sort_expr::LexOrdering; -use std::sync::Arc; + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; fn prepare_accumulator(data_type: &DataType) -> Box { - let schema = Arc::new(Schema::new(vec![Field::new("f", data_type.clone(), true)])); + let field = Field::new("f", data_type.clone(), true).into(); + let schema = Arc::new(Schema::new(vec![Arc::clone(&field)])); let accumulator_args = AccumulatorArgs { - return_type: data_type, + return_field: field, schema: &schema, ignore_nulls: false, - ordering_req: &LexOrdering::default(), + order_bys: &[], is_reversed: false, name: "SUM(f)", is_distinct: false, diff --git a/datafusion/functions-aggregate/src/approx_distinct.rs b/datafusion/functions-aggregate/src/approx_distinct.rs index c97dba1925ca..0d5dcd5c2085 100644 --- a/datafusion/functions-aggregate/src/approx_distinct.rs +++ b/datafusion/functions-aggregate/src/approx_distinct.rs @@ -23,7 +23,7 @@ use arrow::array::{ GenericBinaryArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray, }; use arrow::datatypes::{ - ArrowPrimitiveType, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, + ArrowPrimitiveType, FieldRef, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; @@ -322,12 +322,13 @@ impl AggregateUDFImpl for ApproxDistinct { Ok(DataType::UInt64) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![Field::new( format_state_name(args.name, "hll_registers"), DataType::Binary, false, - )]) + ) + .into()]) } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { diff --git a/datafusion/functions-aggregate/src/approx_median.rs b/datafusion/functions-aggregate/src/approx_median.rs index 787e08bae286..0f2e3039ca9f 100644 --- a/datafusion/functions-aggregate/src/approx_median.rs +++ b/datafusion/functions-aggregate/src/approx_median.rs @@ -17,11 +17,11 @@ //! Defines physical expressions for APPROX_MEDIAN that can be evaluated MEDIAN at runtime during query execution +use arrow::datatypes::DataType::{Float64, UInt64}; +use arrow::datatypes::{DataType, Field, FieldRef}; use std::any::Any; use std::fmt::Debug; - -use arrow::datatypes::DataType::{Float64, UInt64}; -use arrow::datatypes::{DataType, Field}; +use std::sync::Arc; use datafusion_common::{not_impl_err, plan_err, Result}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; @@ -45,7 +45,7 @@ make_udaf_expr_and_func!( /// APPROX_MEDIAN aggregate expression #[user_doc( doc_section(label = "Approximate Functions"), - description = "Returns the approximate median (50th percentile) of input values. It is an alias of `approx_percentile_cont(x, 0.5)`.", + description = "Returns the approximate median (50th percentile) of input values. 
It is an alias of `approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY x)`.", syntax_example = "approx_median(expression)", sql_example = r#"```sql > SELECT approx_median(column_name) FROM table_name; @@ -91,7 +91,7 @@ impl AggregateUDFImpl for ApproxMedian { self } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new(format_state_name(args.name, "max_size"), UInt64, false), Field::new(format_state_name(args.name, "sum"), Float64, false), @@ -103,7 +103,10 @@ impl AggregateUDFImpl for ApproxMedian { Field::new_list_field(Float64, true), false, ), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn name(&self) -> &str { diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index 1fad5f73703c..9b0d62e936bc 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -22,6 +22,7 @@ use std::sync::Arc; use arrow::array::{Array, RecordBatch}; use arrow::compute::{filter, is_not_null}; +use arrow::datatypes::FieldRef; use arrow::{ array::{ ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, @@ -29,11 +30,11 @@ use arrow::{ }, datatypes::{DataType, Field, Schema}, }; - use datafusion_common::{ downcast_value, internal_err, not_impl_datafusion_err, not_impl_err, plan_err, Result, ScalarValue, }; +use datafusion_expr::expr::{AggregateFunction, Sort}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS}; use datafusion_expr::utils::format_state_name; @@ -51,29 +52,39 @@ create_func!(ApproxPercentileCont, approx_percentile_cont_udaf); /// Computes the approximate percentile continuous of a set of numbers pub fn approx_percentile_cont( - expression: Expr, + order_by: Sort, percentile: Expr, centroids: Option, ) -> Expr { + let expr = order_by.expr.clone(); + let args = if let Some(centroids) = centroids { - vec![expression, percentile, centroids] + vec![expr, percentile, centroids] } else { - vec![expression, percentile] + vec![expr, percentile] }; - approx_percentile_cont_udaf().call(args) + + Expr::AggregateFunction(AggregateFunction::new_udf( + approx_percentile_cont_udaf(), + args, + false, + None, + Some(vec![order_by]), + None, + )) } #[user_doc( doc_section(label = "Approximate Functions"), description = "Returns the approximate percentile of input values using the t-digest algorithm.", - syntax_example = "approx_percentile_cont(expression, percentile, centroids)", + syntax_example = "approx_percentile_cont(percentile, centroids) WITHIN GROUP (ORDER BY expression)", sql_example = r#"```sql -> SELECT approx_percentile_cont(column_name, 0.75, 100) FROM table_name; -+-------------------------------------------------+ -| approx_percentile_cont(column_name, 0.75, 100) | -+-------------------------------------------------+ -| 65.0 | -+-------------------------------------------------+ +> SELECT approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name; ++-----------------------------------------------------------------------+ +| approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) | ++-----------------------------------------------------------------------+ +| 65.0 | ++-----------------------------------------------------------------------+ ```"#, standard_argument(name = "expression",), argument( 
@@ -130,6 +141,19 @@ impl ApproxPercentileCont { args: AccumulatorArgs, ) -> Result { let percentile = validate_input_percentile_expr(&args.exprs[1])?; + + let is_descending = args + .order_bys + .first() + .map(|sort_expr| sort_expr.options.descending) + .unwrap_or(false); + + let percentile = if is_descending { + 1.0 - percentile + } else { + percentile + }; + let tdigest_max_size = if args.exprs.len() == 3 { Some(validate_input_max_size_expr(&args.exprs[2])?) } else { @@ -232,7 +256,7 @@ impl AggregateUDFImpl for ApproxPercentileCont { #[allow(rustdoc::private_intra_doc_links)] /// See [`TDigest::to_scalar_state()`] for a description of the serialized /// state. - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( format_state_name(args.name, "max_size"), @@ -264,7 +288,10 @@ impl AggregateUDFImpl for ApproxPercentileCont { Field::new_list_field(DataType::Float64, true), false, ), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn name(&self) -> &str { @@ -292,6 +319,14 @@ impl AggregateUDFImpl for ApproxPercentileCont { Ok(arg_types[0].clone()) } + fn supports_null_handling_clause(&self) -> bool { + false + } + + fn is_ordered_set_aggregate(&self) -> bool { + true + } + fn documentation(&self) -> Option<&Documentation> { self.doc() } diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs index 16dac2c1b8f0..5180d4588962 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs @@ -20,11 +20,8 @@ use std::fmt::{Debug, Formatter}; use std::mem::size_of_val; use std::sync::Arc; -use arrow::{ - array::ArrayRef, - datatypes::{DataType, Field}, -}; - +use arrow::datatypes::FieldRef; +use arrow::{array::ArrayRef, datatypes::DataType}; use datafusion_common::ScalarValue; use datafusion_common::{not_impl_err, plan_err, Result}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; @@ -52,14 +49,14 @@ make_udaf_expr_and_func!( #[user_doc( doc_section(label = "Approximate Functions"), description = "Returns the weighted approximate percentile of input values using the t-digest algorithm.", - syntax_example = "approx_percentile_cont_with_weight(expression, weight, percentile)", + syntax_example = "approx_percentile_cont_with_weight(weight, percentile) WITHIN GROUP (ORDER BY expression)", sql_example = r#"```sql -> SELECT approx_percentile_cont_with_weight(column_name, weight_column, 0.90) FROM table_name; -+----------------------------------------------------------------------+ -| approx_percentile_cont_with_weight(column_name, weight_column, 0.90) | -+----------------------------------------------------------------------+ -| 78.5 | -+----------------------------------------------------------------------+ +> SELECT approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) FROM table_name; ++---------------------------------------------------------------------------------------------+ +| approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) | ++---------------------------------------------------------------------------------------------+ +| 78.5 | ++---------------------------------------------------------------------------------------------+ ```"#, standard_argument(name = "expression", prefix = "The"), 
argument( @@ -174,10 +171,18 @@ impl AggregateUDFImpl for ApproxPercentileContWithWeight { #[allow(rustdoc::private_intra_doc_links)] /// See [`TDigest::to_scalar_state()`] for a description of the serialized /// state. - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { self.approx_percentile_cont.state_fields(args) } + fn supports_null_handling_clause(&self) -> bool { + false + } + + fn is_ordered_set_aggregate(&self) -> bool { + true + } + fn documentation(&self) -> Option<&Documentation> { self.doc() } diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index d658744c1ba5..4ec73e306e0f 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -17,26 +17,31 @@ //! `ARRAY_AGG` aggregate implementation: [`ArrayAgg`] -use arrow::array::{new_empty_array, Array, ArrayRef, AsArray, ListArray, StructArray}; -use arrow::compute::SortOptions; -use arrow::datatypes::{DataType, Field, Fields}; +use std::cmp::Ordering; +use std::collections::{HashSet, VecDeque}; +use std::mem::{size_of, size_of_val}; +use std::sync::Arc; + +use arrow::array::{ + make_array, new_empty_array, Array, ArrayRef, AsArray, BooleanArray, ListArray, + StructArray, +}; +use arrow::compute::{filter, SortOptions}; +use arrow::datatypes::{DataType, Field, FieldRef, Fields}; use datafusion_common::cast::as_list_array; +use datafusion_common::scalar::copy_array_data; use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder}; -use datafusion_common::{exec_err, ScalarValue}; -use datafusion_common::{internal_err, Result}; +use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; -use datafusion_expr::{Accumulator, Signature, Volatility}; -use datafusion_expr::{AggregateUDFImpl, Documentation}; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, +}; use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays; use datafusion_functions_aggregate_common::utils::ordering_fields; use datafusion_macros::user_doc; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; -use std::cmp::Ordering; -use std::collections::{HashSet, VecDeque}; -use std::mem::{size_of, size_of_val}; -use std::sync::Arc; make_udaf_expr_and_func!( ArrayAgg, @@ -92,10 +97,6 @@ impl AggregateUDFImpl for ArrayAgg { "array_agg" } - fn aliases(&self) -> &[String] { - &[] - } - fn signature(&self) -> &Signature { &self.signature } @@ -107,39 +108,46 @@ impl AggregateUDFImpl for ArrayAgg { )))) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { if args.is_distinct { return Ok(vec![Field::new_list( format_state_name(args.name, "distinct_array_agg"), // See COMMENTS.md to understand why nullable is set to true - Field::new_list_field(args.input_types[0].clone(), true), + Field::new_list_field(args.input_fields[0].data_type().clone(), true), true, - )]); + ) + .into()]); } let mut fields = vec![Field::new_list( format_state_name(args.name, "array_agg"), // See COMMENTS.md to understand why nullable is set to true - Field::new_list_field(args.input_types[0].clone(), true), + Field::new_list_field(args.input_fields[0].data_type().clone(), true), true, - )]; + ) + .into()]; if 
args.ordering_fields.is_empty() { return Ok(fields); } let orderings = args.ordering_fields.to_vec(); - fields.push(Field::new_list( - format_state_name(args.name, "array_agg_orderings"), - Field::new_list_field(DataType::Struct(Fields::from(orderings)), true), - false, - )); + fields.push( + Field::new_list( + format_state_name(args.name, "array_agg_orderings"), + Field::new_list_field(DataType::Struct(Fields::from(orderings)), true), + false, + ) + .into(), + ); Ok(fields) } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { let data_type = acc_args.exprs[0].data_type(acc_args.schema)?; + let ignore_nulls = + acc_args.ignore_nulls && acc_args.exprs[0].nullable(acc_args.schema)?; if acc_args.is_distinct { // Limitation similar to Postgres. The aggregation function can only mix @@ -156,28 +164,30 @@ impl AggregateUDFImpl for ArrayAgg { // ARRAY_AGG(DISTINCT concat(col, '') ORDER BY concat(col, '')) <- Valid // ARRAY_AGG(DISTINCT col ORDER BY other_col) <- Invalid // ARRAY_AGG(DISTINCT col ORDER BY concat(col, '')) <- Invalid - if acc_args.ordering_req.len() > 1 { - return exec_err!("In an aggregate with DISTINCT, ORDER BY expressions must appear in argument list"); - } - let mut sort_option: Option = None; - if let Some(order) = acc_args.ordering_req.first() { - if !order.expr.eq(&acc_args.exprs[0]) { - return exec_err!("In an aggregate with DISTINCT, ORDER BY expressions must appear in argument list"); + let sort_option = match acc_args.order_bys { + [single] if single.expr.eq(&acc_args.exprs[0]) => Some(single.options), + [] => None, + _ => { + return exec_err!( + "In an aggregate with DISTINCT, ORDER BY expressions must appear in argument list" + ); } - sort_option = Some(order.options) - } + }; return Ok(Box::new(DistinctArrayAggAccumulator::try_new( &data_type, sort_option, + ignore_nulls, )?)); } - if acc_args.ordering_req.is_empty() { - return Ok(Box::new(ArrayAggAccumulator::try_new(&data_type)?)); - } + let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else { + return Ok(Box::new(ArrayAggAccumulator::try_new( + &data_type, + ignore_nulls, + )?)); + }; - let ordering_dtypes = acc_args - .ordering_req + let ordering_dtypes = ordering .iter() .map(|e| e.expr.data_type(acc_args.schema)) .collect::>>()?; @@ -185,8 +195,9 @@ impl AggregateUDFImpl for ArrayAgg { OrderSensitiveArrayAggAccumulator::try_new( &data_type, &ordering_dtypes, - acc_args.ordering_req.clone(), + ordering, acc_args.is_reversed, + ignore_nulls, ) .map(|acc| Box::new(acc) as _) } @@ -204,18 +215,20 @@ impl AggregateUDFImpl for ArrayAgg { pub struct ArrayAggAccumulator { values: Vec, datatype: DataType, + ignore_nulls: bool, } impl ArrayAggAccumulator { /// new array_agg accumulator based on given item data type - pub fn try_new(datatype: &DataType) -> Result { + pub fn try_new(datatype: &DataType, ignore_nulls: bool) -> Result { Ok(Self { values: vec![], datatype: datatype.clone(), + ignore_nulls, }) } - /// This function will return the underlying list array values if all valid values are consecutive without gaps (i.e. no null value point to a non empty list) + /// This function will return the underlying list array values if all valid values are consecutive without gaps (i.e. 
no null value point to a non-empty list) /// If there are gaps but only in the end of the list array, the function will return the values without the null values in the end fn get_optional_values_to_merge_as_is(list_array: &ListArray) -> Option { let offsets = list_array.value_offsets(); @@ -239,7 +252,7 @@ impl ArrayAggAccumulator { return Some(list_array.values().slice(0, 0)); } - // According to the Arrow spec, null values can point to non empty lists + // According to the Arrow spec, null values can point to non-empty lists // So this will check if all null values starting from the first valid value to the last one point to a 0 length list so we can just slice the underlying value // Unwrapping is safe as we just checked if there is a null value @@ -247,7 +260,7 @@ impl ArrayAggAccumulator { let mut valid_slices_iter = nulls.valid_slices(); - // This is safe as we validated that that are at least 1 valid value in the array + // This is safe as we validated that there is at least 1 valid value in the array let (start, end) = valid_slices_iter.next().unwrap(); let start_offset = offsets[start]; @@ -257,7 +270,7 @@ impl ArrayAggAccumulator { let mut end_offset_of_last_valid_value = offsets[end]; for (start, end) in valid_slices_iter { - // If there is a null value that point to a non empty list than the start offset of the valid value + // If there is a null value that point to a non-empty list than the start offset of the valid value // will be different that the end offset of the last valid value if offsets[start] != end_offset_of_last_valid_value { return None; @@ -288,10 +301,27 @@ impl Accumulator for ArrayAggAccumulator { return internal_err!("expects single batch"); } - let val = Arc::clone(&values[0]); + let val = &values[0]; + let nulls = if self.ignore_nulls { + val.logical_nulls() + } else { + None + }; + + let val = match nulls { + Some(nulls) if nulls.null_count() >= val.len() => return Ok(()), + Some(nulls) => filter(val, &BooleanArray::new(nulls.inner().clone(), None))?, + None => Arc::clone(val), + }; + if !val.is_empty() { - self.values.push(val); + // The ArrayRef might be holding a reference to its original input buffer, so + // storing it here directly copied/compacted avoids over accounting memory + // not used here. + self.values + .push(make_array(copy_array_data(&val.to_data()))); } + Ok(()) } @@ -360,17 +390,20 @@ struct DistinctArrayAggAccumulator { values: HashSet, datatype: DataType, sort_options: Option, + ignore_nulls: bool, } impl DistinctArrayAggAccumulator { pub fn try_new( datatype: &DataType, sort_options: Option, + ignore_nulls: bool, ) -> Result { Ok(Self { values: HashSet::new(), datatype: datatype.clone(), sort_options, + ignore_nulls, }) } } @@ -385,11 +418,21 @@ impl Accumulator for DistinctArrayAggAccumulator { return Ok(()); } - let array = &values[0]; + let val = &values[0]; + let nulls = if self.ignore_nulls { + val.logical_nulls() + } else { + None + }; - for i in 0..array.len() { - let scalar = ScalarValue::try_from_array(&array, i)?; - self.values.insert(scalar); + let nulls = nulls.as_ref(); + if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) { + for i in 0..val.len() { + if nulls.is_none_or(|nulls| nulls.is_valid(i)) { + self.values + .insert(ScalarValue::try_from_array(val, i)?.compacted()); + } + } } Ok(()) @@ -471,6 +514,8 @@ pub(crate) struct OrderSensitiveArrayAggAccumulator { ordering_req: LexOrdering, /// Whether the aggregation is running in reverse. reverse: bool, + /// Whether the aggregation should ignore null values. 
+ ignore_nulls: bool, } impl OrderSensitiveArrayAggAccumulator { @@ -481,6 +526,7 @@ impl OrderSensitiveArrayAggAccumulator { ordering_dtypes: &[DataType], ordering_req: LexOrdering, reverse: bool, + ignore_nulls: bool, ) -> Result { let mut datatypes = vec![datatype.clone()]; datatypes.extend(ordering_dtypes.iter().cloned()); @@ -490,8 +536,34 @@ impl OrderSensitiveArrayAggAccumulator { datatypes, ordering_req, reverse, + ignore_nulls, }) } + + fn evaluate_orderings(&self) -> Result { + let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]); + + let column_wise_ordering_values = if self.ordering_values.is_empty() { + fields + .iter() + .map(|f| new_empty_array(f.data_type())) + .collect::>() + } else { + (0..fields.len()) + .map(|i| { + let column_values = self.ordering_values.iter().map(|x| x[i].clone()); + ScalarValue::iter_to_array(column_values) + }) + .collect::>()? + }; + + let ordering_array = StructArray::try_new( + Fields::from(fields), + column_wise_ordering_values, + None, + )?; + Ok(SingleRowListArrayBuilder::new(Arc::new(ordering_array)).build_list_scalar()) + } } impl Accumulator for OrderSensitiveArrayAggAccumulator { @@ -500,11 +572,28 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { return Ok(()); } - let n_row = values[0].len(); - for index in 0..n_row { - let row = get_row_at_idx(values, index)?; - self.values.push(row[0].clone()); - self.ordering_values.push(row[1..].to_vec()); + let val = &values[0]; + let ord = &values[1..]; + let nulls = if self.ignore_nulls { + val.logical_nulls() + } else { + None + }; + + let nulls = nulls.as_ref(); + if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) { + for i in 0..val.len() { + if nulls.is_none_or(|nulls| nulls.is_valid(i)) { + self.values + .push(ScalarValue::try_from_array(val, i)?.compacted()); + self.ordering_values.push( + get_row_at_idx(ord, i)? + .into_iter() + .map(|v| v.compacted()) + .collect(), + ) + } + } } Ok(()) @@ -635,41 +724,15 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { } } -impl OrderSensitiveArrayAggAccumulator { - fn evaluate_orderings(&self) -> Result { - let fields = ordering_fields(self.ordering_req.as_ref(), &self.datatypes[1..]); - let num_columns = fields.len(); - let struct_field = Fields::from(fields.clone()); - - let mut column_wise_ordering_values = vec![]; - for i in 0..num_columns { - let column_values = self - .ordering_values - .iter() - .map(|x| x[i].clone()) - .collect::>(); - let array = if column_values.is_empty() { - new_empty_array(fields[i].data_type()) - } else { - ScalarValue::iter_to_array(column_values.into_iter())? 
- }; - column_wise_ordering_values.push(array); - } - - let ordering_array = - StructArray::try_new(struct_field, column_wise_ordering_values, None)?; - Ok(SingleRowListArrayBuilder::new(Arc::new(ordering_array)).build_list_scalar()) - } -} - #[cfg(test)] mod tests { use super::*; + use arrow::array::{ListBuilder, StringBuilder}; use arrow::datatypes::{FieldRef, Schema}; use datafusion_common::cast::as_generic_string_array; use datafusion_common::internal_err; use datafusion_physical_expr::expressions::Column; - use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; + use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use std::sync::Arc; #[test] @@ -931,10 +994,60 @@ mod tests { Ok(()) } + #[test] + fn does_not_over_account_memory() -> Result<()> { + let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string().build_two()?; + + acc1.update_batch(&[data(["a", "c", "b"])])?; + acc2.update_batch(&[data(["b", "c", "a"])])?; + acc1 = merge(acc1, acc2)?; + + // without compaction, the size is 2652. + assert_eq!(acc1.size(), 732); + + Ok(()) + } + #[test] + fn does_not_over_account_memory_distinct() -> Result<()> { + let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string() + .distinct() + .build_two()?; + + acc1.update_batch(&[string_list_data([ + vec!["a", "b", "c"], + vec!["d", "e", "f"], + ])])?; + acc2.update_batch(&[string_list_data([vec!["e", "f", "g"]])])?; + acc1 = merge(acc1, acc2)?; + + // without compaction, the size is 16660 + assert_eq!(acc1.size(), 1660); + + Ok(()) + } + + #[test] + fn does_not_over_account_memory_ordered() -> Result<()> { + let mut acc = ArrayAggAccumulatorBuilder::string() + .order_by_col("col", SortOptions::new(false, false)) + .build()?; + + acc.update_batch(&[string_list_data([ + vec!["a", "b", "c"], + vec!["c", "d", "e"], + vec!["b", "c", "d"], + ])])?; + + // without compaction, the size is 17112 + assert_eq!(acc.size(), 2112); + + Ok(()) + } + struct ArrayAggAccumulatorBuilder { - data_type: DataType, + return_field: FieldRef, distinct: bool, - ordering: LexOrdering, + order_bys: Vec, schema: Schema, } @@ -945,15 +1058,13 @@ mod tests { fn new(data_type: DataType) -> Self { Self { - data_type: data_type.clone(), - distinct: Default::default(), - ordering: Default::default(), + return_field: Field::new("f", data_type.clone(), true).into(), + distinct: false, + order_bys: vec![], schema: Schema { fields: Fields::from(vec![Field::new( "col", - DataType::List(FieldRef::new(Field::new( - "item", data_type, true, - ))), + DataType::new_list(data_type, true), true, )]), metadata: Default::default(), @@ -967,22 +1078,23 @@ mod tests { } fn order_by_col(mut self, col: &str, sort_options: SortOptions) -> Self { - self.ordering.extend([PhysicalSortExpr::new( + let new_order = PhysicalSortExpr::new( Arc::new( Column::new_with_schema(col, &self.schema) .expect("column not available in schema"), ), sort_options, - )]); + ); + self.order_bys.push(new_order); self } fn build(&self) -> Result> { ArrayAgg::default().accumulator(AccumulatorArgs { - return_type: &self.data_type, + return_field: Arc::clone(&self.return_field), schema: &self.schema, ignore_nulls: false, - ordering_req: &self.ordering, + order_bys: &self.order_bys, is_reversed: false, name: "", is_distinct: self.distinct, @@ -1007,10 +1119,19 @@ mod tests { fn print_nulls(sort: Vec>) -> Vec { sort.into_iter() - .map(|v| v.unwrap_or("NULL".to_string())) + .map(|v| v.unwrap_or_else(|| "NULL".to_string())) .collect() } + fn string_list_data<'a>(data: impl 
IntoIterator>) -> ArrayRef { + let mut builder = ListBuilder::new(StringBuilder::new()); + for string_list in data.into_iter() { + builder.append_value(string_list.iter().map(Some).collect::>()); + } + + Arc::new(builder.finish()) + } + fn data(list: [T; N]) -> ArrayRef where ScalarValue: From, diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 141771b0412f..3c1d33e093b5 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -24,8 +24,9 @@ use arrow::array::{ use arrow::compute::sum; use arrow::datatypes::{ - i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType, Field, - Float64Type, UInt64Type, + i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType, + DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType, + DurationSecondType, Field, FieldRef, Float64Type, TimeUnit, UInt64Type, }; use datafusion_common::{ exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue, @@ -120,7 +121,7 @@ impl AggregateUDFImpl for Avg { let data_type = acc_args.exprs[0].data_type(acc_args.schema)?; // instantiate specialized accumulator based for the type - match (&data_type, acc_args.return_type) { + match (&data_type, acc_args.return_field.data_type()) { (Float64, Float64) => Ok(Box::::default()), ( Decimal128(sum_precision, sum_scale), @@ -145,15 +146,25 @@ impl AggregateUDFImpl for Avg { target_precision: *target_precision, target_scale: *target_scale, })), + + (Duration(time_unit), Duration(result_unit)) => { + Ok(Box::new(DurationAvgAccumulator { + sum: None, + count: 0, + time_unit: *time_unit, + result_unit: *result_unit, + })) + } + _ => exec_err!( "AvgAccumulator for ({} --> {})", &data_type, - acc_args.return_type + acc_args.return_field.data_type() ), } } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( format_state_name(args.name, "count"), @@ -162,16 +173,19 @@ impl AggregateUDFImpl for Avg { ), Field::new( format_state_name(args.name, "sum"), - args.input_types[0].clone(), + args.input_fields[0].data_type().clone(), true, ), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { matches!( - args.return_type, - DataType::Float64 | DataType::Decimal128(_, _) + args.return_field.data_type(), + DataType::Float64 | DataType::Decimal128(_, _) | DataType::Duration(_) ) } @@ -183,11 +197,11 @@ impl AggregateUDFImpl for Avg { let data_type = args.exprs[0].data_type(args.schema)?; // instantiate specialized accumulator based for the type - match (&data_type, args.return_type) { + match (&data_type, args.return_field.data_type()) { (Float64, Float64) => { Ok(Box::new(AvgGroupsAccumulator::::new( &data_type, - args.return_type, + args.return_field.data_type(), |sum: f64, count: u64| Ok(sum / count as f64), ))) } @@ -206,7 +220,7 @@ impl AggregateUDFImpl for Avg { Ok(Box::new(AvgGroupsAccumulator::::new( &data_type, - args.return_type, + args.return_field.data_type(), avg_fn, ))) } @@ -227,15 +241,54 @@ impl AggregateUDFImpl for Avg { Ok(Box::new(AvgGroupsAccumulator::::new( &data_type, - args.return_type, + args.return_field.data_type(), avg_fn, ))) } + (Duration(time_unit), Duration(_result_unit)) => { + let avg_fn = move |sum: i64, count: u64| Ok(sum / count as i64); + + match time_unit { + TimeUnit::Second => 
Ok(Box::new(AvgGroupsAccumulator::< + DurationSecondType, + _, + >::new( + &data_type, + args.return_type(), + avg_fn, + ))), + TimeUnit::Millisecond => Ok(Box::new(AvgGroupsAccumulator::< + DurationMillisecondType, + _, + >::new( + &data_type, + args.return_type(), + avg_fn, + ))), + TimeUnit::Microsecond => Ok(Box::new(AvgGroupsAccumulator::< + DurationMicrosecondType, + _, + >::new( + &data_type, + args.return_type(), + avg_fn, + ))), + TimeUnit::Nanosecond => Ok(Box::new(AvgGroupsAccumulator::< + DurationNanosecondType, + _, + >::new( + &data_type, + args.return_type(), + avg_fn, + ))), + } + } + _ => not_impl_err!( "AvgGroupsAccumulator for ({} --> {})", &data_type, - args.return_type + args.return_field.data_type() ), } } @@ -399,6 +452,105 @@ impl Accumulator for DecimalAvgAccumu } } +/// An accumulator to compute the average for duration values +#[derive(Debug)] +struct DurationAvgAccumulator { + sum: Option, + count: u64, + time_unit: TimeUnit, + result_unit: TimeUnit, +} + +impl Accumulator for DurationAvgAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let array = &values[0]; + self.count += (array.len() - array.null_count()) as u64; + + let sum_value = match self.time_unit { + TimeUnit::Second => sum(array.as_primitive::()), + TimeUnit::Millisecond => sum(array.as_primitive::()), + TimeUnit::Microsecond => sum(array.as_primitive::()), + TimeUnit::Nanosecond => sum(array.as_primitive::()), + }; + + if let Some(x) = sum_value { + let v = self.sum.get_or_insert(0); + *v += x; + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let avg = self.sum.map(|sum| sum / self.count as i64); + + match self.result_unit { + TimeUnit::Second => Ok(ScalarValue::DurationSecond(avg)), + TimeUnit::Millisecond => Ok(ScalarValue::DurationMillisecond(avg)), + TimeUnit::Microsecond => Ok(ScalarValue::DurationMicrosecond(avg)), + TimeUnit::Nanosecond => Ok(ScalarValue::DurationNanosecond(avg)), + } + } + + fn size(&self) -> usize { + size_of_val(self) + } + + fn state(&mut self) -> Result> { + let duration_value = match self.time_unit { + TimeUnit::Second => ScalarValue::DurationSecond(self.sum), + TimeUnit::Millisecond => ScalarValue::DurationMillisecond(self.sum), + TimeUnit::Microsecond => ScalarValue::DurationMicrosecond(self.sum), + TimeUnit::Nanosecond => ScalarValue::DurationNanosecond(self.sum), + }; + + Ok(vec![ScalarValue::from(self.count), duration_value]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.count += sum(states[0].as_primitive::()).unwrap_or_default(); + + let sum_value = match self.time_unit { + TimeUnit::Second => sum(states[1].as_primitive::()), + TimeUnit::Millisecond => { + sum(states[1].as_primitive::()) + } + TimeUnit::Microsecond => { + sum(states[1].as_primitive::()) + } + TimeUnit::Nanosecond => { + sum(states[1].as_primitive::()) + } + }; + + if let Some(x) = sum_value { + let v = self.sum.get_or_insert(0); + *v += x; + } + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let array = &values[0]; + self.count -= (array.len() - array.null_count()) as u64; + + let sum_value = match self.time_unit { + TimeUnit::Second => sum(array.as_primitive::()), + TimeUnit::Millisecond => sum(array.as_primitive::()), + TimeUnit::Microsecond => sum(array.as_primitive::()), + TimeUnit::Nanosecond => sum(array.as_primitive::()), + }; + + if let Some(x) = sum_value { + self.sum = Some(self.sum.unwrap() - x); + } + Ok(()) + } + + fn supports_retract_batch(&self) -> bool { + true + } +} + /// 
An accumulator to compute the average of `[PrimitiveArray]`. /// Stores values as native types, and does overflow checking /// diff --git a/datafusion/functions-aggregate/src/bit_and_or_xor.rs b/datafusion/functions-aggregate/src/bit_and_or_xor.rs index 50ab50abc9e2..4512162ba5d3 100644 --- a/datafusion/functions-aggregate/src/bit_and_or_xor.rs +++ b/datafusion/functions-aggregate/src/bit_and_or_xor.rs @@ -25,8 +25,8 @@ use std::mem::{size_of, size_of_val}; use ahash::RandomState; use arrow::array::{downcast_integer, Array, ArrayRef, AsArray}; use arrow::datatypes::{ - ArrowNativeType, ArrowNumericType, DataType, Field, Int16Type, Int32Type, Int64Type, - Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, + ArrowNativeType, ArrowNumericType, DataType, Field, FieldRef, Int16Type, Int32Type, + Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; use datafusion_common::cast::as_list_array; @@ -87,7 +87,7 @@ macro_rules! accumulator_helper { /// `is_distinct` is boolean value indicating whether the operation is distinct or not. macro_rules! downcast_bitwise_accumulator { ($args:ident, $opr:expr, $is_distinct: expr) => { - match $args.return_type { + match $args.return_field.data_type() { DataType::Int8 => accumulator_helper!(Int8Type, $opr, $is_distinct), DataType::Int16 => accumulator_helper!(Int16Type, $opr, $is_distinct), DataType::Int32 => accumulator_helper!(Int32Type, $opr, $is_distinct), @@ -101,7 +101,7 @@ macro_rules! downcast_bitwise_accumulator { "{} not supported for {}: {}", stringify!($opr), $args.name, - $args.return_type + $args.return_field.data_type() ) } } @@ -205,7 +205,7 @@ enum BitwiseOperationType { impl Display for BitwiseOperationType { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self) + write!(f, "{self:?}") } } @@ -263,7 +263,7 @@ impl AggregateUDFImpl for BitwiseOperation { downcast_bitwise_accumulator!(acc_args, self.operation, acc_args.is_distinct) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { if self.operation == BitwiseOperationType::Xor && args.is_distinct { Ok(vec![Field::new_list( format_state_name( @@ -271,15 +271,17 @@ impl AggregateUDFImpl for BitwiseOperation { format!("{} distinct", self.name()).as_str(), ), // See COMMENTS.md to understand why nullable is set to true - Field::new_list_field(args.return_type.clone(), true), + Field::new_list_field(args.return_type().clone(), true), false, - )]) + ) + .into()]) } else { Ok(vec![Field::new( format_state_name(args.name, self.name()), - args.return_type.clone(), + args.return_field.data_type().clone(), true, - )]) + ) + .into()]) } } @@ -291,7 +293,7 @@ impl AggregateUDFImpl for BitwiseOperation { &self, args: AccumulatorArgs, ) -> Result> { - let data_type = args.return_type; + let data_type = args.return_field.data_type(); let operation = &self.operation; downcast_integer! 
{ data_type => (group_accumulator_helper, data_type, operation), diff --git a/datafusion/functions-aggregate/src/bool_and_or.rs b/datafusion/functions-aggregate/src/bool_and_or.rs index 1b33a7900c00..d779e0a399b5 100644 --- a/datafusion/functions-aggregate/src/bool_and_or.rs +++ b/datafusion/functions-aggregate/src/bool_and_or.rs @@ -24,8 +24,8 @@ use arrow::array::ArrayRef; use arrow::array::BooleanArray; use arrow::compute::bool_and as compute_bool_and; use arrow::compute::bool_or as compute_bool_or; -use arrow::datatypes::DataType; use arrow::datatypes::Field; +use arrow::datatypes::{DataType, FieldRef}; use datafusion_common::internal_err; use datafusion_common::{downcast_value, not_impl_err}; @@ -150,12 +150,13 @@ impl AggregateUDFImpl for BoolAnd { Ok(Box::::default()) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![Field::new( format_state_name(args.name, self.name()), DataType::Boolean, true, - )]) + ) + .into()]) } fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { @@ -166,22 +167,18 @@ impl AggregateUDFImpl for BoolAnd { &self, args: AccumulatorArgs, ) -> Result> { - match args.return_type { + match args.return_field.data_type() { DataType::Boolean => { Ok(Box::new(BooleanGroupsAccumulator::new(|x, y| x && y, true))) } _ => not_impl_err!( "GroupsAccumulator not supported for {} with {}", args.name, - args.return_type + args.return_field.data_type() ), } } - fn aliases(&self) -> &[String] { - &[] - } - fn order_sensitivity(&self) -> AggregateOrderSensitivity { AggregateOrderSensitivity::Insensitive } @@ -288,12 +285,13 @@ impl AggregateUDFImpl for BoolOr { Ok(Box::::default()) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![Field::new( format_state_name(args.name, self.name()), DataType::Boolean, true, - )]) + ) + .into()]) } fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { @@ -304,7 +302,7 @@ impl AggregateUDFImpl for BoolOr { &self, args: AccumulatorArgs, ) -> Result> { - match args.return_type { + match args.return_field.data_type() { DataType::Boolean => Ok(Box::new(BooleanGroupsAccumulator::new( |x, y| x || y, false, @@ -312,15 +310,11 @@ impl AggregateUDFImpl for BoolOr { _ => not_impl_err!( "GroupsAccumulator not supported for {} with {}", args.name, - args.return_type + args.return_field.data_type() ), } } - fn aliases(&self) -> &[String] { - &[] - } - fn order_sensitivity(&self) -> AggregateOrderSensitivity { AggregateOrderSensitivity::Insensitive } diff --git a/datafusion/functions-aggregate/src/correlation.rs b/datafusion/functions-aggregate/src/correlation.rs index ac57256ce882..0a7345245ca8 100644 --- a/datafusion/functions-aggregate/src/correlation.rs +++ b/datafusion/functions-aggregate/src/correlation.rs @@ -27,7 +27,7 @@ use arrow::array::{ UInt64Array, }; use arrow::compute::{and, filter, is_not_null, kernels::cast}; -use arrow::datatypes::{Float64Type, UInt64Type}; +use arrow::datatypes::{FieldRef, Float64Type, UInt64Type}; use arrow::{ array::ArrayRef, datatypes::{DataType, Field}, @@ -117,7 +117,7 @@ impl AggregateUDFImpl for Correlation { Ok(Box::new(CorrelationAccumulator::try_new()?)) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let name = args.name; Ok(vec![ Field::new(format_state_name(name, "count"), DataType::UInt64, true), @@ -130,7 +130,10 @@ impl AggregateUDFImpl 
for Correlation { DataType::Float64, true, ), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 2d995b4a4179..d1fe410321f6 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -16,53 +16,49 @@ // under the License. use ahash::RandomState; -use datafusion_common::stats::Precision; -use datafusion_expr::expr::WindowFunction; -use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator; -use datafusion_macros::user_doc; -use datafusion_physical_expr::expressions; -use std::collections::HashSet; -use std::fmt::Debug; -use std::mem::{size_of, size_of_val}; -use std::ops::BitAnd; -use std::sync::Arc; - use arrow::{ - array::{ArrayRef, AsArray}, + array::{Array, ArrayRef, AsArray, BooleanArray, Int64Array, PrimitiveArray}, + buffer::BooleanBuffer, compute, datatypes::{ DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field, - Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, - Time32MillisecondType, Time32SecondType, Time64MicrosecondType, + FieldRef, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, + Int8Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }, }; - -use arrow::{ - array::{Array, BooleanArray, Int64Array, PrimitiveArray}, - buffer::BooleanBuffer, -}; use datafusion_common::{ - downcast_value, internal_err, not_impl_err, Result, ScalarValue, + downcast_value, internal_err, not_impl_err, stats::Precision, + utils::expr::COUNT_STAR_EXPANSION, Result, ScalarValue, }; -use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::{ - function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl, - Documentation, EmitTo, GroupsAccumulator, SetMonotonicity, Signature, Volatility, + expr::WindowFunction, + function::{AccumulatorArgs, StateFieldsArgs}, + utils::format_state_name, + Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, GroupsAccumulator, + ReversedUDAF, SetMonotonicity, Signature, StatisticsArgs, TypeSignature, Volatility, + WindowFunctionDefinition, }; -use datafusion_expr::{ - Expr, ReversedUDAF, StatisticsArgs, TypeSignature, WindowFunctionDefinition, +use datafusion_functions_aggregate_common::aggregate::{ + count_distinct::BytesDistinctCountAccumulator, + count_distinct::BytesViewDistinctCountAccumulator, + count_distinct::DictionaryCountAccumulator, + count_distinct::FloatDistinctCountAccumulator, + count_distinct::PrimitiveDistinctCountAccumulator, + groups_accumulator::accumulate::accumulate_indices, }; -use datafusion_functions_aggregate_common::aggregate::count_distinct::{ - BytesDistinctCountAccumulator, FloatDistinctCountAccumulator, - PrimitiveDistinctCountAccumulator, -}; -use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_indices; +use datafusion_macros::user_doc; +use datafusion_physical_expr::expressions; use datafusion_physical_expr_common::binary_map::OutputType; - -use datafusion_common::utils::expr::COUNT_STAR_EXPANSION; +use std::{ + collections::HashSet, + fmt::Debug, + mem::{size_of, size_of_val}, + ops::BitAnd, + sync::Arc, +}; 
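The `state_fields` hunks above (correlation, approx_median, avg, and others in this patch) all switch the return type from `Vec<Field>` to `Vec<FieldRef>` and finish with the same `.map(Arc::new).collect()` step. A small illustrative sketch of that pattern, assuming `FieldRef` is arrow's `Arc<Field>` alias; the function and field names are made up for the example:

```rust
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, FieldRef};

// Illustrative only: the shared state_fields() migration pattern in this patch.
fn example_state_fields(name: &str) -> Vec<FieldRef> {
    vec![
        Field::new(format!("{name}[count]"), DataType::UInt64, true),
        Field::new(format!("{name}[sum]"), DataType::Float64, true),
    ]
    .into_iter()
    .map(Arc::new) // Field -> FieldRef
    .collect()
}
```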
make_udaf_expr_and_func!( Count, count, @@ -100,7 +96,7 @@ pub fn count_distinct(expr: Expr) -> Expr { /// let expr = col(expr.schema_name().to_string()); /// ``` pub fn count_all() -> Expr { - count(Expr::Literal(COUNT_STAR_EXPANSION)).alias("count(*)") + count(Expr::Literal(COUNT_STAR_EXPANSION, None)).alias("count(*)") } /// Creates window aggregation to count all rows. @@ -123,9 +119,9 @@ pub fn count_all() -> Expr { /// let expr = col(expr.schema_name().to_string()); /// ``` pub fn count_all_window() -> Expr { - Expr::WindowFunction(WindowFunction::new( + Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(count_udaf()), - vec![Expr::Literal(COUNT_STAR_EXPANSION)], + vec![Expr::Literal(COUNT_STAR_EXPANSION, None)], )) } @@ -179,6 +175,107 @@ impl Count { } } } +fn get_count_accumulator(data_type: &DataType) -> Box { + match data_type { + // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator + DataType::Int8 => Box::new(PrimitiveDistinctCountAccumulator::::new( + data_type, + )), + DataType::Int16 => Box::new(PrimitiveDistinctCountAccumulator::::new( + data_type, + )), + DataType::Int32 => Box::new(PrimitiveDistinctCountAccumulator::::new( + data_type, + )), + DataType::Int64 => Box::new(PrimitiveDistinctCountAccumulator::::new( + data_type, + )), + DataType::UInt8 => Box::new(PrimitiveDistinctCountAccumulator::::new( + data_type, + )), + DataType::UInt16 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::UInt32 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::UInt64 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::< + Decimal128Type, + >::new(data_type)), + DataType::Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::< + Decimal256Type, + >::new(data_type)), + + DataType::Date32 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Date64 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Time32(TimeUnit::Millisecond) => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Time32(TimeUnit::Second) => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Time64(TimeUnit::Microsecond) => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Time64(TimeUnit::Nanosecond) => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Timestamp(TimeUnit::Microsecond, _) => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Timestamp(TimeUnit::Millisecond, _) => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Timestamp(TimeUnit::Second, _) => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + + DataType::Float16 => { + Box::new(FloatDistinctCountAccumulator::::new()) + } + DataType::Float32 => { + Box::new(FloatDistinctCountAccumulator::::new()) + } + DataType::Float64 => { + Box::new(FloatDistinctCountAccumulator::::new()) + } + + DataType::Utf8 => { + Box::new(BytesDistinctCountAccumulator::::new(OutputType::Utf8)) + } + DataType::Utf8View => { + Box::new(BytesViewDistinctCountAccumulator::new(OutputType::Utf8View)) + } + DataType::LargeUtf8 => { + 
Box::new(BytesDistinctCountAccumulator::::new(OutputType::Utf8)) + } + DataType::Binary => Box::new(BytesDistinctCountAccumulator::::new( + OutputType::Binary, + )), + DataType::BinaryView => Box::new(BytesViewDistinctCountAccumulator::new( + OutputType::BinaryView, + )), + DataType::LargeBinary => Box::new(BytesDistinctCountAccumulator::::new( + OutputType::Binary, + )), + + // Use the generic accumulator based on `ScalarValue` for all other types + _ => Box::new(DistinctCountAccumulator { + values: HashSet::default(), + state_data_type: data_type.clone(), + }), + } +} impl AggregateUDFImpl for Count { fn as_any(&self) -> &dyn std::any::Any { @@ -201,20 +298,27 @@ impl AggregateUDFImpl for Count { false } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { if args.is_distinct { + let dtype: DataType = match &args.input_fields[0].data_type() { + DataType::Dictionary(_, values_type) => (**values_type).clone(), + &dtype => dtype.clone(), + }; + Ok(vec![Field::new_list( format_state_name(args.name, "count distinct"), // See COMMENTS.md to understand why nullable is set to true - Field::new_list_field(args.input_types[0].clone(), true), + Field::new_list_field(dtype, true), false, - )]) + ) + .into()]) } else { Ok(vec![Field::new( format_state_name(args.name, "count"), DataType::Int64, false, - )]) + ) + .into()]) } } @@ -228,121 +332,16 @@ impl AggregateUDFImpl for Count { } let data_type = &acc_args.exprs[0].data_type(acc_args.schema)?; - Ok(match data_type { - // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator - DataType::Int8 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::Int16 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::Int32 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::Int64 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::UInt8 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::UInt16 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::UInt32 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::UInt64 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::< - Decimal128Type, - >::new(data_type)), - DataType::Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::< - Decimal256Type, - >::new(data_type)), - - DataType::Date32 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::Date64 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::Time32(TimeUnit::Millisecond) => Box::new( - PrimitiveDistinctCountAccumulator::::new( - data_type, - ), - ), - DataType::Time32(TimeUnit::Second) => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::Time64(TimeUnit::Microsecond) => Box::new( - PrimitiveDistinctCountAccumulator::::new( - data_type, - ), - ), - DataType::Time64(TimeUnit::Nanosecond) => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::Timestamp(TimeUnit::Microsecond, _) => Box::new( - PrimitiveDistinctCountAccumulator::::new( - data_type, - ), - ), - DataType::Timestamp(TimeUnit::Millisecond, _) => Box::new( - PrimitiveDistinctCountAccumulator::::new( - data_type, - ), - ), - 
DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new( - PrimitiveDistinctCountAccumulator::::new( - data_type, - ), - ), - DataType::Timestamp(TimeUnit::Second, _) => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), - DataType::Float16 => { - Box::new(FloatDistinctCountAccumulator::::new()) - } - DataType::Float32 => { - Box::new(FloatDistinctCountAccumulator::::new()) - } - DataType::Float64 => { - Box::new(FloatDistinctCountAccumulator::::new()) - } - - DataType::Utf8 => { - Box::new(BytesDistinctCountAccumulator::::new(OutputType::Utf8)) - } - DataType::Utf8View => { - Box::new(BytesViewDistinctCountAccumulator::new(OutputType::Utf8View)) - } - DataType::LargeUtf8 => { - Box::new(BytesDistinctCountAccumulator::::new(OutputType::Utf8)) + Ok(match data_type { + DataType::Dictionary(_, values_type) => { + let inner = get_count_accumulator(values_type); + Box::new(DictionaryCountAccumulator::new(inner)) } - DataType::Binary => Box::new(BytesDistinctCountAccumulator::::new( - OutputType::Binary, - )), - DataType::BinaryView => Box::new(BytesViewDistinctCountAccumulator::new( - OutputType::BinaryView, - )), - DataType::LargeBinary => Box::new(BytesDistinctCountAccumulator::::new( - OutputType::Binary, - )), - - // Use the generic accumulator based on `ScalarValue` for all other types - _ => Box::new(DistinctCountAccumulator { - values: HashSet::default(), - state_data_type: data_type.clone(), - }), + _ => get_count_accumulator(data_type), }) } - fn aliases(&self) -> &[String] { - &[] - } - fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { // groups accumulator only supports `COUNT(c1)`, not // `COUNT(c1, c2)`, etc @@ -708,8 +707,8 @@ impl Accumulator for DistinctCountAccumulator { } (0..arr.len()).try_for_each(|index| { - if !arr.is_null(index) { - let scalar = ScalarValue::try_from_array(arr, index)?; + let scalar = ScalarValue::try_from_array(arr, index)?; + if !scalar.is_null() { self.values.insert(scalar); } Ok(()) @@ -754,8 +753,28 @@ impl Accumulator for DistinctCountAccumulator { #[cfg(test)] mod tests { + use super::*; - use arrow::array::NullArray; + use arrow::{ + array::{DictionaryArray, Int32Array, NullArray, StringArray}, + datatypes::{DataType, Field, Int32Type, Schema}, + }; + use datafusion_expr::function::AccumulatorArgs; + use datafusion_physical_expr::expressions::Column; + use std::sync::Arc; + /// Helper function to create a dictionary array with non-null keys but some null values + /// Returns a dictionary array where: + /// - keys are [0, 1, 2, 0, 1] (all non-null) + /// - values are ["a", null, "c"] + /// - so the keys reference: "a", null, "c", "a", null + fn create_dictionary_with_null_values() -> Result> { + let values = StringArray::from(vec![Some("a"), None, Some("c")]); + let keys = Int32Array::from(vec![0, 1, 2, 0, 1]); // references "a", null, "c", "a", null + Ok(DictionaryArray::::try_new( + keys, + Arc::new(values), + )?) 
+ } #[test] fn count_accumulator_nulls() -> Result<()> { @@ -764,4 +783,99 @@ mod tests { assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0))); Ok(()) } + + #[test] + fn test_nested_dictionary() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new( + "dict_col", + DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Utf8), + )), + ), + true, + )])); + + // Using Count UDAF's accumulator + let count = Count::new(); + let expr = Arc::new(Column::new("dict_col", 0)); + let args = AccumulatorArgs { + schema: &schema, + exprs: &[expr], + is_distinct: true, + name: "count", + ignore_nulls: false, + is_reversed: false, + return_field: Arc::new(Field::new_list_field(DataType::Int64, true)), + order_bys: &[], + }; + + let inner_dict = + DictionaryArray::::from_iter(["a", "b", "c", "d", "a", "b"]); + + let keys = Int32Array::from(vec![0, 1, 2, 0, 3, 1]); + let dict_of_dict = + DictionaryArray::::try_new(keys, Arc::new(inner_dict))?; + + let mut acc = count.accumulator(args)?; + acc.update_batch(&[Arc::new(dict_of_dict)])?; + assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(4))); + + Ok(()) + } + + #[test] + fn count_distinct_accumulator_dictionary_with_null_values() -> Result<()> { + let dict_array = create_dictionary_with_null_values()?; + + // The expected behavior is that count_distinct should count only non-null values + // which in this case are "a" and "c" (appearing as 0 and 2 in keys) + let mut accumulator = DistinctCountAccumulator { + values: HashSet::default(), + state_data_type: dict_array.data_type().clone(), + }; + + accumulator.update_batch(&[Arc::new(dict_array)])?; + + // Should have 2 distinct non-null values ("a" and "c") + assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(2))); + Ok(()) + } + + #[test] + fn count_accumulator_dictionary_with_null_values() -> Result<()> { + let dict_array = create_dictionary_with_null_values()?; + + // The expected behavior is that count should only count non-null values + let mut accumulator = CountAccumulator::new(); + + accumulator.update_batch(&[Arc::new(dict_array)])?; + + // 5 elements in the array, of which 2 reference null values (the two 1s in the keys) + // So we should count 3 non-null values + assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(3))); + Ok(()) + } + + #[test] + fn count_distinct_accumulator_dictionary_all_null_values() -> Result<()> { + // Create a dictionary array that only contains null values + let dict_values = StringArray::from(vec![None, Some("abc")]); + let dict_indices = Int32Array::from(vec![0; 5]); + let dict_array = + DictionaryArray::::try_new(dict_indices, Arc::new(dict_values))?; + + let mut accumulator = DistinctCountAccumulator { + values: HashSet::default(), + state_data_type: dict_array.data_type().clone(), + }; + + accumulator.update_batch(&[Arc::new(dict_array)])?; + + // All referenced values are null so count(distinct) should be 0 + assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0))); + Ok(()) + } } diff --git a/datafusion/functions-aggregate/src/covariance.rs b/datafusion/functions-aggregate/src/covariance.rs index d4ae27533c6d..9f37a73e5429 100644 --- a/datafusion/functions-aggregate/src/covariance.rs +++ b/datafusion/functions-aggregate/src/covariance.rs @@ -17,15 +17,12 @@ //! [`CovarianceSample`]: covariance sample aggregations. 
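
The `count.rs` changes above factor the per-type distinct accumulators into `get_count_accumulator` and add a dictionary branch: for `DataType::Dictionary(_, values_type)` the accumulator built for the value type is wrapped in a `DictionaryCountAccumulator`, and `DistinctCountAccumulator::update_batch` now checks nullness on the materialized `ScalarValue`, so keys that reference null dictionary values are no longer counted (the new tests exercise exactly this). A simplified sketch of the delegation idea (not the actual `DictionaryCountAccumulator` implementation, and using a hypothetical helper name):

    // Sketch only: flatten Dictionary(K, V) into a plain array of V, then delegate to
    // an inner accumulator built for V. Nulls reached through the keys remain null.
    use arrow::array::ArrayRef;
    use arrow::compute::cast;
    use arrow::datatypes::DataType;
    use datafusion_common::Result;
    use datafusion_expr::Accumulator;

    fn update_dictionary_batch(
        inner: &mut dyn Accumulator,
        values: &[ArrayRef],
        value_type: &DataType,
    ) -> Result<()> {
        // `cast` expands the dictionary keys into a plain array of the value type.
        let flattened = cast(&values[0], value_type)?;
        inner.update_batch(&[flattened])
    }

Keeping the specialized primitive and bytes accumulators usable behind this wrapper avoids falling back to the generic `ScalarValue`-based path for dictionary-encoded columns.
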
-use std::fmt::Debug; -use std::mem::size_of_val; - +use arrow::datatypes::FieldRef; use arrow::{ array::{ArrayRef, Float64Array, UInt64Array}, compute::kernels::cast, datatypes::{DataType, Field}, }; - use datafusion_common::{ downcast_value, plan_err, unwrap_or_internal_err, DataFusionError, Result, ScalarValue, @@ -38,6 +35,9 @@ use datafusion_expr::{ }; use datafusion_functions_aggregate_common::stats::StatsType; use datafusion_macros::user_doc; +use std::fmt::Debug; +use std::mem::size_of_val; +use std::sync::Arc; make_udaf_expr_and_func!( CovarianceSample, @@ -120,7 +120,7 @@ impl AggregateUDFImpl for CovarianceSample { Ok(DataType::Float64) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let name = args.name; Ok(vec![ Field::new(format_state_name(name, "count"), DataType::UInt64, true), @@ -131,7 +131,10 @@ impl AggregateUDFImpl for CovarianceSample { DataType::Float64, true, ), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { @@ -210,7 +213,7 @@ impl AggregateUDFImpl for CovariancePopulation { Ok(DataType::Float64) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let name = args.name; Ok(vec![ Field::new(format_state_name(name, "count"), DataType::UInt64, true), @@ -221,7 +224,10 @@ impl AggregateUDFImpl for CovariancePopulation { DataType::Float64, true, ), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index ec8c440b77e5..42c0a57fbf28 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -29,8 +29,8 @@ use arrow::array::{ use arrow::buffer::{BooleanBuffer, NullBuffer}; use arrow::compute::{self, LexicographicalComparator, SortColumn, SortOptions}; use arrow::datatypes::{ - DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field, Float16Type, - Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, + DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field, FieldRef, + Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, @@ -45,7 +45,7 @@ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; use datafusion_expr::{ Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, ExprFunctionExt, - GroupsAccumulator, Signature, SortExpr, Volatility, + GroupsAccumulator, ReversedUDAF, Signature, SortExpr, Volatility, }; use datafusion_functions_aggregate_common::utils::get_sort_options; use datafusion_macros::user_doc; @@ -149,89 +149,92 @@ impl AggregateUDFImpl for FirstValue { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - let ordering_dtypes = acc_args - .ordering_req + let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else { + return TrivialFirstValueAccumulator::try_new( + acc_args.return_field.data_type(), + acc_args.ignore_nulls, + ) + .map(|acc| Box::new(acc) as _); + }; + let ordering_dtypes = 
ordering .iter() .map(|e| e.expr.data_type(acc_args.schema)) .collect::>>()?; - - // When requirement is empty, or it is signalled by outside caller that - // the ordering requirement is/will be satisfied. - let requirement_satisfied = - acc_args.ordering_req.is_empty() || self.requirement_satisfied; - FirstValueAccumulator::try_new( - acc_args.return_type, + acc_args.return_field.data_type(), &ordering_dtypes, - acc_args.ordering_req.clone(), + ordering, acc_args.ignore_nulls, ) - .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _) + .map(|acc| { + Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ + }) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let mut fields = vec![Field::new( format_state_name(args.name, "first_value"), - args.return_type.clone(), + args.return_type().clone(), true, - )]; - fields.extend(args.ordering_fields.to_vec()); - fields.push(Field::new("is_set", DataType::Boolean, true)); + ) + .into()]; + fields.extend(args.ordering_fields.iter().cloned()); + fields.push(Field::new("is_set", DataType::Boolean, true).into()); Ok(fields) } fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { - // TODO: extract to function use DataType::*; - matches!( - args.return_type, - Int8 | Int16 - | Int32 - | Int64 - | UInt8 - | UInt16 - | UInt32 - | UInt64 - | Float16 - | Float32 - | Float64 - | Decimal128(_, _) - | Decimal256(_, _) - | Date32 - | Date64 - | Time32(_) - | Time64(_) - | Timestamp(_, _) - ) + !args.order_bys.is_empty() + && matches!( + args.return_field.data_type(), + Int8 | Int16 + | Int32 + | Int64 + | UInt8 + | UInt16 + | UInt32 + | UInt64 + | Float16 + | Float32 + | Float64 + | Decimal128(_, _) + | Decimal256(_, _) + | Date32 + | Date64 + | Time32(_) + | Time64(_) + | Timestamp(_, _) + ) } fn create_groups_accumulator( &self, args: AccumulatorArgs, ) -> Result> { - // TODO: extract to function - fn create_accumulator( + fn create_accumulator( args: AccumulatorArgs, - ) -> Result> - where - T: ArrowPrimitiveType + Send, - { - let ordering_dtypes = args - .ordering_req + ) -> Result> { + let Some(ordering) = LexOrdering::new(args.order_bys.to_vec()) else { + return internal_err!("Groups accumulator must have an ordering."); + }; + + let ordering_dtypes = ordering .iter() .map(|e| e.expr.data_type(args.schema)) .collect::>>()?; - Ok(Box::new(FirstPrimitiveGroupsAccumulator::::try_new( - args.ordering_req.clone(), + FirstPrimitiveGroupsAccumulator::::try_new( + ordering, args.ignore_nulls, - args.return_type, + args.return_field.data_type(), &ordering_dtypes, true, - )?)) + ) + .map(|acc| Box::new(acc) as _) } - match args.return_type { + match args.return_field.data_type() { DataType::Int8 => create_accumulator::(args), DataType::Int16 => create_accumulator::(args), DataType::Int32 => create_accumulator::(args), @@ -276,19 +279,13 @@ impl AggregateUDFImpl for FirstValue { create_accumulator::(args) } - _ => { - internal_err!( - "GroupsAccumulator not supported for first_value({})", - args.return_type - ) - } + _ => internal_err!( + "GroupsAccumulator not supported for first_value({})", + args.return_field.data_type() + ), } } - fn aliases(&self) -> &[String] { - &[] - } - fn with_beneficial_ordering( self: Arc, beneficial_ordering: bool, @@ -302,8 +299,8 @@ impl AggregateUDFImpl for FirstValue { AggregateOrderSensitivity::Beneficial } - fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { - 
datafusion_expr::ReversedUDAF::Reversed(last_value_udaf()) + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Reversed(last_value_udaf()) } fn documentation(&self) -> Option<&Documentation> { @@ -349,8 +346,6 @@ where pick_first_in_group: bool, // derived from `ordering_req`. sort_options: Vec, - // Stores whether incoming data already satisfies the ordering requirement. - input_requirement_satisfied: bool, // Ignore null values. ignore_nulls: bool, /// The output type @@ -369,20 +364,17 @@ where ordering_dtypes: &[DataType], pick_first_in_group: bool, ) -> Result { - let requirement_satisfied = ordering_req.is_empty(); - let default_orderings = ordering_dtypes .iter() .map(ScalarValue::try_from) - .collect::>>()?; + .collect::>()?; - let sort_options = get_sort_options(ordering_req.as_ref()); + let sort_options = get_sort_options(&ordering_req); Ok(Self { null_builder: BooleanBufferBuilder::new(0), ordering_req, sort_options, - input_requirement_satisfied: requirement_satisfied, ignore_nulls, default_orderings, data_type: data_type.clone(), @@ -395,18 +387,6 @@ where }) } - fn need_update(&self, group_idx: usize) -> bool { - if !self.is_sets.get_bit(group_idx) { - return true; - } - - if self.ignore_nulls && !self.null_builder.get_bit(group_idx) { - return true; - } - - !self.input_requirement_satisfied - } - fn should_update_state( &self, group_idx: usize, @@ -572,17 +552,12 @@ where let group_idx = *group_idx; let passed_filter = opt_filter.is_none_or(|x| x.value(idx_in_val)); - let is_set = is_set_arr.is_none_or(|x| x.value(idx_in_val)); if !passed_filter || !is_set { continue; } - if !self.need_update(group_idx) { - continue; - } - if self.ignore_nulls && vals.is_null(idx_in_val) { continue; } @@ -719,7 +694,7 @@ where let (is_set_arr, val_and_order_cols) = match values.split_last() { Some(result) => result, - None => return internal_err!("Empty row in FISRT_VALUE"), + None => return internal_err!("Empty row in FIRST_VALUE"), }; let is_set_arr = as_boolean_array(is_set_arr)?; @@ -752,7 +727,7 @@ where fn size(&self) -> usize { self.vals.capacity() * size_of::() - + self.null_builder.capacity() / 8 // capacity is in bits, so convert to bytes + + self.null_builder.capacity() / 8 // capacity is in bits, so convert to bytes + self.is_sets.capacity() / 8 + self.size_of_orderings + self.min_of_each_group_buf.0.capacity() * size_of::() @@ -781,14 +756,96 @@ where } } } + +/// This accumulator is used when there is no ordering specified for the +/// `FIRST_VALUE` aggregation. It simply returns the first value it sees +/// according to the pre-existing ordering of the input data, and provides +/// a fast path for this case without needing to maintain any ordering state. +#[derive(Debug)] +pub struct TrivialFirstValueAccumulator { + first: ScalarValue, + // Whether we have seen the first value yet. + is_set: bool, + // Ignore null values. + ignore_nulls: bool, +} + +impl TrivialFirstValueAccumulator { + /// Creates a new `TrivialFirstValueAccumulator` for the given `data_type`. 
+ pub fn try_new(data_type: &DataType, ignore_nulls: bool) -> Result { + ScalarValue::try_from(data_type).map(|first| Self { + first, + is_set: false, + ignore_nulls, + }) + } +} + +impl Accumulator for TrivialFirstValueAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![self.first.clone(), ScalarValue::from(self.is_set)]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if !self.is_set { + // Get first entry according to the pre-existing ordering (0th index): + let value = &values[0]; + let mut first_idx = None; + if self.ignore_nulls { + // If ignoring nulls, find the first non-null value. + for i in 0..value.len() { + if !value.is_null(i) { + first_idx = Some(i); + break; + } + } + } else if !value.is_empty() { + // If not ignoring nulls, return the first value if it exists. + first_idx = Some(0); + } + if let Some(first_idx) = first_idx { + let mut row = get_row_at_idx(values, first_idx)?; + self.first = row.swap_remove(0); + self.first.compact(); + self.is_set = true; + } + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + // FIRST_VALUE(first1, first2, first3, ...) + // Second index contains is_set flag. + if !self.is_set { + let flags = states[1].as_boolean(); + let filtered_states = + filter_states_according_to_is_set(&states[0..1], flags)?; + if let Some(first) = filtered_states.first() { + if !first.is_empty() { + self.first = ScalarValue::try_from_array(first, 0)?; + self.is_set = true; + } + } + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + Ok(self.first.clone()) + } + + fn size(&self) -> usize { + size_of_val(self) - size_of_val(&self.first) + self.first.size() + } +} + #[derive(Debug)] pub struct FirstValueAccumulator { first: ScalarValue, - // At the beginning, `is_set` is false, which means `first` is not seen yet. - // Once we see the first value, we set the `is_set` flag and do not update `first` anymore. + // Whether we have seen the first value yet. is_set: bool, - // Stores ordering values, of the aggregator requirement corresponding to first value - // of the aggregator. These values are used during merging of multiple partitions. + // Stores values of the ordering columns corresponding to the first value. + // These values are used during merging of multiple partitions. orderings: Vec, // Stores the applicable ordering requirement. ordering_req: LexOrdering, @@ -809,14 +866,13 @@ impl FirstValueAccumulator { let orderings = ordering_dtypes .iter() .map(ScalarValue::try_from) - .collect::>>()?; - let requirement_satisfied = ordering_req.is_empty(); + .collect::>()?; ScalarValue::try_from(data_type).map(|first| Self { first, is_set: false, orderings, ordering_req, - requirement_satisfied, + requirement_satisfied: false, ignore_nulls, }) } @@ -827,9 +883,13 @@ impl FirstValueAccumulator { } // Updates state with the values in the given row. 
- fn update_with_new_row(&mut self, row: &[ScalarValue]) { - self.first = row[0].clone(); - self.orderings = row[1..].to_vec(); + fn update_with_new_row(&mut self, mut row: Vec) { + // Ensure any Array based scalars hold have a single value to reduce memory pressure + for s in row.iter_mut() { + s.compact(); + } + self.first = row.remove(0); + self.orderings = row; self.is_set = true; } @@ -880,29 +940,23 @@ impl Accumulator for FirstValueAccumulator { fn state(&mut self) -> Result> { let mut result = vec![self.first.clone()]; result.extend(self.orderings.iter().cloned()); - result.push(ScalarValue::Boolean(Some(self.is_set))); + result.push(ScalarValue::from(self.is_set)); Ok(result) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if !self.is_set { - if let Some(first_idx) = self.get_first_idx(values)? { - let row = get_row_at_idx(values, first_idx)?; - self.update_with_new_row(&row); - } - } else if !self.requirement_satisfied { - if let Some(first_idx) = self.get_first_idx(values)? { - let row = get_row_at_idx(values, first_idx)?; - let orderings = &row[1..]; - if compare_rows( - &self.orderings, - orderings, - &get_sort_options(self.ordering_req.as_ref()), - )? - .is_gt() - { - self.update_with_new_row(&row); - } + if let Some(first_idx) = self.get_first_idx(values)? { + let row = get_row_at_idx(values, first_idx)?; + if !self.is_set + || (!self.requirement_satisfied + && compare_rows( + &self.orderings, + &row[1..], + &get_sort_options(&self.ordering_req), + )? + .is_gt()) + { + self.update_with_new_row(row); } } Ok(()) @@ -916,19 +970,17 @@ impl Accumulator for FirstValueAccumulator { let filtered_states = filter_states_according_to_is_set(&states[0..is_set_idx], flags)?; // 1..is_set_idx range corresponds to ordering section - let sort_columns = convert_to_sort_cols( - &filtered_states[1..is_set_idx], - self.ordering_req.as_ref(), - ); + let sort_columns = + convert_to_sort_cols(&filtered_states[1..is_set_idx], &self.ordering_req); let comparator = LexicographicalComparator::try_new(&sort_columns)?; let min = (0..filtered_states[0].len()).min_by(|&a, &b| comparator.compare(a, b)); if let Some(first_idx) = min { - let first_row = get_row_at_idx(&filtered_states, first_idx)?; + let mut first_row = get_row_at_idx(&filtered_states, first_idx)?; // When collecting orderings, we exclude the is_set flag from the state. let first_ordering = &first_row[1..is_set_idx]; - let sort_options = get_sort_options(self.ordering_req.as_ref()); + let sort_options = get_sort_options(&self.ordering_req); // Either there is no existing value, or there is an earlier version in new data. if !self.is_set || compare_rows(&self.orderings, first_ordering, &sort_options)?.is_gt() @@ -936,7 +988,9 @@ impl Accumulator for FirstValueAccumulator { // Update with first value in the state. Note that we should exclude the // is_set flag from the state. Otherwise, we will end up with a state // containing two is_set flags. 
- self.update_with_new_row(&first_row[0..is_set_idx]); + assert!(is_set_idx <= first_row.len()); + first_row.resize(is_set_idx, ScalarValue::Null); + self.update_with_new_row(first_row); } } Ok(()) @@ -1021,46 +1075,40 @@ impl AggregateUDFImpl for LastValue { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - let ordering_dtypes = acc_args - .ordering_req + let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else { + return TrivialLastValueAccumulator::try_new( + acc_args.return_field.data_type(), + acc_args.ignore_nulls, + ) + .map(|acc| Box::new(acc) as _); + }; + let ordering_dtypes = ordering .iter() .map(|e| e.expr.data_type(acc_args.schema)) .collect::>>()?; - - let requirement_satisfied = - acc_args.ordering_req.is_empty() || self.requirement_satisfied; - LastValueAccumulator::try_new( - acc_args.return_type, + acc_args.return_field.data_type(), &ordering_dtypes, - acc_args.ordering_req.clone(), + ordering, acc_args.ignore_nulls, ) - .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _) + .map(|acc| { + Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ + }) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { - let StateFieldsArgs { - name, - input_types, - return_type: _, - ordering_fields, - is_distinct: _, - } = args; + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let mut fields = vec![Field::new( - format_state_name(name, "last_value"), - input_types[0].clone(), + format_state_name(args.name, "last_value"), + args.return_field.data_type().clone(), true, - )]; - fields.extend(ordering_fields.to_vec()); - fields.push(Field::new("is_set", DataType::Boolean, true)); + ) + .into()]; + fields.extend(args.ordering_fields.iter().cloned()); + fields.push(Field::new("is_set", DataType::Boolean, true).into()); Ok(fields) } - fn aliases(&self) -> &[String] { - &[] - } - fn with_beneficial_ordering( self: Arc, beneficial_ordering: bool, @@ -1074,8 +1122,8 @@ impl AggregateUDFImpl for LastValue { AggregateOrderSensitivity::Beneficial } - fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { - datafusion_expr::ReversedUDAF::Reversed(first_value_udaf()) + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Reversed(first_value_udaf()) } fn documentation(&self) -> Option<&Documentation> { @@ -1084,26 +1132,27 @@ impl AggregateUDFImpl for LastValue { fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { use DataType::*; - matches!( - args.return_type, - Int8 | Int16 - | Int32 - | Int64 - | UInt8 - | UInt16 - | UInt32 - | UInt64 - | Float16 - | Float32 - | Float64 - | Decimal128(_, _) - | Decimal256(_, _) - | Date32 - | Date64 - | Time32(_) - | Time64(_) - | Timestamp(_, _) - ) + !args.order_bys.is_empty() + && matches!( + args.return_field.data_type(), + Int8 | Int16 + | Int32 + | Int64 + | UInt8 + | UInt16 + | UInt32 + | UInt64 + | Float16 + | Float32 + | Float64 + | Decimal128(_, _) + | Decimal256(_, _) + | Date32 + | Date64 + | Time32(_) + | Time64(_) + | Timestamp(_, _) + ) } fn create_groups_accumulator( @@ -1116,22 +1165,25 @@ impl AggregateUDFImpl for LastValue { where T: ArrowPrimitiveType + Send, { - let ordering_dtypes = args - .ordering_req + let Some(ordering) = LexOrdering::new(args.order_bys.to_vec()) else { + return internal_err!("Groups accumulator must have an ordering."); + }; + + let ordering_dtypes = ordering .iter() .map(|e| e.expr.data_type(args.schema)) .collect::>>()?; Ok(Box::new(FirstPrimitiveGroupsAccumulator::::try_new( - 
args.ordering_req.clone(), + ordering, args.ignore_nulls, - args.return_type, + args.return_field.data_type(), &ordering_dtypes, false, )?)) } - match args.return_type { + match args.return_field.data_type() { DataType::Int8 => create_accumulator::(args), DataType::Int16 => create_accumulator::(args), DataType::Int32 => create_accumulator::(args), @@ -1179,13 +1231,92 @@ impl AggregateUDFImpl for LastValue { _ => { internal_err!( "GroupsAccumulator not supported for last_value({})", - args.return_type + args.return_field.data_type() ) } } } } +/// This accumulator is used when there is no ordering specified for the +/// `LAST_VALUE` aggregation. It simply updates the last value it sees +/// according to the pre-existing ordering of the input data, and provides +/// a fast path for this case without needing to maintain any ordering state. +#[derive(Debug)] +pub struct TrivialLastValueAccumulator { + last: ScalarValue, + // The `is_set` flag keeps track of whether the last value is finalized. + // This information is used to discriminate genuine NULLs and NULLS that + // occur due to empty partitions. + is_set: bool, + // Ignore null values. + ignore_nulls: bool, +} + +impl TrivialLastValueAccumulator { + /// Creates a new `TrivialLastValueAccumulator` for the given `data_type`. + pub fn try_new(data_type: &DataType, ignore_nulls: bool) -> Result { + ScalarValue::try_from(data_type).map(|last| Self { + last, + is_set: false, + ignore_nulls, + }) + } +} + +impl Accumulator for TrivialLastValueAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![self.last.clone(), ScalarValue::from(self.is_set)]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + // Get last entry according to the pre-existing ordering (0th index): + let value = &values[0]; + let mut last_idx = None; + if self.ignore_nulls { + // If ignoring nulls, find the last non-null value. + for i in (0..value.len()).rev() { + if !value.is_null(i) { + last_idx = Some(i); + break; + } + } + } else if !value.is_empty() { + // If not ignoring nulls, return the last value if it exists. + last_idx = Some(value.len() - 1); + } + if let Some(last_idx) = last_idx { + let mut row = get_row_at_idx(values, last_idx)?; + self.last = row.swap_remove(0); + self.last.compact(); + self.is_set = true; + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + // LAST_VALUE(last1, last2, last3, ...) + // Second index contains is_set flag. + let flags = states[1].as_boolean(); + let filtered_states = filter_states_according_to_is_set(&states[0..1], flags)?; + if let Some(last) = filtered_states.last() { + if !last.is_empty() { + self.last = ScalarValue::try_from_array(last, 0)?; + self.is_set = true; + } + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + Ok(self.last.clone()) + } + + fn size(&self) -> usize { + size_of_val(self) - size_of_val(&self.last) + self.last.size() + } +} + #[derive(Debug)] struct LastValueAccumulator { last: ScalarValue, @@ -1193,6 +1324,8 @@ struct LastValueAccumulator { // This information is used to discriminate genuine NULLs and NULLS that // occur due to empty partitions. is_set: bool, + // Stores values of the ordering columns corresponding to the first value. + // These values are used during merging of multiple partitions. orderings: Vec, // Stores the applicable ordering requirement. 
ordering_req: LexOrdering, @@ -1213,22 +1346,25 @@ impl LastValueAccumulator { let orderings = ordering_dtypes .iter() .map(ScalarValue::try_from) - .collect::>>()?; - let requirement_satisfied = ordering_req.is_empty(); + .collect::>()?; ScalarValue::try_from(data_type).map(|last| Self { last, is_set: false, orderings, ordering_req, - requirement_satisfied, + requirement_satisfied: false, ignore_nulls, }) } // Updates state with the values in the given row. - fn update_with_new_row(&mut self, row: &[ScalarValue]) { - self.last = row[0].clone(); - self.orderings = row[1..].to_vec(); + fn update_with_new_row(&mut self, mut row: Vec) { + // Ensure any Array based scalars hold have a single value to reduce memory pressure + for s in row.iter_mut() { + s.compact(); + } + self.last = row.remove(0); + self.orderings = row; self.is_set = true; } @@ -1250,6 +1386,7 @@ impl LastValueAccumulator { return Ok((!value.is_empty()).then_some(value.len() - 1)); } } + let sort_columns = ordering_values .iter() .zip(self.ordering_req.iter()) @@ -1281,31 +1418,27 @@ impl Accumulator for LastValueAccumulator { fn state(&mut self) -> Result> { let mut result = vec![self.last.clone()]; result.extend(self.orderings.clone()); - result.push(ScalarValue::Boolean(Some(self.is_set))); + result.push(ScalarValue::from(self.is_set)); Ok(result) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if !self.is_set || self.requirement_satisfied { - if let Some(last_idx) = self.get_last_idx(values)? { - let row = get_row_at_idx(values, last_idx)?; - self.update_with_new_row(&row); - } - } else if let Some(last_idx) = self.get_last_idx(values)? { + if let Some(last_idx) = self.get_last_idx(values)? { let row = get_row_at_idx(values, last_idx)?; let orderings = &row[1..]; // Update when there is a more recent entry - if compare_rows( - &self.orderings, - orderings, - &get_sort_options(self.ordering_req.as_ref()), - )? - .is_lt() + if !self.is_set + || self.requirement_satisfied + || compare_rows( + &self.orderings, + orderings, + &get_sort_options(&self.ordering_req), + )? + .is_lt() { - self.update_with_new_row(&row); + self.update_with_new_row(row); } } - Ok(()) } @@ -1317,19 +1450,17 @@ impl Accumulator for LastValueAccumulator { let filtered_states = filter_states_according_to_is_set(&states[0..is_set_idx], flags)?; // 1..is_set_idx range corresponds to ordering section - let sort_columns = convert_to_sort_cols( - &filtered_states[1..is_set_idx], - self.ordering_req.as_ref(), - ); + let sort_columns = + convert_to_sort_cols(&filtered_states[1..is_set_idx], &self.ordering_req); let comparator = LexicographicalComparator::try_new(&sort_columns)?; let max = (0..filtered_states[0].len()).max_by(|&a, &b| comparator.compare(a, b)); if let Some(last_idx) = max { - let last_row = get_row_at_idx(&filtered_states, last_idx)?; + let mut last_row = get_row_at_idx(&filtered_states, last_idx)?; // When collecting orderings, we exclude the is_set flag from the state. let last_ordering = &last_row[1..is_set_idx]; - let sort_options = get_sort_options(self.ordering_req.as_ref()); + let sort_options = get_sort_options(&self.ordering_req); // Either there is no existing value, or there is a newer (latest) // version in the new data: if !self.is_set @@ -1339,7 +1470,9 @@ impl Accumulator for LastValueAccumulator { // Update with last value in the state. Note that we should exclude the // is_set flag from the state. Otherwise, we will end up with a state // containing two is_set flags. 
- self.update_with_new_row(&last_row[0..is_set_idx]); + assert!(is_set_idx <= last_row.len()); + last_row.resize(is_set_idx, ScalarValue::Null); + self.update_with_new_row(last_row); } } Ok(()) @@ -1366,7 +1499,7 @@ fn filter_states_according_to_is_set( states .iter() .map(|state| compute::filter(state, flags).map_err(|e| arrow_datafusion_err!(e))) - .collect::>>() + .collect() } /// Combines array refs and their corresponding orderings to construct `SortColumn`s. @@ -1377,30 +1510,28 @@ fn convert_to_sort_cols(arrs: &[ArrayRef], sort_exprs: &LexOrdering) -> Vec>() + .collect() } #[cfg(test)] mod tests { - use arrow::{array::Int64Array, compute::SortOptions, datatypes::Schema}; + use std::iter::repeat_with; + + use arrow::{ + array::{Int64Array, ListArray}, + compute::SortOptions, + datatypes::Schema, + }; use datafusion_physical_expr::{expressions::col, PhysicalSortExpr}; use super::*; #[test] fn test_first_last_value_value() -> Result<()> { - let mut first_accumulator = FirstValueAccumulator::try_new( - &DataType::Int64, - &[], - LexOrdering::default(), - false, - )?; - let mut last_accumulator = LastValueAccumulator::try_new( - &DataType::Int64, - &[], - LexOrdering::default(), - false, - )?; + let mut first_accumulator = + TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?; + let mut last_accumulator = + TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?; // first value in the tuple is start of the range (inclusive), // second value in the tuple is end of the range (exclusive) let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)]; @@ -1437,22 +1568,14 @@ mod tests { .collect::>(); // FirstValueAccumulator - let mut first_accumulator = FirstValueAccumulator::try_new( - &DataType::Int64, - &[], - LexOrdering::default(), - false, - )?; + let mut first_accumulator = + TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?; first_accumulator.update_batch(&[Arc::clone(&arrs[0])])?; let state1 = first_accumulator.state()?; - let mut first_accumulator = FirstValueAccumulator::try_new( - &DataType::Int64, - &[], - LexOrdering::default(), - false, - )?; + let mut first_accumulator = + TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?; first_accumulator.update_batch(&[Arc::clone(&arrs[1])])?; let state2 = first_accumulator.state()?; @@ -1467,34 +1590,22 @@ mod tests { ])?); } - let mut first_accumulator = FirstValueAccumulator::try_new( - &DataType::Int64, - &[], - LexOrdering::default(), - false, - )?; + let mut first_accumulator = + TrivialFirstValueAccumulator::try_new(&DataType::Int64, false)?; first_accumulator.merge_batch(&states)?; let merged_state = first_accumulator.state()?; assert_eq!(merged_state.len(), state1.len()); // LastValueAccumulator - let mut last_accumulator = LastValueAccumulator::try_new( - &DataType::Int64, - &[], - LexOrdering::default(), - false, - )?; + let mut last_accumulator = + TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?; last_accumulator.update_batch(&[Arc::clone(&arrs[0])])?; let state1 = last_accumulator.state()?; - let mut last_accumulator = LastValueAccumulator::try_new( - &DataType::Int64, - &[], - LexOrdering::default(), - false, - )?; + let mut last_accumulator = + TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?; last_accumulator.update_batch(&[Arc::clone(&arrs[1])])?; let state2 = last_accumulator.state()?; @@ -1509,12 +1620,8 @@ mod tests { ])?); } - let mut last_accumulator = LastValueAccumulator::try_new( - &DataType::Int64, - &[], - LexOrdering::default(), - 
false, - )?; + let mut last_accumulator = + TrivialLastValueAccumulator::try_new(&DataType::Int64, false)?; last_accumulator.merge_batch(&states)?; let merged_state = last_accumulator.state()?; @@ -1524,7 +1631,7 @@ mod tests { } #[test] - fn test_frist_group_acc() -> Result<()> { + fn test_first_group_acc() -> Result<()> { let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int64, true), Field::new("b", DataType::Int64, true), @@ -1533,13 +1640,13 @@ mod tests { Field::new("e", DataType::Boolean, true), ])); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { + let sort_keys = [PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]); + }]; let mut group_acc = FirstPrimitiveGroupsAccumulator::::try_new( - sort_key, + sort_keys.into(), true, &DataType::Int64, &[DataType::Int64], @@ -1627,13 +1734,13 @@ mod tests { Field::new("e", DataType::Boolean, true), ])); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { + let sort_keys = [PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]); + }]; let mut group_acc = FirstPrimitiveGroupsAccumulator::::try_new( - sort_key, + sort_keys.into(), true, &DataType::Int64, &[DataType::Int64], @@ -1708,13 +1815,13 @@ mod tests { Field::new("e", DataType::Boolean, true), ])); - let sort_key = LexOrdering::new(vec![PhysicalSortExpr { + let sort_keys = [PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]); + }]; let mut group_acc = FirstPrimitiveGroupsAccumulator::::try_new( - sort_key, + sort_keys.into(), true, &DataType::Int64, &[DataType::Int64], @@ -1772,4 +1879,56 @@ mod tests { Ok(()) } + + #[test] + fn test_first_list_acc_size() -> Result<()> { + fn size_after_batch(values: &[ArrayRef]) -> Result { + let mut first_accumulator = TrivialFirstValueAccumulator::try_new( + &DataType::List(Arc::new(Field::new_list_field(DataType::Int64, false))), + false, + )?; + + first_accumulator.update_batch(values)?; + + Ok(first_accumulator.size()) + } + + let batch1 = ListArray::from_iter_primitive::( + repeat_with(|| Some(vec![Some(1)])).take(10000), + ); + let batch2 = + ListArray::from_iter_primitive::([Some(vec![Some(1)])]); + + let size1 = size_after_batch(&[Arc::new(batch1)])?; + let size2 = size_after_batch(&[Arc::new(batch2)])?; + assert_eq!(size1, size2); + + Ok(()) + } + + #[test] + fn test_last_list_acc_size() -> Result<()> { + fn size_after_batch(values: &[ArrayRef]) -> Result { + let mut last_accumulator = TrivialLastValueAccumulator::try_new( + &DataType::List(Arc::new(Field::new_list_field(DataType::Int64, false))), + false, + )?; + + last_accumulator.update_batch(values)?; + + Ok(last_accumulator.size()) + } + + let batch1 = ListArray::from_iter_primitive::( + repeat_with(|| Some(vec![Some(1)])).take(10000), + ); + let batch2 = + ListArray::from_iter_primitive::([Some(vec![Some(1)])]); + + let size1 = size_after_batch(&[Arc::new(batch1)])?; + let size2 = size_after_batch(&[Arc::new(batch2)])?; + assert_eq!(size1, size2); + + Ok(()) + } } diff --git a/datafusion/functions-aggregate/src/grouping.rs b/datafusion/functions-aggregate/src/grouping.rs index 445774ff11e7..0727cf33036a 100644 --- a/datafusion/functions-aggregate/src/grouping.rs +++ b/datafusion/functions-aggregate/src/grouping.rs @@ -20,8 +20,8 @@ use std::any::Any; use std::fmt; -use arrow::datatypes::DataType; use arrow::datatypes::Field; +use arrow::datatypes::{DataType, FieldRef}; use datafusion_common::{not_impl_err, Result}; use 
datafusion_expr::function::AccumulatorArgs; use datafusion_expr::function::StateFieldsArgs; @@ -105,12 +105,13 @@ impl AggregateUDFImpl for Grouping { Ok(DataType::Int32) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![Field::new( format_state_name(args.name, "grouping"), DataType::Int32, true, - )]) + ) + .into()]) } fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 7944280291eb..b5bb69f6da9d 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -220,8 +220,7 @@ mod tests { for alias in func.aliases() { assert!( names.insert(alias.to_string().to_lowercase()), - "duplicate function name: {}", - alias + "duplicate function name: {alias}" ); } } diff --git a/datafusion/functions-aggregate/src/macros.rs b/datafusion/functions-aggregate/src/macros.rs index b464dde6ccab..18f27c3c4ae3 100644 --- a/datafusion/functions-aggregate/src/macros.rs +++ b/datafusion/functions-aggregate/src/macros.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +#[macro_export] macro_rules! make_udaf_expr { ($EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { // "fluent expr_fn" style function @@ -34,6 +35,7 @@ macro_rules! make_udaf_expr { }; } +#[macro_export] macro_rules! make_udaf_expr_and_func { ($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { make_udaf_expr!($EXPR_FN, $($arg)*, $DOC, $AGGREGATE_UDF_FN); @@ -59,6 +61,7 @@ macro_rules! make_udaf_expr_and_func { }; } +#[macro_export] macro_rules! create_func { ($UDAF:ty, $AGGREGATE_UDF_FN:ident) => { create_func!($UDAF, $AGGREGATE_UDF_FN, <$UDAF>::default()); diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index ba6b63260e06..5c3d265d1d6b 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -35,7 +35,7 @@ use arrow::{ use arrow::array::Array; use arrow::array::ArrowNativeTypeOp; -use arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType}; +use arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType, FieldRef}; use datafusion_common::{ internal_datafusion_err, internal_err, DataFusionError, HashSet, Result, ScalarValue, @@ -125,9 +125,9 @@ impl AggregateUDFImpl for Median { Ok(arg_types[0].clone()) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { //Intermediate state is a list of the elements we have collected so far - let field = Field::new_list_field(args.input_types[0].clone(), true); + let field = Field::new_list_field(args.input_fields[0].data_type().clone(), true); let state_name = if args.is_distinct { "distinct_median" } else { @@ -138,7 +138,8 @@ impl AggregateUDFImpl for Median { format_state_name(args.name, state_name), DataType::List(Arc::new(field)), true, - )]) + ) + .into()]) } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { @@ -213,10 +214,6 @@ impl AggregateUDFImpl for Median { } } - fn aliases(&self) -> &[String] { - &[] - } - fn documentation(&self) -> Option<&Documentation> { self.doc() } diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index ea4cad548803..0bd36a14be76 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ 
b/datafusion/functions-aggregate/src/min_max.rs @@ -19,30 +19,21 @@ //! [`Min`] and [`MinAccumulator`] accumulator for the `min` function mod min_max_bytes; +mod min_max_struct; -use arrow::array::{ - ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array, - Decimal128Array, Decimal256Array, DurationMicrosecondArray, DurationMillisecondArray, - DurationNanosecondArray, DurationSecondArray, Float16Array, Float32Array, - Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, IntervalDayTimeArray, - IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray, - LargeStringArray, StringArray, StringViewArray, Time32MillisecondArray, - Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, - TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, - TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, -}; -use arrow::compute; +use arrow::array::ArrayRef; use arrow::datatypes::{ DataType, Decimal128Type, Decimal256Type, DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType, DurationSecondType, Float16Type, - Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalUnit, - UInt16Type, UInt32Type, UInt64Type, UInt8Type, + Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, + UInt32Type, UInt64Type, UInt8Type, }; use datafusion_common::stats::Precision; use datafusion_common::{ - downcast_value, exec_err, internal_err, ColumnStatistics, DataFusionError, Result, + exec_err, internal_err, ColumnStatistics, DataFusionError, Result, }; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; +use datafusion_functions_aggregate_common::min_max::{max_batch, min_batch}; use datafusion_physical_expr::expressions; use std::cmp::Ordering; use std::fmt::Debug; @@ -55,6 +46,7 @@ use arrow::datatypes::{ }; use crate::min_max::min_max_bytes::MinMaxBytesAccumulator; +use crate::min_max::min_max_struct::MinMaxStructAccumulator; use datafusion_common::ScalarValue; use datafusion_expr::{ function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Documentation, @@ -231,17 +223,15 @@ impl AggregateUDFImpl for Max { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - Ok(Box::new(MaxAccumulator::try_new(acc_args.return_type)?)) - } - - fn aliases(&self) -> &[String] { - &[] + Ok(Box::new(MaxAccumulator::try_new( + acc_args.return_field.data_type(), + )?)) } fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { use DataType::*; matches!( - args.return_type, + args.return_field.data_type(), Int8 | Int16 | Int32 | Int64 @@ -266,6 +256,7 @@ impl AggregateUDFImpl for Max { | LargeBinary | BinaryView | Duration(_) + | Struct(_) ) } @@ -275,7 +266,7 @@ impl AggregateUDFImpl for Max { ) -> Result> { use DataType::*; use TimeUnit::*; - let data_type = args.return_type; + let data_type = args.return_field.data_type(); match data_type { Int8 => primitive_max_accumulator!(data_type, i8, Int8Type), Int16 => primitive_max_accumulator!(data_type, i16, Int16Type), @@ -341,7 +332,9 @@ impl AggregateUDFImpl for Max { Utf8 | LargeUtf8 | Utf8View | Binary | LargeBinary | BinaryView => { Ok(Box::new(MinMaxBytesAccumulator::new_max(data_type.clone()))) } - + Struct(_) => Ok(Box::new(MinMaxStructAccumulator::new_max( + data_type.clone(), + ))), // This is only reached if groups_accumulator_supported is out of sync _ => internal_err!("GroupsAccumulator not supported for max({})", data_type), } @@ -351,7 
+344,9 @@ impl AggregateUDFImpl for Max { &self, args: AccumulatorArgs, ) -> Result> { - Ok(Box::new(SlidingMaxAccumulator::try_new(args.return_type)?)) + Ok(Box::new(SlidingMaxAccumulator::try_new( + args.return_field.data_type(), + )?)) } fn is_descending(&self) -> Option { @@ -383,280 +378,31 @@ impl AggregateUDFImpl for Max { } } -// Statically-typed version of min/max(array) -> ScalarValue for string types -macro_rules! typed_min_max_batch_string { - ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{ - let array = downcast_value!($VALUES, $ARRAYTYPE); - let value = compute::$OP(array); - let value = value.and_then(|e| Some(e.to_string())); - ScalarValue::$SCALAR(value) - }}; -} -// Statically-typed version of min/max(array) -> ScalarValue for binary types. -macro_rules! typed_min_max_batch_binary { - ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{ - let array = downcast_value!($VALUES, $ARRAYTYPE); - let value = compute::$OP(array); - let value = value.and_then(|e| Some(e.to_vec())); - ScalarValue::$SCALAR(value) - }}; -} - -// Statically-typed version of min/max(array) -> ScalarValue for non-string types. -macro_rules! typed_min_max_batch { - ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{ - let array = downcast_value!($VALUES, $ARRAYTYPE); - let value = compute::$OP(array); - ScalarValue::$SCALAR(value, $($EXTRA_ARGS.clone()),*) - }}; -} - -// Statically-typed version of min/max(array) -> ScalarValue for non-string types. -// this is a macro to support both operations (min and max). -macro_rules! min_max_batch { - ($VALUES:expr, $OP:ident) => {{ - match $VALUES.data_type() { - DataType::Null => ScalarValue::Null, - DataType::Decimal128(precision, scale) => { - typed_min_max_batch!( - $VALUES, - Decimal128Array, - Decimal128, - $OP, - precision, - scale - ) - } - DataType::Decimal256(precision, scale) => { - typed_min_max_batch!( - $VALUES, - Decimal256Array, - Decimal256, - $OP, - precision, - scale - ) - } - // all types that have a natural order - DataType::Float64 => { - typed_min_max_batch!($VALUES, Float64Array, Float64, $OP) - } - DataType::Float32 => { - typed_min_max_batch!($VALUES, Float32Array, Float32, $OP) - } - DataType::Float16 => { - typed_min_max_batch!($VALUES, Float16Array, Float16, $OP) - } - DataType::Int64 => typed_min_max_batch!($VALUES, Int64Array, Int64, $OP), - DataType::Int32 => typed_min_max_batch!($VALUES, Int32Array, Int32, $OP), - DataType::Int16 => typed_min_max_batch!($VALUES, Int16Array, Int16, $OP), - DataType::Int8 => typed_min_max_batch!($VALUES, Int8Array, Int8, $OP), - DataType::UInt64 => typed_min_max_batch!($VALUES, UInt64Array, UInt64, $OP), - DataType::UInt32 => typed_min_max_batch!($VALUES, UInt32Array, UInt32, $OP), - DataType::UInt16 => typed_min_max_batch!($VALUES, UInt16Array, UInt16, $OP), - DataType::UInt8 => typed_min_max_batch!($VALUES, UInt8Array, UInt8, $OP), - DataType::Timestamp(TimeUnit::Second, tz_opt) => { - typed_min_max_batch!( - $VALUES, - TimestampSecondArray, - TimestampSecond, - $OP, - tz_opt - ) - } - DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => typed_min_max_batch!( - $VALUES, - TimestampMillisecondArray, - TimestampMillisecond, - $OP, - tz_opt - ), - DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => typed_min_max_batch!( - $VALUES, - TimestampMicrosecondArray, - TimestampMicrosecond, - $OP, - tz_opt - ), - DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => typed_min_max_batch!( - $VALUES, - TimestampNanosecondArray, - 
TimestampNanosecond, - $OP, - tz_opt - ), - DataType::Date32 => typed_min_max_batch!($VALUES, Date32Array, Date32, $OP), - DataType::Date64 => typed_min_max_batch!($VALUES, Date64Array, Date64, $OP), - DataType::Time32(TimeUnit::Second) => { - typed_min_max_batch!($VALUES, Time32SecondArray, Time32Second, $OP) - } - DataType::Time32(TimeUnit::Millisecond) => { - typed_min_max_batch!( - $VALUES, - Time32MillisecondArray, - Time32Millisecond, - $OP - ) - } - DataType::Time64(TimeUnit::Microsecond) => { - typed_min_max_batch!( - $VALUES, - Time64MicrosecondArray, - Time64Microsecond, - $OP - ) - } - DataType::Time64(TimeUnit::Nanosecond) => { - typed_min_max_batch!( - $VALUES, - Time64NanosecondArray, - Time64Nanosecond, - $OP - ) - } - DataType::Interval(IntervalUnit::YearMonth) => { - typed_min_max_batch!( - $VALUES, - IntervalYearMonthArray, - IntervalYearMonth, - $OP - ) - } - DataType::Interval(IntervalUnit::DayTime) => { - typed_min_max_batch!($VALUES, IntervalDayTimeArray, IntervalDayTime, $OP) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - typed_min_max_batch!( - $VALUES, - IntervalMonthDayNanoArray, - IntervalMonthDayNano, - $OP - ) - } - DataType::Duration(TimeUnit::Second) => { - typed_min_max_batch!($VALUES, DurationSecondArray, DurationSecond, $OP) - } - DataType::Duration(TimeUnit::Millisecond) => { - typed_min_max_batch!( - $VALUES, - DurationMillisecondArray, - DurationMillisecond, - $OP - ) - } - DataType::Duration(TimeUnit::Microsecond) => { - typed_min_max_batch!( - $VALUES, - DurationMicrosecondArray, - DurationMicrosecond, - $OP - ) - } - DataType::Duration(TimeUnit::Nanosecond) => { - typed_min_max_batch!( - $VALUES, - DurationNanosecondArray, - DurationNanosecond, - $OP - ) - } - other => { - // This should have been handled before - return internal_err!( - "Min/Max accumulator not implemented for type {:?}", - other - ); +macro_rules! 
min_max_generic { + ($VALUE:expr, $DELTA:expr, $OP:ident) => {{ + if $VALUE.is_null() { + let mut delta_copy = $DELTA.clone(); + // When the new value won we want to compact it to + // avoid storing the entire input + delta_copy.compact(); + delta_copy + } else if $DELTA.is_null() { + $VALUE.clone() + } else { + match $VALUE.partial_cmp(&$DELTA) { + Some(choose_min_max!($OP)) => { + // When the new value won we want to compact it to + // avoid storing the entire input + let mut delta_copy = $DELTA.clone(); + delta_copy.compact(); + delta_copy + } + _ => $VALUE.clone(), } } }}; } -/// dynamically-typed min(array) -> ScalarValue -fn min_batch(values: &ArrayRef) -> Result { - Ok(match values.data_type() { - DataType::Utf8 => { - typed_min_max_batch_string!(values, StringArray, Utf8, min_string) - } - DataType::LargeUtf8 => { - typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, min_string) - } - DataType::Utf8View => { - typed_min_max_batch_string!( - values, - StringViewArray, - Utf8View, - min_string_view - ) - } - DataType::Boolean => { - typed_min_max_batch!(values, BooleanArray, Boolean, min_boolean) - } - DataType::Binary => { - typed_min_max_batch_binary!(&values, BinaryArray, Binary, min_binary) - } - DataType::LargeBinary => { - typed_min_max_batch_binary!( - &values, - LargeBinaryArray, - LargeBinary, - min_binary - ) - } - DataType::BinaryView => { - typed_min_max_batch_binary!( - &values, - BinaryViewArray, - BinaryView, - min_binary_view - ) - } - _ => min_max_batch!(values, min), - }) -} - -/// dynamically-typed max(array) -> ScalarValue -pub fn max_batch(values: &ArrayRef) -> Result { - Ok(match values.data_type() { - DataType::Utf8 => { - typed_min_max_batch_string!(values, StringArray, Utf8, max_string) - } - DataType::LargeUtf8 => { - typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, max_string) - } - DataType::Utf8View => { - typed_min_max_batch_string!( - values, - StringViewArray, - Utf8View, - max_string_view - ) - } - DataType::Boolean => { - typed_min_max_batch!(values, BooleanArray, Boolean, max_boolean) - } - DataType::Binary => { - typed_min_max_batch_binary!(&values, BinaryArray, Binary, max_binary) - } - DataType::BinaryView => { - typed_min_max_batch_binary!( - &values, - BinaryViewArray, - BinaryView, - max_binary_view - ) - } - DataType::LargeBinary => { - typed_min_max_batch_binary!( - &values, - LargeBinaryArray, - LargeBinary, - max_binary - ) - } - _ => min_max_batch!(values, max), - }) -} - // min/max of two non-string scalar values. macro_rules! typed_min_max { ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{ @@ -923,6 +669,37 @@ macro_rules! 
min_max { ) => { typed_min_max!(lhs, rhs, DurationNanosecond, $OP) } + + ( + lhs @ ScalarValue::Struct(_), + rhs @ ScalarValue::Struct(_), + ) => { + min_max_generic!(lhs, rhs, $OP) + } + + ( + lhs @ ScalarValue::List(_), + rhs @ ScalarValue::List(_), + ) => { + min_max_generic!(lhs, rhs, $OP) + } + + + ( + lhs @ ScalarValue::LargeList(_), + rhs @ ScalarValue::LargeList(_), + ) => { + min_max_generic!(lhs, rhs, $OP) + } + + + ( + lhs @ ScalarValue::FixedSizeList(_), + rhs @ ScalarValue::FixedSizeList(_), + ) => { + min_max_generic!(lhs, rhs, $OP) + } + e => { return internal_err!( "MIN/MAX is not expected to receive scalars of incompatible types {:?}", @@ -1098,17 +875,15 @@ impl AggregateUDFImpl for Min { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - Ok(Box::new(MinAccumulator::try_new(acc_args.return_type)?)) - } - - fn aliases(&self) -> &[String] { - &[] + Ok(Box::new(MinAccumulator::try_new( + acc_args.return_field.data_type(), + )?)) } fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { use DataType::*; matches!( - args.return_type, + args.return_field.data_type(), Int8 | Int16 | Int32 | Int64 @@ -1133,6 +908,7 @@ impl AggregateUDFImpl for Min { | LargeBinary | BinaryView | Duration(_) + | Struct(_) ) } @@ -1142,7 +918,7 @@ impl AggregateUDFImpl for Min { ) -> Result> { use DataType::*; use TimeUnit::*; - let data_type = args.return_type; + let data_type = args.return_field.data_type(); match data_type { Int8 => primitive_min_accumulator!(data_type, i8, Int8Type), Int16 => primitive_min_accumulator!(data_type, i16, Int16Type), @@ -1208,7 +984,9 @@ impl AggregateUDFImpl for Min { Utf8 | LargeUtf8 | Utf8View | Binary | LargeBinary | BinaryView => { Ok(Box::new(MinMaxBytesAccumulator::new_min(data_type.clone()))) } - + Struct(_) => Ok(Box::new(MinMaxStructAccumulator::new_min( + data_type.clone(), + ))), // This is only reached if groups_accumulator_supported is out of sync _ => internal_err!("GroupsAccumulator not supported for min({})", data_type), } @@ -1218,7 +996,9 @@ impl AggregateUDFImpl for Min { &self, args: AccumulatorArgs, ) -> Result> { - Ok(Box::new(SlidingMinAccumulator::try_new(args.return_type)?)) + Ok(Box::new(SlidingMinAccumulator::try_new( + args.return_field.data_type(), + )?)) } fn is_descending(&self) -> Option { @@ -1627,8 +1407,15 @@ make_udaf_expr_and_func!( #[cfg(test)] mod tests { use super::*; - use arrow::datatypes::{ - IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType, + use arrow::{ + array::{ + DictionaryArray, Float32Array, Int32Array, IntervalDayTimeArray, + IntervalMonthDayNanoArray, IntervalYearMonthArray, StringArray, + }, + datatypes::{ + IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, + IntervalYearMonthType, + }, }; use std::sync::Arc; @@ -1768,10 +1555,10 @@ mod tests { use rand::Rng; fn get_random_vec_i32(len: usize) -> Vec { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let mut input = Vec::with_capacity(len); for _i in 0..len { - input.push(rng.gen_range(0..100)); + input.push(rng.random_range(0..100)); } input } @@ -1854,9 +1641,31 @@ mod tests { #[test] fn test_get_min_max_return_type_coerce_dictionary() -> Result<()> { let data_type = - DataType::Dictionary(Box::new(DataType::Utf8), Box::new(DataType::Int32)); + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); let result = get_min_max_result_type(&[data_type])?; - assert_eq!(result, vec![DataType::Int32]); + assert_eq!(result, vec![DataType::Utf8]); + Ok(()) + } + + #[test] + 
fn test_min_max_dictionary() -> Result<()> { + let values = StringArray::from(vec!["b", "c", "a", "🦀", "d"]); + let keys = Int32Array::from(vec![Some(0), Some(1), Some(2), None, Some(4)]); + let dict_array = + DictionaryArray::try_new(keys, Arc::new(values) as ArrayRef).unwrap(); + let dict_array_ref = Arc::new(dict_array) as ArrayRef; + let rt_type = + get_min_max_result_type(&[dict_array_ref.data_type().clone()])?[0].clone(); + + let mut min_acc = MinAccumulator::try_new(&rt_type)?; + min_acc.update_batch(&[Arc::clone(&dict_array_ref)])?; + let min_result = min_acc.evaluate()?; + assert_eq!(min_result, ScalarValue::Utf8(Some("a".to_string()))); + + let mut max_acc = MaxAccumulator::try_new(&rt_type)?; + max_acc.update_batch(&[Arc::clone(&dict_array_ref)])?; + let max_result = max_acc.evaluate()?; + assert_eq!(max_result, ScalarValue::Utf8(Some("🦀".to_string()))); Ok(()) } } diff --git a/datafusion/functions-aggregate/src/min_max/min_max_struct.rs b/datafusion/functions-aggregate/src/min_max/min_max_struct.rs new file mode 100644 index 000000000000..8038f2f01d90 --- /dev/null +++ b/datafusion/functions-aggregate/src/min_max/min_max_struct.rs @@ -0,0 +1,544 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{cmp::Ordering, sync::Arc}; + +use arrow::{ + array::{ + Array, ArrayData, ArrayRef, AsArray, BooleanArray, MutableArrayData, StructArray, + }, + datatypes::DataType, +}; +use datafusion_common::{ + internal_err, + scalar::{copy_array_data, partial_cmp_struct}, + Result, +}; +use datafusion_expr::{EmitTo, GroupsAccumulator}; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::apply_filter_as_nulls; + +/// Accumulator for MIN/MAX operations on Struct data types. +/// +/// This accumulator tracks the minimum or maximum struct value encountered +/// during aggregation, depending on the `is_min` flag. +/// +/// The comparison is done based on the struct fields in order. +pub(crate) struct MinMaxStructAccumulator { + /// Inner data storage. 
+ inner: MinMaxStructState, + /// if true, is `MIN` otherwise is `MAX` + is_min: bool, +} + +impl MinMaxStructAccumulator { + pub fn new_min(data_type: DataType) -> Self { + Self { + inner: MinMaxStructState::new(data_type), + is_min: true, + } + } + + pub fn new_max(data_type: DataType) -> Self { + Self { + inner: MinMaxStructState::new(data_type), + is_min: false, + } + } +} + +impl GroupsAccumulator for MinMaxStructAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + let array = &values[0]; + assert_eq!(array.len(), group_indices.len()); + assert_eq!(array.data_type(), &self.inner.data_type); + // apply filter if needed + let array = apply_filter_as_nulls(array, opt_filter)?; + + fn struct_min(a: &StructArray, b: &StructArray) -> bool { + matches!(partial_cmp_struct(a, b), Some(Ordering::Less)) + } + + fn struct_max(a: &StructArray, b: &StructArray) -> bool { + matches!(partial_cmp_struct(a, b), Some(Ordering::Greater)) + } + + if self.is_min { + self.inner.update_batch( + array.as_struct(), + group_indices, + total_num_groups, + struct_min, + ) + } else { + self.inner.update_batch( + array.as_struct(), + group_indices, + total_num_groups, + struct_max, + ) + } + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + let (_, min_maxes) = self.inner.emit_to(emit_to); + let fields = match &self.inner.data_type { + DataType::Struct(fields) => fields, + _ => return internal_err!("Data type is not a struct"), + }; + let null_array = StructArray::new_null(fields.clone(), 1); + let min_maxes_data: Vec = min_maxes + .iter() + .map(|v| match v { + Some(v) => v.to_data(), + None => null_array.to_data(), + }) + .collect(); + let min_maxes_refs: Vec<&ArrayData> = min_maxes_data.iter().collect(); + let mut copy = MutableArrayData::new(min_maxes_refs, true, min_maxes_data.len()); + + for (i, item) in min_maxes_data.iter().enumerate() { + copy.extend(i, 0, item.len()); + } + let result = copy.freeze(); + assert_eq!(&self.inner.data_type, result.data_type()); + Ok(Arc::new(StructArray::from(result))) + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + // min/max are their own states (no transition needed) + self.evaluate(emit_to).map(|arr| vec![arr]) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + // min/max are their own states (no transition needed) + self.update_batch(values, group_indices, opt_filter, total_num_groups) + } + + fn convert_to_state( + &self, + values: &[ArrayRef], + opt_filter: Option<&BooleanArray>, + ) -> Result> { + // Min/max do not change the values as they are their own states + // apply the filter by combining with the null mask, if any + let output = apply_filter_as_nulls(&values[0], opt_filter)?; + Ok(vec![output]) + } + + fn supports_convert_to_state(&self) -> bool { + true + } + + fn size(&self) -> usize { + self.inner.size() + } +} + +#[derive(Debug)] +struct MinMaxStructState { + /// The minimum/maximum value for each group + min_max: Vec>, + /// The data type of the array + data_type: DataType, + /// The total bytes of the string data (for pre-allocating the final array, + /// and tracking memory usage) + total_data_bytes: usize, +} + +#[derive(Debug, Clone)] +enum MinMaxLocation { + /// the min/max value is stored in the existing `min_max` array + ExistingMinMax, + /// the min/max value is stored in the 
input array at the given index + Input(StructArray), +} + +/// Implement the MinMaxStructState with a comparison function +/// for comparing structs +impl MinMaxStructState { + /// Create a new MinMaxStructState + /// + /// # Arguments: + /// * `data_type`: The data type of the arrays that will be passed to this accumulator + fn new(data_type: DataType) -> Self { + Self { + min_max: vec![], + data_type, + total_data_bytes: 0, + } + } + + /// Set the specified group to the given value, updating memory usage appropriately + fn set_value(&mut self, group_index: usize, new_val: &StructArray) { + let new_val = StructArray::from(copy_array_data(&new_val.to_data())); + match self.min_max[group_index].as_mut() { + None => { + self.total_data_bytes += new_val.get_array_memory_size(); + self.min_max[group_index] = Some(new_val); + } + Some(existing_val) => { + // Copy data over to avoid re-allocating + self.total_data_bytes -= existing_val.get_array_memory_size(); + self.total_data_bytes += new_val.get_array_memory_size(); + *existing_val = new_val; + } + } + } + + /// Updates the min/max values for the given struct values + /// + /// `cmp` is the comparison function to use, called like `cmp(new_val, existing_val)` + /// returns true if the `new_val` should replace `existing_val` + fn update_batch<F>( + &mut self, + array: &StructArray, + group_indices: &[usize], + total_num_groups: usize, + mut cmp: F, + ) -> Result<()> + where + F: FnMut(&StructArray, &StructArray) -> bool + Send + Sync, + { + self.min_max.resize(total_num_groups, None); + // Minimize value copies by calculating the new min/maxes for each group + // in this batch (either the existing min/max or the new input value) + // and updating the owned values in `self.min_maxes` at most once + let mut locations = vec![MinMaxLocation::ExistingMinMax; total_num_groups]; + + // Figure out the new min value for each group + for (index, group_index) in (0..array.len()).zip(group_indices.iter()) { + let group_index = *group_index; + if array.is_null(index) { + continue; + } + let new_val = array.slice(index, 1); + + let existing_val = match &locations[group_index] { + // previous input value was the min/max, so compare it + MinMaxLocation::Input(existing_val) => existing_val, + MinMaxLocation::ExistingMinMax => { + let Some(existing_val) = self.min_max[group_index].as_ref() else { + // no existing min/max, so this is the new min/max + locations[group_index] = MinMaxLocation::Input(new_val); + continue; + }; + existing_val + } + }; + + // Compare the new value to the existing value, replacing if necessary + if cmp(&new_val, existing_val) { + locations[group_index] = MinMaxLocation::Input(new_val); + } + } + + // Update self.min_max with any new min/max values we found in the input + for (group_index, location) in locations.iter().enumerate() { + match location { + MinMaxLocation::ExistingMinMax => {} + MinMaxLocation::Input(new_val) => self.set_value(group_index, new_val), + } + } + Ok(()) + } + + /// Emits the specified min_max values + /// + /// Returns (data_capacity, min_maxes), updating the current value of total_data_bytes + /// + /// - `data_capacity`: the total length of the emitted struct values, + /// - `min_maxes`: the actual min/max values for each group + fn emit_to(&mut self, emit_to: EmitTo) -> (usize, Vec<Option<StructArray>>) { + match emit_to { + EmitTo::All => { + ( + std::mem::take(&mut self.total_data_bytes), // reset total bytes and min_max + std::mem::take(&mut self.min_max), + ) + } + EmitTo::First(n) => { + let first_min_maxes: Vec<_> =
self.min_max.drain(..n).collect(); + let first_data_capacity: usize = first_min_maxes + .iter() + .map(|opt| opt.as_ref().map(|s| s.len()).unwrap_or(0)) + .sum(); + self.total_data_bytes -= first_data_capacity; + (first_data_capacity, first_min_maxes) + } + } + } + + fn size(&self) -> usize { + self.total_data_bytes + self.min_max.len() * size_of::>() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Int32Array, StringArray, StructArray}; + use arrow::datatypes::{DataType, Field, Fields, Int32Type}; + use std::sync::Arc; + + fn create_test_struct_array( + int_values: Vec>, + str_values: Vec>, + ) -> StructArray { + let int_array = Int32Array::from(int_values); + let str_array = StringArray::from(str_values); + + let fields = vec![ + Field::new("int_field", DataType::Int32, true), + Field::new("str_field", DataType::Utf8, true), + ]; + + StructArray::new( + Fields::from(fields), + vec![ + Arc::new(int_array) as ArrayRef, + Arc::new(str_array) as ArrayRef, + ], + None, + ) + } + + fn create_nested_struct_array( + int_values: Vec>, + str_values: Vec>, + ) -> StructArray { + let inner_struct = create_test_struct_array(int_values, str_values); + + let fields = vec![Field::new("inner", inner_struct.data_type().clone(), true)]; + + StructArray::new( + Fields::from(fields), + vec![Arc::new(inner_struct) as ArrayRef], + None, + ) + } + + #[test] + fn test_min_max_simple_struct() { + let array = create_test_struct_array( + vec![Some(1), Some(2), Some(3)], + vec![Some("a"), Some("b"), Some("c")], + ); + + let mut min_accumulator = + MinMaxStructAccumulator::new_min(array.data_type().clone()); + let mut max_accumulator = + MinMaxStructAccumulator::new_max(array.data_type().clone()); + let values = vec![Arc::new(array) as ArrayRef]; + let group_indices = vec![0, 0, 0]; + + min_accumulator + .update_batch(&values, &group_indices, None, 1) + .unwrap(); + max_accumulator + .update_batch(&values, &group_indices, None, 1) + .unwrap(); + let min_result = min_accumulator.evaluate(EmitTo::All).unwrap(); + let max_result = max_accumulator.evaluate(EmitTo::All).unwrap(); + let min_result = min_result.as_struct(); + let max_result = max_result.as_struct(); + + assert_eq!(min_result.len(), 1); + let int_array = min_result.column(0).as_primitive::(); + let str_array = min_result.column(1).as_string::(); + assert_eq!(int_array.value(0), 1); + assert_eq!(str_array.value(0), "a"); + + assert_eq!(max_result.len(), 1); + let int_array = max_result.column(0).as_primitive::(); + let str_array = max_result.column(1).as_string::(); + assert_eq!(int_array.value(0), 3); + assert_eq!(str_array.value(0), "c"); + } + + #[test] + fn test_min_max_nested_struct() { + let array = create_nested_struct_array( + vec![Some(1), Some(2), Some(3)], + vec![Some("a"), Some("b"), Some("c")], + ); + + let mut min_accumulator = + MinMaxStructAccumulator::new_min(array.data_type().clone()); + let mut max_accumulator = + MinMaxStructAccumulator::new_max(array.data_type().clone()); + let values = vec![Arc::new(array) as ArrayRef]; + let group_indices = vec![0, 0, 0]; + + min_accumulator + .update_batch(&values, &group_indices, None, 1) + .unwrap(); + max_accumulator + .update_batch(&values, &group_indices, None, 1) + .unwrap(); + let min_result = min_accumulator.evaluate(EmitTo::All).unwrap(); + let max_result = max_accumulator.evaluate(EmitTo::All).unwrap(); + let min_result = min_result.as_struct(); + let max_result = max_result.as_struct(); + + assert_eq!(min_result.len(), 1); + let inner = 
min_result.column(0).as_struct(); + let int_array = inner.column(0).as_primitive::(); + let str_array = inner.column(1).as_string::(); + assert_eq!(int_array.value(0), 1); + assert_eq!(str_array.value(0), "a"); + + assert_eq!(max_result.len(), 1); + let inner = max_result.column(0).as_struct(); + let int_array = inner.column(0).as_primitive::(); + let str_array = inner.column(1).as_string::(); + assert_eq!(int_array.value(0), 3); + assert_eq!(str_array.value(0), "c"); + } + + #[test] + fn test_min_max_with_nulls() { + let array = create_test_struct_array( + vec![Some(1), None, Some(3)], + vec![Some("a"), None, Some("c")], + ); + + let mut min_accumulator = + MinMaxStructAccumulator::new_min(array.data_type().clone()); + let mut max_accumulator = + MinMaxStructAccumulator::new_max(array.data_type().clone()); + let values = vec![Arc::new(array) as ArrayRef]; + let group_indices = vec![0, 0, 0]; + + min_accumulator + .update_batch(&values, &group_indices, None, 1) + .unwrap(); + max_accumulator + .update_batch(&values, &group_indices, None, 1) + .unwrap(); + let min_result = min_accumulator.evaluate(EmitTo::All).unwrap(); + let max_result = max_accumulator.evaluate(EmitTo::All).unwrap(); + let min_result = min_result.as_struct(); + let max_result = max_result.as_struct(); + + assert_eq!(min_result.len(), 1); + let int_array = min_result.column(0).as_primitive::(); + let str_array = min_result.column(1).as_string::(); + assert_eq!(int_array.value(0), 1); + assert_eq!(str_array.value(0), "a"); + + assert_eq!(max_result.len(), 1); + let int_array = max_result.column(0).as_primitive::(); + let str_array = max_result.column(1).as_string::(); + assert_eq!(int_array.value(0), 3); + assert_eq!(str_array.value(0), "c"); + } + + #[test] + fn test_min_max_multiple_groups() { + let array = create_test_struct_array( + vec![Some(1), Some(2), Some(3), Some(4)], + vec![Some("a"), Some("b"), Some("c"), Some("d")], + ); + + let mut min_accumulator = + MinMaxStructAccumulator::new_min(array.data_type().clone()); + let mut max_accumulator = + MinMaxStructAccumulator::new_max(array.data_type().clone()); + let values = vec![Arc::new(array) as ArrayRef]; + let group_indices = vec![0, 1, 0, 1]; + + min_accumulator + .update_batch(&values, &group_indices, None, 2) + .unwrap(); + max_accumulator + .update_batch(&values, &group_indices, None, 2) + .unwrap(); + let min_result = min_accumulator.evaluate(EmitTo::All).unwrap(); + let max_result = max_accumulator.evaluate(EmitTo::All).unwrap(); + let min_result = min_result.as_struct(); + let max_result = max_result.as_struct(); + + assert_eq!(min_result.len(), 2); + let int_array = min_result.column(0).as_primitive::(); + let str_array = min_result.column(1).as_string::(); + assert_eq!(int_array.value(0), 1); + assert_eq!(str_array.value(0), "a"); + assert_eq!(int_array.value(1), 2); + assert_eq!(str_array.value(1), "b"); + + assert_eq!(max_result.len(), 2); + let int_array = max_result.column(0).as_primitive::(); + let str_array = max_result.column(1).as_string::(); + assert_eq!(int_array.value(0), 3); + assert_eq!(str_array.value(0), "c"); + assert_eq!(int_array.value(1), 4); + assert_eq!(str_array.value(1), "d"); + } + + #[test] + fn test_min_max_with_filter() { + let array = create_test_struct_array( + vec![Some(1), Some(2), Some(3), Some(4)], + vec![Some("a"), Some("b"), Some("c"), Some("d")], + ); + + // Create a filter that only keeps even numbers + let filter = BooleanArray::from(vec![false, true, false, true]); + + let mut min_accumulator = + 
MinMaxStructAccumulator::new_min(array.data_type().clone()); + let mut max_accumulator = + MinMaxStructAccumulator::new_max(array.data_type().clone()); + let values = vec![Arc::new(array) as ArrayRef]; + let group_indices = vec![0, 0, 0, 0]; + + min_accumulator + .update_batch(&values, &group_indices, Some(&filter), 1) + .unwrap(); + max_accumulator + .update_batch(&values, &group_indices, Some(&filter), 1) + .unwrap(); + let min_result = min_accumulator.evaluate(EmitTo::All).unwrap(); + let max_result = max_accumulator.evaluate(EmitTo::All).unwrap(); + let min_result = min_result.as_struct(); + let max_result = max_result.as_struct(); + + assert_eq!(min_result.len(), 1); + let int_array = min_result.column(0).as_primitive::(); + let str_array = min_result.column(1).as_string::(); + assert_eq!(int_array.value(0), 2); + assert_eq!(str_array.value(0), "b"); + + assert_eq!(max_result.len(), 1); + let int_array = max_result.column(0).as_primitive::(); + let str_array = max_result.column(1).as_string::(); + assert_eq!(int_array.value(0), 4); + assert_eq!(str_array.value(0), "d"); + } +} diff --git a/datafusion/functions-aggregate/src/nth_value.rs b/datafusion/functions-aggregate/src/nth_value.rs index d84bd02a6baf..fac0d33ceabe 100644 --- a/datafusion/functions-aggregate/src/nth_value.rs +++ b/datafusion/functions-aggregate/src/nth_value.rs @@ -24,7 +24,7 @@ use std::mem::{size_of, size_of_val}; use std::sync::Arc; use arrow::array::{new_empty_array, ArrayRef, AsArray, StructArray}; -use arrow::datatypes::{DataType, Field, Fields}; +use arrow::datatypes::{DataType, Field, FieldRef, Fields}; use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder}; use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValue}; @@ -86,7 +86,7 @@ pub fn nth_value( description = "The position (nth) of the value to retrieve, based on the ordering." ) )] -/// Expression for a `NTH_VALUE(... ORDER BY ..., ...)` aggregation. In a multi +/// Expression for a `NTH_VALUE(..., ... ORDER BY ...)` aggregation. In a multi /// partition setting, partial aggregations are computed for every partition, /// and then their results are merged. 
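// Illustrative sketch (not part of this diff): the NTH_VALUE accumulators below
// buffer incoming values and then pick the n-th element, where a positive `n`
// counts 1-based from the start and a negative `n` counts 1-based from the end
// (n == 0 is rejected at construction). The helper name `nth_from_buffer` is
// hypothetical; it only demonstrates the index arithmetic used by `evaluate`.
fn nth_from_buffer<T: Clone>(values: &[T], n: i64) -> Option<T> {
    if n == 0 {
        // 0 is an invalid 1-based index
        return None;
    }
    let n_required = n.unsigned_abs() as usize;
    let idx = if n > 0 {
        // 1-based index from the start
        let forward_idx = n_required - 1;
        (forward_idx < values.len()).then_some(forward_idx)
    } else {
        // 1-based index from the end
        values.len().checked_sub(n_required)
    };
    idx.map(|i| values[i].clone())
}
// e.g. nth_from_buffer(&[10, 20, 30], 2) == Some(20)
//      nth_from_buffer(&[10, 20, 30], -1) == Some(30)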
#[derive(Debug)] @@ -148,27 +148,28 @@ impl AggregateUDFImpl for NthValueAgg { } }; - let ordering_dtypes = acc_args - .ordering_req + let Some(ordering) = LexOrdering::new(acc_args.order_bys.to_vec()) else { + return TrivialNthValueAccumulator::try_new( + n, + acc_args.return_field.data_type(), + ) + .map(|acc| Box::new(acc) as _); + }; + let ordering_dtypes = ordering .iter() .map(|e| e.expr.data_type(acc_args.schema)) .collect::>>()?; let data_type = acc_args.exprs[0].data_type(acc_args.schema)?; - NthValueAccumulator::try_new( - n, - &data_type, - &ordering_dtypes, - acc_args.ordering_req.clone(), - ) - .map(|acc| Box::new(acc) as _) + NthValueAccumulator::try_new(n, &data_type, &ordering_dtypes, ordering) + .map(|acc| Box::new(acc) as _) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let mut fields = vec![Field::new_list( format_state_name(self.name(), "nth_value"), // See COMMENTS.md to understand why nullable is set to true - Field::new_list_field(args.input_types[0].clone(), true), + Field::new_list_field(args.input_fields[0].data_type().clone(), true), false, )]; let orderings = args.ordering_fields.to_vec(); @@ -179,11 +180,7 @@ impl AggregateUDFImpl for NthValueAgg { false, )); } - Ok(fields) - } - - fn aliases(&self) -> &[String] { - &[] + Ok(fields.into_iter().map(Arc::new).collect()) } fn reverse_expr(&self) -> ReversedUDAF { @@ -195,6 +192,126 @@ impl AggregateUDFImpl for NthValueAgg { } } +#[derive(Debug)] +pub struct TrivialNthValueAccumulator { + /// The `N` value. + n: i64, + /// Stores entries in the `NTH_VALUE` result. + values: VecDeque, + /// Data types of the value. + datatype: DataType, +} + +impl TrivialNthValueAccumulator { + /// Create a new order-insensitive NTH_VALUE accumulator based on the given + /// item data type. + pub fn try_new(n: i64, datatype: &DataType) -> Result { + if n == 0 { + // n cannot be 0 + return internal_err!("Nth value indices are 1 based. 0 is invalid index"); + } + Ok(Self { + n, + values: VecDeque::new(), + datatype: datatype.clone(), + }) + } + + /// Updates state, with the `values`. Fetch contains missing number of entries for state to be complete + /// None represents all of the new `values` need to be added to the state. + fn append_new_data( + &mut self, + values: &[ArrayRef], + fetch: Option, + ) -> Result<()> { + let n_row = values[0].len(); + let n_to_add = if let Some(fetch) = fetch { + std::cmp::min(fetch, n_row) + } else { + n_row + }; + for index in 0..n_to_add { + let mut row = get_row_at_idx(values, index)?; + self.values.push_back(row.swap_remove(0)); + // At index 1, we have n index argument, which is constant. + } + Ok(()) + } +} + +impl Accumulator for TrivialNthValueAccumulator { + /// Updates its state with the `values`. Assumes data in the `values` satisfies the required + /// ordering for the accumulator (across consecutive batches, not just batch-wise). 
+ fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if !values.is_empty() { + let n_required = self.n.unsigned_abs() as usize; + let from_start = self.n > 0; + if from_start { + // direction is from start + let n_remaining = n_required.saturating_sub(self.values.len()); + self.append_new_data(values, Some(n_remaining))?; + } else { + // direction is from end + self.append_new_data(values, None)?; + let start_offset = self.values.len().saturating_sub(n_required); + if start_offset > 0 { + self.values.drain(0..start_offset); + } + } + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if !states.is_empty() { + // First entry in the state is the aggregation result. + let n_required = self.n.unsigned_abs() as usize; + let array_agg_res = ScalarValue::convert_array_to_scalar_vec(&states[0])?; + for v in array_agg_res.into_iter() { + self.values.extend(v); + if self.values.len() > n_required { + // There is enough data collected, can stop merging: + break; + } + } + } + Ok(()) + } + + fn state(&mut self) -> Result> { + let mut values_cloned = self.values.clone(); + let values_slice = values_cloned.make_contiguous(); + Ok(vec![ScalarValue::List(ScalarValue::new_list_nullable( + values_slice, + &self.datatype, + ))]) + } + + fn evaluate(&mut self) -> Result { + let n_required = self.n.unsigned_abs() as usize; + let from_start = self.n > 0; + let nth_value_idx = if from_start { + // index is from start + let forward_idx = n_required - 1; + (forward_idx < self.values.len()).then_some(forward_idx) + } else { + // index is from end + self.values.len().checked_sub(n_required) + }; + if let Some(idx) = nth_value_idx { + Ok(self.values[idx].clone()) + } else { + ScalarValue::try_from(self.datatype.clone()) + } + } + + fn size(&self) -> usize { + size_of_val(self) + ScalarValue::size_of_vec_deque(&self.values) + - size_of_val(&self.values) + + size_of::() + } +} + #[derive(Debug)] pub struct NthValueAccumulator { /// The `N` value. @@ -236,6 +353,64 @@ impl NthValueAccumulator { ordering_req, }) } + + fn evaluate_orderings(&self) -> Result { + let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]); + + let mut column_wise_ordering_values = vec![]; + let num_columns = fields.len(); + for i in 0..num_columns { + let column_values = self + .ordering_values + .iter() + .map(|x| x[i].clone()) + .collect::>(); + let array = if column_values.is_empty() { + new_empty_array(fields[i].data_type()) + } else { + ScalarValue::iter_to_array(column_values.into_iter())? + }; + column_wise_ordering_values.push(array); + } + + let struct_field = Fields::from(fields); + let ordering_array = + StructArray::try_new(struct_field, column_wise_ordering_values, None)?; + + Ok(SingleRowListArrayBuilder::new(Arc::new(ordering_array)).build_list_scalar()) + } + + fn evaluate_values(&self) -> ScalarValue { + let mut values_cloned = self.values.clone(); + let values_slice = values_cloned.make_contiguous(); + ScalarValue::List(ScalarValue::new_list_nullable( + values_slice, + &self.datatypes[0], + )) + } + + /// Updates state, with the `values`. Fetch contains missing number of entries for state to be complete + /// None represents all of the new `values` need to be added to the state. 
+ fn append_new_data( + &mut self, + values: &[ArrayRef], + fetch: Option, + ) -> Result<()> { + let n_row = values[0].len(); + let n_to_add = if let Some(fetch) = fetch { + std::cmp::min(fetch, n_row) + } else { + n_row + }; + for index in 0..n_to_add { + let row = get_row_at_idx(values, index)?; + self.values.push_back(row[0].clone()); + // At index 1, we have n index argument. + // Ordering values cover starting from 2nd index to end + self.ordering_values.push_back(row[2..].to_vec()); + } + Ok(()) + } } impl Accumulator for NthValueAccumulator { @@ -269,91 +444,60 @@ impl Accumulator for NthValueAccumulator { if states.is_empty() { return Ok(()); } - // First entry in the state is the aggregation result. - let array_agg_values = &states[0]; - let n_required = self.n.unsigned_abs() as usize; - if self.ordering_req.is_empty() { - let array_agg_res = - ScalarValue::convert_array_to_scalar_vec(array_agg_values)?; - for v in array_agg_res.into_iter() { - self.values.extend(v); - if self.values.len() > n_required { - // There is enough data collected can stop merging - break; - } - } - } else if let Some(agg_orderings) = states[1].as_list_opt::() { - // 2nd entry stores values received for ordering requirement columns, for each aggregation value inside NTH_VALUE list. - // For each `StructArray` inside NTH_VALUE list, we will receive an `Array` that stores - // values received from its ordering requirement expression. (This information is necessary for during merging). - - // Stores NTH_VALUE results coming from each partition - let mut partition_values: Vec> = vec![]; - // Stores ordering requirement expression results coming from each partition - let mut partition_ordering_values: Vec>> = vec![]; - - // Existing values should be merged also. - partition_values.push(self.values.clone()); - - partition_ordering_values.push(self.ordering_values.clone()); - - let array_agg_res = - ScalarValue::convert_array_to_scalar_vec(array_agg_values)?; - - for v in array_agg_res.into_iter() { - partition_values.push(v.into()); - } - - let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?; - - let ordering_values = orderings.into_iter().map(|partition_ordering_rows| { - // Extract value from struct to ordering_rows for each group/partition - partition_ordering_rows.into_iter().map(|ordering_row| { - if let ScalarValue::Struct(s) = ordering_row { - let mut ordering_columns_per_row = vec![]; - - for column in s.columns() { - let sv = ScalarValue::try_from_array(column, 0)?; - ordering_columns_per_row.push(sv); - } - - Ok(ordering_columns_per_row) - } else { - exec_err!( - "Expects to receive ScalarValue::Struct(Some(..), _) but got: {:?}", - ordering_row.data_type() - ) - } - }).collect::>>() - }).collect::>>()?; - for ordering_values in ordering_values.into_iter() { - partition_ordering_values.push(ordering_values.into()); - } - - let sort_options = self - .ordering_req - .iter() - .map(|sort_expr| sort_expr.options) - .collect::>(); - let (new_values, new_orderings) = merge_ordered_arrays( - &mut partition_values, - &mut partition_ordering_values, - &sort_options, - )?; - self.values = new_values.into(); - self.ordering_values = new_orderings.into(); - } else { + // Second entry stores values received for ordering requirement columns + // for each aggregation value inside NTH_VALUE list. For each `StructArray` + // inside this list, we will receive an `Array` that stores values received + // from its ordering requirement expression. This information is necessary + // during merging. 
+ let Some(agg_orderings) = states[1].as_list_opt::() else { return exec_err!("Expects to receive a list array"); + }; + + // Stores NTH_VALUE results coming from each partition + let mut partition_values = vec![self.values.clone()]; + // First entry in the state is the aggregation result. + let array_agg_res = ScalarValue::convert_array_to_scalar_vec(&states[0])?; + for v in array_agg_res.into_iter() { + partition_values.push(v.into()); } + // Stores ordering requirement expression results coming from each partition: + let mut partition_ordering_values = vec![self.ordering_values.clone()]; + let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?; + // Extract value from struct to ordering_rows for each group/partition: + for partition_ordering_rows in orderings.into_iter() { + let ordering_values = partition_ordering_rows.into_iter().map(|ordering_row| { + let ScalarValue::Struct(s_array) = ordering_row else { + return exec_err!( + "Expects to receive ScalarValue::Struct(Some(..), _) but got: {:?}", + ordering_row.data_type() + ); + }; + s_array + .columns() + .iter() + .map(|column| ScalarValue::try_from_array(column, 0)) + .collect() + }).collect::>>()?; + partition_ordering_values.push(ordering_values); + } + + let sort_options = self + .ordering_req + .iter() + .map(|sort_expr| sort_expr.options) + .collect::>(); + let (new_values, new_orderings) = merge_ordered_arrays( + &mut partition_values, + &mut partition_ordering_values, + &sort_options, + )?; + self.values = new_values.into(); + self.ordering_values = new_orderings.into(); Ok(()) } fn state(&mut self) -> Result> { - let mut result = vec![self.evaluate_values()]; - if !self.ordering_req.is_empty() { - result.push(self.evaluate_orderings()?); - } - Ok(result) + Ok(vec![self.evaluate_values(), self.evaluate_orderings()?]) } fn evaluate(&mut self) -> Result { @@ -396,63 +540,3 @@ impl Accumulator for NthValueAccumulator { total } } - -impl NthValueAccumulator { - fn evaluate_orderings(&self) -> Result { - let fields = ordering_fields(self.ordering_req.as_ref(), &self.datatypes[1..]); - let struct_field = Fields::from(fields.clone()); - - let mut column_wise_ordering_values = vec![]; - let num_columns = fields.len(); - for i in 0..num_columns { - let column_values = self - .ordering_values - .iter() - .map(|x| x[i].clone()) - .collect::>(); - let array = if column_values.is_empty() { - new_empty_array(fields[i].data_type()) - } else { - ScalarValue::iter_to_array(column_values.into_iter())? - }; - column_wise_ordering_values.push(array); - } - - let ordering_array = - StructArray::try_new(struct_field, column_wise_ordering_values, None)?; - - Ok(SingleRowListArrayBuilder::new(Arc::new(ordering_array)).build_list_scalar()) - } - - fn evaluate_values(&self) -> ScalarValue { - let mut values_cloned = self.values.clone(); - let values_slice = values_cloned.make_contiguous(); - ScalarValue::List(ScalarValue::new_list_nullable( - values_slice, - &self.datatypes[0], - )) - } - - /// Updates state, with the `values`. Fetch contains missing number of entries for state to be complete - /// None represents all of the new `values` need to be added to the state. 
- fn append_new_data( - &mut self, - values: &[ArrayRef], - fetch: Option, - ) -> Result<()> { - let n_row = values[0].len(); - let n_to_add = if let Some(fetch) = fetch { - std::cmp::min(fetch, n_row) - } else { - n_row - }; - for index in 0..n_to_add { - let row = get_row_at_idx(values, index)?; - self.values.push_back(row[0].clone()); - // At index 1, we have n index argument. - // Ordering values cover starting from 2nd index to end - self.ordering_values.push_back(row[2..].to_vec()); - } - Ok(()) - } -} diff --git a/datafusion/functions-aggregate/src/planner.rs b/datafusion/functions-aggregate/src/planner.rs index c8cb84118995..f0e37f6b1dbe 100644 --- a/datafusion/functions-aggregate/src/planner.rs +++ b/datafusion/functions-aggregate/src/planner.rs @@ -100,7 +100,7 @@ impl ExprPlanner for AggregateFunctionPlanner { let new_expr = Expr::AggregateFunction(AggregateFunction::new_udf( func, - vec![Expr::Literal(COUNT_STAR_EXPANSION)], + vec![Expr::Literal(COUNT_STAR_EXPANSION, None)], distinct, filter, order_by, diff --git a/datafusion/functions-aggregate/src/regr.rs b/datafusion/functions-aggregate/src/regr.rs index 82575d15e50b..0f84aa1323f5 100644 --- a/datafusion/functions-aggregate/src/regr.rs +++ b/datafusion/functions-aggregate/src/regr.rs @@ -18,6 +18,7 @@ //! Defines physical expressions that can evaluated at runtime during query execution use arrow::array::Float64Array; +use arrow::datatypes::FieldRef; use arrow::{ array::{ArrayRef, UInt64Array}, compute::cast, @@ -38,7 +39,7 @@ use datafusion_expr::{ use std::any::Any; use std::fmt::Debug; use std::mem::size_of_val; -use std::sync::LazyLock; +use std::sync::{Arc, LazyLock}; macro_rules! make_regr_udaf_expr_and_func { ($EXPR_FN:ident, $AGGREGATE_UDF_FN:ident, $REGR_TYPE:expr) => { @@ -278,7 +279,7 @@ impl AggregateUDFImpl for Regr { Ok(Box::new(RegrAccumulator::try_new(&self.regr_type)?)) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( format_state_name(args.name, "count"), @@ -310,7 +311,10 @@ impl AggregateUDFImpl for Regr { DataType::Float64, true, ), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index adf86a128cfb..bf6d21a808e7 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -23,8 +23,8 @@ use std::mem::align_of_val; use std::sync::Arc; use arrow::array::Float64Array; +use arrow::datatypes::FieldRef; use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; - use datafusion_common::{internal_err, not_impl_err, Result}; use datafusion_common::{plan_err, ScalarValue}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; @@ -109,7 +109,7 @@ impl AggregateUDFImpl for Stddev { Ok(DataType::Float64) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( format_state_name(args.name, "count"), @@ -122,7 +122,10 @@ impl AggregateUDFImpl for Stddev { true, ), Field::new(format_state_name(args.name, "m2"), DataType::Float64, true), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { @@ -217,7 +220,7 @@ impl AggregateUDFImpl for StddevPop { &self.signature } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn 
state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( format_state_name(args.name, "count"), @@ -230,7 +233,10 @@ impl AggregateUDFImpl for StddevPop { true, ), Field::new(format_state_name(args.name, "m2"), DataType::Float64, true), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { @@ -387,7 +393,6 @@ mod tests { use datafusion_expr::AggregateUDF; use datafusion_functions_aggregate_common::utils::get_accum_scalar_values_as_arrays; use datafusion_physical_expr::expressions::col; - use datafusion_physical_expr_common::sort_expr::LexOrdering; use std::sync::Arc; #[test] @@ -436,10 +441,10 @@ mod tests { schema: &Schema, ) -> Result { let args1 = AccumulatorArgs { - return_type: &DataType::Float64, + return_field: Field::new("f", DataType::Float64, true).into(), schema, ignore_nulls: false, - ordering_req: &LexOrdering::default(), + order_bys: &[], name: "a", is_distinct: false, is_reversed: false, @@ -447,10 +452,10 @@ mod tests { }; let args2 = AccumulatorArgs { - return_type: &DataType::Float64, + return_field: Field::new("f", DataType::Float64, true).into(), schema, ignore_nulls: false, - ordering_req: &LexOrdering::default(), + order_bys: &[], name: "a", is_distinct: false, is_reversed: false, diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs index a7594b9ccb01..09199e19cffc 100644 --- a/datafusion/functions-aggregate/src/string_agg.rs +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -17,12 +17,15 @@ //! [`StringAgg`] accumulator for the `string_agg` function +use std::any::Any; +use std::mem::size_of_val; + use crate::array_agg::ArrayAgg; + use arrow::array::ArrayRef; -use arrow::datatypes::{DataType, Field}; -use datafusion_common::cast::as_generic_string_array; -use datafusion_common::Result; -use datafusion_common::{internal_err, not_impl_err, ScalarValue}; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; +use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue}; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::{ Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, Volatility, @@ -30,8 +33,6 @@ use datafusion_expr::{ use datafusion_functions_aggregate_common::accumulator::StateFieldsArgs; use datafusion_macros::user_doc; use datafusion_physical_expr::expressions::Literal; -use std::any::Any; -use std::mem::size_of_val; make_udaf_expr_and_func!( StringAgg, @@ -95,9 +96,15 @@ impl StringAgg { TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]), TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]), TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Null]), + TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8View]), TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]), TypeSignature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]), TypeSignature::Exact(vec![DataType::Utf8, DataType::Null]), + TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8View]), + TypeSignature::Exact(vec![DataType::Utf8View, DataType::Utf8View]), + TypeSignature::Exact(vec![DataType::Utf8View, DataType::LargeUtf8]), + TypeSignature::Exact(vec![DataType::Utf8View, DataType::Null]), + TypeSignature::Exact(vec![DataType::Utf8View, DataType::Utf8]), ], Volatility::Immutable, ), @@ -129,7 +136,7 @@ impl AggregateUDFImpl for StringAgg { Ok(DataType::LargeUtf8) } - fn 
state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { self.array_agg.state_fields(args) } @@ -154,7 +161,12 @@ impl AggregateUDFImpl for StringAgg { }; let array_agg_acc = self.array_agg.accumulator(AccumulatorArgs { - return_type: &DataType::new_list(acc_args.return_type.clone(), true), + return_field: Field::new( + "f", + DataType::new_list(acc_args.return_field.data_type().clone(), true), + true, + ) + .into(), exprs: &filter_index(acc_args.exprs, 1), ..acc_args })?; @@ -206,6 +218,10 @@ impl Accumulator for StringAggAccumulator { .iter() .flatten() .collect(), + DataType::Utf8View => as_string_view_array(list.values())? + .iter() + .flatten() + .collect(), _ => { return internal_err!( "Expected elements to of type Utf8 or LargeUtf8, but got {}", @@ -256,7 +272,7 @@ mod tests { use arrow::datatypes::{Fields, Schema}; use datafusion_common::internal_err; use datafusion_physical_expr::expressions::Column; - use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; + use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use std::sync::Arc; #[test] @@ -398,7 +414,7 @@ mod tests { struct StringAggAccumulatorBuilder { sep: String, distinct: bool, - ordering: LexOrdering, + order_bys: Vec, schema: Schema, } @@ -407,7 +423,7 @@ mod tests { Self { sep: sep.to_string(), distinct: Default::default(), - ordering: Default::default(), + order_bys: vec![], schema: Schema { fields: Fields::from(vec![Field::new( "col", @@ -424,7 +440,7 @@ mod tests { } fn order_by_col(mut self, col: &str, sort_options: SortOptions) -> Self { - self.ordering.extend([PhysicalSortExpr::new( + self.order_bys.extend([PhysicalSortExpr::new( Arc::new( Column::new_with_schema(col, &self.schema) .expect("column not available in schema"), @@ -436,10 +452,10 @@ mod tests { fn build(&self) -> Result> { StringAgg::new().accumulator(AccumulatorArgs { - return_type: &DataType::LargeUtf8, + return_field: Field::new("f", DataType::LargeUtf8, true).into(), schema: &self.schema, ignore_nulls: false, - ordering_req: &self.ordering, + order_bys: &self.order_bys, is_reversed: false, name: "", is_distinct: self.distinct, diff --git a/datafusion/functions-aggregate/src/sum.rs b/datafusion/functions-aggregate/src/sum.rs index 76a1315c2d88..9495e087d250 100644 --- a/datafusion/functions-aggregate/src/sum.rs +++ b/datafusion/functions-aggregate/src/sum.rs @@ -26,8 +26,8 @@ use std::mem::{size_of, size_of_val}; use arrow::array::Array; use arrow::array::ArrowNativeTypeOp; use arrow::array::{ArrowNumericType, AsArray}; -use arrow::datatypes::ArrowNativeType; use arrow::datatypes::ArrowPrimitiveType; +use arrow::datatypes::{ArrowNativeType, FieldRef}; use arrow::datatypes::{ DataType, Decimal128Type, Decimal256Type, Float64Type, Int64Type, UInt64Type, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, @@ -63,17 +63,27 @@ make_udaf_expr_and_func!( /// `helper` is a macro accepting (ArrowPrimitiveType, DataType) macro_rules! 
downcast_sum { ($args:ident, $helper:ident) => { - match $args.return_type { - DataType::UInt64 => $helper!(UInt64Type, $args.return_type), - DataType::Int64 => $helper!(Int64Type, $args.return_type), - DataType::Float64 => $helper!(Float64Type, $args.return_type), - DataType::Decimal128(_, _) => $helper!(Decimal128Type, $args.return_type), - DataType::Decimal256(_, _) => $helper!(Decimal256Type, $args.return_type), + match $args.return_field.data_type().clone() { + DataType::UInt64 => { + $helper!(UInt64Type, $args.return_field.data_type().clone()) + } + DataType::Int64 => { + $helper!(Int64Type, $args.return_field.data_type().clone()) + } + DataType::Float64 => { + $helper!(Float64Type, $args.return_field.data_type().clone()) + } + DataType::Decimal128(_, _) => { + $helper!(Decimal128Type, $args.return_field.data_type().clone()) + } + DataType::Decimal256(_, _) => { + $helper!(Decimal256Type, $args.return_field.data_type().clone()) + } _ => { not_impl_err!( "Sum not supported for {}: {}", $args.name, - $args.return_type + $args.return_field.data_type() ) } } @@ -191,27 +201,25 @@ impl AggregateUDFImpl for Sum { } } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { if args.is_distinct { Ok(vec![Field::new_list( format_state_name(args.name, "sum distinct"), // See COMMENTS.md to understand why nullable is set to true - Field::new_list_field(args.return_type.clone(), true), + Field::new_list_field(args.return_type().clone(), true), false, - )]) + ) + .into()]) } else { Ok(vec![Field::new( format_state_name(args.name, "sum"), - args.return_type.clone(), + args.return_type().clone(), true, - )]) + ) + .into()]) } } - fn aliases(&self) -> &[String] { - &[] - } - fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { !args.is_distinct } diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index 53e3e0cc56cd..586b2dab0ae6 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -18,15 +18,13 @@ //! [`VarianceSample`]: variance sample aggregations. //! [`VariancePopulation`]: variance population aggregations. 
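// Illustrative sketch (not part of this diff): the variance accumulators below
// keep three pieces of state per group -- `count`, `mean`, and `m2` (the running
// sum of squared deviations) -- which is Welford's online algorithm. This
// hypothetical helper only shows why those three state fields are sufficient
// to produce the variance; the merge step across partitions is not shown.
struct WelfordState {
    count: u64,
    mean: f64,
    m2: f64,
}

impl WelfordState {
    fn update(&mut self, x: f64) {
        self.count += 1;
        let delta = x - self.mean;
        self.mean += delta / self.count as f64;
        let delta2 = x - self.mean;
        self.m2 += delta * delta2;
    }

    // Sample variance; the population variant divides by `count` instead.
    fn variance_sample(&self) -> Option<f64> {
        (self.count > 1).then(|| self.m2 / (self.count - 1) as f64)
    }
}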
+use arrow::datatypes::FieldRef; use arrow::{ array::{Array, ArrayRef, BooleanArray, Float64Array, UInt64Array}, buffer::NullBuffer, compute::kernels::cast, datatypes::{DataType, Field}, }; -use std::mem::{size_of, size_of_val}; -use std::{fmt::Debug, sync::Arc}; - use datafusion_common::{downcast_value, not_impl_err, plan_err, Result, ScalarValue}; use datafusion_expr::{ function::{AccumulatorArgs, StateFieldsArgs}, @@ -38,6 +36,8 @@ use datafusion_functions_aggregate_common::{ aggregate::groups_accumulator::accumulate::accumulate, stats::StatsType, }; use datafusion_macros::user_doc; +use std::mem::{size_of, size_of_val}; +use std::{fmt::Debug, sync::Arc}; make_udaf_expr_and_func!( VarianceSample, @@ -107,13 +107,16 @@ impl AggregateUDFImpl for VarianceSample { Ok(DataType::Float64) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let name = args.name; Ok(vec![ Field::new(format_state_name(name, "count"), DataType::UInt64, true), Field::new(format_state_name(name, "mean"), DataType::Float64, true), Field::new(format_state_name(name, "m2"), DataType::Float64, true), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { @@ -200,13 +203,16 @@ impl AggregateUDFImpl for VariancePopulation { Ok(DataType::Float64) } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let name = args.name; Ok(vec![ Field::new(format_state_name(name, "count"), DataType::UInt64, true), Field::new(format_state_name(name, "mean"), DataType::Float64, true), Field::new(format_state_name(name, "m2"), DataType::Float64, true), - ]) + ] + .into_iter() + .map(Arc::new) + .collect()) } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { diff --git a/datafusion/functions-nested/Cargo.toml b/datafusion/functions-nested/Cargo.toml index 9a7b1f460ef5..87a480e16003 100644 --- a/datafusion/functions-nested/Cargo.toml +++ b/datafusion/functions-nested/Cargo.toml @@ -46,6 +46,7 @@ datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-functions = { workspace = true } datafusion-functions-aggregate = { workspace = true } +datafusion-functions-aggregate-common = { workspace = true } datafusion-macros = { workspace = true } datafusion-physical-expr-common = { workspace = true } itertools = { workspace = true, features = ["use_std"] } diff --git a/datafusion/functions-nested/README.md b/datafusion/functions-nested/README.md index 8a5047c838ab..0fa93619b97b 100644 --- a/datafusion/functions-nested/README.md +++ b/datafusion/functions-nested/README.md @@ -24,4 +24,9 @@ This crate contains functions for working with arrays, maps and structs, such as `array_append` that work with `ListArray`, `LargeListArray` and `FixedListArray` types from the `arrow` crate. +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. 
+ [df]: https://crates.io/crates/datafusion +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/functions-nested/benches/map.rs b/datafusion/functions-nested/benches/map.rs index 2774b24b902a..55dd7ad14460 100644 --- a/datafusion/functions-nested/benches/map.rs +++ b/datafusion/functions-nested/benches/map.rs @@ -36,7 +36,7 @@ fn keys(rng: &mut ThreadRng) -> Vec { let mut keys = HashSet::with_capacity(1000); while keys.len() < 1000 { - keys.insert(rng.gen_range(0..10000).to_string()); + keys.insert(rng.random_range(0..10000).to_string()); } keys.into_iter().collect() @@ -46,20 +46,23 @@ fn values(rng: &mut ThreadRng) -> Vec { let mut values = HashSet::with_capacity(1000); while values.len() < 1000 { - values.insert(rng.gen_range(0..10000)); + values.insert(rng.random_range(0..10000)); } values.into_iter().collect() } fn criterion_benchmark(c: &mut Criterion) { c.bench_function("make_map_1000", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let keys = keys(&mut rng); let values = values(&mut rng); let mut buffer = Vec::new(); for i in 0..1000 { - buffer.push(Expr::Literal(ScalarValue::Utf8(Some(keys[i].clone())))); - buffer.push(Expr::Literal(ScalarValue::Int32(Some(values[i])))); + buffer.push(Expr::Literal( + ScalarValue::Utf8(Some(keys[i].clone())), + None, + )); + buffer.push(Expr::Literal(ScalarValue::Int32(Some(values[i])), None)); } let planner = NestedFunctionPlanner {}; @@ -74,7 +77,7 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("map_1000", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let field = Arc::new(Field::new_list_field(DataType::Utf8, true)); let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, 1000])); let key_list = ListArray::new( @@ -94,17 +97,23 @@ fn criterion_benchmark(c: &mut Criterion) { let keys = ColumnarValue::Scalar(ScalarValue::List(Arc::new(key_list))); let values = ColumnarValue::Scalar(ScalarValue::List(Arc::new(value_list))); - let return_type = &map_udf() + let return_type = map_udf() .return_type(&[DataType::Utf8, DataType::Int32]) .expect("should get return type"); + let arg_fields = vec![ + Field::new("a", keys.data_type(), true).into(), + Field::new("a", values.data_type(), true).into(), + ]; + let return_field = Field::new("f", return_type, true).into(); b.iter(|| { black_box( map_udf() .invoke_with_args(ScalarFunctionArgs { args: vec![keys.clone(), values.clone()], + arg_fields: arg_fields.clone(), number_rows: 1, - return_type, + return_field: Arc::clone(&return_field), }) .expect("map should work on valid values"), ); diff --git a/datafusion/functions-nested/src/array_has.rs b/datafusion/functions-nested/src/array_has.rs index 5ef1491313b1..57ff8490b9d4 100644 --- a/datafusion/functions-nested/src/array_has.rs +++ b/datafusion/functions-nested/src/array_has.rs @@ -17,16 +17,14 @@ //! [`ScalarUDFImpl`] definitions for array_has, array_has_all and array_has_any functions. 
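// Illustrative sketch (not from this diff): the array_has kernels below
// distinguish null lists from empty lists -- a null haystack row yields a null
// result, an empty (but non-null) list yields `false`, and otherwise the result
// is whether the needle occurs in the list. This hypothetical helper mirrors
// that per-row behaviour on plain slices.
fn array_has_row(haystack: Option<&[i32]>, needle: i32) -> Option<bool> {
    // `None` models a null list; an empty slice models `[]`, which is false.
    haystack.map(|values| values.contains(&needle))
}
// e.g. array_has_row(Some(&[1, 2, 3]), 2) == Some(true)
//      array_has_row(Some(&[]), 2)        == Some(false)
//      array_has_row(None, 2)             == None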
-use arrow::array::{ - Array, ArrayRef, BooleanArray, Datum, GenericListArray, OffsetSizeTrait, Scalar, -}; +use arrow::array::{Array, ArrayRef, BooleanArray, Datum, Scalar}; use arrow::buffer::BooleanBuffer; use arrow::datatypes::DataType; use arrow::row::{RowConverter, Rows, SortField}; -use datafusion_common::cast::as_generic_list_array; +use datafusion_common::cast::{as_fixed_size_list_array, as_generic_list_array}; use datafusion_common::utils::string_utils::string_array_to_vec; use datafusion_common::utils::take_function_args; -use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::expr::{InList, ScalarFunction}; use datafusion_expr::simplify::ExprSimplifyResult; use datafusion_expr::{ @@ -133,7 +131,7 @@ impl ScalarUDFImpl for ArrayHas { // if the haystack is a constant list, we can use an inlist expression which is more // efficient because the haystack is not varying per-row - if let Expr::Literal(ScalarValue::List(array)) = haystack { + if let Expr::Literal(ScalarValue::List(array), _) = haystack { // TODO: support LargeList // (not supported by `convert_array_to_scalar_vec`) // (FixedSizeList not supported either, but seems to have worked fine when attempting to @@ -147,7 +145,7 @@ impl ScalarUDFImpl for ArrayHas { let list = scalar_values .into_iter() .flatten() - .map(Expr::Literal) + .map(|v| Expr::Literal(v, None)) .collect(); return Ok(ExprSimplifyResult::Simplified(Expr::InList(InList { @@ -218,34 +216,98 @@ fn array_has_inner_for_scalar( haystack: &ArrayRef, needle: &dyn Datum, ) -> Result { - match haystack.data_type() { - DataType::List(_) => array_has_dispatch_for_scalar::(haystack, needle), - DataType::LargeList(_) => array_has_dispatch_for_scalar::(haystack, needle), - _ => exec_err!( - "array_has does not support type '{:?}'.", - haystack.data_type() - ), - } + let haystack = haystack.as_ref().try_into()?; + array_has_dispatch_for_scalar(haystack, needle) } fn array_has_inner_for_array(haystack: &ArrayRef, needle: &ArrayRef) -> Result { - match haystack.data_type() { - DataType::List(_) => array_has_dispatch_for_array::(haystack, needle), - DataType::LargeList(_) => array_has_dispatch_for_array::(haystack, needle), - _ => exec_err!( - "array_has does not support type '{:?}'.", - haystack.data_type() - ), + let haystack = haystack.as_ref().try_into()?; + array_has_dispatch_for_array(haystack, needle) +} + +enum ArrayWrapper<'a> { + FixedSizeList(&'a arrow::array::FixedSizeListArray), + List(&'a arrow::array::GenericListArray), + LargeList(&'a arrow::array::GenericListArray), +} + +impl<'a> TryFrom<&'a dyn Array> for ArrayWrapper<'a> { + type Error = DataFusionError; + + fn try_from( + value: &'a dyn Array, + ) -> std::result::Result, Self::Error> { + match value.data_type() { + DataType::List(_) => { + Ok(ArrayWrapper::List(as_generic_list_array::(value)?)) + } + DataType::LargeList(_) => Ok(ArrayWrapper::LargeList( + as_generic_list_array::(value)?, + )), + DataType::FixedSizeList(_, _) => Ok(ArrayWrapper::FixedSizeList( + as_fixed_size_list_array(value)?, + )), + _ => exec_err!("array_has does not support type '{:?}'.", value.data_type()), + } } } -fn array_has_dispatch_for_array( - haystack: &ArrayRef, +impl<'a> ArrayWrapper<'a> { + fn len(&self) -> usize { + match self { + ArrayWrapper::FixedSizeList(arr) => arr.len(), + ArrayWrapper::List(arr) => arr.len(), + ArrayWrapper::LargeList(arr) => arr.len(), + } + } + + fn iter(&self) -> Box> + 'a> { + match self { + 
ArrayWrapper::FixedSizeList(arr) => Box::new(arr.iter()), + ArrayWrapper::List(arr) => Box::new(arr.iter()), + ArrayWrapper::LargeList(arr) => Box::new(arr.iter()), + } + } + + fn values(&self) -> &ArrayRef { + match self { + ArrayWrapper::FixedSizeList(arr) => arr.values(), + ArrayWrapper::List(arr) => arr.values(), + ArrayWrapper::LargeList(arr) => arr.values(), + } + } + + fn value_type(&self) -> DataType { + match self { + ArrayWrapper::FixedSizeList(arr) => arr.value_type(), + ArrayWrapper::List(arr) => arr.value_type(), + ArrayWrapper::LargeList(arr) => arr.value_type(), + } + } + + fn offsets(&self) -> Box + 'a> { + match self { + ArrayWrapper::FixedSizeList(arr) => { + let offsets = (0..=arr.len()) + .step_by(arr.value_length() as usize) + .collect::>(); + Box::new(offsets.into_iter()) + } + ArrayWrapper::List(arr) => { + Box::new(arr.offsets().iter().map(|o| (*o) as usize)) + } + ArrayWrapper::LargeList(arr) => { + Box::new(arr.offsets().iter().map(|o| (*o) as usize)) + } + } + } +} + +fn array_has_dispatch_for_array( + haystack: ArrayWrapper<'_>, needle: &ArrayRef, ) -> Result { - let haystack = as_generic_list_array::(haystack)?; let mut boolean_builder = BooleanArray::builder(haystack.len()); - for (i, arr) in haystack.iter().enumerate() { if arr.is_none() || needle.is_null(i) { boolean_builder.append_null(); @@ -261,14 +323,12 @@ fn array_has_dispatch_for_array( Ok(Arc::new(boolean_builder.finish())) } -fn array_has_dispatch_for_scalar( - haystack: &ArrayRef, +fn array_has_dispatch_for_scalar( + haystack: ArrayWrapper<'_>, needle: &dyn Datum, ) -> Result { - let haystack = as_generic_list_array::(haystack)?; let values = haystack.values(); let is_nested = values.data_type().is_nested(); - let offsets = haystack.value_offsets(); // If first argument is empty list (second argument is non-null), return false // i.e. 
array_has([], non-null element) -> false if values.is_empty() { @@ -279,51 +339,128 @@ fn array_has_dispatch_for_scalar( } let eq_array = compare_with_eq(values, needle, is_nested)?; let mut final_contained = vec![None; haystack.len()]; - for (i, offset) in offsets.windows(2).enumerate() { - let start = offset[0].to_usize().unwrap(); - let end = offset[1].to_usize().unwrap(); + + // Check validity buffer to distinguish between null and empty arrays + let validity = match &haystack { + ArrayWrapper::FixedSizeList(arr) => arr.nulls(), + ArrayWrapper::List(arr) => arr.nulls(), + ArrayWrapper::LargeList(arr) => arr.nulls(), + }; + + for (i, (start, end)) in haystack.offsets().tuple_windows().enumerate() { let length = end - start; - // For non-nested list, length is 0 for null + + // Check if the array at this position is null + if let Some(validity_buffer) = validity { + if !validity_buffer.is_valid(i) { + final_contained[i] = None; // null array -> null result + continue; + } + } + + // For non-null arrays: length is 0 for empty arrays if length == 0 { - continue; + final_contained[i] = Some(false); // empty array -> false + } else { + let sliced_array = eq_array.slice(start, length); + final_contained[i] = Some(sliced_array.true_count() > 0); } - let sliced_array = eq_array.slice(start, length); - final_contained[i] = Some(sliced_array.true_count() > 0); } Ok(Arc::new(BooleanArray::from(final_contained))) } fn array_has_all_inner(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - DataType::List(_) => { - array_has_all_and_any_dispatch::(&args[0], &args[1], ComparisonType::All) - } - DataType::LargeList(_) => { - array_has_all_and_any_dispatch::(&args[0], &args[1], ComparisonType::All) + array_has_all_and_any_inner(args, ComparisonType::All) +} + +// General row comparison for array_has_all and array_has_any +fn general_array_has_for_all_and_any<'a>( + haystack: &ArrayWrapper<'a>, + needle: &ArrayWrapper<'a>, + comparison_type: ComparisonType, +) -> Result { + let mut boolean_builder = BooleanArray::builder(haystack.len()); + let converter = RowConverter::new(vec![SortField::new(haystack.value_type())])?; + + for (arr, sub_arr) in haystack.iter().zip(needle.iter()) { + if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) { + let arr_values = converter.convert_columns(&[arr])?; + let sub_arr_values = converter.convert_columns(&[sub_arr])?; + boolean_builder.append_value(general_array_has_all_and_any_kernel( + arr_values, + sub_arr_values, + comparison_type, + )); + } else { + boolean_builder.append_null(); } - _ => exec_err!( - "array_has does not support type '{:?}'.", - args[0].data_type() - ), } + + Ok(Arc::new(boolean_builder.finish())) } -fn array_has_any_inner(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - DataType::List(_) => { - array_has_all_and_any_dispatch::(&args[0], &args[1], ComparisonType::Any) +// String comparison for array_has_all and array_has_any +fn array_has_all_and_any_string_internal<'a>( + haystack: &ArrayWrapper<'a>, + needle: &ArrayWrapper<'a>, + comparison_type: ComparisonType, +) -> Result { + let mut boolean_builder = BooleanArray::builder(haystack.len()); + for (arr, sub_arr) in haystack.iter().zip(needle.iter()) { + match (arr, sub_arr) { + (Some(arr), Some(sub_arr)) => { + let haystack_array = string_array_to_vec(&arr); + let needle_array = string_array_to_vec(&sub_arr); + boolean_builder.append_value(array_has_string_kernel( + haystack_array, + needle_array, + comparison_type, + )); + } + (_, _) => { + boolean_builder.append_null(); 
+ } } - DataType::LargeList(_) => { - array_has_all_and_any_dispatch::(&args[0], &args[1], ComparisonType::Any) + } + + Ok(Arc::new(boolean_builder.finish())) +} + +fn array_has_all_and_any_dispatch<'a>( + haystack: &ArrayWrapper<'a>, + needle: &ArrayWrapper<'a>, + comparison_type: ComparisonType, +) -> Result { + if needle.values().is_empty() { + let buffer = match comparison_type { + ComparisonType::All => BooleanBuffer::new_set(haystack.len()), + ComparisonType::Any => BooleanBuffer::new_unset(haystack.len()), + }; + Ok(Arc::new(BooleanArray::from(buffer))) + } else { + match needle.value_type() { + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => { + array_has_all_and_any_string_internal(haystack, needle, comparison_type) + } + _ => general_array_has_for_all_and_any(haystack, needle, comparison_type), } - _ => exec_err!( - "array_has does not support type '{:?}'.", - args[0].data_type() - ), } } +fn array_has_all_and_any_inner( + args: &[ArrayRef], + comparison_type: ComparisonType, +) -> Result { + let haystack: ArrayWrapper = args[0].as_ref().try_into()?; + let needle: ArrayWrapper = args[1].as_ref().try_into()?; + array_has_all_and_any_dispatch(&haystack, &needle, comparison_type) +} + +fn array_has_any_inner(args: &[ArrayRef]) -> Result { + array_has_all_and_any_inner(args, ComparisonType::Any) +} + #[user_doc( doc_section(label = "Array Functions"), description = "Returns true if all elements of sub-array exist in array.", @@ -481,55 +618,6 @@ enum ComparisonType { Any, } -fn array_has_all_and_any_dispatch( - haystack: &ArrayRef, - needle: &ArrayRef, - comparison_type: ComparisonType, -) -> Result { - let haystack = as_generic_list_array::(haystack)?; - let needle = as_generic_list_array::(needle)?; - if needle.values().is_empty() { - let buffer = match comparison_type { - ComparisonType::All => BooleanBuffer::new_set(haystack.len()), - ComparisonType::Any => BooleanBuffer::new_unset(haystack.len()), - }; - return Ok(Arc::new(BooleanArray::from(buffer))); - } - match needle.data_type() { - DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => { - array_has_all_and_any_string_internal::(haystack, needle, comparison_type) - } - _ => general_array_has_for_all_and_any::(haystack, needle, comparison_type), - } -} - -// String comparison for array_has_all and array_has_any -fn array_has_all_and_any_string_internal( - array: &GenericListArray, - needle: &GenericListArray, - comparison_type: ComparisonType, -) -> Result { - let mut boolean_builder = BooleanArray::builder(array.len()); - for (arr, sub_arr) in array.iter().zip(needle.iter()) { - match (arr, sub_arr) { - (Some(arr), Some(sub_arr)) => { - let haystack_array = string_array_to_vec(&arr); - let needle_array = string_array_to_vec(&sub_arr); - boolean_builder.append_value(array_has_string_kernel( - haystack_array, - needle_array, - comparison_type, - )); - } - (_, _) => { - boolean_builder.append_null(); - } - } - } - - Ok(Arc::new(boolean_builder.finish())) -} - fn array_has_string_kernel( haystack: Vec>, needle: Vec>, @@ -547,32 +635,6 @@ fn array_has_string_kernel( } } -// General row comparison for array_has_all and array_has_any -fn general_array_has_for_all_and_any( - haystack: &GenericListArray, - needle: &GenericListArray, - comparison_type: ComparisonType, -) -> Result { - let mut boolean_builder = BooleanArray::builder(haystack.len()); - let converter = RowConverter::new(vec![SortField::new(haystack.value_type())])?; - - for (arr, sub_arr) in haystack.iter().zip(needle.iter()) { - if let (Some(arr), 
Some(sub_arr)) = (arr, sub_arr) { - let arr_values = converter.convert_columns(&[arr])?; - let sub_arr_values = converter.convert_columns(&[sub_arr])?; - boolean_builder.append_value(general_array_has_all_and_any_kernel( - arr_values, - sub_arr_values, - comparison_type, - )); - } else { - boolean_builder.append_null(); - } - } - - Ok(Arc::new(boolean_builder.finish())) -} - fn general_array_has_all_and_any_kernel( haystack_rows: Rows, needle_rows: Rows, diff --git a/datafusion/functions-nested/src/cardinality.rs b/datafusion/functions-nested/src/cardinality.rs index f2f23841586c..98bda81ef25f 100644 --- a/datafusion/functions-nested/src/cardinality.rs +++ b/datafusion/functions-nested/src/cardinality.rs @@ -23,12 +23,12 @@ use arrow::array::{ }; use arrow::datatypes::{ DataType, - DataType::{FixedSizeList, LargeList, List, Map, UInt64}, + DataType::{LargeList, List, Map, Null, UInt64}, }; use datafusion_common::cast::{as_large_list_array, as_list_array, as_map_array}; -use datafusion_common::utils::take_function_args; +use datafusion_common::exec_err; +use datafusion_common::utils::{take_function_args, ListCoercion}; use datafusion_common::Result; -use datafusion_common::{exec_err, plan_err}; use datafusion_expr::{ ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature, Volatility, @@ -52,7 +52,7 @@ impl Cardinality { vec![ TypeSignature::ArraySignature(ArrayFunctionSignature::Array { arguments: vec![ArrayFunctionArgument::Array], - array_coercion: None, + array_coercion: Some(ListCoercion::FixedSizedListToList), }), TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray), ], @@ -103,13 +103,8 @@ impl ScalarUDFImpl for Cardinality { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) | Map(_, _) => UInt64, - _ => { - return plan_err!("The cardinality function can only accept List/LargeList/FixedSizeList/Map."); - } - }) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(UInt64) } fn invoke_with_args( @@ -131,21 +126,22 @@ impl ScalarUDFImpl for Cardinality { /// Cardinality SQL function pub fn cardinality_inner(args: &[ArrayRef]) -> Result { let [array] = take_function_args("cardinality", args)?; - match &array.data_type() { + match array.data_type() { + Null => Ok(Arc::new(UInt64Array::from_value(0, array.len()))), List(_) => { - let list_array = as_list_array(&array)?; + let list_array = as_list_array(array)?; generic_list_cardinality::(list_array) } LargeList(_) => { - let list_array = as_large_list_array(&array)?; + let list_array = as_large_list_array(array)?; generic_list_cardinality::(list_array) } Map(_, _) => { - let map_array = as_map_array(&array)?; + let map_array = as_map_array(array)?; generic_map_cardinality(map_array) } - other => { - exec_err!("cardinality does not support type '{:?}'", other) + arg_type => { + exec_err!("cardinality does not support type {arg_type}") } } } diff --git a/datafusion/functions-nested/src/concat.rs b/datafusion/functions-nested/src/concat.rs index f4b9208e5c83..e8b7fc27b481 100644 --- a/datafusion/functions-nested/src/concat.rs +++ b/datafusion/functions-nested/src/concat.rs @@ -17,29 +17,32 @@ //! [`ScalarUDFImpl`] definitions for `array_append`, `array_prepend` and `array_concat` functions. 
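The `array_has` rework above replaces the per-offset-type generic functions with an `ArrayWrapper` enum that borrows whichever list layout it is given, so the per-row logic is written once for `List`, `LargeList` and `FixedSizeList`. A self-contained sketch of that wrapper-enum pattern, using only the `arrow` crate; the `ListLike` type below is an illustrative stand-in, not the DataFusion implementation:

```rust
// One enum borrows any of the three list layouts; shared methods such as
// `len` dispatch on the variant instead of being monomorphised per type.
use arrow::array::{Array, FixedSizeListArray, LargeListArray, ListArray};
use arrow::datatypes::DataType;

enum ListLike<'a> {
    List(&'a ListArray),
    LargeList(&'a LargeListArray),
    FixedSize(&'a FixedSizeListArray),
}

impl<'a> TryFrom<&'a dyn Array> for ListLike<'a> {
    type Error = String;

    fn try_from(value: &'a dyn Array) -> Result<Self, Self::Error> {
        match value.data_type() {
            DataType::List(_) => Ok(ListLike::List(
                value.as_any().downcast_ref::<ListArray>().unwrap(),
            )),
            DataType::LargeList(_) => Ok(ListLike::LargeList(
                value.as_any().downcast_ref::<LargeListArray>().unwrap(),
            )),
            DataType::FixedSizeList(_, _) => Ok(ListLike::FixedSize(
                value.as_any().downcast_ref::<FixedSizeListArray>().unwrap(),
            )),
            other => Err(format!("unsupported list type {other}")),
        }
    }
}

impl ListLike<'_> {
    fn len(&self) -> usize {
        match self {
            ListLike::List(a) => a.len(),
            ListLike::LargeList(a) => a.len(),
            ListLike::FixedSize(a) => a.len(),
        }
    }
}

fn main() {
    // Two rows: [1, 2] and NULL.
    let list = ListArray::from_iter_primitive::<arrow::datatypes::Int32Type, _, _>(
        vec![Some(vec![Some(1), Some(2)]), None],
    );
    let wrapped = ListLike::try_from(&list as &dyn Array).unwrap();
    assert_eq!(wrapped.len(), 2);
}
```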
+use std::any::Any; use std::sync::Arc; -use std::{any::Any, cmp::Ordering}; +use crate::make_array::make_array_inner; +use crate::utils::{align_array_dimensions, check_datatypes, make_scalar_function}; use arrow::array::{ - Array, ArrayRef, Capacities, GenericListArray, MutableArrayData, NullBufferBuilder, - OffsetSizeTrait, + Array, ArrayData, ArrayRef, Capacities, GenericListArray, MutableArrayData, + NullBufferBuilder, OffsetSizeTrait, }; use arrow::buffer::OffsetBuffer; use arrow::datatypes::{DataType, Field}; -use datafusion_common::utils::ListCoercion; +use datafusion_common::utils::{ + base_type, coerced_type_with_base_type_only, ListCoercion, +}; use datafusion_common::Result; use datafusion_common::{ cast::as_generic_list_array, - exec_err, not_impl_err, plan_err, + exec_err, plan_err, utils::{list_ndims, take_function_args}, }; +use datafusion_expr::binary::type_union_resolution; use datafusion_expr::{ - ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, - ScalarUDFImpl, Signature, TypeSignature, Volatility, + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; - -use crate::utils::{align_array_dimensions, check_datatypes, make_scalar_function}; +use itertools::Itertools; make_udf_expr_and_func!( ArrayAppend, @@ -106,7 +109,12 @@ impl ScalarUDFImpl for ArrayAppend { } fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(arg_types[0].clone()) + let [array_type, element_type] = take_function_args(self.name(), arg_types)?; + if array_type.is_null() { + Ok(DataType::new_list(element_type.clone(), true)) + } else { + Ok(array_type.clone()) + } } fn invoke_with_args( @@ -166,18 +174,7 @@ impl Default for ArrayPrepend { impl ArrayPrepend { pub fn new() -> Self { Self { - signature: Signature { - type_signature: TypeSignature::ArraySignature( - ArrayFunctionSignature::Array { - arguments: vec![ - ArrayFunctionArgument::Element, - ArrayFunctionArgument::Array, - ], - array_coercion: Some(ListCoercion::FixedSizedListToList), - }, - ), - volatility: Volatility::Immutable, - }, + signature: Signature::element_and_array(Volatility::Immutable), aliases: vec![ String::from("list_prepend"), String::from("array_push_front"), @@ -201,7 +198,12 @@ impl ScalarUDFImpl for ArrayPrepend { } fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(arg_types[1].clone()) + let [element_type, array_type] = take_function_args(self.name(), arg_types)?; + if array_type.is_null() { + Ok(DataType::new_list(element_type.clone(), true)) + } else { + Ok(array_type.clone()) + } } fn invoke_with_args( @@ -263,7 +265,7 @@ impl Default for ArrayConcat { impl ArrayConcat { pub fn new() -> Self { Self { - signature: Signature::variadic_any(Volatility::Immutable), + signature: Signature::user_defined(Volatility::Immutable), aliases: vec![ String::from("array_cat"), String::from("list_concat"), @@ -287,39 +289,40 @@ impl ScalarUDFImpl for ArrayConcat { } fn return_type(&self, arg_types: &[DataType]) -> Result { - let mut expr_type = DataType::Null; let mut max_dims = 0; + let mut large_list = false; + let mut element_types = Vec::with_capacity(arg_types.len()); for arg_type in arg_types { - let DataType::List(field) = arg_type else { - return plan_err!( - "The array_concat function can only accept list as the args." 
- ); - }; - if !field.data_type().equals_datatype(&DataType::Null) { - let dims = list_ndims(arg_type); - expr_type = match max_dims.cmp(&dims) { - Ordering::Greater => expr_type, - Ordering::Equal => { - if expr_type == DataType::Null { - arg_type.clone() - } else if !expr_type.equals_datatype(arg_type) { - return plan_err!( - "It is not possible to concatenate arrays of different types. Expected: {}, got: {}", expr_type, arg_type - ); - } else { - expr_type - } - } - - Ordering::Less => { - max_dims = dims; - arg_type.clone() - } - }; + match arg_type { + DataType::Null | DataType::List(_) | DataType::FixedSizeList(..) => (), + DataType::LargeList(_) => large_list = true, + arg_type => { + return plan_err!("{} does not support type {arg_type}", self.name()) + } } + + max_dims = max_dims.max(list_ndims(arg_type)); + element_types.push(base_type(arg_type)) } - Ok(expr_type) + if max_dims == 0 { + Ok(DataType::Null) + } else if let Some(mut return_type) = type_union_resolution(&element_types) { + for _ in 1..max_dims { + return_type = DataType::new_list(return_type, true) + } + + if large_list { + Ok(DataType::new_large_list(return_type, true)) + } else { + Ok(DataType::new_list(return_type, true)) + } + } else { + plan_err!( + "Failed to unify argument types of {}: {arg_types:?}", + self.name() + ) + } } fn invoke_with_args( @@ -333,6 +336,16 @@ impl ScalarUDFImpl for ArrayConcat { &self.aliases } + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let base_type = base_type(&self.return_type(arg_types)?); + let coercion = Some(&ListCoercion::FixedSizedListToList); + let arg_types = arg_types.iter().map(|arg_type| { + coerced_type_with_base_type_only(arg_type, &base_type, coercion) + }); + + Ok(arg_types.collect()) + } + fn documentation(&self) -> Option<&Documentation> { self.doc() } @@ -341,24 +354,38 @@ impl ScalarUDFImpl for ArrayConcat { /// Array_concat/Array_cat SQL function pub(crate) fn array_concat_inner(args: &[ArrayRef]) -> Result { if args.is_empty() { - return exec_err!("array_concat expects at least one arguments"); + return exec_err!("array_concat expects at least one argument"); } - let mut new_args = vec![]; + let mut all_null = true; + let mut large_list = false; for arg in args { - let ndim = list_ndims(arg.data_type()); - let base_type = datafusion_common::utils::base_type(arg.data_type()); - if ndim == 0 { - return not_impl_err!("Array is not type '{base_type:?}'."); + match arg.data_type() { + DataType::Null => continue, + DataType::LargeList(_) => large_list = true, + _ => (), } - if !base_type.eq(&DataType::Null) { - new_args.push(Arc::clone(arg)); + if arg.null_count() < arg.len() { + all_null = false; } } - match &args[0].data_type() { - DataType::LargeList(_) => concat_internal::(new_args.as_slice()), - _ => concat_internal::(new_args.as_slice()), + if all_null { + // Return a null array with the same type as the first non-null-type argument + let return_type = args + .iter() + .map(|arg| arg.data_type()) + .find_or_first(|d| !d.is_null()) + .unwrap(); // Safe because args is non-empty + + Ok(arrow::array::make_array(ArrayData::new_null( + return_type, + args[0].len(), + ))) + } else if large_list { + concat_internal::(args) + } else { + concat_internal::(args) } } @@ -427,21 +454,23 @@ fn concat_internal(args: &[ArrayRef]) -> Result { /// Array_append SQL function pub(crate) fn array_append_inner(args: &[ArrayRef]) -> Result { - let [array, _] = take_function_args("array_append", args)?; - + let [array, values] = take_function_args("array_append", 
args)?; match array.data_type() { + DataType::Null => make_array_inner(&[Arc::clone(values)]), + DataType::List(_) => general_append_and_prepend::(args, true), DataType::LargeList(_) => general_append_and_prepend::(args, true), - _ => general_append_and_prepend::(args, true), + arg_type => exec_err!("array_append does not support type {arg_type}"), } } /// Array_prepend SQL function pub(crate) fn array_prepend_inner(args: &[ArrayRef]) -> Result { - let [_, array] = take_function_args("array_prepend", args)?; - + let [values, array] = take_function_args("array_prepend", args)?; match array.data_type() { + DataType::Null => make_array_inner(&[Arc::clone(values)]), + DataType::List(_) => general_append_and_prepend::(args, false), DataType::LargeList(_) => general_append_and_prepend::(args, false), - _ => general_append_and_prepend::(args, false), + arg_type => exec_err!("array_prepend does not support type {arg_type}"), } } diff --git a/datafusion/functions-nested/src/dimension.rs b/datafusion/functions-nested/src/dimension.rs index a7d033641413..d1e6b1be4cfa 100644 --- a/datafusion/functions-nested/src/dimension.rs +++ b/datafusion/functions-nested/src/dimension.rs @@ -17,24 +17,26 @@ //! [`ScalarUDFImpl`] definitions for array_dims and array_ndims functions. -use arrow::array::{ - Array, ArrayRef, GenericListArray, ListArray, OffsetSizeTrait, UInt64Array, -}; +use arrow::array::{Array, ArrayRef, ListArray, UInt64Array}; use arrow::datatypes::{ DataType, - DataType::{FixedSizeList, LargeList, List, UInt64}, - Field, UInt64Type, + DataType::{FixedSizeList, LargeList, List, Null, UInt64}, + UInt64Type, }; use std::any::Any; -use datafusion_common::cast::{as_large_list_array, as_list_array}; -use datafusion_common::{exec_err, plan_err, utils::take_function_args, Result}; +use datafusion_common::cast::{ + as_fixed_size_list_array, as_large_list_array, as_list_array, +}; +use datafusion_common::{exec_err, utils::take_function_args, Result}; use crate::utils::{compute_array_dims, make_scalar_function}; +use datafusion_common::utils::list_ndims; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; +use itertools::Itertools; use std::sync::Arc; make_udf_expr_and_func!( @@ -77,7 +79,7 @@ impl Default for ArrayDims { impl ArrayDims { pub fn new() -> Self { Self { - signature: Signature::array(Volatility::Immutable), + signature: Signature::arrays(1, None, Volatility::Immutable), aliases: vec!["list_dims".to_string()], } } @@ -95,15 +97,8 @@ impl ScalarUDFImpl for ArrayDims { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => { - List(Arc::new(Field::new_list_field(UInt64, true))) - } - _ => { - return plan_err!("The array_dims function can only accept List/LargeList/FixedSizeList."); - } - }) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::new_list(UInt64, true)) } fn invoke_with_args( @@ -156,7 +151,7 @@ pub(super) struct ArrayNdims { impl ArrayNdims { pub fn new() -> Self { Self { - signature: Signature::array(Volatility::Immutable), + signature: Signature::arrays(1, None, Volatility::Immutable), aliases: vec![String::from("list_ndims")], } } @@ -174,13 +169,8 @@ impl ScalarUDFImpl for ArrayNdims { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => UInt64, - _ => { - return plan_err!("The 
array_ndims function can only accept List/LargeList/FixedSizeList."); - } - }) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(UInt64) } fn invoke_with_args( @@ -202,61 +192,42 @@ impl ScalarUDFImpl for ArrayNdims { /// Array_dims SQL function pub fn array_dims_inner(args: &[ArrayRef]) -> Result { let [array] = take_function_args("array_dims", args)?; - - let data = match array.data_type() { - List(_) => { - let array = as_list_array(&array)?; - array - .iter() - .map(compute_array_dims) - .collect::>>()? - } - LargeList(_) => { - let array = as_large_list_array(&array)?; - array - .iter() - .map(compute_array_dims) - .collect::>>()? - } - array_type => { - return exec_err!("array_dims does not support type '{array_type:?}'"); + let data: Vec<_> = match array.data_type() { + List(_) => as_list_array(&array)? + .iter() + .map(compute_array_dims) + .try_collect()?, + LargeList(_) => as_large_list_array(&array)? + .iter() + .map(compute_array_dims) + .try_collect()?, + FixedSizeList(..) => as_fixed_size_list_array(&array)? + .iter() + .map(compute_array_dims) + .try_collect()?, + arg_type => { + return exec_err!("array_dims does not support type {arg_type}"); } }; let result = ListArray::from_iter_primitive::(data); - - Ok(Arc::new(result) as ArrayRef) + Ok(Arc::new(result)) } /// Array_ndims SQL function pub fn array_ndims_inner(args: &[ArrayRef]) -> Result { - let [array_dim] = take_function_args("array_ndims", args)?; + let [array] = take_function_args("array_ndims", args)?; - fn general_list_ndims( - array: &GenericListArray, - ) -> Result { - let mut data = Vec::new(); - let ndims = datafusion_common::utils::list_ndims(array.data_type()); - - for arr in array.iter() { - if arr.is_some() { - data.push(Some(ndims)) - } else { - data.push(None) - } - } - - Ok(Arc::new(UInt64Array::from(data)) as ArrayRef) + fn general_list_ndims(array: &ArrayRef) -> Result { + let ndims = list_ndims(array.data_type()); + let data = vec![ndims; array.len()]; + let result = UInt64Array::new(data.into(), array.nulls().cloned()); + Ok(Arc::new(result)) } - match array_dim.data_type() { - List(_) => { - let array = as_list_array(&array_dim)?; - general_list_ndims::(array) - } - LargeList(_) => { - let array = as_large_list_array(&array_dim)?; - general_list_ndims::(array) - } - array_type => exec_err!("array_ndims does not support type {array_type:?}"), + + match array.data_type() { + Null => Ok(Arc::new(UInt64Array::new_null(array.len()))), + List(_) | LargeList(_) | FixedSizeList(..) 
=> general_list_ndims(array), + arg_type => exec_err!("array_ndims does not support type {arg_type}"), } } diff --git a/datafusion/functions-nested/src/distance.rs b/datafusion/functions-nested/src/distance.rs index cfc7fccdd70c..3392e194b176 100644 --- a/datafusion/functions-nested/src/distance.rs +++ b/datafusion/functions-nested/src/distance.rs @@ -23,21 +23,22 @@ use arrow::array::{ }; use arrow::datatypes::{ DataType, - DataType::{FixedSizeList, Float64, LargeList, List}, + DataType::{FixedSizeList, LargeList, List, Null}, }; use datafusion_common::cast::{ as_float32_array, as_float64_array, as_generic_list_array, as_int32_array, as_int64_array, }; -use datafusion_common::utils::coerced_fixed_size_list_to_list; +use datafusion_common::utils::{coerced_type_with_base_type_only, ListCoercion}; use datafusion_common::{ - exec_err, internal_datafusion_err, utils::take_function_args, Result, + exec_err, internal_datafusion_err, plan_err, utils::take_function_args, Result, }; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; use datafusion_functions::{downcast_arg, downcast_named_arg}; use datafusion_macros::user_doc; +use itertools::Itertools; use std::any::Any; use std::sync::Arc; @@ -104,24 +105,26 @@ impl ScalarUDFImpl for ArrayDistance { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => Ok(Float64), - _ => exec_err!("The array_distance function can only accept List/LargeList/FixedSizeList."), - } + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { let [_, _] = take_function_args(self.name(), arg_types)?; - let mut result = Vec::new(); - for arg_type in arg_types { - match arg_type { - List(_) | LargeList(_) | FixedSizeList(_, _) => result.push(coerced_fixed_size_list_to_list(arg_type)), - _ => return exec_err!("The array_distance function can only accept List/LargeList/FixedSizeList."), + let coercion = Some(&ListCoercion::FixedSizedListToList); + let arg_types = arg_types.iter().map(|arg_type| { + if matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) { + Ok(coerced_type_with_base_type_only( + arg_type, + &DataType::Float64, + coercion, + )) + } else { + plan_err!("{} does not support type {arg_type}", self.name()) } - } + }); - Ok(result) + arg_types.try_collect() } fn invoke_with_args( @@ -142,12 +145,11 @@ impl ScalarUDFImpl for ArrayDistance { pub fn array_distance_inner(args: &[ArrayRef]) -> Result { let [array1, array2] = take_function_args("array_distance", args)?; - - match (&array1.data_type(), &array2.data_type()) { + match (array1.data_type(), array2.data_type()) { (List(_), List(_)) => general_array_distance::(args), (LargeList(_), LargeList(_)) => general_array_distance::(args), - (array_type1, array_type2) => { - exec_err!("array_distance does not support types '{array_type1:?}' and '{array_type2:?}'") + (arg_type1, arg_type2) => { + exec_err!("array_distance does not support types {arg_type1} and {arg_type2}") } } } @@ -243,7 +245,7 @@ fn compute_array_distance( /// Converts an array of any numeric type to a Float64Array. 
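The distance changes above coerce list arguments to a `Float64` element type, and the helper that follows widens the individual numeric arrays. A small sketch of that widening step, assuming only the `arrow` crate and its `cast` compute kernel (the `to_f64` name is illustrative):

```rust
// Widen any numeric array to Float64 before doing distance math.
use std::sync::Arc;

use arrow::array::{Array, ArrayRef, Float64Array, Int32Array};
use arrow::compute::cast;
use arrow::datatypes::DataType;
use arrow::error::ArrowError;

fn to_f64(array: &ArrayRef) -> Result<Float64Array, ArrowError> {
    let casted = cast(array, &DataType::Float64)?;
    Ok(casted
        .as_any()
        .downcast_ref::<Float64Array>()
        .expect("cast to Float64 produces a Float64Array")
        .clone())
}

fn main() -> Result<(), ArrowError> {
    let ints: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let floats = to_f64(&ints)?;
    assert_eq!(floats.value(2), 3.0);
    Ok(())
}
```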
fn convert_to_f64_array(array: &ArrayRef) -> Result { match array.data_type() { - Float64 => Ok(as_float64_array(array)?.clone()), + DataType::Float64 => Ok(as_float64_array(array)?.clone()), DataType::Float32 => { let array = as_float32_array(array)?; let converted: Float64Array = diff --git a/datafusion/functions-nested/src/empty.rs b/datafusion/functions-nested/src/empty.rs index dcefd583e937..67c795886bde 100644 --- a/datafusion/functions-nested/src/empty.rs +++ b/datafusion/functions-nested/src/empty.rs @@ -18,13 +18,14 @@ //! [`ScalarUDFImpl`] definitions for array_empty function. use crate::utils::make_scalar_function; -use arrow::array::{ArrayRef, BooleanArray, OffsetSizeTrait}; +use arrow::array::{Array, ArrayRef, BooleanArray, OffsetSizeTrait}; +use arrow::buffer::BooleanBuffer; use arrow::datatypes::{ DataType, DataType::{Boolean, FixedSizeList, LargeList, List}, }; use datafusion_common::cast::as_generic_list_array; -use datafusion_common::{exec_err, plan_err, utils::take_function_args, Result}; +use datafusion_common::{exec_err, utils::take_function_args, Result}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; @@ -71,7 +72,7 @@ impl Default for ArrayEmpty { impl ArrayEmpty { pub fn new() -> Self { Self { - signature: Signature::array(Volatility::Immutable), + signature: Signature::arrays(1, None, Volatility::Immutable), aliases: vec!["array_empty".to_string(), "list_empty".to_string()], } } @@ -89,13 +90,8 @@ impl ScalarUDFImpl for ArrayEmpty { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => Boolean, - _ => { - return plan_err!("The array_empty function can only accept List/LargeList/FixedSizeList."); - } - }) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Boolean) } fn invoke_with_args( @@ -117,21 +113,25 @@ impl ScalarUDFImpl for ArrayEmpty { /// Array_empty SQL function pub fn array_empty_inner(args: &[ArrayRef]) -> Result { let [array] = take_function_args("array_empty", args)?; - - let array_type = array.data_type(); - match array_type { + match array.data_type() { List(_) => general_array_empty::(array), LargeList(_) => general_array_empty::(array), - _ => exec_err!("array_empty does not support type '{array_type:?}'."), + FixedSizeList(_, size) => { + let values = if *size == 0 { + BooleanBuffer::new_set(array.len()) + } else { + BooleanBuffer::new_unset(array.len()) + }; + Ok(Arc::new(BooleanArray::new(values, array.nulls().cloned()))) + } + arg_type => exec_err!("array_empty does not support type {arg_type}"), } } fn general_array_empty(array: &ArrayRef) -> Result { - let array = as_generic_list_array::(array)?; - - let builder = array + let result = as_generic_list_array::(array)? 
.iter() .map(|arr| arr.map(|arr| arr.is_empty())) .collect::(); - Ok(Arc::new(builder)) + Ok(Arc::new(result)) } diff --git a/datafusion/functions-nested/src/extract.rs b/datafusion/functions-nested/src/extract.rs index 321dda55ce09..95bf5a7341d9 100644 --- a/datafusion/functions-nested/src/extract.rs +++ b/datafusion/functions-nested/src/extract.rs @@ -19,12 +19,12 @@ use arrow::array::{ Array, ArrayRef, ArrowNativeTypeOp, Capacities, GenericListArray, Int64Array, - MutableArrayData, NullBufferBuilder, OffsetSizeTrait, + MutableArrayData, NullArray, NullBufferBuilder, OffsetSizeTrait, }; use arrow::buffer::OffsetBuffer; use arrow::datatypes::DataType; use arrow::datatypes::{ - DataType::{FixedSizeList, LargeList, List}, + DataType::{FixedSizeList, LargeList, List, Null}, Field, }; use datafusion_common::cast::as_int64_array; @@ -163,13 +163,9 @@ impl ScalarUDFImpl for ArrayElement { fn return_type(&self, arg_types: &[DataType]) -> Result { match &arg_types[0] { - List(field) - | LargeList(field) - | FixedSizeList(field, _) => Ok(field.data_type().clone()), - DataType::Null => Ok(List(Arc::new(Field::new_list_field(DataType::Int64, true)))), - _ => plan_err!( - "ArrayElement can only accept List, LargeList or FixedSizeList as the first argument" - ), + Null => Ok(Null), + List(field) | LargeList(field) => Ok(field.data_type().clone()), + arg_type => plan_err!("{} does not support type {arg_type}", self.name()), } } @@ -200,6 +196,7 @@ fn array_element_inner(args: &[ArrayRef]) -> Result { let [array, indexes] = take_function_args("array_element", args)?; match &array.data_type() { + Null => Ok(Arc::new(NullArray::new(array.len()))), List(_) => { let array = as_list_array(&array)?; let indexes = as_int64_array(&indexes)?; @@ -210,10 +207,9 @@ fn array_element_inner(args: &[ArrayRef]) -> Result { let indexes = as_int64_array(&indexes)?; general_array_element::(array, indexes) } - _ => exec_err!( - "array_element does not support type: {:?}", - array.data_type() - ), + arg_type => { + exec_err!("array_element does not support type {arg_type}") + } } } @@ -225,6 +221,10 @@ where i64: TryInto, { let values = array.values(); + if values.data_type().is_null() { + return Ok(Arc::new(NullArray::new(array.len()))); + } + let original_data = values.to_data(); let capacity = Capacities::Array(original_data.len()); @@ -238,8 +238,7 @@ where { let index: O = index.try_into().map_err(|_| { DataFusionError::Execution(format!( - "array_element got invalid index: {}", - index + "array_element got invalid index: {index}" )) })?; // 0 ~ len - 1 diff --git a/datafusion/functions-nested/src/flatten.rs b/datafusion/functions-nested/src/flatten.rs index f288035948dc..c6fa2831f4f0 100644 --- a/datafusion/functions-nested/src/flatten.rs +++ b/datafusion/functions-nested/src/flatten.rs @@ -18,19 +18,18 @@ //! [`ScalarUDFImpl`] definitions for flatten function. 
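The `array_empty` change above handles `FixedSizeList` without iterating: every row of a fixed-size list has the same length, so the answer is a constant buffer. A short sketch of that shortcut, assuming only the `arrow` crate (`constant_empty_answer` is an illustrative name; the real code also forwards the list's null buffer):

```rust
use arrow::array::BooleanArray;
use arrow::buffer::BooleanBuffer;

fn constant_empty_answer(num_rows: usize, fixed_size: i32) -> BooleanArray {
    let values = if fixed_size == 0 {
        // Every list is empty -> all true.
        BooleanBuffer::new_set(num_rows)
    } else {
        // Every list has elements -> all false.
        BooleanBuffer::new_unset(num_rows)
    };
    BooleanArray::new(values, None)
}

fn main() {
    let result = constant_empty_answer(4, 3);
    assert_eq!(result.true_count(), 0);
}
```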
use crate::utils::make_scalar_function; -use arrow::array::{ArrayRef, GenericListArray, OffsetSizeTrait}; +use arrow::array::{Array, ArrayRef, GenericListArray, OffsetSizeTrait}; use arrow::buffer::OffsetBuffer; use arrow::datatypes::{ DataType, DataType::{FixedSizeList, LargeList, List, Null}, }; -use datafusion_common::cast::{ - as_generic_list_array, as_large_list_array, as_list_array, -}; +use datafusion_common::cast::{as_large_list_array, as_list_array}; +use datafusion_common::utils::ListCoercion; use datafusion_common::{exec_err, utils::take_function_args, Result}; use datafusion_expr::{ - ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, - TypeSignature, Volatility, + ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, + ScalarUDFImpl, Signature, TypeSignature, Volatility, }; use datafusion_macros::user_doc; use std::any::Any; @@ -77,9 +76,11 @@ impl Flatten { pub fn new() -> Self { Self { signature: Signature { - // TODO (https://github.com/apache/datafusion/issues/13757) flatten should be single-step, not recursive type_signature: TypeSignature::ArraySignature( - ArrayFunctionSignature::RecursiveArray, + ArrayFunctionSignature::Array { + arguments: vec![ArrayFunctionArgument::Array], + array_coercion: Some(ListCoercion::FixedSizedListToList), + }, ), volatility: Volatility::Immutable, }, @@ -102,25 +103,23 @@ impl ScalarUDFImpl for Flatten { } fn return_type(&self, arg_types: &[DataType]) -> Result { - fn get_base_type(data_type: &DataType) -> Result { - match data_type { - List(field) | FixedSizeList(field, _) - if matches!(field.data_type(), List(_) | FixedSizeList(_, _)) => - { - get_base_type(field.data_type()) - } - LargeList(field) if matches!(field.data_type(), LargeList(_)) => { - get_base_type(field.data_type()) + let data_type = match &arg_types[0] { + List(field) | FixedSizeList(field, _) => match field.data_type() { + List(field) | FixedSizeList(field, _) => List(Arc::clone(field)), + _ => arg_types[0].clone(), + }, + LargeList(field) => match field.data_type() { + List(field) | LargeList(field) | FixedSizeList(field, _) => { + LargeList(Arc::clone(field)) } - Null | List(_) | LargeList(_) => Ok(data_type.to_owned()), - FixedSizeList(field, _) => Ok(List(Arc::clone(field))), - _ => exec_err!( - "Not reachable, data_type should be List, LargeList or FixedSizeList" - ), - } - } + _ => arg_types[0].clone(), + }, + Null => Null, + _ => exec_err!( + "Not reachable, data_type should be List, LargeList or FixedSizeList" + )?, + }; - let data_type = get_base_type(&arg_types[0])?; Ok(data_type) } @@ -146,14 +145,64 @@ pub fn flatten_inner(args: &[ArrayRef]) -> Result { match array.data_type() { List(_) => { - let list_arr = as_list_array(&array)?; - let flattened_array = flatten_internal::(list_arr.clone(), None)?; - Ok(Arc::new(flattened_array) as ArrayRef) + let (_field, offsets, values, nulls) = + as_list_array(&array)?.clone().into_parts(); + let values = cast_fsl_to_list(values)?; + + match values.data_type() { + List(_) => { + let (inner_field, inner_offsets, inner_values, _) = + as_list_array(&values)?.clone().into_parts(); + let offsets = get_offsets_for_flatten::(inner_offsets, offsets); + let flattened_array = GenericListArray::::new( + inner_field, + offsets, + inner_values, + nulls, + ); + + Ok(Arc::new(flattened_array) as ArrayRef) + } + LargeList(_) => { + exec_err!("flatten does not support type '{:?}'", array.data_type())? 
+ } + _ => Ok(Arc::clone(array) as ArrayRef), + } } LargeList(_) => { - let list_arr = as_large_list_array(&array)?; - let flattened_array = flatten_internal::(list_arr.clone(), None)?; - Ok(Arc::new(flattened_array) as ArrayRef) + let (_field, offsets, values, nulls) = + as_large_list_array(&array)?.clone().into_parts(); + let values = cast_fsl_to_list(values)?; + + match values.data_type() { + List(_) => { + let (inner_field, inner_offsets, inner_values, _) = + as_list_array(&values)?.clone().into_parts(); + let offsets = get_large_offsets_for_flatten(inner_offsets, offsets); + let flattened_array = GenericListArray::::new( + inner_field, + offsets, + inner_values, + nulls, + ); + + Ok(Arc::new(flattened_array) as ArrayRef) + } + LargeList(_) => { + let (inner_field, inner_offsets, inner_values, nulls) = + as_large_list_array(&values)?.clone().into_parts(); + let offsets = get_offsets_for_flatten::(inner_offsets, offsets); + let flattened_array = GenericListArray::::new( + inner_field, + offsets, + inner_values, + nulls, + ); + + Ok(Arc::new(flattened_array) as ArrayRef) + } + _ => Ok(Arc::clone(array) as ArrayRef), + } } Null => Ok(Arc::clone(array)), _ => { @@ -162,37 +211,6 @@ pub fn flatten_inner(args: &[ArrayRef]) -> Result { } } -fn flatten_internal( - list_arr: GenericListArray, - indexes: Option>, -) -> Result> { - let (field, offsets, values, _) = list_arr.clone().into_parts(); - let data_type = field.data_type(); - - match data_type { - // Recursively get the base offsets for flattened array - List(_) | LargeList(_) => { - let sub_list = as_generic_list_array::(&values)?; - if let Some(indexes) = indexes { - let offsets = get_offsets_for_flatten(offsets, indexes); - flatten_internal::(sub_list.clone(), Some(offsets)) - } else { - flatten_internal::(sub_list.clone(), Some(offsets)) - } - } - // Reach the base level, create a new list array - _ => { - if let Some(indexes) = indexes { - let offsets = get_offsets_for_flatten(offsets, indexes); - let list_arr = GenericListArray::::new(field, offsets, values, None); - Ok(list_arr) - } else { - Ok(list_arr) - } - } - } -} - // Create new offsets that are equivalent to `flatten` the array. fn get_offsets_for_flatten( offsets: OffsetBuffer, @@ -205,3 +223,25 @@ fn get_offsets_for_flatten( .collect(); OffsetBuffer::new(offsets.into()) } + +// Create new large offsets that are equivalent to `flatten` the array. +fn get_large_offsets_for_flatten( + offsets: OffsetBuffer, + indexes: OffsetBuffer
, +) -> OffsetBuffer { + let buffer = offsets.into_inner(); + let offsets: Vec = indexes + .iter() + .map(|i| buffer[i.to_usize().unwrap()].to_i64().unwrap()) + .collect(); + OffsetBuffer::new(offsets.into()) +} + +fn cast_fsl_to_list(array: ArrayRef) -> Result { + match array.data_type() { + FixedSizeList(field, _) => { + Ok(arrow::compute::cast(&array, &List(Arc::clone(field)))?) + } + _ => Ok(array), + } +} diff --git a/datafusion/functions-nested/src/length.rs b/datafusion/functions-nested/src/length.rs index 3c3a42da0d69..0da12684158e 100644 --- a/datafusion/functions-nested/src/length.rs +++ b/datafusion/functions-nested/src/length.rs @@ -19,13 +19,16 @@ use crate::utils::make_scalar_function; use arrow::array::{ - Array, ArrayRef, Int64Array, LargeListArray, ListArray, OffsetSizeTrait, UInt64Array, + Array, ArrayRef, FixedSizeListArray, Int64Array, LargeListArray, ListArray, + OffsetSizeTrait, UInt64Array, }; use arrow::datatypes::{ DataType, DataType::{FixedSizeList, LargeList, List, UInt64}, }; -use datafusion_common::cast::{as_generic_list_array, as_int64_array}; +use datafusion_common::cast::{ + as_fixed_size_list_array, as_generic_list_array, as_int64_array, +}; use datafusion_common::{exec_err, internal_datafusion_err, plan_err, Result}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, @@ -119,6 +122,23 @@ impl ScalarUDFImpl for ArrayLength { } } +macro_rules! array_length_impl { + ($array:expr, $dimension:expr) => {{ + let array = $array; + let dimension = match $dimension { + Some(d) => as_int64_array(d)?.clone(), + None => Int64Array::from_value(1, array.len()), + }; + let result = array + .iter() + .zip(dimension.iter()) + .map(|(arr, dim)| compute_array_length(arr, dim)) + .collect::>()?; + + Ok(Arc::new(result) as ArrayRef) + }}; +} + /// Array_length SQL function pub fn array_length_inner(args: &[ArrayRef]) -> Result { if args.len() != 1 && args.len() != 2 { @@ -128,26 +148,18 @@ pub fn array_length_inner(args: &[ArrayRef]) -> Result { match &args[0].data_type() { List(_) => general_array_length::(args), LargeList(_) => general_array_length::(args), + FixedSizeList(_, _) => fixed_size_array_length(args), array_type => exec_err!("array_length does not support type '{array_type:?}'"), } } +fn fixed_size_array_length(array: &[ArrayRef]) -> Result { + array_length_impl!(as_fixed_size_list_array(&array[0])?, array.get(1)) +} + /// Dispatch array length computation based on the offset type. 
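The flatten rewrite above drops the recursive helper and flattens exactly one level by recomposing offsets: each outer offset is looked up in the inner offset buffer. A dependency-free sketch of that offset arithmetic (illustrative only; the real code operates on arrow `OffsetBuffer`s):

```rust
// For a List<List<T>>, the flattened offsets are obtained by looking up each
// outer offset in the inner offset buffer.
fn flatten_offsets(outer: &[i32], inner: &[i32]) -> Vec<i32> {
    outer.iter().map(|&i| inner[i as usize]).collect()
}

fn main() {
    // Outer list with 2 rows: row 0 holds 2 inner lists, row 1 holds 1.
    let outer = vec![0, 2, 3];
    // Inner lists of lengths 2, 1 and 3 over the values buffer.
    let inner = vec![0, 2, 3, 6];
    // Row 0 flattens to values[0..3], row 1 to values[3..6].
    assert_eq!(flatten_offsets(&outer, &inner), vec![0, 3, 6]);
}
```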
fn general_array_length(array: &[ArrayRef]) -> Result { - let list_array = as_generic_list_array::(&array[0])?; - let dimension = if array.len() == 2 { - as_int64_array(&array[1])?.clone() - } else { - Int64Array::from_value(1, list_array.len()) - }; - - let result = list_array - .iter() - .zip(dimension.iter()) - .map(|(arr, dim)| compute_array_length(arr, dim)) - .collect::>()?; - - Ok(Arc::new(result) as ArrayRef) + array_length_impl!(as_generic_list_array::(&array[0])?, array.get(1)) } /// Returns the length of a concrete array dimension @@ -185,6 +197,10 @@ fn compute_array_length( value = downcast_arg!(value, LargeListArray).value(0); current_dimension += 1; } + FixedSizeList(_, _) => { + value = downcast_arg!(value, FixedSizeListArray).value(0); + current_dimension += 1; + } _ => return Ok(None), } } diff --git a/datafusion/functions-nested/src/lib.rs b/datafusion/functions-nested/src/lib.rs index c9a61d98cd44..1d3f11b50c61 100644 --- a/datafusion/functions-nested/src/lib.rs +++ b/datafusion/functions-nested/src/lib.rs @@ -50,10 +50,11 @@ pub mod flatten; pub mod length; pub mod make_array; pub mod map; +pub mod map_entries; pub mod map_extract; pub mod map_keys; pub mod map_values; -pub mod max; +pub mod min_max; pub mod planner; pub mod position; pub mod range; @@ -95,9 +96,12 @@ pub mod expr_fn { pub use super::flatten::flatten; pub use super::length::array_length; pub use super::make_array::make_array; + pub use super::map_entries::map_entries; pub use super::map_extract::map_extract; pub use super::map_keys::map_keys; pub use super::map_values::map_values; + pub use super::min_max::array_max; + pub use super::min_max::array_min; pub use super::position::array_position; pub use super::position::array_positions; pub use super::range::gen_series; @@ -146,7 +150,8 @@ pub fn all_default_nested_functions() -> Vec> { length::array_length_udf(), distance::array_distance_udf(), flatten::flatten_udf(), - max::array_max_udf(), + min_max::array_max_udf(), + min_max::array_min_udf(), sort::array_sort_udf(), repeat::array_repeat_udf(), resize::array_resize_udf(), @@ -163,6 +168,7 @@ pub fn all_default_nested_functions() -> Vec> { replace::array_replace_all_udf(), replace::array_replace_udf(), map::map_udf(), + map_entries::map_entries_udf(), map_extract::map_extract_udf(), map_keys::map_keys_udf(), map_values::map_values_udf(), @@ -201,8 +207,7 @@ mod tests { for alias in func.aliases() { assert!( names.insert(alias.to_string().to_lowercase()), - "duplicate function name: {}", - alias + "duplicate function name: {alias}" ); } } diff --git a/datafusion/functions-nested/src/make_array.rs b/datafusion/functions-nested/src/make_array.rs index 4daaafc5a888..babb03919157 100644 --- a/datafusion/functions-nested/src/make_array.rs +++ b/datafusion/functions-nested/src/make_array.rs @@ -28,10 +28,7 @@ use arrow::array::{ }; use arrow::buffer::OffsetBuffer; use arrow::datatypes::DataType; -use arrow::datatypes::{ - DataType::{List, Null}, - Field, -}; +use arrow::datatypes::{DataType::Null, Field}; use datafusion_common::utils::SingleRowListArrayBuilder; use datafusion_common::{plan_err, Result}; use datafusion_expr::binary::{ @@ -105,16 +102,14 @@ impl ScalarUDFImpl for MakeArray { } fn return_type(&self, arg_types: &[DataType]) -> Result { - match arg_types.len() { - 0 => Ok(empty_array_type()), - _ => { - // At this point, all the type in array should be coerced to the same one - Ok(List(Arc::new(Field::new_list_field( - arg_types[0].to_owned(), - true, - )))) - } - } + let element_type = if 
arg_types.is_empty() { + Null + } else { + // At this point, all the type in array should be coerced to the same one. + arg_types[0].to_owned() + }; + + Ok(DataType::new_list(element_type, true)) } fn invoke_with_args( @@ -129,26 +124,16 @@ impl ScalarUDFImpl for MakeArray { } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - let mut errors = vec![]; - match try_type_union_resolution_with_struct(arg_types) { - Ok(r) => return Ok(r), - Err(e) => { - errors.push(e); - } + if let Ok(unified) = try_type_union_resolution_with_struct(arg_types) { + return Ok(unified); } - if let Some(new_type) = type_union_resolution(arg_types) { - if new_type.is_null() { - Ok(vec![DataType::Int64; arg_types.len()]) - } else { - Ok(vec![new_type; arg_types.len()]) - } + if let Some(unified) = type_union_resolution(arg_types) { + Ok(vec![unified; arg_types.len()]) } else { plan_err!( - "Fail to find the valid type between {:?} for {}, errors are {:?}", - arg_types, - self.name(), - errors + "Failed to unify argument types of {}: {arg_types:?}", + self.name() ) } } @@ -158,35 +143,25 @@ impl ScalarUDFImpl for MakeArray { } } -// Empty array is a special case that is useful for many other array functions -pub(super) fn empty_array_type() -> DataType { - List(Arc::new(Field::new_list_field(DataType::Int64, true))) -} - /// `make_array_inner` is the implementation of the `make_array` function. /// Constructs an array using the input `data` as `ArrayRef`. /// Returns a reference-counted `Array` instance result. pub(crate) fn make_array_inner(arrays: &[ArrayRef]) -> Result { - let mut data_type = Null; - for arg in arrays { - let arg_data_type = arg.data_type(); - if !arg_data_type.equals_datatype(&Null) { - data_type = arg_data_type.clone(); - break; - } - } + let data_type = arrays.iter().find_map(|arg| { + let arg_type = arg.data_type(); + (!arg_type.is_null()).then_some(arg_type) + }); - match data_type { + let data_type = data_type.unwrap_or(&Null); + if data_type.is_null() { // Either an empty array or all nulls: - Null => { - let length = arrays.iter().map(|a| a.len()).sum(); - // By default Int64 - let array = new_null_array(&DataType::Int64, length); - Ok(Arc::new( - SingleRowListArrayBuilder::new(array).build_list_array(), - )) - } - _ => array_array::(arrays, data_type), + let length = arrays.iter().map(|a| a.len()).sum(); + let array = new_null_array(&Null, length); + Ok(Arc::new( + SingleRowListArrayBuilder::new(array).build_list_array(), + )) + } else { + array_array::(arrays, data_type.clone()) } } diff --git a/datafusion/functions-nested/src/map_entries.rs b/datafusion/functions-nested/src/map_entries.rs new file mode 100644 index 000000000000..b3323d0b5c39 --- /dev/null +++ b/datafusion/functions-nested/src/map_entries.rs @@ -0,0 +1,146 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for map_entries function. + +use crate::utils::{get_map_entry_field, make_scalar_function}; +use arrow::array::{Array, ArrayRef, ListArray}; +use arrow::datatypes::{DataType, Field, Fields}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{cast::as_map_array, exec_err, Result}; +use datafusion_expr::{ + ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, + TypeSignature, Volatility, +}; +use datafusion_macros::user_doc; +use std::any::Any; +use std::sync::Arc; + +make_udf_expr_and_func!( + MapEntriesFunc, + map_entries, + map, + "Return a list of all entries in the map.", + map_entries_udf +); + +#[user_doc( + doc_section(label = "Map Functions"), + description = "Returns a list of all entries in the map.", + syntax_example = "map_entries(map)", + sql_example = r#"```sql +SELECT map_entries(MAP {'a': 1, 'b': NULL, 'c': 3}); +---- +[{'key': a, 'value': 1}, {'key': b, 'value': NULL}, {'key': c, 'value': 3}] + +SELECT map_entries(map([100, 5], [42, 43])); +---- +[{'key': 100, 'value': 42}, {'key': 5, 'value': 43}] +```"#, + argument( + name = "map", + description = "Map expression. Can be a constant, column, or function, and any combination of map operators." + ) +)] +#[derive(Debug)] +pub struct MapEntriesFunc { + signature: Signature, +} + +impl Default for MapEntriesFunc { + fn default() -> Self { + Self::new() + } +} + +impl MapEntriesFunc { + pub fn new() -> Self { + Self { + signature: Signature::new( + TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray), + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for MapEntriesFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "map_entries" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + let [map_type] = take_function_args(self.name(), arg_types)?; + let map_fields = get_map_entry_field(map_type)?; + Ok(DataType::List(Arc::new(Field::new_list_field( + DataType::Struct(Fields::from(vec![ + Field::new( + "key", + map_fields.first().unwrap().data_type().clone(), + false, + ), + Field::new( + "value", + map_fields.get(1).unwrap().data_type().clone(), + true, + ), + ])), + false, + )))) + } + + fn invoke_with_args( + &self, + args: datafusion_expr::ScalarFunctionArgs, + ) -> Result { + make_scalar_function(map_entries_inner)(&args.args) + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +fn map_entries_inner(args: &[ArrayRef]) -> Result { + let [map_arg] = take_function_args("map_entries", args)?; + + let map_array = match map_arg.data_type() { + DataType::Map(_, _) => as_map_array(&map_arg)?, + _ => return exec_err!("Argument for map_entries should be a map"), + }; + + Ok(Arc::new(ListArray::new( + Arc::new(Field::new_list_field( + DataType::Struct(Fields::from(vec![ + Field::new("key", map_array.key_type().clone(), false), + Field::new("value", map_array.value_type().clone(), true), + ])), + false, + )), + map_array.offsets().clone(), + Arc::new(map_array.entries().clone()), + map_array.nulls().cloned(), + ))) +} diff --git a/datafusion/functions-nested/src/map_values.rs b/datafusion/functions-nested/src/map_values.rs index f82e4bfa1a89..8247fdd4a74c 100644 --- a/datafusion/functions-nested/src/map_values.rs +++ b/datafusion/functions-nested/src/map_values.rs @@ -19,15 +19,16 @@ use 
crate::utils::{get_map_entry_field, make_scalar_function}; use arrow::array::{Array, ArrayRef, ListArray}; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::utils::take_function_args; -use datafusion_common::{cast::as_map_array, exec_err, Result}; +use datafusion_common::{cast::as_map_array, exec_err, internal_err, Result}; use datafusion_expr::{ ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature, Volatility, }; use datafusion_macros::user_doc; use std::any::Any; +use std::ops::Deref; use std::sync::Arc; make_udf_expr_and_func!( @@ -91,13 +92,23 @@ impl ScalarUDFImpl for MapValuesFunc { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - let [map_type] = take_function_args(self.name(), arg_types)?; - let map_fields = get_map_entry_field(map_type)?; - Ok(DataType::List(Arc::new(Field::new_list_field( - map_fields.last().unwrap().data_type().clone(), - true, - )))) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args( + &self, + args: datafusion_expr::ReturnFieldArgs, + ) -> Result { + let [map_type] = take_function_args(self.name(), args.arg_fields)?; + + Ok(Field::new( + self.name(), + DataType::List(get_map_values_field_as_list_field(map_type.data_type())?), + // Nullable if the map is nullable + args.arg_fields.iter().any(|x| x.is_nullable()), + ) + .into()) } fn invoke_with_args( @@ -121,9 +132,139 @@ fn map_values_inner(args: &[ArrayRef]) -> Result { }; Ok(Arc::new(ListArray::new( - Arc::new(Field::new_list_field(map_array.value_type().clone(), true)), + get_map_values_field_as_list_field(map_arg.data_type())?, map_array.offsets().clone(), Arc::clone(map_array.values()), map_array.nulls().cloned(), ))) } + +fn get_map_values_field_as_list_field(map_type: &DataType) -> Result { + let map_fields = get_map_entry_field(map_type)?; + + let values_field = map_fields + .last() + .unwrap() + .deref() + .clone() + .with_name(Field::LIST_FIELD_DEFAULT_NAME); + + Ok(Arc::new(values_field)) +} + +#[cfg(test)] +mod tests { + use crate::map_values::MapValuesFunc; + use arrow::datatypes::{DataType, Field, FieldRef}; + use datafusion_common::ScalarValue; + use datafusion_expr::ScalarUDFImpl; + use std::sync::Arc; + + #[test] + fn return_type_field() { + fn get_map_field( + is_map_nullable: bool, + is_keys_nullable: bool, + is_values_nullable: bool, + ) -> FieldRef { + Field::new_map( + "something", + "entries", + Arc::new(Field::new("keys", DataType::Utf8, is_keys_nullable)), + Arc::new(Field::new( + "values", + DataType::LargeUtf8, + is_values_nullable, + )), + false, + is_map_nullable, + ) + .into() + } + + fn get_list_field( + name: &str, + is_list_nullable: bool, + list_item_type: DataType, + is_list_items_nullable: bool, + ) -> FieldRef { + Field::new_list( + name, + Arc::new(Field::new_list_field( + list_item_type, + is_list_items_nullable, + )), + is_list_nullable, + ) + .into() + } + + fn get_return_field(field: FieldRef) -> FieldRef { + let func = MapValuesFunc::new(); + let args = datafusion_expr::ReturnFieldArgs { + arg_fields: &[field], + scalar_arguments: &[None::<&ScalarValue>], + }; + + func.return_field_from_args(args).unwrap() + } + + // Test cases: + // + // | Input Map || Expected Output | + // | ------------------------------------------------------ || ----------------------------------------------------- | + // | map nullable | map keys 
nullable | map values nullable || expected list nullable | expected list items nullable | + // | ------------ | ----------------- | ------------------- || ---------------------- | ---------------------------- | + // | false | false | false || false | false | + // | false | false | true || false | true | + // | false | true | false || false | false | + // | false | true | true || false | true | + // | true | false | false || true | false | + // | true | false | true || true | true | + // | true | true | false || true | false | + // | true | true | true || true | true | + // + // --------------- + // We added the key nullability to show that it does not affect the nullability of the list or the list items. + + assert_eq!( + get_return_field(get_map_field(false, false, false)), + get_list_field("map_values", false, DataType::LargeUtf8, false) + ); + + assert_eq!( + get_return_field(get_map_field(false, false, true)), + get_list_field("map_values", false, DataType::LargeUtf8, true) + ); + + assert_eq!( + get_return_field(get_map_field(false, true, false)), + get_list_field("map_values", false, DataType::LargeUtf8, false) + ); + + assert_eq!( + get_return_field(get_map_field(false, true, true)), + get_list_field("map_values", false, DataType::LargeUtf8, true) + ); + + assert_eq!( + get_return_field(get_map_field(true, false, false)), + get_list_field("map_values", true, DataType::LargeUtf8, false) + ); + + assert_eq!( + get_return_field(get_map_field(true, false, true)), + get_list_field("map_values", true, DataType::LargeUtf8, true) + ); + + assert_eq!( + get_return_field(get_map_field(true, true, false)), + get_list_field("map_values", true, DataType::LargeUtf8, false) + ); + + assert_eq!( + get_return_field(get_map_field(true, true, true)), + get_list_field("map_values", true, DataType::LargeUtf8, true) + ); + } +} diff --git a/datafusion/functions-nested/src/max.rs b/datafusion/functions-nested/src/max.rs deleted file mode 100644 index 32957edc62b5..000000000000 --- a/datafusion/functions-nested/src/max.rs +++ /dev/null @@ -1,138 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! [`ScalarUDFImpl`] definitions for array_max function. 
-use crate::utils::make_scalar_function; -use arrow::array::ArrayRef; -use arrow::datatypes::DataType; -use arrow::datatypes::DataType::List; -use datafusion_common::cast::as_list_array; -use datafusion_common::utils::take_function_args; -use datafusion_common::{exec_err, ScalarValue}; -use datafusion_doc::Documentation; -use datafusion_expr::{ - ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, -}; -use datafusion_functions_aggregate::min_max; -use datafusion_macros::user_doc; -use itertools::Itertools; -use std::any::Any; - -make_udf_expr_and_func!( - ArrayMax, - array_max, - array, - "returns the maximum value in the array.", - array_max_udf -); - -#[user_doc( - doc_section(label = "Array Functions"), - description = "Returns the maximum value in the array.", - syntax_example = "array_max(array)", - sql_example = r#"```sql -> select array_max([3,1,4,2]); -+-----------------------------------------+ -| array_max(List([3,1,4,2])) | -+-----------------------------------------+ -| 4 | -+-----------------------------------------+ -```"#, - argument( - name = "array", - description = "Array expression. Can be a constant, column, or function, and any combination of array operators." - ) -)] -#[derive(Debug)] -pub struct ArrayMax { - signature: Signature, - aliases: Vec, -} - -impl Default for ArrayMax { - fn default() -> Self { - Self::new() - } -} - -impl ArrayMax { - pub fn new() -> Self { - Self { - signature: Signature::array(Volatility::Immutable), - aliases: vec!["list_max".to_string()], - } - } -} - -impl ScalarUDFImpl for ArrayMax { - fn as_any(&self) -> &dyn Any { - self - } - - fn name(&self) -> &str { - "array_max" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { - match &arg_types[0] { - List(field) => Ok(field.data_type().clone()), - _ => exec_err!("Not reachable, data_type should be List"), - } - } - - fn invoke_with_args( - &self, - args: ScalarFunctionArgs, - ) -> datafusion_common::Result { - make_scalar_function(array_max_inner)(&args.args) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } - - fn documentation(&self) -> Option<&Documentation> { - self.doc() - } -} - -/// array_max SQL function -/// -/// There is one argument for array_max as the array. -/// `array_max(array)` -/// -/// For example: -/// > array_max(\[1, 3, 2]) -> 3 -pub fn array_max_inner(args: &[ArrayRef]) -> datafusion_common::Result { - let [arg1] = take_function_args("array_max", args)?; - - match arg1.data_type() { - List(_) => { - let input_list_array = as_list_array(&arg1)?; - let result_vec = input_list_array - .iter() - .flat_map(|arr| min_max::max_batch(&arr.unwrap())) - .collect_vec(); - ScalarValue::iter_to_array(result_vec) - } - _ => exec_err!("array_max does not support type: {:?}", arg1.data_type()), - } -} diff --git a/datafusion/functions-nested/src/min_max.rs b/datafusion/functions-nested/src/min_max.rs new file mode 100644 index 000000000000..0eb416471a86 --- /dev/null +++ b/datafusion/functions-nested/src/min_max.rs @@ -0,0 +1,224 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for array_max function. +use crate::utils::make_scalar_function; +use arrow::array::{ArrayRef, GenericListArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::{LargeList, List}; +use datafusion_common::cast::{as_large_list_array, as_list_array}; +use datafusion_common::utils::take_function_args; +use datafusion_common::Result; +use datafusion_common::{exec_err, plan_err, ScalarValue}; +use datafusion_doc::Documentation; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_functions_aggregate_common::min_max::{max_batch, min_batch}; +use datafusion_macros::user_doc; +use itertools::Itertools; +use std::any::Any; + +make_udf_expr_and_func!( + ArrayMax, + array_max, + array, + "returns the maximum value in the array.", + array_max_udf +); + +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns the maximum value in the array.", + syntax_example = "array_max(array)", + sql_example = r#"```sql +> select array_max([3,1,4,2]); ++-----------------------------------------+ +| array_max(List([3,1,4,2])) | ++-----------------------------------------+ +| 4 | ++-----------------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ) +)] +#[derive(Debug)] +pub struct ArrayMax { + signature: Signature, + aliases: Vec, +} + +impl Default for ArrayMax { + fn default() -> Self { + Self::new() + } +} + +impl ArrayMax { + pub fn new() -> Self { + Self { + signature: Signature::array(Volatility::Immutable), + aliases: vec!["list_max".to_string()], + } + } +} + +impl ScalarUDFImpl for ArrayMax { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "array_max" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + let [array] = take_function_args(self.name(), arg_types)?; + match array { + List(field) | LargeList(field) => Ok(field.data_type().clone()), + arg_type => plan_err!("{} does not support type {arg_type}", self.name()), + } + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(array_max_inner)(&args.args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +/// array_max SQL function +/// +/// There is one argument for array_max as the array. 
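An aside on `array_min_max_helper`, defined toward the end of this new `min_max.rs` file: every list row is reduced with the corresponding aggregate kernel (`max_batch` or `min_batch`), and a NULL row yields a typed NULL result instead of being dropped. A simplified sketch of that row-wise reduction over plain vectors of `i64`, not the `ScalarValue`-based code:

```rust
/// Row-wise maximum: one result per list row; a NULL row produces a NULL result,
/// and an empty row also produces NULL (it has no maximum).
fn array_max_rows(rows: &[Option<Vec<i64>>]) -> Vec<Option<i64>> {
    rows.iter()
        .map(|row| row.as_ref().and_then(|values| values.iter().copied().max()))
        .collect()
}

fn main() {
    let rows = vec![Some(vec![3, 1, 4, 2]), None, Some(vec![])];
    assert_eq!(array_max_rows(&rows), vec![Some(4), None, None]);
}
```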
+/// `array_max(array)` +/// +/// For example: +/// > array_max(\[1, 3, 2]) -> 3 +pub fn array_max_inner(args: &[ArrayRef]) -> Result { + let [array] = take_function_args("array_max", args)?; + match array.data_type() { + List(_) => array_min_max_helper(as_list_array(array)?, max_batch), + LargeList(_) => array_min_max_helper(as_large_list_array(array)?, max_batch), + arg_type => exec_err!("array_max does not support type: {arg_type}"), + } +} + +make_udf_expr_and_func!( + ArrayMin, + array_min, + array, + "returns the minimum value in the array", + array_min_udf +); +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns the minimum value in the array.", + syntax_example = "array_min(array)", + sql_example = r#"```sql +> select array_min([3,1,4,2]); ++-----------------------------------------+ +| array_min(List([3,1,4,2])) | ++-----------------------------------------+ +| 1 | ++-----------------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ) +)] +#[derive(Debug)] +struct ArrayMin { + signature: Signature, +} + +impl Default for ArrayMin { + fn default() -> Self { + Self::new() + } +} + +impl ArrayMin { + fn new() -> Self { + Self { + signature: Signature::array(Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for ArrayMin { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "array_min" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + let [array] = take_function_args(self.name(), arg_types)?; + match array { + List(field) | LargeList(field) => Ok(field.data_type().clone()), + arg_type => plan_err!("{} does not support type {}", self.name(), arg_type), + } + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(array_min_inner)(&args.args) + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +pub fn array_min_inner(args: &[ArrayRef]) -> Result { + let [array] = take_function_args("array_min", args)?; + match array.data_type() { + List(_) => array_min_max_helper(as_list_array(array)?, min_batch), + LargeList(_) => array_min_max_helper(as_large_list_array(array)?, min_batch), + arg_type => exec_err!("array_min does not support type: {arg_type}"), + } +} + +fn array_min_max_helper( + array: &GenericListArray, + agg_fn: fn(&ArrayRef) -> Result, +) -> Result { + let null_value = ScalarValue::try_from(array.value_type())?; + let result_vec: Vec = array + .iter() + .map(|arr| arr.as_ref().map_or_else(|| Ok(null_value.clone()), agg_fn)) + .try_collect()?; + ScalarValue::iter_to_array(result_vec) +} diff --git a/datafusion/functions-nested/src/position.rs b/datafusion/functions-nested/src/position.rs index b186b65407c3..76ee7133d333 100644 --- a/datafusion/functions-nested/src/position.rs +++ b/datafusion/functions-nested/src/position.rs @@ -52,7 +52,7 @@ make_udf_expr_and_func!( #[user_doc( doc_section(label = "Array Functions"), - description = "Returns the position of the first occurrence of the specified element in the array.", + description = "Returns the position of the first occurrence of the specified element in the array, or NULL if not found.", syntax_example = "array_position(array, element)\narray_position(array, element, index)", sql_example = r#"```sql > select array_position([1, 2, 2, 3, 1, 4], 2); @@ -76,7 +76,10 @@ make_udf_expr_and_func!( name = 
"element", description = "Element to search for position in the array." ), - argument(name = "index", description = "Index at which to start searching.") + argument( + name = "index", + description = "Index at which to start searching (1-indexed)." + ) )] #[derive(Debug)] pub struct ArrayPosition { @@ -170,7 +173,7 @@ fn general_position_dispatch(args: &[ArrayRef]) -> Result= arr.len() { + if from < 0 || from as usize > arr.len() { return internal_err!("start_from index out of bounds"); } } else { diff --git a/datafusion/functions-nested/src/reverse.rs b/datafusion/functions-nested/src/reverse.rs index 140cd19aeff9..eb36047e09f6 100644 --- a/datafusion/functions-nested/src/reverse.rs +++ b/datafusion/functions-nested/src/reverse.rs @@ -19,12 +19,15 @@ use crate::utils::make_scalar_function; use arrow::array::{ - Array, ArrayRef, Capacities, GenericListArray, MutableArrayData, OffsetSizeTrait, + Array, ArrayRef, Capacities, FixedSizeListArray, GenericListArray, MutableArrayData, + OffsetSizeTrait, }; use arrow::buffer::OffsetBuffer; -use arrow::datatypes::DataType::{LargeList, List, Null}; +use arrow::datatypes::DataType::{FixedSizeList, LargeList, List, Null}; use arrow::datatypes::{DataType, FieldRef}; -use datafusion_common::cast::{as_large_list_array, as_list_array}; +use datafusion_common::cast::{ + as_fixed_size_list_array, as_large_list_array, as_list_array, +}; use datafusion_common::{exec_err, utils::take_function_args, Result}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, @@ -125,6 +128,10 @@ pub fn array_reverse_inner(arg: &[ArrayRef]) -> Result { let array = as_large_list_array(input_array)?; general_array_reverse::(array, field) } + FixedSizeList(field, _) => { + let array = as_fixed_size_list_array(input_array)?; + fixed_size_array_reverse(array, field) + } Null => Ok(Arc::clone(input_array)), array_type => exec_err!("array_reverse does not support type '{array_type:?}'."), } @@ -175,3 +182,40 @@ fn general_array_reverse>( Some(nulls.into()), )?)) } + +fn fixed_size_array_reverse( + array: &FixedSizeListArray, + field: &FieldRef, +) -> Result { + let values = array.values(); + let original_data = values.to_data(); + let capacity = Capacities::Array(original_data.len()); + let mut nulls = vec![]; + let mut mutable = + MutableArrayData::with_capacities(vec![&original_data], false, capacity); + let value_length = array.value_length() as usize; + + for row_index in 0..array.len() { + // skip the null value + if array.is_null(row_index) { + nulls.push(false); + mutable.extend(0, 0, 1); + continue; + } else { + nulls.push(true); + } + let start = row_index * value_length; + let end = start + value_length; + for idx in (start..end).rev() { + mutable.extend(0, idx, idx + 1); + } + } + + let data = mutable.freeze(); + Ok(Arc::new(FixedSizeListArray::try_new( + Arc::clone(field), + array.value_length(), + arrow::array::make_array(data), + Some(nulls.into()), + )?)) +} diff --git a/datafusion/functions-nested/src/set_ops.rs b/datafusion/functions-nested/src/set_ops.rs index a67945b1f1e1..4f9457aa59c6 100644 --- a/datafusion/functions-nested/src/set_ops.rs +++ b/datafusion/functions-nested/src/set_ops.rs @@ -17,16 +17,21 @@ //! [`ScalarUDFImpl`] definitions for array_union, array_intersect and array_distinct functions. 
-use crate::make_array::{empty_array_type, make_array_inner}; use crate::utils::make_scalar_function; -use arrow::array::{new_empty_array, Array, ArrayRef, GenericListArray, OffsetSizeTrait}; +use arrow::array::{ + new_null_array, Array, ArrayRef, GenericListArray, LargeListArray, ListArray, + OffsetSizeTrait, +}; use arrow::buffer::OffsetBuffer; use arrow::compute; -use arrow::datatypes::DataType::{FixedSizeList, LargeList, List, Null}; +use arrow::datatypes::DataType::{LargeList, List, Null}; use arrow::datatypes::{DataType, Field, FieldRef}; use arrow::row::{RowConverter, SortField}; use datafusion_common::cast::{as_large_list_array, as_list_array}; -use datafusion_common::{exec_err, internal_err, utils::take_function_args, Result}; +use datafusion_common::utils::ListCoercion; +use datafusion_common::{ + exec_err, internal_err, plan_err, utils::take_function_args, Result, +}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; @@ -104,7 +109,11 @@ impl Default for ArrayUnion { impl ArrayUnion { pub fn new() -> Self { Self { - signature: Signature::any(2, Volatility::Immutable), + signature: Signature::arrays( + 2, + Some(ListCoercion::FixedSizedListToList), + Volatility::Immutable, + ), aliases: vec![String::from("list_union")], } } @@ -124,8 +133,10 @@ impl ScalarUDFImpl for ArrayUnion { } fn return_type(&self, arg_types: &[DataType]) -> Result { - match (&arg_types[0], &arg_types[1]) { - (&Null, dt) => Ok(dt.clone()), + let [array1, array2] = take_function_args(self.name(), arg_types)?; + match (array1, array2) { + (Null, Null) => Ok(DataType::new_list(Null, true)), + (Null, dt) => Ok(dt.clone()), (dt, Null) => Ok(dt.clone()), (dt, _) => Ok(dt.clone()), } @@ -183,7 +194,11 @@ pub(super) struct ArrayIntersect { impl ArrayIntersect { pub fn new() -> Self { Self { - signature: Signature::any(2, Volatility::Immutable), + signature: Signature::arrays( + 2, + Some(ListCoercion::FixedSizedListToList), + Volatility::Immutable, + ), aliases: vec![String::from("list_intersect")], } } @@ -203,10 +218,12 @@ impl ScalarUDFImpl for ArrayIntersect { } fn return_type(&self, arg_types: &[DataType]) -> Result { - match (arg_types[0].clone(), arg_types[1].clone()) { - (Null, Null) | (Null, _) => Ok(Null), - (_, Null) => Ok(empty_array_type()), - (dt, _) => Ok(dt), + let [array1, array2] = take_function_args(self.name(), arg_types)?; + match (array1, array2) { + (Null, Null) => Ok(DataType::new_list(Null, true)), + (Null, dt) => Ok(dt.clone()), + (dt, Null) => Ok(dt.clone()), + (dt, _) => Ok(dt.clone()), } } @@ -273,16 +290,11 @@ impl ScalarUDFImpl for ArrayDistinct { fn return_type(&self, arg_types: &[DataType]) -> Result { match &arg_types[0] { - List(field) | FixedSizeList(field, _) => Ok(List(Arc::new( - Field::new_list_field(field.data_type().clone(), true), - ))), - LargeList(field) => Ok(LargeList(Arc::new(Field::new_list_field( - field.data_type().clone(), - true, - )))), - _ => exec_err!( - "Not reachable, data_type should be List, LargeList or FixedSizeList" - ), + List(field) => Ok(DataType::new_list(field.data_type().clone(), true)), + LargeList(field) => { + Ok(DataType::new_large_list(field.data_type().clone(), true)) + } + arg_type => plan_err!("{} does not support type {arg_type}", self.name()), } } @@ -305,24 +317,18 @@ impl ScalarUDFImpl for ArrayDistinct { /// array_distinct SQL function /// example: from list [1, 3, 2, 3, 1, 2, 4] to [1, 2, 3, 4] fn array_distinct_inner(args: &[ArrayRef]) -> Result { - let [input_array] = 
take_function_args("array_distinct", args)?; - - // handle null - if input_array.data_type() == &Null { - return Ok(Arc::clone(input_array)); - } - - // handle for list & largelist - match input_array.data_type() { + let [array] = take_function_args("array_distinct", args)?; + match array.data_type() { + Null => Ok(Arc::clone(array)), List(field) => { - let array = as_list_array(&input_array)?; + let array = as_list_array(&array)?; general_array_distinct(array, field) } LargeList(field) => { - let array = as_large_list_array(&input_array)?; + let array = as_large_list_array(&array)?; general_array_distinct(array, field) } - array_type => exec_err!("array_distinct does not support type '{array_type:?}'"), + arg_type => exec_err!("array_distinct does not support type {arg_type}"), } } @@ -347,80 +353,76 @@ fn generic_set_lists( field: Arc, set_op: SetOp, ) -> Result { - if matches!(l.value_type(), Null) { + if l.is_empty() || l.value_type().is_null() { let field = Arc::new(Field::new_list_field(r.value_type(), true)); return general_array_distinct::(r, &field); - } else if matches!(r.value_type(), Null) { + } else if r.is_empty() || r.value_type().is_null() { let field = Arc::new(Field::new_list_field(l.value_type(), true)); return general_array_distinct::(l, &field); } - // Handle empty array at rhs case - // array_union(arr, []) -> arr; - // array_intersect(arr, []) -> []; - if r.value_length(0).is_zero() { - if set_op == SetOp::Union { - return Ok(Arc::new(l.clone()) as ArrayRef); - } else { - return Ok(Arc::new(r.clone()) as ArrayRef); - } - } - if l.value_type() != r.value_type() { return internal_err!("{set_op:?} is not implemented for '{l:?}' and '{r:?}'"); } - let dt = l.value_type(); - let mut offsets = vec![OffsetSize::usize_as(0)]; let mut new_arrays = vec![]; - - let converter = RowConverter::new(vec![SortField::new(dt)])?; + let converter = RowConverter::new(vec![SortField::new(l.value_type())])?; for (first_arr, second_arr) in l.iter().zip(r.iter()) { - if let (Some(first_arr), Some(second_arr)) = (first_arr, second_arr) { - let l_values = converter.convert_columns(&[first_arr])?; - let r_values = converter.convert_columns(&[second_arr])?; - - let l_iter = l_values.iter().sorted().dedup(); - let values_set: HashSet<_> = l_iter.clone().collect(); - let mut rows = if set_op == SetOp::Union { - l_iter.collect::>() - } else { - vec![] - }; - for r_val in r_values.iter().sorted().dedup() { - match set_op { - SetOp::Union => { - if !values_set.contains(&r_val) { - rows.push(r_val); - } + let l_values = if let Some(first_arr) = first_arr { + converter.convert_columns(&[first_arr])? + } else { + converter.convert_columns(&[])? + }; + + let r_values = if let Some(second_arr) = second_arr { + converter.convert_columns(&[second_arr])? + } else { + converter.convert_columns(&[])? 
+ }; + + let l_iter = l_values.iter().sorted().dedup(); + let values_set: HashSet<_> = l_iter.clone().collect(); + let mut rows = if set_op == SetOp::Union { + l_iter.collect() + } else { + vec![] + }; + + for r_val in r_values.iter().sorted().dedup() { + match set_op { + SetOp::Union => { + if !values_set.contains(&r_val) { + rows.push(r_val); } - SetOp::Intersect => { - if values_set.contains(&r_val) { - rows.push(r_val); - } + } + SetOp::Intersect => { + if values_set.contains(&r_val) { + rows.push(r_val); } } } - - let last_offset = match offsets.last().copied() { - Some(offset) => offset, - None => return internal_err!("offsets should not be empty"), - }; - offsets.push(last_offset + OffsetSize::usize_as(rows.len())); - let arrays = converter.convert_rows(rows)?; - let array = match arrays.first() { - Some(array) => Arc::clone(array), - None => { - return internal_err!("{set_op}: failed to get array from rows"); - } - }; - new_arrays.push(array); } + + let last_offset = match offsets.last() { + Some(offset) => *offset, + None => return internal_err!("offsets should not be empty"), + }; + + offsets.push(last_offset + OffsetSize::usize_as(rows.len())); + let arrays = converter.convert_rows(rows)?; + let array = match arrays.first() { + Some(array) => Arc::clone(array), + None => { + return internal_err!("{set_op}: failed to get array from rows"); + } + }; + + new_arrays.push(array); } let offsets = OffsetBuffer::new(offsets.into()); - let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::>(); + let new_arrays_ref: Vec<_> = new_arrays.iter().map(|v| v.as_ref()).collect(); let values = compute::concat(&new_arrays_ref)?; let arr = GenericListArray::::try_new(field, offsets, values, None)?; Ok(Arc::new(arr)) @@ -431,38 +433,59 @@ fn general_set_op( array2: &ArrayRef, set_op: SetOp, ) -> Result { + fn empty_array(data_type: &DataType, len: usize, large: bool) -> Result { + let field = Arc::new(Field::new_list_field(data_type.clone(), true)); + let values = new_null_array(data_type, len); + if large { + Ok(Arc::new(LargeListArray::try_new( + field, + OffsetBuffer::new_zeroed(len), + values, + None, + )?)) + } else { + Ok(Arc::new(ListArray::try_new( + field, + OffsetBuffer::new_zeroed(len), + values, + None, + )?)) + } + } + match (array1.data_type(), array2.data_type()) { + (Null, Null) => Ok(Arc::new(ListArray::new_null( + Arc::new(Field::new_list_field(Null, true)), + array1.len(), + ))), (Null, List(field)) => { if set_op == SetOp::Intersect { - return Ok(new_empty_array(&Null)); + return empty_array(field.data_type(), array1.len(), false); } let array = as_list_array(&array2)?; general_array_distinct::(array, field) } - (List(field), Null) => { if set_op == SetOp::Intersect { - return make_array_inner(&[]); + return empty_array(field.data_type(), array1.len(), false); } let array = as_list_array(&array1)?; general_array_distinct::(array, field) } (Null, LargeList(field)) => { if set_op == SetOp::Intersect { - return Ok(new_empty_array(&Null)); + return empty_array(field.data_type(), array1.len(), true); } let array = as_large_list_array(&array2)?; general_array_distinct::(array, field) } (LargeList(field), Null) => { if set_op == SetOp::Intersect { - return make_array_inner(&[]); + return empty_array(field.data_type(), array1.len(), true); } let array = as_large_list_array(&array1)?; general_array_distinct::(array, field) } - (Null, Null) => Ok(new_empty_array(&Null)), - (List(field), List(_)) => { let array1 = as_list_array(&array1)?; let array2 = as_list_array(&array2)?; 
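The rewritten `generic_set_lists` above works one pair of rows at a time: both sides are converted to comparable rows, sorted and de-duplicated, and then either merged (union keeps the distinct left values plus any right values not already seen) or filtered (intersect keeps the right values that are present on the left). A standalone sketch of that per-row logic over plain integers, without `RowConverter` (the function and enum here are illustrative, not the crate's API):

```rust
use std::collections::HashSet;

#[derive(PartialEq)]
enum SetOp {
    Union,
    Intersect,
}

/// Distinct union or intersection of one pair of list rows, mirroring the
/// output order of the kernel: distinct left values first, then right values.
fn set_op_row(l: &[i64], r: &[i64], op: SetOp) -> Vec<i64> {
    let mut l_sorted: Vec<i64> = l.to_vec();
    l_sorted.sort_unstable();
    l_sorted.dedup();
    let l_set: HashSet<i64> = l_sorted.iter().copied().collect();

    let mut r_sorted: Vec<i64> = r.to_vec();
    r_sorted.sort_unstable();
    r_sorted.dedup();

    let mut out = if op == SetOp::Union { l_sorted } else { Vec::new() };
    for v in r_sorted {
        match op {
            SetOp::Union if !l_set.contains(&v) => out.push(v),
            SetOp::Intersect if l_set.contains(&v) => out.push(v),
            _ => {}
        }
    }
    out
}

fn main() {
    assert_eq!(set_op_row(&[3, 1, 1], &[2, 3], SetOp::Union), vec![1, 3, 2]);
    assert_eq!(set_op_row(&[3, 1, 1], &[2, 3], SetOp::Intersect), vec![3]);
}
```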
diff --git a/datafusion/functions-nested/src/sort.rs b/datafusion/functions-nested/src/sort.rs index 85737ef135bc..7b2f41c0541c 100644 --- a/datafusion/functions-nested/src/sort.rs +++ b/datafusion/functions-nested/src/sort.rs @@ -21,11 +21,11 @@ use crate::utils::make_scalar_function; use arrow::array::{new_null_array, Array, ArrayRef, ListArray, NullBufferBuilder}; use arrow::buffer::OffsetBuffer; use arrow::compute::SortColumn; -use arrow::datatypes::DataType::{FixedSizeList, LargeList, List}; use arrow::datatypes::{DataType, Field}; use arrow::{compute, compute::SortOptions}; use datafusion_common::cast::{as_list_array, as_string_array}; -use datafusion_common::{exec_err, Result}; +use datafusion_common::utils::ListCoercion; +use datafusion_common::{exec_err, plan_err, Result}; use datafusion_expr::{ ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature, Volatility, @@ -93,14 +93,14 @@ impl ArraySort { vec![ TypeSignature::ArraySignature(ArrayFunctionSignature::Array { arguments: vec![ArrayFunctionArgument::Array], - array_coercion: None, + array_coercion: Some(ListCoercion::FixedSizedListToList), }), TypeSignature::ArraySignature(ArrayFunctionSignature::Array { arguments: vec![ ArrayFunctionArgument::Array, ArrayFunctionArgument::String, ], - array_coercion: None, + array_coercion: Some(ListCoercion::FixedSizedListToList), }), TypeSignature::ArraySignature(ArrayFunctionSignature::Array { arguments: vec![ @@ -108,7 +108,7 @@ impl ArraySort { ArrayFunctionArgument::String, ArrayFunctionArgument::String, ], - array_coercion: None, + array_coercion: Some(ListCoercion::FixedSizedListToList), }), ], Volatility::Immutable, @@ -133,17 +133,13 @@ impl ScalarUDFImpl for ArraySort { fn return_type(&self, arg_types: &[DataType]) -> Result { match &arg_types[0] { - List(field) | FixedSizeList(field, _) => Ok(List(Arc::new( - Field::new_list_field(field.data_type().clone(), true), - ))), - LargeList(field) => Ok(LargeList(Arc::new(Field::new_list_field( - field.data_type().clone(), - true, - )))), DataType::Null => Ok(DataType::Null), - _ => exec_err!( - "Not reachable, data_type should be List, LargeList or FixedSizeList" - ), + DataType::List(field) => { + Ok(DataType::new_list(field.data_type().clone(), true)) + } + arg_type => { + plan_err!("{} does not support type {arg_type}", self.name()) + } } } @@ -169,6 +165,16 @@ pub fn array_sort_inner(args: &[ArrayRef]) -> Result { return exec_err!("array_sort expects one to three arguments"); } + if args[0].data_type().is_null() { + return Ok(Arc::clone(&args[0])); + } + + let list_array = as_list_array(&args[0])?; + let row_count = list_array.len(); + if row_count == 0 || list_array.value_type().is_null() { + return Ok(Arc::clone(&args[0])); + } + if args[1..].iter().any(|array| array.is_null(0)) { return Ok(new_null_array(args[0].data_type(), args[0].len())); } @@ -193,12 +199,6 @@ pub fn array_sort_inner(args: &[ArrayRef]) -> Result { _ => return exec_err!("array_sort expects 1 to 3 arguments"), }; - let list_array = as_list_array(&args[0])?; - let row_count = list_array.len(); - if row_count == 0 { - return Ok(Arc::clone(&args[0])); - } - let mut array_lengths = vec![]; let mut arrays = vec![]; let mut valid = NullBufferBuilder::new(row_count); diff --git a/datafusion/functions-nested/src/utils.rs b/datafusion/functions-nested/src/utils.rs index 74b21a3ceb47..ed08a8235874 100644 --- a/datafusion/functions-nested/src/utils.rs +++ b/datafusion/functions-nested/src/utils.rs @@ -22,17 
+22,15 @@ use std::sync::Arc; use arrow::datatypes::{DataType, Field, Fields}; use arrow::array::{ - Array, ArrayRef, BooleanArray, GenericListArray, ListArray, OffsetSizeTrait, Scalar, - UInt32Array, + Array, ArrayRef, BooleanArray, GenericListArray, OffsetSizeTrait, Scalar, UInt32Array, }; use arrow::buffer::OffsetBuffer; -use datafusion_common::cast::{as_large_list_array, as_list_array}; -use datafusion_common::{ - exec_err, internal_datafusion_err, internal_err, plan_err, Result, ScalarValue, +use datafusion_common::cast::{ + as_fixed_size_list_array, as_large_list_array, as_list_array, }; +use datafusion_common::{exec_err, internal_err, plan_err, Result, ScalarValue}; use datafusion_expr::ColumnarValue; -use datafusion_functions::{downcast_arg, downcast_named_arg}; pub(crate) fn check_datatypes(name: &str, args: &[&ArrayRef]) -> Result<()> { let data_type = args[0].data_type(); @@ -234,8 +232,16 @@ pub(crate) fn compute_array_dims( loop { match value.data_type() { - DataType::List(..) => { - value = downcast_arg!(value, ListArray).value(0); + DataType::List(_) => { + value = as_list_array(&value)?.value(0); + res.push(Some(value.len() as u64)); + } + DataType::LargeList(_) => { + value = as_large_list_array(&value)?.value(0); + res.push(Some(value.len() as u64)); + } + DataType::FixedSizeList(..) => { + value = as_fixed_size_list_array(&value)?.value(0); res.push(Some(value.len() as u64)); } _ => return Ok(Some(res)), @@ -261,6 +267,7 @@ pub(crate) fn get_map_entry_field(data_type: &DataType) -> Result<&Fields> { #[cfg(test)] mod tests { use super::*; + use arrow::array::ListArray; use arrow::datatypes::Int64Type; use datafusion_common::utils::SingleRowListArrayBuilder; diff --git a/datafusion/functions-table/README.md b/datafusion/functions-table/README.md index c4e7a5aff999..485abe560dad 100644 --- a/datafusion/functions-table/README.md +++ b/datafusion/functions-table/README.md @@ -23,4 +23,9 @@ This crate contains table functions that can be used in DataFusion queries. +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + [df]: https://crates.io/crates/datafusion +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/functions-table/src/generate_series.rs b/datafusion/functions-table/src/generate_series.rs index ee95567ab73d..ecd8870124ab 100644 --- a/datafusion/functions-table/src/generate_series.rs +++ b/datafusion/functions-table/src/generate_series.rs @@ -15,8 +15,12 @@ // specific language governing permissions and limitations // under the License. 
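Loosely mirroring the `compute_array_dims` change in `utils.rs` a few hunks above, which now descends through `List`, `LargeList` and `FixedSizeList` alike: the dimensions of a nested list are read off by repeatedly following the first element at each nesting level. A toy version over a hand-rolled recursive enum rather than Arrow arrays (the names and types here are made up for illustration):

```rust
/// Toy nested value, standing in for arbitrarily nested Arrow list arrays.
enum Nested {
    Leaf(i64),
    List(Vec<Nested>),
}

/// Record the length seen at each nesting level by following the first element,
/// in the spirit of compute_array_dims descending via `.value(0)`.
fn dims(mut value: &Nested) -> Vec<u64> {
    let mut res = Vec::new();
    while let Nested::List(items) = value {
        res.push(items.len() as u64);
        match items.first() {
            Some(first) => value = first,
            None => break,
        }
    }
    res
}

fn main() {
    // [[1, 2, 3], [4, 5, 6]] has dims [2, 3]: outer length, then first inner length.
    let v = Nested::List(vec![
        Nested::List(vec![Nested::Leaf(1), Nested::Leaf(2), Nested::Leaf(3)]),
        Nested::List(vec![Nested::Leaf(4), Nested::Leaf(5), Nested::Leaf(6)]),
    ]);
    assert_eq!(dims(&v), vec![2, 3]);
}
```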
-use arrow::array::Int64Array; -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::array::timezone::Tz; +use arrow::array::types::TimestampNanosecondType; +use arrow::array::{ArrayRef, Int64Array, TimestampNanosecondArray}; +use arrow::datatypes::{ + DataType, Field, IntervalMonthDayNano, Schema, SchemaRef, TimeUnit, +}; use arrow::record_batch::RecordBatch; use async_trait::async_trait; use datafusion_catalog::Session; @@ -28,97 +32,265 @@ use datafusion_physical_plan::memory::{LazyBatchGenerator, LazyMemoryExec}; use datafusion_physical_plan::ExecutionPlan; use parking_lot::RwLock; use std::fmt; +use std::str::FromStr; use std::sync::Arc; +/// Empty generator that produces no rows - used when series arguments contain null values +#[derive(Debug, Clone)] +struct Empty { + name: &'static str, +} + +impl LazyBatchGenerator for Empty { + fn generate_next_batch(&mut self) -> Result> { + Ok(None) + } +} + +impl fmt::Display for Empty { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}: empty", self.name) + } +} + +/// Trait for values that can be generated in a series +trait SeriesValue: fmt::Debug + Clone + Send + Sync + 'static { + type StepType: fmt::Debug + Clone + Send + Sync; + type ValueType: fmt::Debug + Clone + Send + Sync; + + /// Check if we've reached the end of the series + fn should_stop(&self, end: Self, step: &Self::StepType, include_end: bool) -> bool; + + /// Advance to the next value in the series + fn advance(&mut self, step: &Self::StepType) -> Result<()>; + + /// Create an Arrow array from a vector of values + fn create_array(&self, values: Vec) -> Result; + + /// Convert self to ValueType for array creation + fn to_value_type(&self) -> Self::ValueType; + + /// Display the value for debugging + fn display_value(&self) -> String; +} + +impl SeriesValue for i64 { + type StepType = i64; + type ValueType = i64; + + fn should_stop(&self, end: Self, step: &Self::StepType, include_end: bool) -> bool { + reach_end_int64(*self, end, *step, include_end) + } + + fn advance(&mut self, step: &Self::StepType) -> Result<()> { + *self += step; + Ok(()) + } + + fn create_array(&self, values: Vec) -> Result { + Ok(Arc::new(Int64Array::from(values))) + } + + fn to_value_type(&self) -> Self::ValueType { + *self + } + + fn display_value(&self) -> String { + self.to_string() + } +} + +#[derive(Debug, Clone)] +struct TimestampValue { + value: i64, + parsed_tz: Option, + tz_str: Option>, +} + +impl SeriesValue for TimestampValue { + type StepType = IntervalMonthDayNano; + type ValueType = i64; + + fn should_stop(&self, end: Self, step: &Self::StepType, include_end: bool) -> bool { + let step_negative = step.months < 0 || step.days < 0 || step.nanoseconds < 0; + + if include_end { + if step_negative { + self.value < end.value + } else { + self.value > end.value + } + } else if step_negative { + self.value <= end.value + } else { + self.value >= end.value + } + } + + fn advance(&mut self, step: &Self::StepType) -> Result<()> { + let tz = self + .parsed_tz + .unwrap_or_else(|| Tz::from_str("+00:00").unwrap()); + let Some(next_ts) = + TimestampNanosecondType::add_month_day_nano(self.value, *step, tz) + else { + return plan_err!( + "Failed to add interval {:?} to timestamp {}", + step, + self.value + ); + }; + self.value = next_ts; + Ok(()) + } + + fn create_array(&self, values: Vec) -> Result { + let array = TimestampNanosecondArray::from(values); + + // Use timezone from self (now we have access to tz through &self) + let array = match 
self.tz_str.as_ref() { + Some(tz_str) => array.with_timezone(Arc::clone(tz_str)), + None => array, + }; + + Ok(Arc::new(array)) + } + + fn to_value_type(&self) -> Self::ValueType { + self.value + } + + fn display_value(&self) -> String { + self.value.to_string() + } +} + /// Indicates the arguments used for generating a series. #[derive(Debug, Clone)] enum GenSeriesArgs { /// ContainsNull signifies that at least one argument(start, end, step) was null, thus no series will be generated. - ContainsNull { + ContainsNull { name: &'static str }, + /// Int64Args holds the start, end, and step values for generating integer series when all arguments are not null. + Int64Args { + start: i64, + end: i64, + step: i64, + /// Indicates whether the end value should be included in the series. include_end: bool, name: &'static str, }, - /// AllNotNullArgs holds the start, end, and step values for generating the series when all arguments are not null. - AllNotNullArgs { + /// TimestampArgs holds the start, end, and step values for generating timestamp series when all arguments are not null. + TimestampArgs { start: i64, end: i64, - step: i64, + step: IntervalMonthDayNano, + tz: Option>, + /// Indicates whether the end value should be included in the series. + include_end: bool, + name: &'static str, + }, + /// DateArgs holds the start, end, and step values for generating date series when all arguments are not null. + /// Internally, dates are converted to timestamps and use the timestamp logic. + DateArgs { + start: i64, + end: i64, + step: IntervalMonthDayNano, /// Indicates whether the end value should be included in the series. include_end: bool, name: &'static str, }, } -/// Table that generates a series of integers from `start`(inclusive) to `end`(inclusive), incrementing by step +/// Table that generates a series of integers/timestamps from `start`(inclusive) to `end`, incrementing by step #[derive(Debug, Clone)] struct GenerateSeriesTable { schema: SchemaRef, args: GenSeriesArgs, } -/// Table state that generates a series of integers from `start`(inclusive) to `end`(inclusive), incrementing by step #[derive(Debug, Clone)] -struct GenerateSeriesState { +struct GenericSeriesState { schema: SchemaRef, - start: i64, // Kept for display - end: i64, - step: i64, + start: T, + end: T, + step: T::StepType, batch_size: usize, - - /// Tracks current position when generating table - current: i64, - /// Indicates whether the end value should be included in the series. 
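The integer stopping rule being refactored in the surrounding hunks (`reach_end` becoming the free function `reach_end_int64`) can be summarized as: with a positive step the series ends once the value passes `end`, or reaches it when the end bound is exclusive, and a negative step flips the comparison. A self-contained sketch of that rule, assuming a fixed step of one for brevity:

```rust
/// Series termination check: positive step stops past `end` (or at `end` when the
/// bound is exclusive); negative step flips the comparisons.
fn reach_end(val: i64, end: i64, step: i64, include_end: bool) -> bool {
    if step > 0 {
        if include_end { val > end } else { val >= end }
    } else if include_end {
        val < end
    } else {
        val <= end
    }
}

fn main() {
    // include_end = true gives an inclusive upper bound, false an exclusive one.
    let inclusive: Vec<i64> = (1i64..).take_while(|v| !reach_end(*v, 3, 1, true)).collect();
    let exclusive: Vec<i64> = (1i64..).take_while(|v| !reach_end(*v, 3, 1, false)).collect();
    assert_eq!(inclusive, vec![1, 2, 3]);
    assert_eq!(exclusive, vec![1, 2]);
}
```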
+ current: T, include_end: bool, name: &'static str, } -impl GenerateSeriesState { - fn reach_end(&self, val: i64) -> bool { - if self.step > 0 { - if self.include_end { - return val > self.end; - } else { - return val >= self.end; - } +impl LazyBatchGenerator for GenericSeriesState { + fn generate_next_batch(&mut self) -> Result> { + let mut buf = Vec::with_capacity(self.batch_size); + + while buf.len() < self.batch_size + && !self + .current + .should_stop(self.end.clone(), &self.step, self.include_end) + { + buf.push(self.current.to_value_type()); + self.current.advance(&self.step)?; } - if self.include_end { - val < self.end - } else { - val <= self.end + if buf.is_empty() { + return Ok(None); } + + let array = self.current.create_array(buf)?; + let batch = RecordBatch::try_new(Arc::clone(&self.schema), vec![array])?; + Ok(Some(batch)) } } -/// Detail to display for 'Explain' plan -impl fmt::Display for GenerateSeriesState { +impl fmt::Display for GenericSeriesState { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, "{}: start={}, end={}, batch_size={}", - self.name, self.start, self.end, self.batch_size + self.name, + self.start.display_value(), + self.end.display_value(), + self.batch_size ) } } -impl LazyBatchGenerator for GenerateSeriesState { - fn generate_next_batch(&mut self) -> Result> { - let mut buf = Vec::with_capacity(self.batch_size); - while buf.len() < self.batch_size && !self.reach_end(self.current) { - buf.push(self.current); - self.current += self.step; +fn reach_end_int64(val: i64, end: i64, step: i64, include_end: bool) -> bool { + if step > 0 { + if include_end { + val > end + } else { + val >= end } - let array = Int64Array::from(buf); + } else if include_end { + val < end + } else { + val <= end + } +} - if array.is_empty() { - return Ok(None); - } +fn validate_interval_step( + step: IntervalMonthDayNano, + start: i64, + end: i64, +) -> Result<()> { + if step.months == 0 && step.days == 0 && step.nanoseconds == 0 { + return plan_err!("Step interval cannot be zero"); + } - let batch = - RecordBatch::try_new(Arc::clone(&self.schema), vec![Arc::new(array)])?; + let step_is_positive = step.months > 0 || step.days > 0 || step.nanoseconds > 0; + let step_is_negative = step.months < 0 || step.days < 0 || step.nanoseconds < 0; - Ok(Some(batch)) + if start > end && step_is_positive { + return plan_err!("Start is bigger than end, but increment is positive: Cannot generate infinite series"); } + + if start < end && step_is_negative { + return plan_err!("Start is smaller than end, but increment is negative: Cannot generate infinite series"); + } + + Ok(()) } #[async_trait] @@ -147,40 +319,96 @@ impl TableProvider for GenerateSeriesTable { Some(projection) => Arc::new(self.schema.project(projection)?), None => self.schema(), }; - let state = match self.args { - // if args have null, then return 0 row - GenSeriesArgs::ContainsNull { include_end, name } => GenerateSeriesState { + let generator: Arc> = match &self.args { + GenSeriesArgs::ContainsNull { name } => Arc::new(RwLock::new(Empty { name })), + GenSeriesArgs::Int64Args { + start, + end, + step, + include_end, + name, + } => Arc::new(RwLock::new(GenericSeriesState { schema: self.schema(), - start: 0, - end: 0, - step: 1, - current: 1, + start: *start, + end: *end, + step: *step, + current: *start, batch_size, - include_end, + include_end: *include_end, name, - }, - GenSeriesArgs::AllNotNullArgs { + })), + GenSeriesArgs::TimestampArgs { start, end, step, + tz, include_end, name, - } => GenerateSeriesState 
{ - schema: self.schema(), + } => { + let parsed_tz = tz + .as_ref() + .map(|s| Tz::from_str(s.as_ref())) + .transpose() + .map_err(|e| { + datafusion_common::DataFusionError::Internal(format!( + "Failed to parse timezone: {e}" + )) + })? + .unwrap_or_else(|| Tz::from_str("+00:00").unwrap()); + Arc::new(RwLock::new(GenericSeriesState { + schema: self.schema(), + start: TimestampValue { + value: *start, + parsed_tz: Some(parsed_tz), + tz_str: tz.clone(), + }, + end: TimestampValue { + value: *end, + parsed_tz: Some(parsed_tz), + tz_str: tz.clone(), + }, + step: *step, + current: TimestampValue { + value: *start, + parsed_tz: Some(parsed_tz), + tz_str: tz.clone(), + }, + batch_size, + include_end: *include_end, + name, + })) + } + GenSeriesArgs::DateArgs { start, end, step, - current: start, - batch_size, include_end, name, - }, + } => Arc::new(RwLock::new(GenericSeriesState { + schema: self.schema(), + start: TimestampValue { + value: *start, + parsed_tz: None, + tz_str: None, + }, + end: TimestampValue { + value: *end, + parsed_tz: None, + tz_str: None, + }, + step: *step, + current: TimestampValue { + value: *start, + parsed_tz: None, + tz_str: None, + }, + batch_size, + include_end: *include_end, + name, + })), }; - Ok(Arc::new(LazyMemoryExec::try_new( - schema, - vec![Arc::new(RwLock::new(state))], - )?)) + Ok(Arc::new(LazyMemoryExec::try_new(schema, vec![generator])?)) } } @@ -196,12 +424,44 @@ impl TableFunctionImpl for GenerateSeriesFuncImpl { return plan_err!("{} function requires 1 to 3 arguments", self.name); } + // Determine the data type from the first argument + match &exprs[0] { + Expr::Literal( + // Default to int64 for null + ScalarValue::Null | ScalarValue::Int64(_), + _, + ) => self.call_int64(exprs), + Expr::Literal(s, _) if matches!(s.data_type(), DataType::Timestamp(_, _)) => { + self.call_timestamp(exprs) + } + Expr::Literal(s, _) if matches!(s.data_type(), DataType::Date32) => { + self.call_date(exprs) + } + Expr::Literal(scalar, _) => { + plan_err!( + "Argument #1 must be an INTEGER, TIMESTAMP, DATE or NULL, got {:?}", + scalar.data_type() + ) + } + _ => plan_err!("Arguments must be literals"), + } + } +} + +impl GenerateSeriesFuncImpl { + fn call_int64(&self, exprs: &[Expr]) -> Result> { let mut normalize_args = Vec::new(); - for expr in exprs { + for (expr_index, expr) in exprs.iter().enumerate() { match expr { - Expr::Literal(ScalarValue::Null) => {} - Expr::Literal(ScalarValue::Int64(Some(n))) => normalize_args.push(*n), - _ => return plan_err!("First argument must be an integer literal"), + Expr::Literal(ScalarValue::Null, _) => {} + Expr::Literal(ScalarValue::Int64(Some(n)), _) => normalize_args.push(*n), + other => { + return plan_err!( + "Argument #{} must be an INTEGER or NULL, got {:?}", + expr_index + 1, + other + ) + } }; } @@ -215,10 +475,7 @@ impl TableFunctionImpl for GenerateSeriesFuncImpl { // contain null return Ok(Arc::new(GenerateSeriesTable { schema, - args: GenSeriesArgs::ContainsNull { - include_end: self.include_end, - name: self.name, - }, + args: GenSeriesArgs::ContainsNull { name: self.name }, })); } @@ -232,20 +489,20 @@ impl TableFunctionImpl for GenerateSeriesFuncImpl { }; if start > end && step > 0 { - return plan_err!("start is bigger than end, but increment is positive: cannot generate infinite series"); + return plan_err!("Start is bigger than end, but increment is positive: Cannot generate infinite series"); } if start < end && step < 0 { - return plan_err!("start is smaller than end, but increment is negative: cannot generate 
infinite series"); + return plan_err!("Start is smaller than end, but increment is negative: Cannot generate infinite series"); } if step == 0 { - return plan_err!("step cannot be zero"); + return plan_err!("Step cannot be zero"); } Ok(Arc::new(GenerateSeriesTable { schema, - args: GenSeriesArgs::AllNotNullArgs { + args: GenSeriesArgs::Int64Args { start, end, step, @@ -254,6 +511,174 @@ impl TableFunctionImpl for GenerateSeriesFuncImpl { }, })) } + + fn call_timestamp(&self, exprs: &[Expr]) -> Result> { + if exprs.len() != 3 { + return plan_err!( + "{} function with timestamps requires exactly 3 arguments", + self.name + ); + } + + // Parse start timestamp + let (start_ts, tz) = match &exprs[0] { + Expr::Literal(ScalarValue::TimestampNanosecond(ts, tz), _) => { + (*ts, tz.clone()) + } + other => { + return plan_err!( + "First argument must be a timestamp or NULL, got {:?}", + other + ) + } + }; + + // Parse end timestamp + let end_ts = match &exprs[1] { + Expr::Literal(ScalarValue::Null, _) => None, + Expr::Literal(ScalarValue::TimestampNanosecond(ts, _), _) => *ts, + other => { + return plan_err!( + "Second argument must be a timestamp or NULL, got {:?}", + other + ) + } + }; + + // Parse step interval + let step_interval = match &exprs[2] { + Expr::Literal(ScalarValue::Null, _) => None, + Expr::Literal(ScalarValue::IntervalMonthDayNano(interval), _) => *interval, + other => { + return plan_err!( + "Third argument must be an interval or NULL, got {:?}", + other + ) + } + }; + + let schema = Arc::new(Schema::new(vec![Field::new( + "value", + DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()), + false, + )])); + + // Check if any argument is null + let (Some(start), Some(end), Some(step)) = (start_ts, end_ts, step_interval) + else { + return Ok(Arc::new(GenerateSeriesTable { + schema, + args: GenSeriesArgs::ContainsNull { name: self.name }, + })); + }; + + // Validate step interval + validate_interval_step(step, start, end)?; + + Ok(Arc::new(GenerateSeriesTable { + schema, + args: GenSeriesArgs::TimestampArgs { + start, + end, + step, + tz, + include_end: self.include_end, + name: self.name, + }, + })) + } + + fn call_date(&self, exprs: &[Expr]) -> Result> { + if exprs.len() != 3 { + return plan_err!( + "{} function with dates requires exactly 3 arguments", + self.name + ); + } + + let schema = Arc::new(Schema::new(vec![Field::new( + "value", + DataType::Timestamp(TimeUnit::Nanosecond, None), + false, + )])); + + // Parse start date + let start_date = match &exprs[0] { + Expr::Literal(ScalarValue::Date32(Some(date)), _) => *date, + Expr::Literal(ScalarValue::Date32(None), _) + | Expr::Literal(ScalarValue::Null, _) => { + return Ok(Arc::new(GenerateSeriesTable { + schema, + args: GenSeriesArgs::ContainsNull { name: self.name }, + })); + } + other => { + return plan_err!( + "First argument must be a date or NULL, got {:?}", + other + ) + } + }; + + // Parse end date + let end_date = match &exprs[1] { + Expr::Literal(ScalarValue::Date32(Some(date)), _) => *date, + Expr::Literal(ScalarValue::Date32(None), _) + | Expr::Literal(ScalarValue::Null, _) => { + return Ok(Arc::new(GenerateSeriesTable { + schema, + args: GenSeriesArgs::ContainsNull { name: self.name }, + })); + } + other => { + return plan_err!( + "Second argument must be a date or NULL, got {:?}", + other + ) + } + }; + + // Parse step interval + let step_interval = match &exprs[2] { + Expr::Literal(ScalarValue::IntervalMonthDayNano(Some(interval)), _) => { + *interval + } + Expr::Literal(ScalarValue::IntervalMonthDayNano(None), _) 
+ | Expr::Literal(ScalarValue::Null, _) => { + return Ok(Arc::new(GenerateSeriesTable { + schema, + args: GenSeriesArgs::ContainsNull { name: self.name }, + })); + } + other => { + return plan_err!( + "Third argument must be an interval or NULL, got {:?}", + other + ) + } + }; + + // Convert Date32 (days since epoch) to timestamp nanoseconds (nanoseconds since epoch) + // Date32 is days since 1970-01-01, so multiply by nanoseconds per day + const NANOS_PER_DAY: i64 = 24 * 60 * 60 * 1_000_000_000; + + let start_ts = start_date as i64 * NANOS_PER_DAY; + let end_ts = end_date as i64 * NANOS_PER_DAY; + + // Validate step interval + validate_interval_step(step_interval, start_ts, end_ts)?; + + Ok(Arc::new(GenerateSeriesTable { + schema, + args: GenSeriesArgs::DateArgs { + start: start_ts, + end: end_ts, + step: step_interval, + include_end: self.include_end, + name: self.name, + }, + })) + } } #[derive(Debug)] diff --git a/datafusion/functions-window-common/README.md b/datafusion/functions-window-common/README.md index de12d25f9731..9f64c9dc8298 100644 --- a/datafusion/functions-window-common/README.md +++ b/datafusion/functions-window-common/README.md @@ -21,6 +21,11 @@ [DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. -This crate contains common functions for implementing user-defined window functions. +This crate contains common functions for implementing window functions. + +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. [df]: https://crates.io/crates/datafusion +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/functions-window-common/src/expr.rs b/datafusion/functions-window-common/src/expr.rs index 76e27b045b0a..774cd5182b30 100644 --- a/datafusion/functions-window-common/src/expr.rs +++ b/datafusion/functions-window-common/src/expr.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::arrow::datatypes::DataType; +use datafusion_common::arrow::datatypes::FieldRef; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::sync::Arc; @@ -25,9 +25,9 @@ pub struct ExpressionArgs<'a> { /// The expressions passed as arguments to the user-defined window /// function. input_exprs: &'a [Arc], - /// The corresponding data types of expressions passed as arguments + /// The corresponding fields of expressions passed as arguments /// to the user-defined window function. - input_types: &'a [DataType], + input_fields: &'a [FieldRef], } impl<'a> ExpressionArgs<'a> { @@ -42,11 +42,11 @@ impl<'a> ExpressionArgs<'a> { /// pub fn new( input_exprs: &'a [Arc], - input_types: &'a [DataType], + input_fields: &'a [FieldRef], ) -> Self { Self { input_exprs, - input_types, + input_fields, } } @@ -56,9 +56,9 @@ impl<'a> ExpressionArgs<'a> { self.input_exprs } - /// Returns the [`DataType`]s corresponding to the input expressions + /// Returns the [`FieldRef`]s corresponding to the input expressions /// to the user-defined window function. 
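Looking back at the `generate_series` date support that finishes just above: a `Date32` literal (days since the epoch) is widened to a nanosecond timestamp by multiplying by the nanoseconds in a day, and the step interval is rejected when it is zero or points away from the end value. A self-contained sketch of those two checks, using a stand-in struct instead of Arrow's `IntervalMonthDayNano`:

```rust
const NANOS_PER_DAY: i64 = 24 * 60 * 60 * 1_000_000_000;

/// Days-since-epoch (Date32 semantics) to nanoseconds-since-epoch.
fn date32_to_nanos(days: i32) -> i64 {
    days as i64 * NANOS_PER_DAY
}

/// Simplified interval, standing in for arrow's IntervalMonthDayNano.
struct Interval {
    months: i32,
    days: i32,
    nanos: i64,
}

/// Reject zero steps and steps that walk away from `end` (the series would never end).
fn validate_step(step: &Interval, start: i64, end: i64) -> Result<(), String> {
    if step.months == 0 && step.days == 0 && step.nanos == 0 {
        return Err("Step interval cannot be zero".into());
    }
    let positive = step.months > 0 || step.days > 0 || step.nanos > 0;
    let negative = step.months < 0 || step.days < 0 || step.nanos < 0;
    if start > end && positive {
        return Err("Start is bigger than end, but increment is positive".into());
    }
    if start < end && negative {
        return Err("Start is smaller than end, but increment is negative".into());
    }
    Ok(())
}

fn main() {
    let start = date32_to_nanos(0);  // 1970-01-01
    let end = date32_to_nanos(31);   // 1970-02-01
    let one_day = Interval { months: 0, days: 1, nanos: 0 };
    assert!(validate_step(&one_day, start, end).is_ok());
    assert!(validate_step(&Interval { months: 0, days: 0, nanos: 0 }, start, end).is_err());
}
```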
- pub fn input_types(&self) -> &'a [DataType] { - self.input_types + pub fn input_fields(&self) -> &'a [FieldRef] { + self.input_fields } } diff --git a/datafusion/functions-window-common/src/field.rs b/datafusion/functions-window-common/src/field.rs index 03f88b0b95cc..8d22efa3bcf4 100644 --- a/datafusion/functions-window-common/src/field.rs +++ b/datafusion/functions-window-common/src/field.rs @@ -15,14 +15,14 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::arrow::datatypes::DataType; +use datafusion_common::arrow::datatypes::FieldRef; /// Metadata for defining the result field from evaluating a /// user-defined window function. pub struct WindowUDFFieldArgs<'a> { - /// The data types corresponding to the arguments to the + /// The fields corresponding to the arguments to the /// user-defined window function. - input_types: &'a [DataType], + input_fields: &'a [FieldRef], /// The display name of the user-defined window function. display_name: &'a str, } @@ -32,22 +32,22 @@ impl<'a> WindowUDFFieldArgs<'a> { /// /// # Arguments /// - /// * `input_types` - The data types corresponding to the + /// * `input_fields` - The fields corresponding to the /// arguments to the user-defined window function. /// * `function_name` - The qualified schema name of the /// user-defined window function expression. /// - pub fn new(input_types: &'a [DataType], display_name: &'a str) -> Self { + pub fn new(input_fields: &'a [FieldRef], display_name: &'a str) -> Self { WindowUDFFieldArgs { - input_types, + input_fields, display_name, } } - /// Returns the data type of input expressions passed as arguments + /// Returns the field of input expressions passed as arguments /// to the user-defined window function. - pub fn input_types(&self) -> &[DataType] { - self.input_types + pub fn input_fields(&self) -> &[FieldRef] { + self.input_fields } /// Returns the name for the field of the final result of evaluating @@ -56,9 +56,9 @@ impl<'a> WindowUDFFieldArgs<'a> { self.display_name } - /// Returns `Some(DataType)` of input expression at index, otherwise + /// Returns `Some(Field)` of input expression at index, otherwise /// returns `None` if the index is out of bounds. - pub fn get_input_type(&self, index: usize) -> Option { - self.input_types.get(index).cloned() + pub fn get_input_field(&self, index: usize) -> Option { + self.input_fields.get(index).cloned() } } diff --git a/datafusion/functions-window-common/src/partition.rs b/datafusion/functions-window-common/src/partition.rs index e853aa8fb05d..61125e596130 100644 --- a/datafusion/functions-window-common/src/partition.rs +++ b/datafusion/functions-window-common/src/partition.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::arrow::datatypes::DataType; +use datafusion_common::arrow::datatypes::FieldRef; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::sync::Arc; @@ -26,9 +26,9 @@ pub struct PartitionEvaluatorArgs<'a> { /// The expressions passed as arguments to the user-defined window /// function. input_exprs: &'a [Arc], - /// The corresponding data types of expressions passed as arguments + /// The corresponding fields of expressions passed as arguments /// to the user-defined window function. - input_types: &'a [DataType], + input_fields: &'a [FieldRef], /// Set to `true` if the user-defined window function is reversed. is_reversed: bool, /// Set to `true` if `IGNORE NULLS` is specified. 
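Stepping back from the `functions-window-common` hunks around here: `ExpressionArgs`, `WindowUDFFieldArgs` and `PartitionEvaluatorArgs` now carry `FieldRef`s instead of bare `DataType`s, so an implementation can preserve the input's type, nullability and metadata when it builds its result field. A hedged sketch using only Arrow types; the function and the display name below are invented for illustration:

```rust
use std::sync::Arc;
use arrow::datatypes::{DataType, Field, FieldRef};

/// Illustration only: with a FieldRef in hand, a window function can keep the
/// input's type and nullability and merely rename the result, instead of
/// rebuilding a Field from a bare DataType.
fn result_field(input: &FieldRef, display_name: &str) -> FieldRef {
    Arc::new(input.as_ref().clone().with_name(display_name))
}

fn main() {
    let input: FieldRef = Arc::new(Field::new("salary", DataType::Int64, true));
    let out = result_field(&input, "lag(salary) ORDER BY [employee_id]");
    assert_eq!(out.data_type(), &DataType::Int64);
    assert!(out.is_nullable());
}
```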
@@ -51,13 +51,13 @@ impl<'a> PartitionEvaluatorArgs<'a> { /// pub fn new( input_exprs: &'a [Arc], - input_types: &'a [DataType], + input_fields: &'a [FieldRef], is_reversed: bool, ignore_nulls: bool, ) -> Self { Self { input_exprs, - input_types, + input_fields, is_reversed, ignore_nulls, } @@ -69,10 +69,10 @@ impl<'a> PartitionEvaluatorArgs<'a> { self.input_exprs } - /// Returns the [`DataType`]s corresponding to the input expressions + /// Returns the [`FieldRef`]s corresponding to the input expressions /// to the user-defined window function. - pub fn input_types(&self) -> &'a [DataType] { - self.input_types + pub fn input_fields(&self) -> &'a [FieldRef] { + self.input_fields } /// Returns `true` when the user-defined window function is diff --git a/datafusion/functions-window/Cargo.toml b/datafusion/functions-window/Cargo.toml index e0c17c579b19..23ee608a8267 100644 --- a/datafusion/functions-window/Cargo.toml +++ b/datafusion/functions-window/Cargo.toml @@ -38,6 +38,7 @@ workspace = true name = "datafusion_functions_window" [dependencies] +arrow = { workspace = true } datafusion-common = { workspace = true } datafusion-doc = { workspace = true } datafusion-expr = { workspace = true } @@ -47,6 +48,3 @@ datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } log = { workspace = true } paste = "1.0.15" - -[dev-dependencies] -arrow = { workspace = true } diff --git a/datafusion/functions-window/README.md b/datafusion/functions-window/README.md index 18590983ca47..746d625b4f8e 100644 --- a/datafusion/functions-window/README.md +++ b/datafusion/functions-window/README.md @@ -21,6 +21,11 @@ [DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. -This crate contains user-defined window functions. +This crate contains window function definitions. + +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. [df]: https://crates.io/crates/datafusion +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/functions-window/src/cume_dist.rs b/datafusion/functions-window/src/cume_dist.rs index d156416a82a4..ed8669948188 100644 --- a/datafusion/functions-window/src/cume_dist.rs +++ b/datafusion/functions-window/src/cume_dist.rs @@ -17,6 +17,7 @@ //! `cume_dist` window function implementation +use arrow::datatypes::FieldRef; use datafusion_common::arrow::array::{ArrayRef, Float64Array}; use datafusion_common::arrow::datatypes::DataType; use datafusion_common::arrow::datatypes::Field; @@ -101,8 +102,8 @@ impl WindowUDFImpl for CumeDist { Ok(Box::::default()) } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - Ok(Field::new(field_args.name(), DataType::Float64, false)) + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Field::new(field_args.name(), DataType::Float64, false).into()) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions-window/src/lead_lag.rs b/datafusion/functions-window/src/lead_lag.rs index 5df20cf5b980..e2a755371ebc 100644 --- a/datafusion/functions-window/src/lead_lag.rs +++ b/datafusion/functions-window/src/lead_lag.rs @@ -18,6 +18,7 @@ //! 
`lead` and `lag` window function implementations use crate::utils::{get_scalar_value_from_args, get_signed_integer}; +use arrow::datatypes::FieldRef; use datafusion_common::arrow::array::ArrayRef; use datafusion_common::arrow::datatypes::DataType; use datafusion_common::arrow::datatypes::Field; @@ -157,6 +158,24 @@ static LAG_DOCUMENTATION: LazyLock = LazyLock::new(|| { the value of expression should be retrieved. Defaults to 1.") .with_argument("default", "The default value if the offset is \ not within the partition. Must be of the same type as expression.") + .with_sql_example(r#"```sql + --Example usage of the lag window function: + SELECT employee_id, + salary, + lag(salary, 1, 0) OVER (ORDER BY employee_id) AS prev_salary + FROM employees; +``` + +```sql ++-------------+--------+-------------+ +| employee_id | salary | prev_salary | ++-------------+--------+-------------+ +| 1 | 30000 | 0 | +| 2 | 50000 | 30000 | +| 3 | 70000 | 50000 | +| 4 | 60000 | 70000 | ++-------------+--------+-------------+ +```"#) .build() }); @@ -175,6 +194,27 @@ static LEAD_DOCUMENTATION: LazyLock = LazyLock::new(|| { forward the value of expression should be retrieved. Defaults to 1.") .with_argument("default", "The default value if the offset is \ not within the partition. Must be of the same type as expression.") + .with_sql_example(r#"```sql +-- Example usage of lead() : +SELECT + employee_id, + department, + salary, + lead(salary, 1, 0) OVER (PARTITION BY department ORDER BY salary) AS next_salary +FROM employees; +``` + +```sql ++-------------+-------------+--------+--------------+ +| employee_id | department | salary | next_salary | ++-------------+-------------+--------+--------------+ +| 1 | Sales | 30000 | 50000 | +| 2 | Sales | 50000 | 70000 | +| 3 | Sales | 70000 | 0 | +| 4 | Engineering | 40000 | 60000 | +| 5 | Engineering | 60000 | 0 | ++-------------+-------------+--------+--------------+ +```"#) .build() }); @@ -201,7 +241,7 @@ impl WindowUDFImpl for WindowShift { /// /// For more details see: fn expressions(&self, expr_args: ExpressionArgs) -> Vec> { - parse_expr(expr_args.input_exprs(), expr_args.input_types()) + parse_expr(expr_args.input_exprs(), expr_args.input_fields()) .into_iter() .collect::>() } @@ -224,7 +264,7 @@ impl WindowUDFImpl for WindowShift { })?; let default_value = parse_default_value( partition_evaluator_args.input_exprs(), - partition_evaluator_args.input_types(), + partition_evaluator_args.input_fields(), )?; Ok(Box::new(WindowShiftEvaluator { @@ -235,10 +275,14 @@ impl WindowUDFImpl for WindowShift { })) } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - let return_type = parse_expr_type(field_args.input_types())?; + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + let return_field = parse_expr_field(field_args.input_fields())?; - Ok(Field::new(field_args.name(), return_type, true)) + Ok(return_field + .as_ref() + .clone() + .with_name(field_args.name()) + .into()) } fn reverse_expr(&self) -> ReversedUDWF { @@ -270,16 +314,16 @@ impl WindowUDFImpl for WindowShift { /// For more details see: fn parse_expr( input_exprs: &[Arc], - input_types: &[DataType], + input_fields: &[FieldRef], ) -> Result> { assert!(!input_exprs.is_empty()); - assert!(!input_types.is_empty()); + assert!(!input_fields.is_empty()); let expr = Arc::clone(input_exprs.first().unwrap()); - let expr_type = input_types.first().unwrap(); + let expr_field = input_fields.first().unwrap(); // Handles the most common case where NULL is unexpected - if !expr_type.is_null() { + if 
!expr_field.data_type().is_null() { return Ok(expr); } @@ -292,36 +336,43 @@ fn parse_expr( }) } -/// Returns the data type of the default value(if provided) when the +static NULL_FIELD: LazyLock = + LazyLock::new(|| Field::new("value", DataType::Null, true).into()); + +/// Returns the field of the default value(if provided) when the /// expression is `NULL`. /// -/// Otherwise, returns the expression type unchanged. -fn parse_expr_type(input_types: &[DataType]) -> Result { - assert!(!input_types.is_empty()); - let expr_type = input_types.first().unwrap_or(&DataType::Null); +/// Otherwise, returns the expression field unchanged. +fn parse_expr_field(input_fields: &[FieldRef]) -> Result { + assert!(!input_fields.is_empty()); + let expr_field = input_fields.first().unwrap_or(&NULL_FIELD); // Handles the most common case where NULL is unexpected - if !expr_type.is_null() { - return Ok(expr_type.clone()); + if !expr_field.data_type().is_null() { + return Ok(expr_field.as_ref().clone().with_nullable(true).into()); } - let default_value_type = input_types.get(2).unwrap_or(&DataType::Null); - Ok(default_value_type.clone()) + let default_value_field = input_fields.get(2).unwrap_or(&NULL_FIELD); + Ok(default_value_field + .as_ref() + .clone() + .with_nullable(true) + .into()) } /// Handles type coercion and null value refinement for default value /// argument depending on the data type of the input expression. fn parse_default_value( input_exprs: &[Arc], - input_types: &[DataType], + input_types: &[FieldRef], ) -> Result { - let expr_type = parse_expr_type(input_types)?; + let expr_field = parse_expr_field(input_types)?; let unparsed = get_scalar_value_from_args(input_exprs, 2)?; unparsed .filter(|v| !v.data_type().is_null()) - .map(|v| v.cast_to(&expr_type)) - .unwrap_or(ScalarValue::try_from(expr_type)) + .map(|v| v.cast_to(expr_field.data_type())) + .unwrap_or_else(|| ScalarValue::try_from(expr_field.data_type())) } #[derive(Debug)] @@ -666,7 +717,12 @@ mod tests { test_i32_result( WindowShift::lead(), - PartitionEvaluatorArgs::new(&[expr], &[DataType::Int32], false, false), + PartitionEvaluatorArgs::new( + &[expr], + &[Field::new("f", DataType::Int32, true).into()], + false, + false, + ), [ Some(-2), Some(3), @@ -688,7 +744,12 @@ mod tests { test_i32_result( WindowShift::lag(), - PartitionEvaluatorArgs::new(&[expr], &[DataType::Int32], false, false), + PartitionEvaluatorArgs::new( + &[expr], + &[Field::new("f", DataType::Int32, true).into()], + false, + false, + ), [ None, Some(1), @@ -713,12 +774,15 @@ mod tests { as Arc; let input_exprs = &[expr, shift_offset, default_value]; - let input_types: &[DataType] = - &[DataType::Int32, DataType::Int32, DataType::Int32]; + let input_fields = [DataType::Int32, DataType::Int32, DataType::Int32] + .into_iter() + .map(|d| Field::new("f", d, true)) + .map(Arc::new) + .collect::>(); test_i32_result( WindowShift::lag(), - PartitionEvaluatorArgs::new(input_exprs, input_types, false, false), + PartitionEvaluatorArgs::new(input_exprs, &input_fields, false, false), [ Some(100), Some(1), diff --git a/datafusion/functions-window/src/macros.rs b/datafusion/functions-window/src/macros.rs index 2ef1eacba953..23414a7a7172 100644 --- a/datafusion/functions-window/src/macros.rs +++ b/datafusion/functions-window/src/macros.rs @@ -40,6 +40,7 @@ /// /// ``` /// # use std::any::Any; +/// use arrow::datatypes::FieldRef; /// # use datafusion_common::arrow::datatypes::{DataType, Field}; /// # use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl}; /// 
# @@ -85,8 +86,8 @@ /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } -/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { -/// # Ok(Field::new(field_args.name(), DataType::Int64, false)) +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # Ok(Field::new(field_args.name(), DataType::Int64, false).into()) /// # } /// # } /// # @@ -138,6 +139,7 @@ macro_rules! get_or_init_udwf { /// 1. With Zero Parameters /// ``` /// # use std::any::Any; +/// use arrow::datatypes::FieldRef; /// # use datafusion_common::arrow::datatypes::{DataType, Field}; /// # use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl}; /// # use datafusion_functions_window::{create_udwf_expr, get_or_init_udwf}; @@ -196,8 +198,8 @@ macro_rules! get_or_init_udwf { /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } -/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { -/// # Ok(Field::new(field_args.name(), DataType::UInt64, false)) +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # Ok(Field::new(field_args.name(), DataType::UInt64, false).into()) /// # } /// # } /// ``` @@ -205,6 +207,7 @@ macro_rules! get_or_init_udwf { /// 2. With Multiple Parameters /// ``` /// # use std::any::Any; +/// use arrow::datatypes::FieldRef; /// # /// # use datafusion_expr::{ /// # PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDFImpl, @@ -283,12 +286,12 @@ macro_rules! get_or_init_udwf { /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } -/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { /// # Ok(Field::new( /// # field_args.name(), -/// # field_args.get_input_type(0).unwrap(), +/// # field_args.get_input_field(0).unwrap().data_type().clone(), /// # false, -/// # )) +/// # ).into()) /// # } /// # } /// ``` @@ -352,6 +355,7 @@ macro_rules! create_udwf_expr { /// /// ``` /// # use std::any::Any; +/// use arrow::datatypes::FieldRef; /// # use datafusion_common::arrow::datatypes::{DataType, Field}; /// # use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl}; /// # @@ -404,8 +408,8 @@ macro_rules! create_udwf_expr { /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } -/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { -/// # Ok(Field::new(field_args.name(), DataType::Int64, false)) +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # Ok(Field::new(field_args.name(), DataType::Int64, false).into()) /// # } /// # } /// # @@ -415,6 +419,7 @@ macro_rules! create_udwf_expr { /// /// ``` /// # use std::any::Any; +/// use arrow::datatypes::FieldRef; /// # use datafusion_common::arrow::datatypes::{DataType, Field}; /// # use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl}; /// # use datafusion_functions_window::{create_udwf_expr, define_udwf_and_expr, get_or_init_udwf}; @@ -468,8 +473,8 @@ macro_rules! 
create_udwf_expr { /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } -/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { -/// # Ok(Field::new(field_args.name(), DataType::UInt64, false)) +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # Ok(Field::new(field_args.name(), DataType::UInt64, false).into()) /// # } /// # } /// ``` @@ -479,6 +484,7 @@ macro_rules! create_udwf_expr { /// /// ``` /// # use std::any::Any; +/// use arrow::datatypes::FieldRef; /// # /// # use datafusion_expr::{ /// # PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDFImpl, @@ -554,12 +560,12 @@ macro_rules! create_udwf_expr { /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } -/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { /// # Ok(Field::new( /// # field_args.name(), -/// # field_args.get_input_type(0).unwrap(), +/// # field_args.get_input_field(0).unwrap().data_type().clone(), /// # false, -/// # )) +/// # ).into()) /// # } /// # } /// ``` @@ -567,6 +573,7 @@ macro_rules! create_udwf_expr { /// /// ``` /// # use std::any::Any; +/// use arrow::datatypes::FieldRef; /// # /// # use datafusion_expr::{ /// # PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDFImpl, @@ -643,12 +650,12 @@ macro_rules! create_udwf_expr { /// # ) -> datafusion_common::Result> { /// # unimplemented!() /// # } -/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { /// # Ok(Field::new( /// # field_args.name(), -/// # field_args.get_input_type(0).unwrap(), +/// # field_args.get_input_field(0).unwrap().data_type().clone(), /// # false, -/// # )) +/// # ).into()) /// # } /// # } /// ``` diff --git a/datafusion/functions-window/src/nth_value.rs b/datafusion/functions-window/src/nth_value.rs index 36e6b83d61ce..0b83e1ff9f08 100644 --- a/datafusion/functions-window/src/nth_value.rs +++ b/datafusion/functions-window/src/nth_value.rs @@ -19,12 +19,7 @@ use crate::utils::{get_scalar_value_from_args, get_signed_integer}; -use std::any::Any; -use std::cmp::Ordering; -use std::fmt::Debug; -use std::ops::Range; -use std::sync::LazyLock; - +use arrow::datatypes::FieldRef; use datafusion_common::arrow::array::ArrayRef; use datafusion_common::arrow::datatypes::{DataType, Field}; use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue}; @@ -37,6 +32,11 @@ use datafusion_expr::{ use datafusion_functions_window_common::field; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use field::WindowUDFFieldArgs; +use std::any::Any; +use std::cmp::Ordering; +use std::fmt::Debug; +use std::ops::Range; +use std::sync::LazyLock; get_or_init_udwf!( First, @@ -135,6 +135,26 @@ static FIRST_VALUE_DOCUMENTATION: LazyLock = LazyLock::new(|| { "first_value(expression)", ) .with_argument("expression", "Expression to operate on") + .with_sql_example(r#"```sql + --Example usage of the first_value window function: + SELECT department, + employee_id, + salary, + first_value(salary) OVER (PARTITION BY department ORDER BY salary DESC) AS top_salary + FROM employees; +``` + +```sql ++-------------+-------------+--------+------------+ +| department | employee_id | salary | top_salary | ++-------------+-------------+--------+------------+ +| Sales | 1 | 70000 | 
70000 | +| Sales | 2 | 50000 | 70000 | +| Sales | 3 | 30000 | 70000 | +| Engineering | 4 | 90000 | 90000 | +| Engineering | 5 | 80000 | 90000 | ++-------------+-------------+--------+------------+ +```"#) .build() }); @@ -150,6 +170,26 @@ static LAST_VALUE_DOCUMENTATION: LazyLock = LazyLock::new(|| { "last_value(expression)", ) .with_argument("expression", "Expression to operate on") + .with_sql_example(r#"```sql +-- SQL example of last_value: +SELECT department, + employee_id, + salary, + last_value(salary) OVER (PARTITION BY department ORDER BY salary) AS running_last_salary +FROM employees; +``` + +```sql ++-------------+-------------+--------+---------------------+ +| department | employee_id | salary | running_last_salary | ++-------------+-------------+--------+---------------------+ +| Sales | 1 | 30000 | 30000 | +| Sales | 2 | 50000 | 50000 | +| Sales | 3 | 70000 | 70000 | +| Engineering | 4 | 40000 | 40000 | +| Engineering | 5 | 60000 | 60000 | ++-------------+-------------+--------+---------------------+ +```"#) .build() }); @@ -269,11 +309,15 @@ impl WindowUDFImpl for NthValue { })) } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - let nullable = true; - let return_type = field_args.input_types().first().unwrap_or(&DataType::Null); + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + let return_type = field_args + .input_fields() + .first() + .map(|f| f.data_type()) + .cloned() + .unwrap_or(DataType::Null); - Ok(Field::new(field_args.name(), return_type.clone(), nullable)) + Ok(Field::new(field_args.name(), return_type, true).into()) } fn reverse_expr(&self) -> ReversedUDWF { @@ -511,7 +555,12 @@ mod tests { let expr = Arc::new(Column::new("c3", 0)) as Arc; test_i32_result( NthValue::first(), - PartitionEvaluatorArgs::new(&[expr], &[DataType::Int32], false, false), + PartitionEvaluatorArgs::new( + &[expr], + &[Field::new("f", DataType::Int32, true).into()], + false, + false, + ), Int32Array::from(vec![1; 8]).iter().collect::(), ) } @@ -521,7 +570,12 @@ mod tests { let expr = Arc::new(Column::new("c3", 0)) as Arc; test_i32_result( NthValue::last(), - PartitionEvaluatorArgs::new(&[expr], &[DataType::Int32], false, false), + PartitionEvaluatorArgs::new( + &[expr], + &[Field::new("f", DataType::Int32, true).into()], + false, + false, + ), Int32Array::from(vec![ Some(1), Some(-2), @@ -545,7 +599,7 @@ mod tests { NthValue::nth(), PartitionEvaluatorArgs::new( &[expr, n_value], - &[DataType::Int32], + &[Field::new("f", DataType::Int32, true).into()], false, false, ), @@ -564,7 +618,7 @@ mod tests { NthValue::nth(), PartitionEvaluatorArgs::new( &[expr, n_value], - &[DataType::Int32], + &[Field::new("f", DataType::Int32, true).into()], false, false, ), diff --git a/datafusion/functions-window/src/ntile.rs b/datafusion/functions-window/src/ntile.rs index 180f7ab02c03..6b4c0960e695 100644 --- a/datafusion/functions-window/src/ntile.rs +++ b/datafusion/functions-window/src/ntile.rs @@ -17,13 +17,10 @@ //! 
`ntile` window function implementation -use std::any::Any; -use std::fmt::Debug; -use std::sync::Arc; - use crate::utils::{ get_scalar_value_from_args, get_signed_integer, get_unsigned_integer, }; +use arrow::datatypes::FieldRef; use datafusion_common::arrow::array::{ArrayRef, UInt64Array}; use datafusion_common::arrow::datatypes::{DataType, Field}; use datafusion_common::{exec_err, DataFusionError, Result}; @@ -34,6 +31,9 @@ use datafusion_functions_window_common::field; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use datafusion_macros::user_doc; use field::WindowUDFFieldArgs; +use std::any::Any; +use std::fmt::Debug; +use std::sync::Arc; get_or_init_udwf!( Ntile, @@ -52,7 +52,29 @@ pub fn ntile(arg: Expr) -> Expr { argument( name = "expression", description = "An integer describing the number groups the partition should be split into" - ) + ), + sql_example = r#"```sql + --Example usage of the ntile window function: + SELECT employee_id, + salary, + ntile(4) OVER (ORDER BY salary DESC) AS quartile + FROM employees; +``` + +```sql ++-------------+--------+----------+ +| employee_id | salary | quartile | ++-------------+--------+----------+ +| 1 | 90000 | 1 | +| 2 | 85000 | 1 | +| 3 | 80000 | 2 | +| 4 | 70000 | 2 | +| 5 | 60000 | 3 | +| 6 | 50000 | 3 | +| 7 | 40000 | 4 | +| 8 | 30000 | 4 | ++-------------+--------+----------+ +```"# )] #[derive(Debug)] pub struct Ntile { @@ -127,10 +149,10 @@ impl WindowUDFImpl for Ntile { Ok(Box::new(NtileEvaluator { n: n as u64 })) } } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { let nullable = false; - Ok(Field::new(field_args.name(), DataType::UInt64, nullable)) + Ok(Field::new(field_args.name(), DataType::UInt64, nullable).into()) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions-window/src/planner.rs b/datafusion/functions-window/src/planner.rs index 1ddd8b27c420..091737bb9c15 100644 --- a/datafusion/functions-window/src/planner.rs +++ b/datafusion/functions-window/src/planner.rs @@ -43,7 +43,7 @@ impl ExprPlanner for WindowFunctionPlanner { null_treatment, } = raw_expr; - let origin_expr = Expr::WindowFunction(WindowFunction { + let origin_expr = Expr::from(WindowFunction { fun: func_def, params: WindowFunctionParams { args, @@ -56,7 +56,10 @@ impl ExprPlanner for WindowFunctionPlanner { let saved_name = NamePreserver::new_for_projection().save(&origin_expr); - let Expr::WindowFunction(WindowFunction { + let Expr::WindowFunction(window_fun) = origin_expr else { + unreachable!("") + }; + let WindowFunction { fun, params: WindowFunctionParams { @@ -66,10 +69,7 @@ impl ExprPlanner for WindowFunctionPlanner { window_frame, null_treatment, }, - }) = origin_expr - else { - unreachable!("") - }; + } = *window_fun; let raw_expr = RawWindowExpr { func_def: fun, args, @@ -95,9 +95,9 @@ impl ExprPlanner for WindowFunctionPlanner { null_treatment, } = raw_expr; - let new_expr = Expr::WindowFunction(WindowFunction::new( + let new_expr = Expr::from(WindowFunction::new( func_def, - vec![Expr::Literal(COUNT_STAR_EXPANSION)], + vec![Expr::Literal(COUNT_STAR_EXPANSION, None)], )) .partition_by(partition_by) .order_by(order_by) diff --git a/datafusion/functions-window/src/rank.rs b/datafusion/functions-window/src/rank.rs index 2ff2c31d8c2a..969a957cddd9 100644 --- a/datafusion/functions-window/src/rank.rs +++ b/datafusion/functions-window/src/rank.rs @@ -18,13 +18,8 @@ //! 
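[Editor's note] The `planner.rs` hunk above reflects two expression API changes: `Expr::WindowFunction` now wraps a boxed `WindowFunction` (built with `Expr::from` and unpacked by dereferencing the box), and `Expr::Literal` takes an optional metadata value as a second argument. A small sketch of both patterns, with import paths assumed rather than taken from the patch:

```rust
use datafusion_common::ScalarValue;
use datafusion_expr::expr::WindowFunction;
use datafusion_expr::Expr;

// Illustrative helper: wrap a WindowFunction in an Expr and take it back out.
fn roundtrip(wf: WindowFunction) -> WindowFunction {
    // `Expr::from` boxes the window function internally.
    let expr = Expr::from(wf);
    match expr {
        // The variant now holds a Box<WindowFunction>; `*` unwraps it.
        Expr::WindowFunction(boxed) => *boxed,
        _ => unreachable!("constructed as a window function above"),
    }
}

// `Expr::Literal` now carries an optional metadata argument.
fn literal_one() -> Expr {
    Expr::Literal(ScalarValue::from(1i64), None)
}
```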
Implementation of `rank`, `dense_rank`, and `percent_rank` window functions, //! which can be evaluated at runtime during query execution. -use std::any::Any; -use std::fmt::Debug; -use std::iter; -use std::ops::Range; -use std::sync::{Arc, LazyLock}; - use crate::define_udwf_and_expr; +use arrow::datatypes::FieldRef; use datafusion_common::arrow::array::ArrayRef; use datafusion_common::arrow::array::{Float64Array, UInt64Array}; use datafusion_common::arrow::compute::SortOptions; @@ -39,6 +34,11 @@ use datafusion_expr::{ use datafusion_functions_window_common::field; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use field::WindowUDFFieldArgs; +use std::any::Any; +use std::fmt::Debug; +use std::iter; +use std::ops::Range; +use std::sync::{Arc, LazyLock}; define_udwf_and_expr!( Rank, @@ -110,6 +110,26 @@ static RANK_DOCUMENTATION: LazyLock = LazyLock::new(|| { skips ranks for identical values.", "rank()") + .with_sql_example(r#"```sql + --Example usage of the rank window function: + SELECT department, + salary, + rank() OVER (PARTITION BY department ORDER BY salary DESC) AS rank + FROM employees; +``` + +```sql ++-------------+--------+------+ +| department | salary | rank | ++-------------+--------+------+ +| Sales | 70000 | 1 | +| Sales | 50000 | 2 | +| Sales | 50000 | 2 | +| Sales | 30000 | 4 | +| Engineering | 90000 | 1 | +| Engineering | 80000 | 2 | ++-------------+--------+------+ +```"#) .build() }); @@ -121,6 +141,26 @@ static DENSE_RANK_DOCUMENTATION: LazyLock = LazyLock::new(|| { Documentation::builder(DOC_SECTION_RANKING, "Returns the rank of the current row without gaps. This function ranks \ rows in a dense manner, meaning consecutive ranks are assigned even for identical \ values.", "dense_rank()") + .with_sql_example(r#"```sql + --Example usage of the dense_rank window function: + SELECT department, + salary, + dense_rank() OVER (PARTITION BY department ORDER BY salary DESC) AS dense_rank + FROM employees; +``` + +```sql ++-------------+--------+------------+ +| department | salary | dense_rank | ++-------------+--------+------------+ +| Sales | 70000 | 1 | +| Sales | 50000 | 2 | +| Sales | 50000 | 2 | +| Sales | 30000 | 3 | +| Engineering | 90000 | 1 | +| Engineering | 80000 | 2 | ++-------------+--------+------------+ +```"#) .build() }); @@ -131,6 +171,23 @@ fn get_dense_rank_doc() -> &'static Documentation { static PERCENT_RANK_DOCUMENTATION: LazyLock = LazyLock::new(|| { Documentation::builder(DOC_SECTION_RANKING, "Returns the percentage rank of the current row within its partition. 
\ The value ranges from 0 to 1 and is computed as `(rank - 1) / (total_rows - 1)`.", "percent_rank()") + .with_sql_example(r#"```sql + --Example usage of the percent_rank window function: + SELECT employee_id, + salary, + percent_rank() OVER (ORDER BY salary) AS percent_rank + FROM employees; +``` + +```sql ++-------------+--------+---------------+ +| employee_id | salary | percent_rank | ++-------------+--------+---------------+ +| 1 | 30000 | 0.00 | +| 2 | 50000 | 0.50 | +| 3 | 70000 | 1.00 | ++-------------+--------+---------------+ +```"#) .build() }); @@ -161,14 +218,14 @@ impl WindowUDFImpl for Rank { })) } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { let return_type = match self.rank_type { RankType::Basic | RankType::Dense => DataType::UInt64, RankType::Percent => DataType::Float64, }; let nullable = false; - Ok(Field::new(field_args.name(), return_type, nullable)) + Ok(Field::new(field_args.name(), return_type, nullable).into()) } fn sort_options(&self) -> Option { diff --git a/datafusion/functions-window/src/row_number.rs b/datafusion/functions-window/src/row_number.rs index 8f462528dbed..ba8627dd86d7 100644 --- a/datafusion/functions-window/src/row_number.rs +++ b/datafusion/functions-window/src/row_number.rs @@ -17,6 +17,7 @@ //! `row_number` window function implementation +use arrow::datatypes::FieldRef; use datafusion_common::arrow::array::ArrayRef; use datafusion_common::arrow::array::UInt64Array; use datafusion_common::arrow::compute::SortOptions; @@ -44,7 +45,27 @@ define_udwf_and_expr!( #[user_doc( doc_section(label = "Ranking Functions"), description = "Number of the current row within its partition, counting from 1.", - syntax_example = "row_number()" + syntax_example = "row_number()", + sql_example = r"```sql + --Example usage of the row_number window function: + SELECT department, + salary, + row_number() OVER (PARTITION BY department ORDER BY salary DESC) AS row_num + FROM employees; +``` + +```sql ++-------------+--------+---------+ +| department | salary | row_num | ++-------------+--------+---------+ +| Sales | 70000 | 1 | +| Sales | 50000 | 2 | +| Sales | 50000 | 3 | +| Sales | 30000 | 4 | +| Engineering | 90000 | 1 | +| Engineering | 80000 | 2 | ++-------------+--------+---------+ +```#" )] #[derive(Debug)] pub struct RowNumber { @@ -86,8 +107,8 @@ impl WindowUDFImpl for RowNumber { Ok(Box::::default()) } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - Ok(Field::new(field_args.name(), DataType::UInt64, false)) + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Field::new(field_args.name(), DataType::UInt64, false).into()) } fn sort_options(&self) -> Option { diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 729770b8a65c..0c4280babc70 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -80,9 +80,9 @@ log = { workspace = true } md-5 = { version = "^0.10.0", optional = true } rand = { workspace = true } regex = { workspace = true, optional = true } -sha2 = { version = "^0.10.1", optional = true } +sha2 = { version = "^0.10.9", optional = true } unicode-segmentation = { version = "^1.7.1", optional = true } -uuid = { version = "1.16", features = ["v4"], optional = true } +uuid = { version = "1.17", features = ["v4"], optional = true } [dev-dependencies] arrow = { workspace = true, features = ["test_utils"] } @@ -90,6 +90,11 @@ criterion = { workspace = true } rand = { workspace = true } 
tokio = { workspace = true, features = ["macros", "rt", "sync"] } +[[bench]] +harness = false +name = "ascii" +required-features = ["string_expressions"] + [[bench]] harness = false name = "concat" diff --git a/datafusion/functions/README.md b/datafusion/functions/README.md index a610d135c0f6..27dc4afc76bb 100644 --- a/datafusion/functions/README.md +++ b/datafusion/functions/README.md @@ -24,4 +24,9 @@ This crate contains packages of function that can be used to customize the functionality of DataFusion. +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + [df]: https://crates.io/crates/datafusion +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/functions/benches/ascii.rs b/datafusion/functions/benches/ascii.rs new file mode 100644 index 000000000000..1c7023f4497e --- /dev/null +++ b/datafusion/functions/benches/ascii.rs @@ -0,0 +1,133 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
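[Editor's note] The new `ascii` benchmark that follows drives `invoke_with_args` through the updated `ScalarFunctionArgs`, which now carries one `FieldRef` per argument plus a `return_field` instead of a bare `return_type`. A stripped-down sketch of the same call outside the Criterion harness; array contents and field names are illustrative:

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, StringArray};
use arrow::datatypes::{DataType, Field};
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};

fn main() -> datafusion_common::Result<()> {
    let ascii = datafusion_functions::string::ascii();

    let input: ArrayRef = Arc::new(StringArray::from(vec!["alpha", "beta", "gamma"]));
    let args = vec![ColumnarValue::Array(input)];

    // One FieldRef per argument, plus a FieldRef describing the return value.
    let arg_fields = vec![Field::new("a", DataType::Utf8, true).into()];
    let return_type = ascii.return_type(&[DataType::Utf8])?;
    let return_field = Arc::new(Field::new("f", return_type, true));

    let result = ascii.invoke_with_args(ScalarFunctionArgs {
        args,
        arg_fields,
        number_rows: 3,
        return_field,
    })?;

    println!("{:?}", result.data_type());
    Ok(())
}
```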
+ +extern crate criterion; +mod helper; + +use arrow::datatypes::{DataType, Field}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_expr::ScalarFunctionArgs; +use helper::gen_string_array; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let ascii = datafusion_functions::string::ascii(); + + // All benches are single batch run with 8192 rows + const N_ROWS: usize = 8192; + const STR_LEN: usize = 16; + const UTF8_DENSITY_OF_ALL_ASCII: f32 = 0.0; + const NORMAL_UTF8_DENSITY: f32 = 0.8; + + for null_density in [0.0, 0.5] { + // StringArray ASCII only + let args_string_ascii = gen_string_array( + N_ROWS, + STR_LEN, + null_density, + UTF8_DENSITY_OF_ALL_ASCII, + false, + ); + + let arg_fields = + vec![Field::new("a", args_string_ascii[0].data_type(), true).into()]; + let return_field = Field::new("f", DataType::Utf8, true).into(); + + c.bench_function( + format!("ascii/string_ascii_only (null_density={null_density})").as_str(), + |b| { + b.iter(|| { + black_box(ascii.invoke_with_args(ScalarFunctionArgs { + args: args_string_ascii.clone(), + arg_fields: arg_fields.clone(), + number_rows: N_ROWS, + return_field: Arc::clone(&return_field), + })) + }) + }, + ); + + // StringArray UTF8 + let args_string_utf8 = + gen_string_array(N_ROWS, STR_LEN, null_density, NORMAL_UTF8_DENSITY, false); + let arg_fields = + vec![Field::new("a", args_string_utf8[0].data_type(), true).into()]; + let return_field = Field::new("f", DataType::Utf8, true).into(); + c.bench_function( + format!("ascii/string_utf8 (null_density={null_density})").as_str(), + |b| { + b.iter(|| { + black_box(ascii.invoke_with_args(ScalarFunctionArgs { + args: args_string_utf8.clone(), + arg_fields: arg_fields.clone(), + number_rows: N_ROWS, + return_field: Arc::clone(&return_field), + })) + }) + }, + ); + + // StringViewArray ASCII only + let args_string_view_ascii = gen_string_array( + N_ROWS, + STR_LEN, + null_density, + UTF8_DENSITY_OF_ALL_ASCII, + true, + ); + let arg_fields = + vec![Field::new("a", args_string_view_ascii[0].data_type(), true).into()]; + let return_field = Field::new("f", DataType::Utf8, true).into(); + c.bench_function( + format!("ascii/string_view_ascii_only (null_density={null_density})") + .as_str(), + |b| { + b.iter(|| { + black_box(ascii.invoke_with_args(ScalarFunctionArgs { + args: args_string_view_ascii.clone(), + arg_fields: arg_fields.clone(), + number_rows: N_ROWS, + return_field: Arc::clone(&return_field), + })) + }) + }, + ); + + // StringViewArray UTF8 + let args_string_view_utf8 = + gen_string_array(N_ROWS, STR_LEN, null_density, NORMAL_UTF8_DENSITY, true); + let arg_fields = + vec![Field::new("a", args_string_view_utf8[0].data_type(), true).into()]; + let return_field = Field::new("f", DataType::Utf8, true).into(); + c.bench_function( + format!("ascii/string_view_utf8 (null_density={null_density})").as_str(), + |b| { + b.iter(|| { + black_box(ascii.invoke_with_args(ScalarFunctionArgs { + args: args_string_view_utf8.clone(), + arg_fields: arg_fields.clone(), + number_rows: N_ROWS, + return_field: Arc::clone(&return_field), + })) + }) + }, + ); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/character_length.rs b/datafusion/functions/benches/character_length.rs index bbcfed021064..b4a9e917f416 100644 --- a/datafusion/functions/benches/character_length.rs +++ b/datafusion/functions/benches/character_length.rs @@ -17,10 +17,11 @@ extern crate criterion; -use 
arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::ScalarFunctionArgs; use helper::gen_string_array; +use std::sync::Arc; mod helper; @@ -28,20 +29,28 @@ fn criterion_benchmark(c: &mut Criterion) { // All benches are single batch run with 8192 rows let character_length = datafusion_functions::unicode::character_length(); - let return_type = DataType::Utf8; + let return_field = Arc::new(Field::new("f", DataType::Utf8, true)); let n_rows = 8192; for str_len in [8, 32, 128, 4096] { // StringArray ASCII only let args_string_ascii = gen_string_array(n_rows, str_len, 0.1, 0.0, false); + let arg_fields = args_string_ascii + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); c.bench_function( - &format!("character_length_StringArray_ascii_str_len_{}", str_len), + &format!("character_length_StringArray_ascii_str_len_{str_len}"), |b| { b.iter(|| { black_box(character_length.invoke_with_args(ScalarFunctionArgs { args: args_string_ascii.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &return_type, + return_field: Arc::clone(&return_field), })) }) }, @@ -49,14 +58,22 @@ fn criterion_benchmark(c: &mut Criterion) { // StringArray UTF8 let args_string_utf8 = gen_string_array(n_rows, str_len, 0.1, 0.5, false); + let arg_fields = args_string_utf8 + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); c.bench_function( - &format!("character_length_StringArray_utf8_str_len_{}", str_len), + &format!("character_length_StringArray_utf8_str_len_{str_len}"), |b| { b.iter(|| { black_box(character_length.invoke_with_args(ScalarFunctionArgs { args: args_string_utf8.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &return_type, + return_field: Arc::clone(&return_field), })) }) }, @@ -64,14 +81,22 @@ fn criterion_benchmark(c: &mut Criterion) { // StringViewArray ASCII only let args_string_view_ascii = gen_string_array(n_rows, str_len, 0.1, 0.0, true); + let arg_fields = args_string_view_ascii + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); c.bench_function( - &format!("character_length_StringViewArray_ascii_str_len_{}", str_len), + &format!("character_length_StringViewArray_ascii_str_len_{str_len}"), |b| { b.iter(|| { black_box(character_length.invoke_with_args(ScalarFunctionArgs { args: args_string_view_ascii.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &return_type, + return_field: Arc::clone(&return_field), })) }) }, @@ -79,14 +104,22 @@ fn criterion_benchmark(c: &mut Criterion) { // StringViewArray UTF8 let args_string_view_utf8 = gen_string_array(n_rows, str_len, 0.1, 0.5, true); + let arg_fields = args_string_view_utf8 + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); c.bench_function( - &format!("character_length_StringViewArray_utf8_str_len_{}", str_len), + &format!("character_length_StringViewArray_utf8_str_len_{str_len}"), |b| { b.iter(|| { black_box(character_length.invoke_with_args(ScalarFunctionArgs { args: args_string_view_utf8.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &return_type, + return_field: Arc::clone(&return_field), })) }) }, diff 
--git a/datafusion/functions/benches/chr.rs b/datafusion/functions/benches/chr.rs index 8575809c21c8..6a956bb78812 100644 --- a/datafusion/functions/benches/chr.rs +++ b/datafusion/functions/benches/chr.rs @@ -23,7 +23,7 @@ use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::string::chr; use rand::{Rng, SeedableRng}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use rand::rngs::StdRng; use std::sync::Arc; @@ -37,27 +37,34 @@ fn criterion_benchmark(c: &mut Criterion) { let size = 1024; let input: PrimitiveArray = { let null_density = 0.2; - let mut rng = seedable_rng(); + let mut rng = StdRng::seed_from_u64(42); (0..size) .map(|_| { - if rng.gen::() < null_density { + if rng.random::() < null_density { None } else { - Some(rng.gen_range::(1i64..10_000)) + Some(rng.random_range::(1i64..10_000)) } }) .collect() }; let input = Arc::new(input); let args = vec![ColumnarValue::Array(input)]; + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect::>(); + c.bench_function("chr", |b| { b.iter(|| { black_box( cot_fn .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }) .unwrap(), ) diff --git a/datafusion/functions/benches/concat.rs b/datafusion/functions/benches/concat.rs index 45ca076e754f..d350c03c497b 100644 --- a/datafusion/functions/benches/concat.rs +++ b/datafusion/functions/benches/concat.rs @@ -16,7 +16,7 @@ // under the License. use arrow::array::ArrayRef; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::create_string_array_with_len; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use datafusion_common::ScalarValue; @@ -37,6 +37,14 @@ fn create_args(size: usize, str_len: usize) -> Vec { fn criterion_benchmark(c: &mut Criterion) { for size in [1024, 4096, 8192] { let args = create_args(size, 32); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let mut group = c.benchmark_group("concat function"); group.bench_function(BenchmarkId::new("concat", size), |b| { b.iter(|| { @@ -45,8 +53,9 @@ fn criterion_benchmark(c: &mut Criterion) { concat() .invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }) .unwrap(), ) diff --git a/datafusion/functions/benches/cot.rs b/datafusion/functions/benches/cot.rs index b2a9ca0b9f47..a32e0d834672 100644 --- a/datafusion/functions/benches/cot.rs +++ b/datafusion/functions/benches/cot.rs @@ -25,7 +25,7 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::math::cot; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use std::sync::Arc; fn criterion_benchmark(c: &mut Criterion) { @@ -33,14 +33,23 @@ fn criterion_benchmark(c: &mut Criterion) { for size in [1024, 4096, 8192] { let f32_array = Arc::new(create_primitive_array::(size, 0.2)); let f32_args = vec![ColumnarValue::Array(f32_array)]; - c.bench_function(&format!("cot f32 array: {}", size), |b| { + let arg_fields = f32_args + .iter() + .enumerate() 
+ .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + + c.bench_function(&format!("cot f32 array: {size}"), |b| { b.iter(|| { black_box( cot_fn .invoke_with_args(ScalarFunctionArgs { args: f32_args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Float32, + return_field: Field::new("f", DataType::Float32, true).into(), }) .unwrap(), ) @@ -48,14 +57,24 @@ fn criterion_benchmark(c: &mut Criterion) { }); let f64_array = Arc::new(create_primitive_array::(size, 0.2)); let f64_args = vec![ColumnarValue::Array(f64_array)]; - c.bench_function(&format!("cot f64 array: {}", size), |b| { + let arg_fields = f64_args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let return_field = Arc::new(Field::new("f", DataType::Float64, true)); + + c.bench_function(&format!("cot f64 array: {size}"), |b| { b.iter(|| { black_box( cot_fn .invoke_with_args(ScalarFunctionArgs { args: f64_args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Float64, + return_field: Arc::clone(&return_field), }) .unwrap(), ) diff --git a/datafusion/functions/benches/date_bin.rs b/datafusion/functions/benches/date_bin.rs index 7ea5fdcb2be2..ac766a002576 100644 --- a/datafusion/functions/benches/date_bin.rs +++ b/datafusion/functions/benches/date_bin.rs @@ -20,6 +20,7 @@ extern crate criterion; use std::sync::Arc; use arrow::array::{Array, ArrayRef, TimestampSecondArray}; +use arrow::datatypes::Field; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_common::ScalarValue; use rand::rngs::ThreadRng; @@ -31,7 +32,7 @@ use datafusion_functions::datetime::date_bin; fn timestamps(rng: &mut ThreadRng) -> TimestampSecondArray { let mut seconds = vec![]; for _ in 0..1000 { - seconds.push(rng.gen_range(0..1_000_000)); + seconds.push(rng.random_range(0..1_000_000)); } TimestampSecondArray::from(seconds) @@ -39,7 +40,7 @@ fn timestamps(rng: &mut ThreadRng) -> TimestampSecondArray { fn criterion_benchmark(c: &mut Criterion) { c.bench_function("date_bin_1000", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let timestamps_array = Arc::new(timestamps(&mut rng)) as ArrayRef; let batch_len = timestamps_array.len(); let interval = ColumnarValue::Scalar(ScalarValue::new_interval_dt(0, 1_000_000)); @@ -48,13 +49,19 @@ fn criterion_benchmark(c: &mut Criterion) { let return_type = udf .return_type(&[interval.data_type(), timestamps.data_type()]) .unwrap(); + let return_field = Arc::new(Field::new("f", return_type, true)); + let arg_fields = vec![ + Field::new("a", interval.data_type(), true).into(), + Field::new("b", timestamps.data_type(), true).into(), + ]; b.iter(|| { black_box( udf.invoke_with_args(ScalarFunctionArgs { args: vec![interval.clone(), timestamps.clone()], + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type: &return_type, + return_field: Arc::clone(&return_field), }) .expect("date_bin should work on valid values"), ) diff --git a/datafusion/functions/benches/date_trunc.rs b/datafusion/functions/benches/date_trunc.rs index e7e96fb7a9fa..ad4d0d0fbb79 100644 --- a/datafusion/functions/benches/date_trunc.rs +++ b/datafusion/functions/benches/date_trunc.rs @@ -20,6 +20,7 @@ extern crate criterion; use std::sync::Arc; use arrow::array::{Array, ArrayRef, TimestampSecondArray}; +use arrow::datatypes::Field; use criterion::{black_box, criterion_group, 
criterion_main, Criterion}; use datafusion_common::ScalarValue; use rand::rngs::ThreadRng; @@ -31,7 +32,7 @@ use datafusion_functions::datetime::date_trunc; fn timestamps(rng: &mut ThreadRng) -> TimestampSecondArray { let mut seconds = vec![]; for _ in 0..1000 { - seconds.push(rng.gen_range(0..1_000_000)); + seconds.push(rng.random_range(0..1_000_000)); } TimestampSecondArray::from(seconds) @@ -39,7 +40,7 @@ fn timestamps(rng: &mut ThreadRng) -> TimestampSecondArray { fn criterion_benchmark(c: &mut Criterion) { c.bench_function("date_trunc_minute_1000", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let timestamps_array = Arc::new(timestamps(&mut rng)) as ArrayRef; let batch_len = timestamps_array.len(); let precision = @@ -47,15 +48,25 @@ fn criterion_benchmark(c: &mut Criterion) { let timestamps = ColumnarValue::Array(timestamps_array); let udf = date_trunc(); let args = vec![precision, timestamps]; - let return_type = &udf + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + + let return_type = udf .return_type(&args.iter().map(|arg| arg.data_type()).collect::>()) .unwrap(); + let return_field = Arc::new(Field::new("f", return_type, true)); b.iter(|| { black_box( udf.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type, + return_field: Arc::clone(&return_field), }) .expect("date_trunc should work on valid values"), ) diff --git a/datafusion/functions/benches/encoding.rs b/datafusion/functions/benches/encoding.rs index cf8f8d2fd62c..830e0324766f 100644 --- a/datafusion/functions/benches/encoding.rs +++ b/datafusion/functions/benches/encoding.rs @@ -17,7 +17,8 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::array::Array; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::create_string_array_with_len; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; @@ -33,19 +34,29 @@ fn criterion_benchmark(c: &mut Criterion) { let encoded = encoding::encode() .invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Array(str_array.clone()), method.clone()], + arg_fields: vec![ + Field::new("a", str_array.data_type().to_owned(), true).into(), + Field::new("b", method.data_type().to_owned(), true).into(), + ], number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }) .unwrap(); + let arg_fields = vec![ + Field::new("a", encoded.data_type().to_owned(), true).into(), + Field::new("b", method.data_type().to_owned(), true).into(), + ]; let args = vec![encoded, method]; + b.iter(|| { black_box( decode .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }) .unwrap(), ) @@ -54,22 +65,34 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function(&format!("hex_decode/{size}"), |b| { let method = ColumnarValue::Scalar("hex".into()); + let arg_fields = vec![ + Field::new("a", str_array.data_type().to_owned(), true).into(), + Field::new("b", method.data_type().to_owned(), true).into(), + ]; let encoded = encoding::encode() .invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Array(str_array.clone()), method.clone()], + arg_fields, number_rows: size, - 
return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }) .unwrap(); + let arg_fields = vec![ + Field::new("a", encoded.data_type().to_owned(), true).into(), + Field::new("b", method.data_type().to_owned(), true).into(), + ]; + let return_field = Field::new("f", DataType::Utf8, true).into(); let args = vec![encoded, method]; + b.iter(|| { black_box( decode .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: Arc::clone(&return_field), }) .unwrap(), ) diff --git a/datafusion/functions/benches/find_in_set.rs b/datafusion/functions/benches/find_in_set.rs index 9307525482c2..bad540f049e2 100644 --- a/datafusion/functions/benches/find_in_set.rs +++ b/datafusion/functions/benches/find_in_set.rs @@ -18,14 +18,14 @@ extern crate criterion; use arrow::array::{StringArray, StringViewArray}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; use criterion::{black_box, criterion_group, criterion_main, Criterion, SamplingMode}; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; -use rand::distributions::Alphanumeric; +use rand::distr::Alphanumeric; use rand::prelude::StdRng; use rand::{Rng, SeedableRng}; use std::sync::Arc; @@ -51,7 +51,7 @@ fn gen_args_array( let mut output_set_vec: Vec> = Vec::with_capacity(n_rows); let mut output_element_vec: Vec> = Vec::with_capacity(n_rows); for _ in 0..n_rows { - let rand_num = rng_ref.gen::(); // [0.0, 1.0) + let rand_num = rng_ref.random::(); // [0.0, 1.0) if rand_num < null_density { output_element_vec.push(None); output_set_vec.push(None); @@ -60,7 +60,7 @@ fn gen_args_array( let mut generated_string = String::with_capacity(str_len_chars); for i in 0..num_elements { for _ in 0..str_len_chars { - let idx = rng_ref.gen_range(0..corpus_char_count); + let idx = rng_ref.random_range(0..corpus_char_count); let char = utf8.chars().nth(idx).unwrap(); generated_string.push(char); } @@ -112,7 +112,7 @@ fn random_element_in_set(string: &str) -> String { } let mut rng = StdRng::seed_from_u64(44); - let random_index = rng.gen_range(0..elements.len()); + let random_index = rng.random_range(0..elements.len()); elements[random_index].to_string() } @@ -153,23 +153,35 @@ fn criterion_benchmark(c: &mut Criterion) { group.measurement_time(Duration::from_secs(10)); let args = gen_args_array(n_rows, str_len, 0.1, 0.5, false); - group.bench_function(format!("string_len_{}", str_len), |b| { + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type().clone(), true).into()) + .collect::>(); + let return_field = Field::new("f", DataType::Int32, true).into(); + group.bench_function(format!("string_len_{str_len}"), |b| { b.iter(|| { black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &DataType::Int32, + return_field: Arc::clone(&return_field), })) }) }); let args = gen_args_array(n_rows, str_len, 0.1, 0.5, true); - group.bench_function(format!("string_view_len_{}", str_len), |b| { + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type().clone(), true).into()) + .collect::>(); + let return_field = Arc::new(Field::new("f", DataType::Int32, true)); + group.bench_function(format!("string_view_len_{str_len}"), |b| { b.iter(|| { 
black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &DataType::Int32, + return_field: Arc::clone(&return_field), })) }) }); @@ -179,23 +191,35 @@ fn criterion_benchmark(c: &mut Criterion) { let mut group = c.benchmark_group("find_in_set_scalar"); let args = gen_args_scalar(n_rows, str_len, 0.1, false); - group.bench_function(format!("string_len_{}", str_len), |b| { + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type().clone(), true).into()) + .collect::>(); + let return_field = Arc::new(Field::new("f", DataType::Int32, true)); + group.bench_function(format!("string_len_{str_len}"), |b| { b.iter(|| { black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &DataType::Int32, + return_field: Arc::clone(&return_field), })) }) }); let args = gen_args_scalar(n_rows, str_len, 0.1, true); - group.bench_function(format!("string_view_len_{}", str_len), |b| { + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type().clone(), true).into()) + .collect::>(); + let return_field = Arc::new(Field::new("f", DataType::Int32, true)); + group.bench_function(format!("string_view_len_{str_len}"), |b| { b.iter(|| { black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &DataType::Int32, + return_field: Arc::clone(&return_field), })) }) }); diff --git a/datafusion/functions/benches/gcd.rs b/datafusion/functions/benches/gcd.rs index f8c855c82ad4..f700d31123a9 100644 --- a/datafusion/functions/benches/gcd.rs +++ b/datafusion/functions/benches/gcd.rs @@ -17,6 +17,7 @@ extern crate criterion; +use arrow::datatypes::Field; use arrow::{ array::{ArrayRef, Int64Array}, datatypes::DataType, @@ -29,9 +30,9 @@ use rand::Rng; use std::sync::Arc; fn generate_i64_array(n_rows: usize) -> ArrayRef { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let values = (0..n_rows) - .map(|_| rng.gen_range(0..1000)) + .map(|_| rng.random_range(0..1000)) .collect::>(); Arc::new(Int64Array::from(values)) as ArrayRef } @@ -47,8 +48,12 @@ fn criterion_benchmark(c: &mut Criterion) { black_box( udf.invoke_with_args(ScalarFunctionArgs { args: vec![array_a.clone(), array_b.clone()], + arg_fields: vec![ + Field::new("a", array_a.data_type(), true).into(), + Field::new("b", array_b.data_type(), true).into(), + ], number_rows: 0, - return_type: &DataType::Int64, + return_field: Field::new("f", DataType::Int64, true).into(), }) .expect("date_bin should work on valid values"), ) @@ -63,8 +68,12 @@ fn criterion_benchmark(c: &mut Criterion) { black_box( udf.invoke_with_args(ScalarFunctionArgs { args: vec![array_a.clone(), scalar_b.clone()], + arg_fields: vec![ + Field::new("a", array_a.data_type(), true).into(), + Field::new("b", scalar_b.data_type(), true).into(), + ], number_rows: 0, - return_type: &DataType::Int64, + return_field: Field::new("f", DataType::Int64, true).into(), }) .expect("date_bin should work on valid values"), ) @@ -79,8 +88,12 @@ fn criterion_benchmark(c: &mut Criterion) { black_box( udf.invoke_with_args(ScalarFunctionArgs { args: vec![scalar_a.clone(), scalar_b.clone()], + arg_fields: vec![ + Field::new("a", scalar_a.data_type(), true).into(), + Field::new("b", scalar_b.data_type(), true).into(), + ], number_rows: 0, - return_type: &DataType::Int64, + return_field: Field::new("f", 
DataType::Int64, true).into(), }) .expect("date_bin should work on valid values"), ) diff --git a/datafusion/functions/benches/helper.rs b/datafusion/functions/benches/helper.rs index 0dbb4b0027d4..a2b110ae4d63 100644 --- a/datafusion/functions/benches/helper.rs +++ b/datafusion/functions/benches/helper.rs @@ -17,7 +17,7 @@ use arrow::array::{StringArray, StringViewArray}; use datafusion_expr::ColumnarValue; -use rand::distributions::Alphanumeric; +use rand::distr::Alphanumeric; use rand::{rngs::StdRng, Rng, SeedableRng}; use std::sync::Arc; @@ -39,14 +39,14 @@ pub fn gen_string_array( let mut output_string_vec: Vec> = Vec::with_capacity(n_rows); for _ in 0..n_rows { - let rand_num = rng_ref.gen::(); // [0.0, 1.0) + let rand_num = rng_ref.random::(); // [0.0, 1.0) if rand_num < null_density { output_string_vec.push(None); } else if rand_num < null_density + utf8_density { // Generate random UTF8 string let mut generated_string = String::with_capacity(str_len_chars); for _ in 0..str_len_chars { - let char = corpus[rng_ref.gen_range(0..corpus.len())]; + let char = corpus[rng_ref.random_range(0..corpus.len())]; generated_string.push(char); } output_string_vec.push(Some(generated_string)); diff --git a/datafusion/functions/benches/initcap.rs b/datafusion/functions/benches/initcap.rs index 97c76831b33c..f89b11dff8fb 100644 --- a/datafusion/functions/benches/initcap.rs +++ b/datafusion/functions/benches/initcap.rs @@ -18,7 +18,7 @@ extern crate criterion; use arrow::array::OffsetSizeTrait; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; @@ -49,14 +49,23 @@ fn criterion_benchmark(c: &mut Criterion) { let initcap = unicode::initcap(); for size in [1024, 4096] { let args = create_args::(size, 8, true); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + c.bench_function( - format!("initcap string view shorter than 12 [size={}]", size).as_str(), + format!("initcap string view shorter than 12 [size={size}]").as_str(), |b| { b.iter(|| { black_box(initcap.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8View, + return_field: Field::new("f", DataType::Utf8View, true).into(), })) }) }, @@ -64,25 +73,27 @@ fn criterion_benchmark(c: &mut Criterion) { let args = create_args::(size, 16, true); c.bench_function( - format!("initcap string view longer than 12 [size={}]", size).as_str(), + format!("initcap string view longer than 12 [size={size}]").as_str(), |b| { b.iter(|| { black_box(initcap.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8View, + return_field: Field::new("f", DataType::Utf8View, true).into(), })) }) }, ); let args = create_args::(size, 16, false); - c.bench_function(format!("initcap string [size={}]", size).as_str(), |b| { + c.bench_function(format!("initcap string [size={size}]").as_str(), |b| { b.iter(|| { black_box(initcap.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), })) }) }); diff --git a/datafusion/functions/benches/isnan.rs b/datafusion/functions/benches/isnan.rs index 42004cc24f69..49d0a9e326dd 100644 --- 
a/datafusion/functions/benches/isnan.rs +++ b/datafusion/functions/benches/isnan.rs @@ -17,7 +17,7 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::{ datatypes::{Float32Type, Float64Type}, util::bench_util::create_primitive_array, @@ -32,14 +32,23 @@ fn criterion_benchmark(c: &mut Criterion) { for size in [1024, 4096, 8192] { let f32_array = Arc::new(create_primitive_array::(size, 0.2)); let f32_args = vec![ColumnarValue::Array(f32_array)]; - c.bench_function(&format!("isnan f32 array: {}", size), |b| { + let arg_fields = f32_args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + + c.bench_function(&format!("isnan f32 array: {size}"), |b| { b.iter(|| { black_box( isnan .invoke_with_args(ScalarFunctionArgs { args: f32_args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Boolean, + return_field: Field::new("f", DataType::Boolean, true).into(), }) .unwrap(), ) @@ -47,14 +56,22 @@ fn criterion_benchmark(c: &mut Criterion) { }); let f64_array = Arc::new(create_primitive_array::(size, 0.2)); let f64_args = vec![ColumnarValue::Array(f64_array)]; - c.bench_function(&format!("isnan f64 array: {}", size), |b| { + let arg_fields = f64_args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + c.bench_function(&format!("isnan f64 array: {size}"), |b| { b.iter(|| { black_box( isnan .invoke_with_args(ScalarFunctionArgs { args: f64_args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Boolean, + return_field: Field::new("f", DataType::Boolean, true).into(), }) .unwrap(), ) diff --git a/datafusion/functions/benches/iszero.rs b/datafusion/functions/benches/iszero.rs index 9e5f6a84804b..6d1d34c7a832 100644 --- a/datafusion/functions/benches/iszero.rs +++ b/datafusion/functions/benches/iszero.rs @@ -17,7 +17,7 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::{ datatypes::{Float32Type, Float64Type}, util::bench_util::create_primitive_array, @@ -33,14 +33,24 @@ fn criterion_benchmark(c: &mut Criterion) { let f32_array = Arc::new(create_primitive_array::(size, 0.2)); let batch_len = f32_array.len(); let f32_args = vec![ColumnarValue::Array(f32_array)]; - c.bench_function(&format!("iszero f32 array: {}", size), |b| { + let arg_fields = f32_args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let return_field = Arc::new(Field::new("f", DataType::Boolean, true)); + + c.bench_function(&format!("iszero f32 array: {size}"), |b| { b.iter(|| { black_box( iszero .invoke_with_args(ScalarFunctionArgs { args: f32_args.clone(), + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type: &DataType::Boolean, + return_field: Arc::clone(&return_field), }) .unwrap(), ) @@ -49,14 +59,24 @@ fn criterion_benchmark(c: &mut Criterion) { let f64_array = Arc::new(create_primitive_array::(size, 0.2)); let batch_len = f64_array.len(); let f64_args = vec![ColumnarValue::Array(f64_array)]; - c.bench_function(&format!("iszero f64 array: {}", size), |b| { + let arg_fields = f64_args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let return_field = Arc::new(Field::new("f", 
DataType::Boolean, true)); + + c.bench_function(&format!("iszero f64 array: {size}"), |b| { b.iter(|| { black_box( iszero .invoke_with_args(ScalarFunctionArgs { args: f64_args.clone(), + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type: &DataType::Boolean, + return_field: Arc::clone(&return_field), }) .unwrap(), ) diff --git a/datafusion/functions/benches/lower.rs b/datafusion/functions/benches/lower.rs index 534e5739225d..cdf1529c108c 100644 --- a/datafusion/functions/benches/lower.rs +++ b/datafusion/functions/benches/lower.rs @@ -18,7 +18,7 @@ extern crate criterion; use arrow::array::{ArrayRef, StringArray, StringViewBuilder}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; @@ -44,7 +44,7 @@ fn create_args2(size: usize) -> Vec { let mut items = Vec::with_capacity(size); items.push("农历新年".to_string()); for i in 1..size { - items.push(format!("DATAFUSION {}", i)); + items.push(format!("DATAFUSION {i}")); } let array = Arc::new(StringArray::from(items)) as ArrayRef; vec![ColumnarValue::Array(array)] @@ -58,11 +58,11 @@ fn create_args3(size: usize) -> Vec { let mut items = Vec::with_capacity(size); let half = size / 2; for i in 0..half { - items.push(format!("DATAFUSION {}", i)); + items.push(format!("DATAFUSION {i}")); } items.push("Ⱦ".to_string()); for i in half + 1..size { - items.push(format!("DATAFUSION {}", i)); + items.push(format!("DATAFUSION {i}")); } let array = Arc::new(StringArray::from(items)) as ArrayRef; vec![ColumnarValue::Array(array)] @@ -124,42 +124,66 @@ fn criterion_benchmark(c: &mut Criterion) { let lower = string::lower(); for size in [1024, 4096, 8192] { let args = create_args1(size, 32); - c.bench_function(&format!("lower_all_values_are_ascii: {}", size), |b| { + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + + c.bench_function(&format!("lower_all_values_are_ascii: {size}"), |b| { b.iter(|| { let args_cloned = args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), })) }) }); let args = create_args2(size); - c.bench_function( - &format!("lower_the_first_value_is_nonascii: {}", size), - |b| { - b.iter(|| { - let args_cloned = args.clone(); - black_box(lower.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - number_rows: size, - return_type: &DataType::Utf8, - })) - }) - }, - ); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + + c.bench_function(&format!("lower_the_first_value_is_nonascii: {size}"), |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(lower.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new("f", DataType::Utf8, true).into(), + })) + }) + }); let args = create_args3(size); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + c.bench_function( - &format!("lower_the_middle_value_is_nonascii: {}", size), + &format!("lower_the_middle_value_is_nonascii: {size}"), |b| { b.iter(|| { let args_cloned = 
args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), })) }) }, @@ -176,29 +200,37 @@ fn criterion_benchmark(c: &mut Criterion) { for &str_len in &str_lens { for &size in &sizes { let args = create_args4(size, str_len, *null_density, mixed); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + c.bench_function( - &format!("lower_all_values_are_ascii_string_views: size: {}, str_len: {}, null_density: {}, mixed: {}", - size, str_len, null_density, mixed), + &format!("lower_all_values_are_ascii_string_views: size: {size}, str_len: {str_len}, null_density: {null_density}, mixed: {mixed}"), |b| b.iter(|| { let args_cloned = args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs{ args: args_cloned, + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), })) }), ); let args = create_args4(size, str_len, *null_density, mixed); c.bench_function( - &format!("lower_all_values_are_ascii_string_views: size: {}, str_len: {}, null_density: {}, mixed: {}", - size, str_len, null_density, mixed), + &format!("lower_all_values_are_ascii_string_views: size: {size}, str_len: {str_len}, null_density: {null_density}, mixed: {mixed}"), |b| b.iter(|| { let args_cloned = args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs{ args: args_cloned, + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), })) }), ); @@ -211,8 +243,9 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs{ args: args_cloned, + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), })) }), ); diff --git a/datafusion/functions/benches/ltrim.rs b/datafusion/functions/benches/ltrim.rs index 457fb499f5a1..7a44f40a689a 100644 --- a/datafusion/functions/benches/ltrim.rs +++ b/datafusion/functions/benches/ltrim.rs @@ -18,7 +18,7 @@ extern crate criterion; use arrow::array::{ArrayRef, LargeStringArray, StringArray, StringViewArray}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{ black_box, criterion_group, criterion_main, measurement::Measurement, BenchmarkGroup, Criterion, SamplingMode, @@ -26,13 +26,9 @@ use criterion::{ use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDF}; use datafusion_functions::string; -use rand::{distributions::Alphanumeric, rngs::StdRng, Rng, SeedableRng}; +use rand::{distr::Alphanumeric, rngs::StdRng, Rng, SeedableRng}; use std::{fmt, sync::Arc}; -pub fn seedable_rng() -> StdRng { - StdRng::seed_from_u64(42) -} - #[derive(Clone, Copy)] pub enum StringArrayType { Utf8View, @@ -58,14 +54,14 @@ pub fn create_string_array_and_characters( remaining_len: usize, string_array_type: StringArrayType, ) -> (ArrayRef, ScalarValue) { - let rng = &mut seedable_rng(); + let rng = &mut StdRng::seed_from_u64(42); // Create `size` rows: // - 10% rows will be `None` // - Other 90% will be strings with same `remaining_len` lengths // We will build the string array on it later. 
let string_iter = (0..size).map(|_| { - if rng.gen::() < 0.1 { + if rng.random::() < 0.1 { None } else { let mut value = trimmed.as_bytes().to_vec(); @@ -136,6 +132,11 @@ fn run_with_string_type( string_type: StringArrayType, ) { let args = create_args(size, characters, trimmed, remaining_len, string_type); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect::>(); group.bench_function( format!( "{string_type} [size={size}, len_before={len}, len_after={remaining_len}]", @@ -145,8 +146,9 @@ fn run_with_string_type( let args_cloned = args.clone(); black_box(ltrim.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), })) }) }, diff --git a/datafusion/functions/benches/make_date.rs b/datafusion/functions/benches/make_date.rs index 8dd7a7a59773..e1f609fbb35c 100644 --- a/datafusion/functions/benches/make_date.rs +++ b/datafusion/functions/benches/make_date.rs @@ -20,7 +20,7 @@ extern crate criterion; use std::sync::Arc; use arrow::array::{Array, ArrayRef, Int32Array}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use rand::rngs::ThreadRng; use rand::Rng; @@ -32,7 +32,7 @@ use datafusion_functions::datetime::make_date; fn years(rng: &mut ThreadRng) -> Int32Array { let mut years = vec![]; for _ in 0..1000 { - years.push(rng.gen_range(1900..2050)); + years.push(rng.random_range(1900..2050)); } Int32Array::from(years) @@ -41,7 +41,7 @@ fn years(rng: &mut ThreadRng) -> Int32Array { fn months(rng: &mut ThreadRng) -> Int32Array { let mut months = vec![]; for _ in 0..1000 { - months.push(rng.gen_range(1..13)); + months.push(rng.random_range(1..13)); } Int32Array::from(months) @@ -50,27 +50,34 @@ fn months(rng: &mut ThreadRng) -> Int32Array { fn days(rng: &mut ThreadRng) -> Int32Array { let mut days = vec![]; for _ in 0..1000 { - days.push(rng.gen_range(1..29)); + days.push(rng.random_range(1..29)); } Int32Array::from(days) } fn criterion_benchmark(c: &mut Criterion) { c.bench_function("make_date_col_col_col_1000", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let years_array = Arc::new(years(&mut rng)) as ArrayRef; let batch_len = years_array.len(); let years = ColumnarValue::Array(years_array); let months = ColumnarValue::Array(Arc::new(months(&mut rng)) as ArrayRef); let days = ColumnarValue::Array(Arc::new(days(&mut rng)) as ArrayRef); + let arg_fields = vec![ + Field::new("a", years.data_type(), true).into(), + Field::new("a", months.data_type(), true).into(), + Field::new("a", days.data_type(), true).into(), + ]; + let return_field = Field::new("f", DataType::Date32, true).into(); b.iter(|| { black_box( make_date() .invoke_with_args(ScalarFunctionArgs { args: vec![years.clone(), months.clone(), days.clone()], + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type: &DataType::Date32, + return_field: Arc::clone(&return_field), }) .expect("make_date should work on valid values"), ) @@ -78,20 +85,26 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("make_date_scalar_col_col_1000", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let year = ColumnarValue::Scalar(ScalarValue::Int32(Some(2025))); let months_arr = Arc::new(months(&mut rng)) as ArrayRef; let batch_len = months_arr.len(); 
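
The hunks above (ltrim.rs, make_date.rs) and the regex/strpos/substr_index/to_char benches that follow migrate from rand 0.8 names to the rand 0.9 API: `rand::thread_rng()` becomes `rand::rng()`, `gen`/`gen_range` become `random`/`random_range`, `rand::distributions` becomes `rand::distr`, `SliceRandom` becomes `IndexedRandom`, and `Uniform` constructors now return a `Result`. The following is a minimal, self-contained sketch of the renamed API, assuming `rand = "0.9"`; it is illustrative only and not part of the patch.

```rust
// Sketch of the rand 0.9 names used in the benchmark changes above (assumption: rand = "0.9").
use rand::distr::{Alphanumeric, Distribution, Uniform};
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};

fn main() {
    // `rand::rng()` replaces `rand::thread_rng()`.
    let mut rng = rand::rng();
    let _coin: f32 = rng.random();          // was `rng.gen::<f32>()`
    let _month = rng.random_range(1..13);   // was `rng.gen_range(1..13)`

    // `Uniform::new_inclusive` now returns a Result instead of the distribution itself.
    let dist = Uniform::new_inclusive(0_i64, 10_i64).unwrap();
    let _sampled = dist.sample(&mut rng);

    // Seeded RNG for reproducible benchmark data, as in the ltrim bench.
    let mut seeded = StdRng::seed_from_u64(42);
    let _text: String = (&mut seeded)
        .sample_iter(&Alphanumeric)
        .take(8)
        .map(char::from)
        .collect();
}
```
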
let months = ColumnarValue::Array(months_arr); let days = ColumnarValue::Array(Arc::new(days(&mut rng)) as ArrayRef); - + let arg_fields = vec![ + Field::new("a", year.data_type(), true).into(), + Field::new("a", months.data_type(), true).into(), + Field::new("a", days.data_type(), true).into(), + ]; + let return_field = Field::new("f", DataType::Date32, true).into(); b.iter(|| { black_box( make_date() .invoke_with_args(ScalarFunctionArgs { args: vec![year.clone(), months.clone(), days.clone()], + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type: &DataType::Date32, + return_field: Arc::clone(&return_field), }) .expect("make_date should work on valid values"), ) @@ -99,20 +112,26 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("make_date_scalar_scalar_col_1000", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let year = ColumnarValue::Scalar(ScalarValue::Int32(Some(2025))); let month = ColumnarValue::Scalar(ScalarValue::Int32(Some(11))); let day_arr = Arc::new(days(&mut rng)); let batch_len = day_arr.len(); let days = ColumnarValue::Array(day_arr); - + let arg_fields = vec![ + Field::new("a", year.data_type(), true).into(), + Field::new("a", month.data_type(), true).into(), + Field::new("a", days.data_type(), true).into(), + ]; + let return_field = Field::new("f", DataType::Date32, true).into(); b.iter(|| { black_box( make_date() .invoke_with_args(ScalarFunctionArgs { args: vec![year.clone(), month.clone(), days.clone()], + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type: &DataType::Date32, + return_field: Arc::clone(&return_field), }) .expect("make_date should work on valid values"), ) @@ -123,14 +142,21 @@ fn criterion_benchmark(c: &mut Criterion) { let year = ColumnarValue::Scalar(ScalarValue::Int32(Some(2025))); let month = ColumnarValue::Scalar(ScalarValue::Int32(Some(11))); let day = ColumnarValue::Scalar(ScalarValue::Int32(Some(26))); + let arg_fields = vec![ + Field::new("a", year.data_type(), true).into(), + Field::new("a", month.data_type(), true).into(), + Field::new("a", day.data_type(), true).into(), + ]; + let return_field = Field::new("f", DataType::Date32, true).into(); b.iter(|| { black_box( make_date() .invoke_with_args(ScalarFunctionArgs { args: vec![year.clone(), month.clone(), day.clone()], + arg_fields: arg_fields.clone(), number_rows: 1, - return_type: &DataType::Date32, + return_field: Arc::clone(&return_field), }) .expect("make_date should work on valid values"), ) diff --git a/datafusion/functions/benches/nullif.rs b/datafusion/functions/benches/nullif.rs index 9096c976bf31..4ac977af9d42 100644 --- a/datafusion/functions/benches/nullif.rs +++ b/datafusion/functions/benches/nullif.rs @@ -17,7 +17,7 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::create_string_array_with_len; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_common::ScalarValue; @@ -33,14 +33,23 @@ fn criterion_benchmark(c: &mut Criterion) { ColumnarValue::Scalar(ScalarValue::Utf8(Some("abcd".to_string()))), ColumnarValue::Array(array), ]; - c.bench_function(&format!("nullif scalar array: {}", size), |b| { + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + + c.bench_function(&format!("nullif scalar array: {size}"), |b| { b.iter(|| { black_box( nullif .invoke_with_args(ScalarFunctionArgs { 
args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }) .unwrap(), ) diff --git a/datafusion/functions/benches/pad.rs b/datafusion/functions/benches/pad.rs index f78a53fbee19..d954ff452ed5 100644 --- a/datafusion/functions/benches/pad.rs +++ b/datafusion/functions/benches/pad.rs @@ -16,14 +16,15 @@ // under the License. use arrow::array::{ArrayRef, ArrowPrimitiveType, OffsetSizeTrait, PrimitiveArray}; -use arrow::datatypes::{DataType, Int64Type}; +use arrow::datatypes::{DataType, Field, Int64Type}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use datafusion_common::DataFusionError; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::unicode::{lpad, rpad}; -use rand::distributions::{Distribution, Uniform}; +use rand::distr::{Distribution, Uniform}; use rand::Rng; use std::sync::Arc; @@ -52,13 +53,13 @@ where dist: Uniform::new_inclusive::(0, len as i64), }; - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); (0..size) .map(|_| { - if rng.gen::() < null_density { + if rng.random::() < null_density { None } else { - Some(rng.sample(&dist)) + Some(rng.sample(dist.dist.unwrap())) } }) .collect() @@ -95,21 +96,41 @@ fn create_args( } } +fn invoke_pad_with_args( + args: Vec, + number_rows: usize, + left_pad: bool, +) -> Result { + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect::>(); + + let scalar_args = ScalarFunctionArgs { + args: args.clone(), + arg_fields, + number_rows, + return_field: Field::new("f", DataType::Utf8, true).into(), + }; + + if left_pad { + lpad().invoke_with_args(scalar_args) + } else { + rpad().invoke_with_args(scalar_args) + } +} + fn criterion_benchmark(c: &mut Criterion) { for size in [1024, 2048] { let mut group = c.benchmark_group("lpad function"); let args = create_args::(size, 32, false); + group.bench_function(BenchmarkId::new("utf8 type", size), |b| { b.iter(|| { criterion::black_box( - lpad() - .invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8, - }) - .unwrap(), + invoke_pad_with_args(args.clone(), size, true).unwrap(), ) }) }); @@ -118,13 +139,7 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(BenchmarkId::new("largeutf8 type", size), |b| { b.iter(|| { criterion::black_box( - lpad() - .invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::LargeUtf8, - }) - .unwrap(), + invoke_pad_with_args(args.clone(), size, true).unwrap(), ) }) }); @@ -133,13 +148,7 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(BenchmarkId::new("stringview type", size), |b| { b.iter(|| { criterion::black_box( - lpad() - .invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8, - }) - .unwrap(), + invoke_pad_with_args(args.clone(), size, true).unwrap(), ) }) }); @@ -152,13 +161,7 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(BenchmarkId::new("utf8 type", size), |b| { b.iter(|| { criterion::black_box( - rpad() - .invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8, - }) - .unwrap(), + invoke_pad_with_args(args.clone(), size, 
false).unwrap(), ) }) }); @@ -167,13 +170,7 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(BenchmarkId::new("largeutf8 type", size), |b| { b.iter(|| { criterion::black_box( - rpad() - .invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::LargeUtf8, - }) - .unwrap(), + invoke_pad_with_args(args.clone(), size, false).unwrap(), ) }) }); @@ -183,13 +180,7 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(BenchmarkId::new("stringview type", size), |b| { b.iter(|| { criterion::black_box( - rpad() - .invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8, - }) - .unwrap(), + invoke_pad_with_args(args.clone(), size, false).unwrap(), ) }) }); diff --git a/datafusion/functions/benches/random.rs b/datafusion/functions/benches/random.rs index 78ebf23e02e0..dc1e280b93b1 100644 --- a/datafusion/functions/benches/random.rs +++ b/datafusion/functions/benches/random.rs @@ -17,14 +17,16 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl}; use datafusion_functions::math::random::RandomFunc; +use std::sync::Arc; fn criterion_benchmark(c: &mut Criterion) { let random_func = RandomFunc::new(); + let return_field = Field::new("f", DataType::Float64, true).into(); // Benchmark to evaluate 1M rows in batch size 8192 let iterations = 1_000_000 / 8192; // Calculate how many iterations are needed to reach approximately 1M rows c.bench_function("random_1M_rows_batch_8192", |b| { @@ -34,8 +36,9 @@ fn criterion_benchmark(c: &mut Criterion) { random_func .invoke_with_args(ScalarFunctionArgs { args: vec![], + arg_fields: vec![], number_rows: 8192, - return_type: &DataType::Float64, + return_field: Arc::clone(&return_field), }) .unwrap(), ); @@ -43,6 +46,7 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); + let return_field = Field::new("f", DataType::Float64, true).into(); // Benchmark to evaluate 1M rows in batch size 128 let iterations_128 = 1_000_000 / 128; // Calculate how many iterations are needed to reach approximately 1M rows with batch size 128 c.bench_function("random_1M_rows_batch_128", |b| { @@ -52,8 +56,9 @@ fn criterion_benchmark(c: &mut Criterion) { random_func .invoke_with_args(ScalarFunctionArgs { args: vec![], + arg_fields: vec![], number_rows: 128, - return_type: &DataType::Float64, + return_field: Arc::clone(&return_field), }) .unwrap(), ); diff --git a/datafusion/functions/benches/regx.rs b/datafusion/functions/benches/regx.rs index 3a1a6a71173e..c0b50ad62f64 100644 --- a/datafusion/functions/benches/regx.rs +++ b/datafusion/functions/benches/regx.rs @@ -26,9 +26,9 @@ use datafusion_functions::regex::regexpcount::regexp_count_func; use datafusion_functions::regex::regexplike::regexp_like; use datafusion_functions::regex::regexpmatch::regexp_match; use datafusion_functions::regex::regexpreplace::regexp_replace; -use rand::distributions::Alphanumeric; +use rand::distr::Alphanumeric; +use rand::prelude::IndexedRandom; use rand::rngs::ThreadRng; -use rand::seq::SliceRandom; use rand::Rng; use std::iter; use std::sync::Arc; @@ -65,7 +65,7 @@ fn regex(rng: &mut ThreadRng) -> StringArray { fn start(rng: &mut ThreadRng) -> Int64Array { let mut data: Vec = vec![]; for _ in 0..1000 { - data.push(rng.gen_range(1..5)); + data.push(rng.random_range(1..5)); } Int64Array::from(data) @@ 
-88,7 +88,7 @@ fn flags(rng: &mut ThreadRng) -> StringArray { fn criterion_benchmark(c: &mut Criterion) { c.bench_function("regexp_count_1000 string", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let data = Arc::new(data(&mut rng)) as ArrayRef; let regex = Arc::new(regex(&mut rng)) as ArrayRef; let start = Arc::new(start(&mut rng)) as ArrayRef; @@ -108,7 +108,7 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("regexp_count_1000 utf8view", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let data = cast(&data(&mut rng), &DataType::Utf8View).unwrap(); let regex = cast(®ex(&mut rng), &DataType::Utf8View).unwrap(); let start = Arc::new(start(&mut rng)) as ArrayRef; @@ -128,7 +128,7 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("regexp_like_1000", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let data = Arc::new(data(&mut rng)) as ArrayRef; let regex = Arc::new(regex(&mut rng)) as ArrayRef; let flags = Arc::new(flags(&mut rng)) as ArrayRef; @@ -142,7 +142,7 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("regexp_like_1000 utf8view", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let data = cast(&data(&mut rng), &DataType::Utf8View).unwrap(); let regex = cast(®ex(&mut rng), &DataType::Utf8View).unwrap(); let flags = cast(&flags(&mut rng), &DataType::Utf8View).unwrap(); @@ -156,7 +156,7 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("regexp_match_1000", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let data = Arc::new(data(&mut rng)) as ArrayRef; let regex = Arc::new(regex(&mut rng)) as ArrayRef; let flags = Arc::new(flags(&mut rng)) as ArrayRef; @@ -174,7 +174,7 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("regexp_match_1000 utf8view", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let data = cast(&data(&mut rng), &DataType::Utf8View).unwrap(); let regex = cast(®ex(&mut rng), &DataType::Utf8View).unwrap(); let flags = cast(&flags(&mut rng), &DataType::Utf8View).unwrap(); @@ -192,7 +192,7 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("regexp_replace_1000", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let data = Arc::new(data(&mut rng)) as ArrayRef; let regex = Arc::new(regex(&mut rng)) as ArrayRef; let flags = Arc::new(flags(&mut rng)) as ArrayRef; @@ -214,7 +214,7 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("regexp_replace_1000 utf8view", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let data = cast(&data(&mut rng), &DataType::Utf8View).unwrap(); let regex = cast(®ex(&mut rng), &DataType::Utf8View).unwrap(); // flags are not allowed to be utf8view according to the function diff --git a/datafusion/functions/benches/repeat.rs b/datafusion/functions/benches/repeat.rs index 5cc6a177d9d9..175933f5f745 100644 --- a/datafusion/functions/benches/repeat.rs +++ b/datafusion/functions/benches/repeat.rs @@ -18,11 +18,12 @@ extern crate criterion; use arrow::array::{ArrayRef, Int64Array, OffsetSizeTrait}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; use criterion::{black_box, criterion_group, criterion_main, Criterion, SamplingMode}; +use datafusion_common::DataFusionError; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use 
datafusion_functions::string; use std::sync::Arc; @@ -56,66 +57,62 @@ fn create_args( } } +fn invoke_repeat_with_args( + args: Vec, + repeat_times: i64, +) -> Result { + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect::>(); + + string::repeat().invoke_with_args(ScalarFunctionArgs { + args, + arg_fields, + number_rows: repeat_times as usize, + return_field: Field::new("f", DataType::Utf8, true).into(), + }) +} + fn criterion_benchmark(c: &mut Criterion) { - let repeat = string::repeat(); for size in [1024, 4096] { // REPEAT 3 TIMES let repeat_times = 3; - let mut group = c.benchmark_group(format!("repeat {} times", repeat_times)); + let mut group = c.benchmark_group(format!("repeat {repeat_times} times")); group.sampling_mode(SamplingMode::Flat); group.sample_size(10); group.measurement_time(Duration::from_secs(10)); let args = create_args::(size, 32, repeat_times, true); group.bench_function( - format!( - "repeat_string_view [size={}, repeat_times={}]", - size, repeat_times - ), + format!("repeat_string_view [size={size}, repeat_times={repeat_times}]"), |b| { b.iter(|| { let args_cloned = args.clone(); - black_box(repeat.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - number_rows: repeat_times as usize, - return_type: &DataType::Utf8, - })) + black_box(invoke_repeat_with_args(args_cloned, repeat_times)) }) }, ); let args = create_args::(size, 32, repeat_times, false); group.bench_function( - format!( - "repeat_string [size={}, repeat_times={}]", - size, repeat_times - ), + format!("repeat_string [size={size}, repeat_times={repeat_times}]"), |b| { b.iter(|| { let args_cloned = args.clone(); - black_box(repeat.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - number_rows: repeat_times as usize, - return_type: &DataType::Utf8, - })) + black_box(invoke_repeat_with_args(args_cloned, repeat_times)) }) }, ); let args = create_args::(size, 32, repeat_times, false); group.bench_function( - format!( - "repeat_large_string [size={}, repeat_times={}]", - size, repeat_times - ), + format!("repeat_large_string [size={size}, repeat_times={repeat_times}]"), |b| { b.iter(|| { let args_cloned = args.clone(); - black_box(repeat.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - number_rows: repeat_times as usize, - return_type: &DataType::Utf8, - })) + black_box(invoke_repeat_with_args(args_cloned, repeat_times)) }) }, ); @@ -124,61 +121,40 @@ fn criterion_benchmark(c: &mut Criterion) { // REPEAT 30 TIMES let repeat_times = 30; - let mut group = c.benchmark_group(format!("repeat {} times", repeat_times)); + let mut group = c.benchmark_group(format!("repeat {repeat_times} times")); group.sampling_mode(SamplingMode::Flat); group.sample_size(10); group.measurement_time(Duration::from_secs(10)); let args = create_args::(size, 32, repeat_times, true); group.bench_function( - format!( - "repeat_string_view [size={}, repeat_times={}]", - size, repeat_times - ), + format!("repeat_string_view [size={size}, repeat_times={repeat_times}]"), |b| { b.iter(|| { let args_cloned = args.clone(); - black_box(repeat.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - number_rows: repeat_times as usize, - return_type: &DataType::Utf8, - })) + black_box(invoke_repeat_with_args(args_cloned, repeat_times)) }) }, ); let args = create_args::(size, 32, repeat_times, false); group.bench_function( - format!( - "repeat_string [size={}, repeat_times={}]", - size, repeat_times - ), + 
format!("repeat_string [size={size}, repeat_times={repeat_times}]"), |b| { b.iter(|| { let args_cloned = args.clone(); - black_box(repeat.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - number_rows: repeat_times as usize, - return_type: &DataType::Utf8, - })) + black_box(invoke_repeat_with_args(args_cloned, repeat_times)) }) }, ); let args = create_args::(size, 32, repeat_times, false); group.bench_function( - format!( - "repeat_large_string [size={}, repeat_times={}]", - size, repeat_times - ), + format!("repeat_large_string [size={size}, repeat_times={repeat_times}]"), |b| { b.iter(|| { let args_cloned = args.clone(); - black_box(repeat.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - number_rows: repeat_times as usize, - return_type: &DataType::Utf8, - })) + black_box(invoke_repeat_with_args(args_cloned, repeat_times)) }) }, ); @@ -187,25 +163,18 @@ fn criterion_benchmark(c: &mut Criterion) { // REPEAT overflow let repeat_times = 1073741824; - let mut group = c.benchmark_group(format!("repeat {} times", repeat_times)); + let mut group = c.benchmark_group(format!("repeat {repeat_times} times")); group.sampling_mode(SamplingMode::Flat); group.sample_size(10); group.measurement_time(Duration::from_secs(10)); let args = create_args::(size, 2, repeat_times, false); group.bench_function( - format!( - "repeat_string overflow [size={}, repeat_times={}]", - size, repeat_times - ), + format!("repeat_string overflow [size={size}, repeat_times={repeat_times}]"), |b| { b.iter(|| { let args_cloned = args.clone(); - black_box(repeat.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - number_rows: repeat_times as usize, - return_type: &DataType::Utf8, - })) + black_box(invoke_repeat_with_args(args_cloned, repeat_times)) }) }, ); diff --git a/datafusion/functions/benches/reverse.rs b/datafusion/functions/benches/reverse.rs index d61f8fb80517..640366011305 100644 --- a/datafusion/functions/benches/reverse.rs +++ b/datafusion/functions/benches/reverse.rs @@ -18,7 +18,7 @@ extern crate criterion; mod helper; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::ScalarFunctionArgs; use helper::gen_string_array; @@ -41,13 +41,18 @@ fn criterion_benchmark(c: &mut Criterion) { false, ); c.bench_function( - &format!("reverse_StringArray_ascii_str_len_{}", str_len), + &format!("reverse_StringArray_ascii_str_len_{str_len}"), |b| { b.iter(|| { black_box(reverse.invoke_with_args(ScalarFunctionArgs { args: args_string_ascii.clone(), + arg_fields: vec![Field::new( + "a", + args_string_ascii[0].data_type(), + true, + ).into()], number_rows: N_ROWS, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), })) }) }, @@ -58,15 +63,17 @@ fn criterion_benchmark(c: &mut Criterion) { gen_string_array(N_ROWS, str_len, NULL_DENSITY, NORMAL_UTF8_DENSITY, false); c.bench_function( &format!( - "reverse_StringArray_utf8_density_{}_str_len_{}", - NORMAL_UTF8_DENSITY, str_len + "reverse_StringArray_utf8_density_{NORMAL_UTF8_DENSITY}_str_len_{str_len}" ), |b| { b.iter(|| { black_box(reverse.invoke_with_args(ScalarFunctionArgs { args: args_string_utf8.clone(), + arg_fields: vec![ + Field::new("a", args_string_utf8[0].data_type(), true).into(), + ], number_rows: N_ROWS, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), })) }) }, @@ -81,13 +88,18 @@ fn criterion_benchmark(c: &mut Criterion) { true, ); 
c.bench_function( - &format!("reverse_StringViewArray_ascii_str_len_{}", str_len), + &format!("reverse_StringViewArray_ascii_str_len_{str_len}"), |b| { b.iter(|| { black_box(reverse.invoke_with_args(ScalarFunctionArgs { args: args_string_view_ascii.clone(), + arg_fields: vec![Field::new( + "a", + args_string_view_ascii[0].data_type(), + true, + ).into()], number_rows: N_ROWS, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), })) }) }, @@ -98,15 +110,19 @@ fn criterion_benchmark(c: &mut Criterion) { gen_string_array(N_ROWS, str_len, NULL_DENSITY, NORMAL_UTF8_DENSITY, true); c.bench_function( &format!( - "reverse_StringViewArray_utf8_density_{}_str_len_{}", - NORMAL_UTF8_DENSITY, str_len + "reverse_StringViewArray_utf8_density_{NORMAL_UTF8_DENSITY}_str_len_{str_len}" ), |b| { b.iter(|| { black_box(reverse.invoke_with_args(ScalarFunctionArgs { args: args_string_view_utf8.clone(), + arg_fields: vec![Field::new( + "a", + args_string_view_utf8[0].data_type(), + true, + ).into()], number_rows: N_ROWS, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), })) }) }, diff --git a/datafusion/functions/benches/signum.rs b/datafusion/functions/benches/signum.rs index 01939fad5f34..10079bcc81c7 100644 --- a/datafusion/functions/benches/signum.rs +++ b/datafusion/functions/benches/signum.rs @@ -19,7 +19,7 @@ extern crate criterion; use arrow::datatypes::DataType; use arrow::{ - datatypes::{Float32Type, Float64Type}, + datatypes::{Field, Float32Type, Float64Type}, util::bench_util::create_primitive_array, }; use criterion::{black_box, criterion_group, criterion_main, Criterion}; @@ -33,14 +33,24 @@ fn criterion_benchmark(c: &mut Criterion) { let f32_array = Arc::new(create_primitive_array::(size, 0.2)); let batch_len = f32_array.len(); let f32_args = vec![ColumnarValue::Array(f32_array)]; - c.bench_function(&format!("signum f32 array: {}", size), |b| { + let arg_fields = f32_args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let return_field = Field::new("f", DataType::Float32, true).into(); + + c.bench_function(&format!("signum f32 array: {size}"), |b| { b.iter(|| { black_box( signum .invoke_with_args(ScalarFunctionArgs { args: f32_args.clone(), + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type: &DataType::Float32, + return_field: Arc::clone(&return_field), }) .unwrap(), ) @@ -50,14 +60,24 @@ fn criterion_benchmark(c: &mut Criterion) { let batch_len = f64_array.len(); let f64_args = vec![ColumnarValue::Array(f64_array)]; - c.bench_function(&format!("signum f64 array: {}", size), |b| { + let arg_fields = f64_args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let return_field = Field::new("f", DataType::Float64, true).into(); + + c.bench_function(&format!("signum f64 array: {size}"), |b| { b.iter(|| { black_box( signum .invoke_with_args(ScalarFunctionArgs { args: f64_args.clone(), + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type: &DataType::Float64, + return_field: Arc::clone(&return_field), }) .unwrap(), ) diff --git a/datafusion/functions/benches/strpos.rs b/datafusion/functions/benches/strpos.rs index df57c229e0ad..df32db1182f1 100644 --- a/datafusion/functions/benches/strpos.rs +++ b/datafusion/functions/benches/strpos.rs @@ -18,10 +18,10 @@ extern crate criterion; use arrow::array::{StringArray, 
StringViewArray}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; -use rand::distributions::Alphanumeric; +use rand::distr::Alphanumeric; use rand::prelude::StdRng; use rand::{Rng, SeedableRng}; use std::str::Chars; @@ -46,7 +46,7 @@ fn gen_string_array( let mut output_string_vec: Vec> = Vec::with_capacity(n_rows); let mut output_sub_string_vec: Vec> = Vec::with_capacity(n_rows); for _ in 0..n_rows { - let rand_num = rng_ref.gen::(); // [0.0, 1.0) + let rand_num = rng_ref.random::(); // [0.0, 1.0) if rand_num < null_density { output_sub_string_vec.push(None); output_string_vec.push(None); @@ -54,7 +54,7 @@ fn gen_string_array( // Generate random UTF8 string let mut generated_string = String::with_capacity(str_len_chars); for _ in 0..str_len_chars { - let idx = rng_ref.gen_range(0..corpus_char_count); + let idx = rng_ref.random_range(0..corpus_char_count); let char = utf8.chars().nth(idx).unwrap(); generated_string.push(char); } @@ -94,8 +94,8 @@ fn random_substring(chars: Chars) -> String { // get the substring of a random length from the input string by byte unit let mut rng = StdRng::seed_from_u64(44); let count = chars.clone().count(); - let start = rng.gen_range(0..count - 1); - let end = rng.gen_range(start + 1..count); + let start = rng.random_range(0..count - 1); + let end = rng.random_range(start + 1..count); chars .enumerate() .filter(|(i, _)| *i >= start && *i < end) @@ -111,14 +111,18 @@ fn criterion_benchmark(c: &mut Criterion) { for str_len in [8, 32, 128, 4096] { // StringArray ASCII only let args_string_ascii = gen_string_array(n_rows, str_len, 0.1, 0.0, false); + let arg_fields = + vec![Field::new("a", args_string_ascii[0].data_type(), true).into()]; + let return_field = Field::new("f", DataType::Int32, true).into(); c.bench_function( - &format!("strpos_StringArray_ascii_str_len_{}", str_len), + &format!("strpos_StringArray_ascii_str_len_{str_len}"), |b| { b.iter(|| { black_box(strpos.invoke_with_args(ScalarFunctionArgs { args: args_string_ascii.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &DataType::Int32, + return_field: Arc::clone(&return_field), })) }) }, @@ -126,29 +130,34 @@ fn criterion_benchmark(c: &mut Criterion) { // StringArray UTF8 let args_string_utf8 = gen_string_array(n_rows, str_len, 0.1, 0.5, false); - c.bench_function( - &format!("strpos_StringArray_utf8_str_len_{}", str_len), - |b| { - b.iter(|| { - black_box(strpos.invoke_with_args(ScalarFunctionArgs { - args: args_string_utf8.clone(), - number_rows: n_rows, - return_type: &DataType::Int32, - })) - }) - }, - ); + let arg_fields = + vec![Field::new("a", args_string_utf8[0].data_type(), true).into()]; + let return_field = Field::new("f", DataType::Int32, true).into(); + c.bench_function(&format!("strpos_StringArray_utf8_str_len_{str_len}"), |b| { + b.iter(|| { + black_box(strpos.invoke_with_args(ScalarFunctionArgs { + args: args_string_utf8.clone(), + arg_fields: arg_fields.clone(), + number_rows: n_rows, + return_field: Arc::clone(&return_field), + })) + }) + }); // StringViewArray ASCII only let args_string_view_ascii = gen_string_array(n_rows, str_len, 0.1, 0.0, true); + let arg_fields = + vec![Field::new("a", args_string_view_ascii[0].data_type(), true).into()]; + let return_field = Field::new("f", DataType::Int32, true).into(); c.bench_function( - &format!("strpos_StringViewArray_ascii_str_len_{}", str_len), + 
&format!("strpos_StringViewArray_ascii_str_len_{str_len}"), |b| { b.iter(|| { black_box(strpos.invoke_with_args(ScalarFunctionArgs { args: args_string_view_ascii.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &DataType::Int32, + return_field: Arc::clone(&return_field), })) }) }, @@ -156,14 +165,18 @@ fn criterion_benchmark(c: &mut Criterion) { // StringViewArray UTF8 let args_string_view_utf8 = gen_string_array(n_rows, str_len, 0.1, 0.5, true); + let arg_fields = + vec![Field::new("a", args_string_view_utf8[0].data_type(), true).into()]; + let return_field = Field::new("f", DataType::Int32, true).into(); c.bench_function( - &format!("strpos_StringViewArray_utf8_str_len_{}", str_len), + &format!("strpos_StringViewArray_utf8_str_len_{str_len}"), |b| { b.iter(|| { black_box(strpos.invoke_with_args(ScalarFunctionArgs { args: args_string_view_utf8.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &DataType::Int32, + return_field: Arc::clone(&return_field), })) }) }, diff --git a/datafusion/functions/benches/substr.rs b/datafusion/functions/benches/substr.rs index 80ab70ef71b0..342e18b0d9a2 100644 --- a/datafusion/functions/benches/substr.rs +++ b/datafusion/functions/benches/substr.rs @@ -18,11 +18,12 @@ extern crate criterion; use arrow::array::{ArrayRef, Int64Array, OffsetSizeTrait}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; use criterion::{black_box, criterion_group, criterion_main, Criterion, SamplingMode}; +use datafusion_common::DataFusionError; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::unicode; use std::sync::Arc; @@ -96,8 +97,25 @@ fn create_args_with_count( } } +fn invoke_substr_with_args( + args: Vec, + number_rows: usize, +) -> Result { + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect::>(); + + unicode::substr().invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields, + number_rows, + return_field: Field::new("f", DataType::Utf8View, true).into(), + }) +} + fn criterion_benchmark(c: &mut Criterion) { - let substr = unicode::substr(); for size in [1024, 4096] { // string_len = 12, substring_len=6 (see `create_args_without_count`) let len = 12; @@ -107,44 +125,19 @@ fn criterion_benchmark(c: &mut Criterion) { let args = create_args_without_count::(size, len, true, true); group.bench_function( - format!("substr_string_view [size={}, strlen={}]", size, len), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8View, - })) - }) - }, + format!("substr_string_view [size={size}, strlen={len}]"), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); let args = create_args_without_count::(size, len, false, false); - group.bench_function( - format!("substr_string [size={}, strlen={}]", size, len), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8View, - })) - }) - }, - ); + group.bench_function(format!("substr_string [size={size}, strlen={len}]"), |b| { + b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))) + }); let args = create_args_without_count::(size, len, true, false); group.bench_function( - 
format!("substr_large_string [size={}, strlen={}]", size, len), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8View, - })) - }) - }, + format!("substr_large_string [size={size}, strlen={len}]"), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); group.finish(); @@ -158,53 +151,20 @@ fn criterion_benchmark(c: &mut Criterion) { let args = create_args_with_count::(size, len, count, true); group.bench_function( - format!( - "substr_string_view [size={}, count={}, strlen={}]", - size, count, len, - ), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8View, - })) - }) - }, + format!("substr_string_view [size={size}, count={count}, strlen={len}]",), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); let args = create_args_with_count::(size, len, count, false); group.bench_function( - format!( - "substr_string [size={}, count={}, strlen={}]", - size, count, len, - ), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8View, - })) - }) - }, + format!("substr_string [size={size}, count={count}, strlen={len}]",), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); let args = create_args_with_count::(size, len, count, false); group.bench_function( - format!( - "substr_large_string [size={}, count={}, strlen={}]", - size, count, len, - ), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8View, - })) - }) - }, + format!("substr_large_string [size={size}, count={count}, strlen={len}]",), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); group.finish(); @@ -218,53 +178,20 @@ fn criterion_benchmark(c: &mut Criterion) { let args = create_args_with_count::(size, len, count, true); group.bench_function( - format!( - "substr_string_view [size={}, count={}, strlen={}]", - size, count, len, - ), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8View, - })) - }) - }, + format!("substr_string_view [size={size}, count={count}, strlen={len}]",), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); let args = create_args_with_count::(size, len, count, false); group.bench_function( - format!( - "substr_string [size={}, count={}, strlen={}]", - size, count, len, - ), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8View, - })) - }) - }, + format!("substr_string [size={size}, count={count}, strlen={len}]",), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); let args = create_args_with_count::(size, len, count, false); group.bench_function( - format!( - "substr_large_string [size={}, count={}, strlen={}]", - size, count, len, - ), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8View, - })) - }) - }, + format!("substr_large_string [size={size}, count={count}, strlen={len}]",), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); group.finish(); diff 
--git a/datafusion/functions/benches/substr_index.rs b/datafusion/functions/benches/substr_index.rs index b1c1c3c34a95..e772fb38fc40 100644 --- a/datafusion/functions/benches/substr_index.rs +++ b/datafusion/functions/benches/substr_index.rs @@ -20,9 +20,9 @@ extern crate criterion; use std::sync::Arc; use arrow::array::{ArrayRef, Int64Array, StringArray}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use rand::distributions::{Alphanumeric, Uniform}; +use rand::distr::{Alphanumeric, Uniform}; use rand::prelude::Distribution; use rand::Rng; @@ -54,21 +54,21 @@ fn data() -> (StringArray, StringArray, Int64Array) { dist: Uniform::new(-4, 5), test: |x: &i64| x != &0, }; - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let mut strings: Vec = vec![]; let mut delimiters: Vec = vec![]; let mut counts: Vec = vec![]; for _ in 0..1000 { - let length = rng.gen_range(20..50); + let length = rng.random_range(20..50); let text: String = (&mut rng) .sample_iter(&Alphanumeric) .take(length) .map(char::from) .collect(); - let char = rng.gen_range(0..text.len()); + let char = rng.random_range(0..text.len()); let delimiter = &text.chars().nth(char).unwrap(); - let count = rng.sample(&dist); + let count = rng.sample(dist.dist.unwrap()); strings.push(text); delimiters.push(delimiter.to_string()); @@ -91,13 +91,22 @@ fn criterion_benchmark(c: &mut Criterion) { let counts = ColumnarValue::Array(Arc::new(counts) as ArrayRef); let args = vec![strings, delimiters, counts]; + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + b.iter(|| { black_box( substr_index() .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }) .expect("substr_index should work on valid values"), ) diff --git a/datafusion/functions/benches/to_char.rs b/datafusion/functions/benches/to_char.rs index 6f20a20dc219..d19714ce6166 100644 --- a/datafusion/functions/benches/to_char.rs +++ b/datafusion/functions/benches/to_char.rs @@ -20,12 +20,12 @@ extern crate criterion; use std::sync::Arc; use arrow::array::{ArrayRef, Date32Array, StringArray}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use chrono::prelude::*; use chrono::TimeDelta; use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use rand::prelude::IndexedRandom; use rand::rngs::ThreadRng; -use rand::seq::SliceRandom; use rand::Rng; use datafusion_common::ScalarValue; @@ -39,7 +39,7 @@ fn random_date_in_range( end_date: NaiveDate, ) -> NaiveDate { let days_in_range = (end_date - start_date).num_days(); - let random_days: i64 = rng.gen_range(0..days_in_range); + let random_days: i64 = rng.random_range(0..days_in_range); start_date + TimeDelta::try_days(random_days).unwrap() } @@ -82,7 +82,7 @@ fn patterns(rng: &mut ThreadRng) -> StringArray { fn criterion_benchmark(c: &mut Criterion) { c.bench_function("to_char_array_array_1000", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let data_arr = data(&mut rng); let batch_len = data_arr.len(); let data = ColumnarValue::Array(Arc::new(data_arr) as ArrayRef); @@ -93,8 +93,12 @@ fn criterion_benchmark(c: &mut Criterion) { to_char() .invoke_with_args(ScalarFunctionArgs { args: vec![data.clone(), 
patterns.clone()], + arg_fields: vec![ + Field::new("a", data.data_type(), true).into(), + Field::new("b", patterns.data_type(), true).into(), + ], number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }) .expect("to_char should work on valid values"), ) @@ -102,7 +106,7 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("to_char_array_scalar_1000", |b| { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let data_arr = data(&mut rng); let batch_len = data_arr.len(); let data = ColumnarValue::Array(Arc::new(data_arr) as ArrayRef); @@ -114,8 +118,12 @@ fn criterion_benchmark(c: &mut Criterion) { to_char() .invoke_with_args(ScalarFunctionArgs { args: vec![data.clone(), patterns.clone()], + arg_fields: vec![ + Field::new("a", data.data_type(), true).into(), + Field::new("b", patterns.data_type(), true).into(), + ], number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }) .expect("to_char should work on valid values"), ) @@ -141,8 +149,12 @@ fn criterion_benchmark(c: &mut Criterion) { to_char() .invoke_with_args(ScalarFunctionArgs { args: vec![data.clone(), pattern.clone()], + arg_fields: vec![ + Field::new("a", data.data_type(), true).into(), + Field::new("b", pattern.data_type(), true).into(), + ], number_rows: 1, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }) .expect("to_char should work on valid values"), ) diff --git a/datafusion/functions/benches/to_hex.rs b/datafusion/functions/benches/to_hex.rs index a45d936c0a52..4a02b74ca42d 100644 --- a/datafusion/functions/benches/to_hex.rs +++ b/datafusion/functions/benches/to_hex.rs @@ -17,7 +17,7 @@ extern crate criterion; -use arrow::datatypes::{DataType, Int32Type, Int64Type}; +use arrow::datatypes::{DataType, Field, Int32Type, Int64Type}; use arrow::util::bench_util::create_primitive_array; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; @@ -30,14 +30,15 @@ fn criterion_benchmark(c: &mut Criterion) { let i32_array = Arc::new(create_primitive_array::(size, 0.2)); let batch_len = i32_array.len(); let i32_args = vec![ColumnarValue::Array(i32_array)]; - c.bench_function(&format!("to_hex i32 array: {}", size), |b| { + c.bench_function(&format!("to_hex i32 array: {size}"), |b| { b.iter(|| { let args_cloned = i32_args.clone(); black_box( hex.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_fields: vec![Field::new("a", DataType::Int32, false).into()], number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }) .unwrap(), ) @@ -46,14 +47,15 @@ fn criterion_benchmark(c: &mut Criterion) { let i64_array = Arc::new(create_primitive_array::(size, 0.2)); let batch_len = i64_array.len(); let i64_args = vec![ColumnarValue::Array(i64_array)]; - c.bench_function(&format!("to_hex i64 array: {}", size), |b| { + c.bench_function(&format!("to_hex i64 array: {size}"), |b| { b.iter(|| { let args_cloned = i64_args.clone(); black_box( hex.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_fields: vec![Field::new("a", DataType::Int64, false).into()], number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }) .unwrap(), ) diff --git a/datafusion/functions/benches/to_timestamp.rs b/datafusion/functions/benches/to_timestamp.rs index 
aec56697691f..d89811348489 100644 --- a/datafusion/functions/benches/to_timestamp.rs +++ b/datafusion/functions/benches/to_timestamp.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use arrow::array::builder::StringBuilder; use arrow::array::{Array, ArrayRef, StringArray}; use arrow::compute::cast; -use arrow::datatypes::{DataType, TimeUnit}; +use arrow::datatypes::{DataType, Field, TimeUnit}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; @@ -109,7 +109,10 @@ fn data_with_formats() -> (StringArray, StringArray, StringArray, StringArray) { ) } fn criterion_benchmark(c: &mut Criterion) { - let return_type = &DataType::Timestamp(TimeUnit::Nanosecond, None); + let return_field = + Field::new("f", DataType::Timestamp(TimeUnit::Nanosecond, None), true).into(); + let arg_field = Field::new("a", DataType::Utf8, false).into(); + let arg_fields = vec![arg_field]; c.bench_function("to_timestamp_no_formats_utf8", |b| { let arr_data = data(); let batch_len = arr_data.len(); @@ -120,8 +123,9 @@ fn criterion_benchmark(c: &mut Criterion) { to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: vec![string_array.clone()], + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type, + return_field: Arc::clone(&return_field), }) .expect("to_timestamp should work on valid values"), ) @@ -138,8 +142,9 @@ fn criterion_benchmark(c: &mut Criterion) { to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: vec![string_array.clone()], + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type, + return_field: Arc::clone(&return_field), }) .expect("to_timestamp should work on valid values"), ) @@ -156,8 +161,9 @@ fn criterion_benchmark(c: &mut Criterion) { to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: vec![string_array.clone()], + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type, + return_field: Arc::clone(&return_field), }) .expect("to_timestamp should work on valid values"), ) @@ -174,13 +180,22 @@ fn criterion_benchmark(c: &mut Criterion) { ColumnarValue::Array(Arc::new(format2) as ArrayRef), ColumnarValue::Array(Arc::new(format3) as ArrayRef), ]; + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + b.iter(|| { black_box( to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type, + return_field: Arc::clone(&return_field), }) .expect("to_timestamp should work on valid values"), ) @@ -205,13 +220,22 @@ fn criterion_benchmark(c: &mut Criterion) { Arc::new(cast(&format3, &DataType::LargeUtf8).unwrap()) as ArrayRef ), ]; + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + b.iter(|| { black_box( to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type, + return_field: Arc::clone(&return_field), }) .expect("to_timestamp should work on valid values"), ) @@ -237,13 +261,22 @@ fn criterion_benchmark(c: &mut Criterion) { Arc::new(cast(&format3, &DataType::Utf8View).unwrap()) as ArrayRef ), ]; + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + b.iter(|| { black_box( to_timestamp() 
.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type, + return_field: Arc::clone(&return_field), }) .expect("to_timestamp should work on valid values"), ) diff --git a/datafusion/functions/benches/trunc.rs b/datafusion/functions/benches/trunc.rs index 7fc93921d2e7..897e21c1e1d9 100644 --- a/datafusion/functions/benches/trunc.rs +++ b/datafusion/functions/benches/trunc.rs @@ -18,7 +18,7 @@ extern crate criterion; use arrow::{ - datatypes::{Float32Type, Float64Type}, + datatypes::{Field, Float32Type, Float64Type}, util::bench_util::create_primitive_array, }; use criterion::{black_box, criterion_group, criterion_main, Criterion}; @@ -33,14 +33,17 @@ fn criterion_benchmark(c: &mut Criterion) { for size in [1024, 4096, 8192] { let f32_array = Arc::new(create_primitive_array::(size, 0.2)); let f32_args = vec![ColumnarValue::Array(f32_array)]; - c.bench_function(&format!("trunc f32 array: {}", size), |b| { + let arg_fields = vec![Field::new("a", DataType::Float32, false).into()]; + let return_field = Field::new("f", DataType::Float32, true).into(); + c.bench_function(&format!("trunc f32 array: {size}"), |b| { b.iter(|| { black_box( trunc .invoke_with_args(ScalarFunctionArgs { args: f32_args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Float32, + return_field: Arc::clone(&return_field), }) .unwrap(), ) @@ -48,14 +51,17 @@ fn criterion_benchmark(c: &mut Criterion) { }); let f64_array = Arc::new(create_primitive_array::(size, 0.2)); let f64_args = vec![ColumnarValue::Array(f64_array)]; - c.bench_function(&format!("trunc f64 array: {}", size), |b| { + let arg_fields = vec![Field::new("a", DataType::Float64, true).into()]; + let return_field = Field::new("f", DataType::Float64, true).into(); + c.bench_function(&format!("trunc f64 array: {size}"), |b| { b.iter(|| { black_box( trunc .invoke_with_args(ScalarFunctionArgs { args: f64_args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Float64, + return_field: Arc::clone(&return_field), }) .unwrap(), ) diff --git a/datafusion/functions/benches/upper.rs b/datafusion/functions/benches/upper.rs index f0bee89c7d37..bf2c4161001e 100644 --- a/datafusion/functions/benches/upper.rs +++ b/datafusion/functions/benches/upper.rs @@ -17,7 +17,7 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::create_string_array_with_len; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; @@ -42,8 +42,9 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(upper.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_fields: vec![Field::new("a", DataType::Utf8, true).into()], number_rows: size, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), })) }) }); diff --git a/datafusion/functions/benches/uuid.rs b/datafusion/functions/benches/uuid.rs index 7b8d156fec21..942af122562a 100644 --- a/datafusion/functions/benches/uuid.rs +++ b/datafusion/functions/benches/uuid.rs @@ -17,7 +17,7 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::ScalarFunctionArgs; use datafusion_functions::string; @@ -28,8 +28,9 @@ fn criterion_benchmark(c: &mut Criterion) { 
b.iter(|| { black_box(uuid.invoke_with_args(ScalarFunctionArgs { args: vec![], + arg_fields: vec![], number_rows: 1024, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), })) }) }); diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index 2686dbf8be3c..e9dee09e74bf 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -17,7 +17,7 @@ //! [`ArrowCastFunc`]: Implementation of the `arrow_cast` -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field, FieldRef}; use arrow::error::ArrowError; use datafusion_common::{ arrow_datafusion_err, exec_err, internal_err, Result, ScalarValue, @@ -29,7 +29,7 @@ use std::any::Any; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, + ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; @@ -113,11 +113,11 @@ impl ScalarUDFImpl for ArrowCastFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("return_type_from_args should be called instead") + internal_err!("return_field_from_args should be called instead") } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - let nullable = args.nullables.iter().any(|&nullable| nullable); + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); let [_, type_arg] = take_function_args(self.name(), args.scalar_arguments)?; @@ -131,7 +131,7 @@ impl ScalarUDFImpl for ArrowCastFunc { ) }, |casted_type| match casted_type.parse::() { - Ok(data_type) => Ok(ReturnInfo::new(data_type, nullable)), + Ok(data_type) => Ok(Field::new(self.name(), data_type, nullable).into()), Err(ArrowError::ParseError(e)) => Err(exec_datafusion_err!("{e}")), Err(e) => Err(arrow_datafusion_err!(e)), }, @@ -177,7 +177,7 @@ impl ScalarUDFImpl for ArrowCastFunc { fn data_type_from_args(args: &[Expr]) -> Result { let [_, type_arg] = take_function_args("arrow_cast", args)?; - let Expr::Literal(ScalarValue::Utf8(Some(val))) = type_arg else { + let Expr::Literal(ScalarValue::Utf8(Some(val)), _) = type_arg else { return exec_err!( "arrow_cast requires its second argument to be a constant string, got {:?}", type_arg diff --git a/datafusion/functions/src/core/coalesce.rs b/datafusion/functions/src/core/coalesce.rs index ba20c23828eb..12a4bef24739 100644 --- a/datafusion/functions/src/core/coalesce.rs +++ b/datafusion/functions/src/core/coalesce.rs @@ -18,11 +18,11 @@ use arrow::array::{new_null_array, BooleanArray}; use arrow::compute::kernels::zip::zip; use arrow::compute::{and, is_not_null, is_null}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::{exec_err, internal_err, Result}; use datafusion_expr::binary::try_type_union_resolution; use datafusion_expr::{ - ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, + ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs, }; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; @@ -79,19 +79,20 @@ impl ScalarUDFImpl for CoalesceFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("return_type_from_args should be called instead") + 
internal_err!("return_field_from_args should be called instead") } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { // If any the arguments in coalesce is non-null, the result is non-null - let nullable = args.nullables.iter().all(|&nullable| nullable); + let nullable = args.arg_fields.iter().all(|f| f.is_nullable()); let return_type = args - .arg_types + .arg_fields .iter() + .map(|f| f.data_type()) .find_or_first(|d| !d.is_null()) .unwrap() .clone(); - Ok(ReturnInfo::new(return_type, nullable)) + Ok(Field::new(self.name(), return_type, nullable).into()) } /// coalesce evaluates to the first value which is not NULL diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index 3ac26b98359b..2f39132871bb 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -20,7 +20,7 @@ use arrow::array::{ Scalar, }; use arrow::compute::SortOptions; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field, FieldRef}; use arrow_buffer::NullBuffer; use datafusion_common::cast::{as_map_array, as_struct_array}; use datafusion_common::{ @@ -28,7 +28,7 @@ use datafusion_common::{ ScalarValue, }; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, + ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, }; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; @@ -108,7 +108,7 @@ impl ScalarUDFImpl for GetFieldFunc { let [base, field_name] = take_function_args(self.name(), args)?; let name = match field_name { - Expr::Literal(name) => name, + Expr::Literal(name, _) => name, other => &ScalarValue::Utf8(Some(other.schema_name().to_string())), }; @@ -118,7 +118,7 @@ impl ScalarUDFImpl for GetFieldFunc { fn schema_name(&self, args: &[Expr]) -> Result { let [base, field_name] = take_function_args(self.name(), args)?; let name = match field_name { - Expr::Literal(name) => name, + Expr::Literal(name, _) => name, other => &ScalarValue::Utf8(Some(other.schema_name().to_string())), }; @@ -130,14 +130,14 @@ impl ScalarUDFImpl for GetFieldFunc { } fn return_type(&self, _: &[DataType]) -> Result { - internal_err!("return_type_from_args should be called instead") + internal_err!("return_field_from_args should be called instead") } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { // Length check handled in the signature debug_assert_eq!(args.scalar_arguments.len(), 2); - match (&args.arg_types[0], args.scalar_arguments[1].as_ref()) { + match (&args.arg_fields[0].data_type(), args.scalar_arguments[1].as_ref()) { (DataType::Map(fields, _), _) => { match fields.data_type() { DataType::Struct(fields) if fields.len() == 2 => { @@ -146,7 +146,8 @@ impl ScalarUDFImpl for GetFieldFunc { // instead, we assume that the second column is the "value" column both here and in // execution. 
let value_field = fields.get(1).expect("fields should have exactly two members"); - Ok(ReturnInfo::new_nullable(value_field.data_type().clone())) + + Ok(value_field.as_ref().clone().with_nullable(true).into()) }, _ => exec_err!("Map fields must contain a Struct with exactly 2 fields"), } @@ -158,10 +159,20 @@ impl ScalarUDFImpl for GetFieldFunc { |field_name| { fields.iter().find(|f| f.name() == field_name) .ok_or(plan_datafusion_err!("Field {field_name} not found in struct")) - .map(|f| ReturnInfo::new_nullable(f.data_type().to_owned())) + .map(|f| { + let mut child_field = f.as_ref().clone(); + + // If the parent is nullable, then getting the child must be nullable, + // so potentially override the return value + + if args.arg_fields[0].is_nullable() { + child_field = child_field.with_nullable(true); + } + Arc::new(child_field) + }) }) }, - (DataType::Null, _) => Ok(ReturnInfo::new_nullable(DataType::Null)), + (DataType::Null, _) => Ok(Field::new(self.name(), DataType::Null, true).into()), (other, _) => exec_err!("The expression to get an indexed field is only valid for `Struct`, `Map` or `Null` types, got {other}"), } } diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index c6329b1ee0af..db080cd62847 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -36,6 +36,7 @@ pub mod overlay; pub mod planner; pub mod r#struct; pub mod union_extract; +pub mod union_tag; pub mod version; // create UDFs @@ -52,6 +53,7 @@ make_udf_function!(coalesce::CoalesceFunc, coalesce); make_udf_function!(greatest::GreatestFunc, greatest); make_udf_function!(least::LeastFunc, least); make_udf_function!(union_extract::UnionExtractFun, union_extract); +make_udf_function!(union_tag::UnionTagFunc, union_tag); make_udf_function!(version::VersionFunc, version); pub mod expr_fn { @@ -101,6 +103,10 @@ pub mod expr_fn { least, "Returns `least(args...)`, which evaluates to the smallest value in the list of expressions or NULL if all the expressions are NULL", args, + ),( + union_tag, + "Returns the name of the currently selected field in the union", + arg1 )); #[doc = "Returns the value of the field with the given name from the struct"] @@ -136,6 +142,7 @@ pub fn functions() -> Vec> { greatest(), least(), union_extract(), + union_tag(), version(), r#struct(), ] diff --git a/datafusion/functions/src/core/named_struct.rs b/datafusion/functions/src/core/named_struct.rs index bba884d96483..115f4a8aba22 100644 --- a/datafusion/functions/src/core/named_struct.rs +++ b/datafusion/functions/src/core/named_struct.rs @@ -16,10 +16,10 @@ // under the License. 
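The `get_field` hunk above also tightens nullability: when the struct column itself is nullable, the extracted child field is reported as nullable even if it was declared non-nullable. A condensed sketch of that rule, with placeholder names:

```rust
use std::sync::Arc;

use arrow::datatypes::{Field, FieldRef};

// Propagate parent nullability onto the child field returned by get_field.
fn child_return_field(parent: &FieldRef, child: &Field) -> FieldRef {
    let mut child_field = child.clone();
    if parent.is_nullable() {
        // A null struct row yields a null child value, so the output must be nullable.
        child_field = child_field.with_nullable(true);
    }
    Arc::new(child_field)
}
```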
use arrow::array::StructArray; -use arrow::datatypes::{DataType, Field, Fields}; +use arrow::datatypes::{DataType, Field, FieldRef, Fields}; use datafusion_common::{exec_err, internal_err, Result}; use datafusion_expr::{ - ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, + ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs, }; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; @@ -91,10 +91,12 @@ impl ScalarUDFImpl for NamedStructFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("named_struct: return_type called instead of return_type_from_args") + internal_err!( + "named_struct: return_type called instead of return_field_from_args" + ) } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { // do not accept 0 arguments. if args.scalar_arguments.is_empty() { return exec_err!( @@ -126,7 +128,13 @@ impl ScalarUDFImpl for NamedStructFunc { ) ) .collect::>>()?; - let types = args.arg_types.iter().skip(1).step_by(2).collect::>(); + let types = args + .arg_fields + .iter() + .skip(1) + .step_by(2) + .map(|f| f.data_type()) + .collect::>(); let return_fields = names .into_iter() @@ -134,13 +142,16 @@ impl ScalarUDFImpl for NamedStructFunc { .map(|(name, data_type)| Ok(Field::new(name, data_type.to_owned(), true))) .collect::>>()?; - Ok(ReturnInfo::new_nullable(DataType::Struct(Fields::from( - return_fields, - )))) + Ok(Field::new( + self.name(), + DataType::Struct(Fields::from(return_fields)), + true, + ) + .into()) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let DataType::Struct(fields) = args.return_type else { + let DataType::Struct(fields) = args.return_type() else { return internal_err!("incorrect named_struct return type"); }; diff --git a/datafusion/functions/src/core/struct.rs b/datafusion/functions/src/core/struct.rs index 8792bf1bd1b9..f068fc18a8b0 100644 --- a/datafusion/functions/src/core/struct.rs +++ b/datafusion/functions/src/core/struct.rs @@ -117,7 +117,7 @@ impl ScalarUDFImpl for StructFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let DataType::Struct(fields) = args.return_type else { + let DataType::Struct(fields) = args.return_type() else { return internal_err!("incorrect struct return type"); }; diff --git a/datafusion/functions/src/core/union_extract.rs b/datafusion/functions/src/core/union_extract.rs index 420eeed42cc3..be49f8226712 100644 --- a/datafusion/functions/src/core/union_extract.rs +++ b/datafusion/functions/src/core/union_extract.rs @@ -16,14 +16,14 @@ // under the License. 
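`named_struct`, like `arrow_cast` and `coalesce` earlier in this patch, moves from `return_type_from_args`/`ReturnInfo` to `return_field_from_args`/`FieldRef`. A minimal implementer-side sketch of that shape, written as free functions standing in for the `ScalarUDFImpl` methods of a hypothetical UDF named `my_func`:

```rust
use arrow::datatypes::{DataType, Field, FieldRef};
use datafusion_common::{internal_err, Result};
use datafusion_expr::ReturnFieldArgs;

// `return_type` becomes a guard that should never be reached...
fn return_type(_arg_types: &[DataType]) -> Result<DataType> {
    internal_err!("return_field_from_args should be called instead")
}

// ...and `return_field_from_args` builds the output Field, deriving nullability
// from the argument Fields instead of the removed `nullables` slice.
fn return_field_from_args(args: ReturnFieldArgs) -> Result<FieldRef> {
    let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
    Ok(Field::new("my_func", DataType::Utf8, nullable).into())
}
```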
use arrow::array::Array; -use arrow::datatypes::{DataType, FieldRef, UnionFields}; +use arrow::datatypes::{DataType, Field, FieldRef, UnionFields}; use datafusion_common::cast::as_union_array; use datafusion_common::utils::take_function_args; use datafusion_common::{ exec_datafusion_err, exec_err, internal_err, Result, ScalarValue, }; use datafusion_doc::Documentation; -use datafusion_expr::{ColumnarValue, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs}; +use datafusion_expr::{ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; @@ -82,35 +82,35 @@ impl ScalarUDFImpl for UnionExtractFun { } fn return_type(&self, _: &[DataType]) -> Result { - // should be using return_type_from_args and not calling the default implementation + // should be using return_field_from_args and not calling the default implementation internal_err!("union_extract should return type from args") } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - if args.arg_types.len() != 2 { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + if args.arg_fields.len() != 2 { return exec_err!( "union_extract expects 2 arguments, got {} instead", - args.arg_types.len() + args.arg_fields.len() ); } - let DataType::Union(fields, _) = &args.arg_types[0] else { + let DataType::Union(fields, _) = &args.arg_fields[0].data_type() else { return exec_err!( "union_extract first argument must be a union, got {} instead", - args.arg_types[0] + args.arg_fields[0].data_type() ); }; let Some(ScalarValue::Utf8(Some(field_name))) = &args.scalar_arguments[1] else { return exec_err!( "union_extract second argument must be a non-null string literal, got {} instead", - args.arg_types[1] + args.arg_fields[1].data_type() ); }; let field = find_field(fields, field_name)?.1; - Ok(ReturnInfo::new_nullable(field.data_type().clone())) + Ok(Field::new(self.name(), field.data_type().clone(), true).into()) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -189,47 +189,67 @@ mod tests { ], ); + let args = vec![ + ColumnarValue::Scalar(ScalarValue::Union( + None, + fields.clone(), + UnionMode::Dense, + )), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ]; + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type().clone(), true).into()) + .collect::>(); + let result = fun.invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::Union( - None, - fields.clone(), - UnionMode::Dense, - )), - ColumnarValue::Scalar(ScalarValue::new_utf8("str")), - ], + args, + arg_fields, number_rows: 1, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), })?; assert_scalar(result, ScalarValue::Utf8(None)); + let args = vec![ + ColumnarValue::Scalar(ScalarValue::Union( + Some((3, Box::new(ScalarValue::Int32(Some(42))))), + fields.clone(), + UnionMode::Dense, + )), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ]; + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type().clone(), true).into()) + .collect::>(); + let result = fun.invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::Union( - Some((3, Box::new(ScalarValue::Int32(Some(42))))), - fields.clone(), - UnionMode::Dense, - )), - ColumnarValue::Scalar(ScalarValue::new_utf8("str")), - ], + args, + arg_fields, number_rows: 1, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), 
})?; assert_scalar(result, ScalarValue::Utf8(None)); + let args = vec![ + ColumnarValue::Scalar(ScalarValue::Union( + Some((1, Box::new(ScalarValue::new_utf8("42")))), + fields.clone(), + UnionMode::Dense, + )), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ]; + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type().clone(), true).into()) + .collect::>(); let result = fun.invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::Union( - Some((1, Box::new(ScalarValue::new_utf8("42")))), - fields.clone(), - UnionMode::Dense, - )), - ColumnarValue::Scalar(ScalarValue::new_utf8("str")), - ], + args, + arg_fields, number_rows: 1, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), })?; assert_scalar(result, ScalarValue::new_utf8("42")); diff --git a/datafusion/functions/src/core/union_tag.rs b/datafusion/functions/src/core/union_tag.rs new file mode 100644 index 000000000000..3a4d96de2bc0 --- /dev/null +++ b/datafusion/functions/src/core/union_tag.rs @@ -0,0 +1,225 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
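The new `union_tag` implementation that follows maps each union row to its field name through a dictionary keyed by type id. A condensed sketch of the indexing scheme its comments describe, assuming only `arrow`'s `UnionFields`:

```rust
use arrow::datatypes::UnionFields;

// Union type ids are unique and in 0..128, but need not start at 0 or be contiguous,
// so size the values vector by the largest type id and index it directly.
fn tag_values(fields: &UnionFields) -> Vec<&str> {
    let len = fields
        .iter()
        .map(|(type_id, _)| type_id as usize + 1)
        .max()
        .unwrap_or(0);
    let mut values = vec![""; len];
    for (type_id, field) in fields.iter() {
        values[type_id as usize] = field.name().as_str();
    }
    values
}
```

The real UDF wraps these values in a `StringArray` and pairs them with the union's type ids as dictionary keys.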
+ +use arrow::array::{Array, AsArray, DictionaryArray, Int8Array, StringArray}; +use arrow::datatypes::DataType; +use datafusion_common::utils::take_function_args; +use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue}; +use datafusion_doc::Documentation; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_macros::user_doc; +use std::sync::Arc; + +#[user_doc( + doc_section(label = "Union Functions"), + description = "Returns the name of the currently selected field in the union", + syntax_example = "union_tag(union_expression)", + sql_example = r#"```sql +❯ select union_column, union_tag(union_column) from table_with_union; ++--------------+-------------------------+ +| union_column | union_tag(union_column) | ++--------------+-------------------------+ +| {a=1} | a | +| {b=3.0} | b | +| {a=4} | a | +| {b=} | b | +| {a=} | a | ++--------------+-------------------------+ +```"#, + standard_argument(name = "union", prefix = "Union") +)] +#[derive(Debug)] +pub struct UnionTagFunc { + signature: Signature, +} + +impl Default for UnionTagFunc { + fn default() -> Self { + Self::new() + } +} + +impl UnionTagFunc { + pub fn new() -> Self { + Self { + signature: Signature::any(1, Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for UnionTagFunc { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "union_tag" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _: &[DataType]) -> Result { + Ok(DataType::Dictionary( + Box::new(DataType::Int8), + Box::new(DataType::Utf8), + )) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let [union_] = take_function_args("union_tag", args.args)?; + + match union_ { + ColumnarValue::Array(array) + if matches!(array.data_type(), DataType::Union(_, _)) => + { + let union_array = array.as_union(); + + let keys = Int8Array::try_new(union_array.type_ids().clone(), None)?; + + let fields = match union_array.data_type() { + DataType::Union(fields, _) => fields, + _ => unreachable!(), + }; + + // Union fields type IDs only constraints are being unique and in the 0..128 range: + // They may not start at 0, be sequential, or even contiguous. + // Therefore, we allocate a values vector with a length equal to the highest type ID plus one, + // ensuring that each field's name can be placed at the index corresponding to its type ID. + let values_len = fields + .iter() + .map(|(type_id, _)| type_id + 1) + .max() + .unwrap_or_default() as usize; + + let mut values = vec![""; values_len]; + + for (type_id, field) in fields.iter() { + values[type_id as usize] = field.name().as_str() + } + + let values = Arc::new(StringArray::from(values)); + + // SAFETY: union type_ids are validated to not be smaller than zero. + // values len is the union biggest type id plus one. 
+ // keys is built from the union type_ids, which contains only valid type ids + // therefore, `keys[i] >= values.len() || keys[i] < 0` never occurs + let dict = unsafe { DictionaryArray::new_unchecked(keys, values) }; + + Ok(ColumnarValue::Array(Arc::new(dict))) + } + ColumnarValue::Scalar(ScalarValue::Union(value, fields, _)) => match value { + Some((value_type_id, _)) => fields + .iter() + .find(|(type_id, _)| value_type_id == *type_id) + .map(|(_, field)| { + ColumnarValue::Scalar(ScalarValue::Dictionary( + Box::new(DataType::Int8), + Box::new(field.name().as_str().into()), + )) + }) + .ok_or_else(|| { + exec_datafusion_err!( + "union_tag: union scalar with unknow type_id {value_type_id}" + ) + }), + None => Ok(ColumnarValue::Scalar(ScalarValue::try_new_null( + args.return_field.data_type(), + )?)), + }, + v => exec_err!("union_tag only support unions, got {:?}", v.data_type()), + } + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +#[cfg(test)] +mod tests { + use super::UnionTagFunc; + use arrow::datatypes::{DataType, Field, UnionFields, UnionMode}; + use datafusion_common::ScalarValue; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; + use std::sync::Arc; + + // when it becomes possible to construct union scalars in SQL, this should go to sqllogictests + #[test] + fn union_scalar() { + let fields = [(0, Arc::new(Field::new("a", DataType::UInt32, false)))] + .into_iter() + .collect(); + + let scalar = ScalarValue::Union( + Some((0, Box::new(ScalarValue::UInt32(Some(0))))), + fields, + UnionMode::Dense, + ); + + let return_type = + DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)); + + let result = UnionTagFunc::new() + .invoke_with_args(ScalarFunctionArgs { + args: vec![ColumnarValue::Scalar(scalar)], + number_rows: 1, + return_field: Field::new("res", return_type, true).into(), + arg_fields: vec![], + }) + .unwrap(); + + assert_scalar( + result, + ScalarValue::Dictionary(Box::new(DataType::Int8), Box::new("a".into())), + ); + } + + #[test] + fn union_scalar_empty() { + let scalar = ScalarValue::Union(None, UnionFields::empty(), UnionMode::Dense); + + let return_type = + DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)); + + let result = UnionTagFunc::new() + .invoke_with_args(ScalarFunctionArgs { + args: vec![ColumnarValue::Scalar(scalar)], + number_rows: 1, + return_field: Field::new("res", return_type, true).into(), + arg_fields: vec![], + }) + .unwrap(); + + assert_scalar( + result, + ScalarValue::Dictionary( + Box::new(DataType::Int8), + Box::new(ScalarValue::Utf8(None)), + ), + ); + } + + fn assert_scalar(value: ColumnarValue, expected: ScalarValue) { + match value { + ColumnarValue::Array(array) => panic!("expected scalar got {array:?}"), + ColumnarValue::Scalar(scalar) => assert_eq!(scalar, expected), + } + } +} diff --git a/datafusion/functions/src/core/version.rs b/datafusion/functions/src/core/version.rs index 34038022f2dc..b3abe246b4b3 100644 --- a/datafusion/functions/src/core/version.rs +++ b/datafusion/functions/src/core/version.rs @@ -97,6 +97,7 @@ impl ScalarUDFImpl for VersionFunc { #[cfg(test)] mod test { use super::*; + use arrow::datatypes::Field; use datafusion_expr::ScalarUDF; #[tokio::test] @@ -105,8 +106,9 @@ mod test { let version = version_udf .invoke_with_args(ScalarFunctionArgs { args: vec![], + arg_fields: vec![], number_rows: 0, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }) .unwrap(); diff --git 
a/datafusion/functions/src/crypto/basic.rs b/datafusion/functions/src/crypto/basic.rs index eaa688c1c335..5bf83943a92d 100644 --- a/datafusion/functions/src/crypto/basic.rs +++ b/datafusion/functions/src/crypto/basic.rs @@ -21,7 +21,7 @@ use arrow::array::{ Array, ArrayRef, BinaryArray, BinaryArrayType, BinaryViewArray, GenericBinaryArray, OffsetSizeTrait, }; -use arrow::array::{AsArray, GenericStringArray, StringArray, StringViewArray}; +use arrow::array::{AsArray, GenericStringArray, StringViewArray}; use arrow::datatypes::DataType; use blake2::{Blake2b512, Blake2s256, Digest}; use blake3::Hasher as Blake3; @@ -169,18 +169,18 @@ pub fn md5(args: &[ColumnarValue]) -> Result { let [data] = take_function_args("md5", args)?; let value = digest_process(data, DigestAlgorithm::Md5)?; - // md5 requires special handling because of its unique utf8 return type + // md5 requires special handling because of its unique utf8view return type Ok(match value { ColumnarValue::Array(array) => { let binary_array = as_binary_array(&array)?; - let string_array: StringArray = binary_array + let string_array: StringViewArray = binary_array .iter() .map(|opt| opt.map(hex_encode::<_>)) .collect(); ColumnarValue::Array(Arc::new(string_array)) } ColumnarValue::Scalar(ScalarValue::Binary(opt)) => { - ColumnarValue::Scalar(ScalarValue::Utf8(opt.map(hex_encode::<_>))) + ColumnarValue::Scalar(ScalarValue::Utf8View(opt.map(hex_encode::<_>))) } _ => return exec_err!("Impossibly got invalid results from digest"), }) diff --git a/datafusion/functions/src/crypto/md5.rs b/datafusion/functions/src/crypto/md5.rs index c1540450029c..e209ed06e28b 100644 --- a/datafusion/functions/src/crypto/md5.rs +++ b/datafusion/functions/src/crypto/md5.rs @@ -92,12 +92,12 @@ impl ScalarUDFImpl for Md5Func { fn return_type(&self, arg_types: &[DataType]) -> Result { use DataType::*; Ok(match &arg_types[0] { - LargeUtf8 | LargeBinary => Utf8, - Utf8View | Utf8 | Binary | BinaryView => Utf8, + LargeUtf8 | LargeBinary => Utf8View, + Utf8View | Utf8 | Binary | BinaryView => Utf8View, Null => Null, Dictionary(_, t) => match **t { - LargeUtf8 | LargeBinary => Utf8, - Utf8 | Binary | BinaryView => Utf8, + LargeUtf8 | LargeBinary => Utf8View, + Utf8 | Binary | BinaryView => Utf8View, Null => Null, _ => { return plan_err!( diff --git a/datafusion/functions/src/datetime/current_date.rs b/datafusion/functions/src/datetime/current_date.rs index 9998e7d3758e..2bda1f262abe 100644 --- a/datafusion/functions/src/datetime/current_date.rs +++ b/datafusion/functions/src/datetime/current_date.rs @@ -108,6 +108,7 @@ impl ScalarUDFImpl for CurrentDateFunc { ); Ok(ExprSimplifyResult::Simplified(Expr::Literal( ScalarValue::Date32(days), + None, ))) } diff --git a/datafusion/functions/src/datetime/current_time.rs b/datafusion/functions/src/datetime/current_time.rs index c416d0240b13..9b9d3997e9d7 100644 --- a/datafusion/functions/src/datetime/current_time.rs +++ b/datafusion/functions/src/datetime/current_time.rs @@ -96,6 +96,7 @@ impl ScalarUDFImpl for CurrentTimeFunc { let nano = now_ts.timestamp_nanos_opt().map(|ts| ts % 86400000000000); Ok(ExprSimplifyResult::Simplified(Expr::Literal( ScalarValue::Time64Nanosecond(nano), + None, ))) } diff --git a/datafusion/functions/src/datetime/date_bin.rs b/datafusion/functions/src/datetime/date_bin.rs index 5ffae46dde48..1c801dfead72 100644 --- a/datafusion/functions/src/datetime/date_bin.rs +++ b/datafusion/functions/src/datetime/date_bin.rs @@ -505,85 +505,88 @@ mod tests { use arrow::array::types::TimestampNanosecondType; 
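Stepping back to the crypto hunks above: `md5` now returns `Utf8View`, so its hex-encoded digests are collected into a `StringViewArray` rather than a `StringArray`. A small self-contained sketch of that conversion (the patch itself reuses the crate's `hex_encode` helper):

```rust
use arrow::array::{BinaryArray, StringViewArray};

// Hex-encode each binary digest into the view-backed string array md5 now returns.
fn digests_to_utf8view(digests: &BinaryArray) -> StringViewArray {
    digests
        .iter()
        .map(|opt| {
            opt.map(|bytes| bytes.iter().map(|b| format!("{b:02x}")).collect::<String>())
        })
        .collect()
}
```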
use arrow::array::{Array, IntervalDayTimeArray, TimestampNanosecondArray}; use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; - use arrow::datatypes::{DataType, TimeUnit}; + use arrow::datatypes::{DataType, Field, FieldRef, TimeUnit}; use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; - use datafusion_common::ScalarValue; + use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; use chrono::TimeDelta; + fn invoke_date_bin_with_args( + args: Vec, + number_rows: usize, + return_field: &FieldRef, + ) -> Result { + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type(), true).into()) + .collect::>(); + + let args = datafusion_expr::ScalarFunctionArgs { + args, + arg_fields, + number_rows, + return_field: Arc::clone(return_field), + }; + DateBinFunc::new().invoke_with_args(args) + } + #[test] fn test_date_bin() { - let mut args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 1, - }, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + let return_field = &Arc::new(Field::new( + "f", + DataType::Timestamp(TimeUnit::Nanosecond, None), + true, + )); + + let mut args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 1, + }))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert!(res.is_ok()); let timestamps = Arc::new((1..6).map(Some).collect::()); let batch_len = timestamps.len(); - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 1, - }, - ))), - ColumnarValue::Array(timestamps), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: batch_len, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 1, + }))), + ColumnarValue::Array(timestamps), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, batch_len, return_field); assert!(res.is_ok()); - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 1, - }, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 1, + }))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert!(res.is_ok()); // stride supports month-day-nano - args = datafusion_expr::ScalarFunctionArgs { - 
args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano(Some( - IntervalMonthDayNano { - months: 0, - days: 0, - nanoseconds: 1, - }, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano(Some( + IntervalMonthDayNano { + months: 0, + days: 0, + nanoseconds: 1, + }, + ))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert!(res.is_ok()); // @@ -591,33 +594,25 @@ mod tests { // // invalid number of arguments - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 1, - }, - )))], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( + IntervalDayTime { + days: 0, + milliseconds: 1, + }, + )))]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN expected two or three arguments" ); // stride: invalid type - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalYearMonth(Some(1))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalYearMonth(Some(1))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN expects stride argument to be an INTERVAL but got Interval(YearMonth)" @@ -625,113 +620,83 @@ mod tests { // stride: invalid value - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 0, - }, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 0, + }))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; - let res = DateBinFunc::new().invoke_with_args(args); + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN stride must be non-zero" ); // stride: overflow of day-time interval - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - 
ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime::MAX, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( + IntervalDayTime::MAX, + ))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN stride argument is too large" ); // stride: overflow of month-day-nano interval - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::new_interval_mdn(0, i32::MAX, 1)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::new_interval_mdn(0, i32::MAX, 1)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN stride argument is too large" ); // stride: month intervals - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::new_interval_mdn(1, 1, 1)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::new_interval_mdn(1, 1, 1)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "This feature is not implemented: DATE_BIN stride does not support combination of month, day and nanosecond intervals" ); // origin: invalid type - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 1, - }, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 1, + }))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN expects 
origin argument to be a TIMESTAMP with nanosecond precision but got Timestamp(Microsecond, None)" ); - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 1, - }, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 1, + }))), + ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert!(res.is_ok()); // unsupported array type for stride @@ -745,16 +710,12 @@ mod tests { }) .collect::(), ); - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Array(intervals), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Array(intervals), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "This feature is not implemented: DATE_BIN only supports literal values for the stride argument, not arrays" @@ -763,21 +724,15 @@ mod tests { // unsupported array type for origin let timestamps = Arc::new((1..6).map(Some).collect::()); let batch_len = timestamps.len(); - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 1, - }, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Array(timestamps), - ], - number_rows: batch_len, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 1, + }))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Array(timestamps), + ]; + let res = invoke_date_bin_with_args(args, batch_len, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "This feature is not implemented: DATE_BIN only supports literal values for the origin argument, not arrays" @@ -893,22 +848,22 @@ mod tests { .collect::() .with_timezone_opt(tz_opt.clone()); let batch_len = input.len(); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::new_interval_dt(1, 0)), - ColumnarValue::Array(Arc::new(input)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( - Some(string_to_timestamp_nanos(origin).unwrap()), - tz_opt.clone(), - )), - ], - number_rows: batch_len, - return_type: &DataType::Timestamp( - TimeUnit::Nanosecond, + let args = vec![ + ColumnarValue::Scalar(ScalarValue::new_interval_dt(1, 0)), + 
ColumnarValue::Array(Arc::new(input)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + Some(string_to_timestamp_nanos(origin).unwrap()), tz_opt.clone(), - ), - }; - let result = DateBinFunc::new().invoke_with_args(args).unwrap(); + )), + ]; + let return_field = &Arc::new(Field::new( + "f", + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()), + true, + )); + let result = + invoke_date_bin_with_args(args, batch_len, return_field).unwrap(); + if let ColumnarValue::Array(result) = result { assert_eq!( result.data_type(), diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index bfd06b39d206..743cdeb1d3f0 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -26,7 +26,7 @@ use arrow::datatypes::DataType::{ Date32, Date64, Duration, Interval, Time32, Time64, Timestamp, }; use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; -use arrow::datatypes::{DataType, TimeUnit}; +use arrow::datatypes::{DataType, Field, FieldRef, TimeUnit}; use datafusion_common::types::{logical_date, NativeType}; use datafusion_common::{ @@ -42,7 +42,7 @@ use datafusion_common::{ Result, ScalarValue, }; use datafusion_expr::{ - ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs, ScalarUDFImpl, Signature, + ColumnarValue, Documentation, ReturnFieldArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, }; use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; @@ -142,10 +142,10 @@ impl ScalarUDFImpl for DatePartFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("return_type_from_args should be called instead") + internal_err!("return_field_from_args should be called instead") } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { let [field, _] = take_function_args(self.name(), args.scalar_arguments)?; field @@ -155,12 +155,13 @@ impl ScalarUDFImpl for DatePartFunc { .filter(|s| !s.is_empty()) .map(|part| { if is_epoch(part) { - ReturnInfo::new_nullable(DataType::Float64) + Field::new(self.name(), DataType::Float64, true) } else { - ReturnInfo::new_nullable(DataType::Int32) + Field::new(self.name(), DataType::Int32, true) } }) }) + .map(Arc::new) .map_or_else( || exec_err!("{} requires non-empty constant string", self.name()), Ok, @@ -231,6 +232,7 @@ impl ScalarUDFImpl for DatePartFunc { fn aliases(&self) -> &[String] { &self.aliases } + fn documentation(&self) -> Option<&Documentation> { self.doc() } diff --git a/datafusion/functions/src/datetime/date_trunc.rs b/datafusion/functions/src/datetime/date_trunc.rs index ed3eb228bf03..8963ef77a53b 100644 --- a/datafusion/functions/src/datetime/date_trunc.rs +++ b/datafusion/functions/src/datetime/date_trunc.rs @@ -471,7 +471,7 @@ fn parse_tz(tz: &Option>) -> Result> { tz.as_ref() .map(|tz| { Tz::from_str(tz).map_err(|op| { - DataFusionError::Execution(format!("failed on timezone {tz}: {:?}", op)) + DataFusionError::Execution(format!("failed on timezone {tz}: {op:?}")) }) }) .transpose() @@ -487,7 +487,7 @@ mod tests { use arrow::array::types::TimestampNanosecondType; use arrow::array::{Array, TimestampNanosecondArray}; use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; - use arrow::datatypes::{DataType, TimeUnit}; + use arrow::datatypes::{DataType, Field, TimeUnit}; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ 
-726,13 +726,23 @@ mod tests { .collect::() .with_timezone_opt(tz_opt.clone()); let batch_len = input.len(); + let arg_fields = vec![ + Field::new("a", DataType::Utf8, false).into(), + Field::new("b", input.data_type().clone(), false).into(), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::from("day")), ColumnarValue::Array(Arc::new(input)), ], + arg_fields, number_rows: batch_len, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()), + return_field: Field::new( + "f", + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()), + true, + ) + .into(), }; let result = DateTruncFunc::new().invoke_with_args(args).unwrap(); if let ColumnarValue::Array(result) = result { @@ -888,13 +898,23 @@ mod tests { .collect::() .with_timezone_opt(tz_opt.clone()); let batch_len = input.len(); + let arg_fields = vec![ + Field::new("a", DataType::Utf8, false).into(), + Field::new("b", input.data_type().clone(), false).into(), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::from("hour")), ColumnarValue::Array(Arc::new(input)), ], + arg_fields, number_rows: batch_len, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()), + return_field: Field::new( + "f", + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()), + true, + ) + .into(), }; let result = DateTruncFunc::new().invoke_with_args(args).unwrap(); if let ColumnarValue::Array(result) = result { diff --git a/datafusion/functions/src/datetime/from_unixtime.rs b/datafusion/functions/src/datetime/from_unixtime.rs index ed8181452dbd..c1497040261c 100644 --- a/datafusion/functions/src/datetime/from_unixtime.rs +++ b/datafusion/functions/src/datetime/from_unixtime.rs @@ -18,20 +18,19 @@ use std::any::Any; use std::sync::Arc; -use arrow::datatypes::DataType; use arrow::datatypes::DataType::{Int64, Timestamp, Utf8}; use arrow::datatypes::TimeUnit::Second; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs, ScalarUDFImpl, Signature, - Volatility, + ColumnarValue, Documentation, ReturnFieldArgs, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; #[user_doc( doc_section(label = "Time and Date Functions"), - description = "Converts an integer to RFC3339 timestamp format (`YYYY-MM-DDT00:00:00.000000000Z`). Integers and unsigned integers are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`) return the corresponding timestamp.", + description = "Converts an integer to RFC3339 timestamp format (`YYYY-MM-DDT00:00:00.000000000Z`). 
Integers and unsigned integers are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`) return the corresponding timestamp.", syntax_example = "from_unixtime(expression[, timezone])", sql_example = r#"```sql > select from_unixtime(1599572549, 'America/New_York'); @@ -82,12 +81,12 @@ impl ScalarUDFImpl for FromUnixtimeFunc { &self.signature } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { // Length check handled in the signature debug_assert!(matches!(args.scalar_arguments.len(), 1 | 2)); if args.scalar_arguments.len() == 1 { - Ok(ReturnInfo::new_nullable(Timestamp(Second, None))) + Ok(Field::new(self.name(), Timestamp(Second, None), true).into()) } else { args.scalar_arguments[1] .and_then(|sv| { @@ -95,12 +94,14 @@ impl ScalarUDFImpl for FromUnixtimeFunc { .flatten() .filter(|s| !s.is_empty()) .map(|tz| { - ReturnInfo::new_nullable(Timestamp( - Second, - Some(Arc::from(tz.to_string())), - )) + Field::new( + self.name(), + Timestamp(Second, Some(Arc::from(tz.to_string()))), + true, + ) }) }) + .map(Arc::new) .map_or_else( || { exec_err!( @@ -114,7 +115,7 @@ impl ScalarUDFImpl for FromUnixtimeFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("call return_type_from_args instead") + internal_err!("call return_field_from_args instead") } fn invoke_with_args( @@ -161,8 +162,8 @@ impl ScalarUDFImpl for FromUnixtimeFunc { #[cfg(test)] mod test { use crate::datetime::from_unixtime::FromUnixtimeFunc; - use arrow::datatypes::DataType; use arrow::datatypes::TimeUnit::Second; + use arrow::datatypes::{DataType, Field}; use datafusion_common::ScalarValue; use datafusion_common::ScalarValue::Int64; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -170,10 +171,12 @@ mod test { #[test] fn test_without_timezone() { + let arg_field = Arc::new(Field::new("a", DataType::Int64, true)); let args = datafusion_expr::ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(Int64(Some(1729900800)))], + arg_fields: vec![arg_field], number_rows: 1, - return_type: &DataType::Timestamp(Second, None), + return_field: Field::new("f", DataType::Timestamp(Second, None), true).into(), }; let result = FromUnixtimeFunc::new().invoke_with_args(args).unwrap(); @@ -187,6 +190,10 @@ mod test { #[test] fn test_with_timezone() { + let arg_fields = vec![ + Field::new("a", DataType::Int64, true).into(), + Field::new("a", DataType::Utf8, true).into(), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(Int64(Some(1729900800))), @@ -194,11 +201,14 @@ mod test { "America/New_York".to_string(), ))), ], + arg_fields, number_rows: 2, - return_type: &DataType::Timestamp( - Second, - Some(Arc::from("America/New_York")), - ), + return_field: Field::new( + "f", + DataType::Timestamp(Second, Some(Arc::from("America/New_York"))), + true, + ) + .into(), }; let result = FromUnixtimeFunc::new().invoke_with_args(args).unwrap(); diff --git a/datafusion/functions/src/datetime/make_date.rs b/datafusion/functions/src/datetime/make_date.rs index 929fa601f107..daa9bd83971f 100644 --- a/datafusion/functions/src/datetime/make_date.rs +++ b/datafusion/functions/src/datetime/make_date.rs @@ -223,25 +223,39 @@ fn make_date_inner( mod tests { use crate::datetime::make_date::MakeDateFunc; use arrow::array::{Array, Date32Array, Int32Array, Int64Array, UInt32Array}; - use arrow::datatypes::DataType; - use datafusion_common::ScalarValue; + use arrow::datatypes::{DataType, Field}; + use 
datafusion_common::{DataFusionError, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; use std::sync::Arc; + fn invoke_make_date_with_args( + args: Vec, + number_rows: usize, + ) -> Result { + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type(), true).into()) + .collect::>(); + let args = datafusion_expr::ScalarFunctionArgs { + args, + arg_fields, + number_rows, + return_field: Field::new("f", DataType::Date32, true).into(), + }; + MakeDateFunc::new().invoke_with_args(args) + } + #[test] fn test_make_date() { - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let res = invoke_make_date_with_args( + vec![ ColumnarValue::Scalar(ScalarValue::Int32(Some(2024))), ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), ColumnarValue::Scalar(ScalarValue::UInt32(Some(14))), ], - number_rows: 1, - return_type: &DataType::Date32, - }; - let res = MakeDateFunc::new() - .invoke_with_args(args) - .expect("that make_date parsed values without error"); + 1, + ) + .expect("that make_date parsed values without error"); if let ColumnarValue::Scalar(ScalarValue::Date32(date)) = res { assert_eq!(19736, date.unwrap()); @@ -249,18 +263,15 @@ mod tests { panic!("Expected a scalar value") } - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let res = invoke_make_date_with_args( + vec![ ColumnarValue::Scalar(ScalarValue::Int64(Some(2024))), ColumnarValue::Scalar(ScalarValue::UInt64(Some(1))), ColumnarValue::Scalar(ScalarValue::UInt32(Some(14))), ], - number_rows: 1, - return_type: &DataType::Date32, - }; - let res = MakeDateFunc::new() - .invoke_with_args(args) - .expect("that make_date parsed values without error"); + 1, + ) + .expect("that make_date parsed values without error"); if let ColumnarValue::Scalar(ScalarValue::Date32(date)) = res { assert_eq!(19736, date.unwrap()); @@ -268,18 +279,15 @@ mod tests { panic!("Expected a scalar value") } - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let res = invoke_make_date_with_args( + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some("2024".to_string()))), ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("1".to_string()))), ColumnarValue::Scalar(ScalarValue::Utf8(Some("14".to_string()))), ], - number_rows: 1, - return_type: &DataType::Date32, - }; - let res = MakeDateFunc::new() - .invoke_with_args(args) - .expect("that make_date parsed values without error"); + 1, + ) + .expect("that make_date parsed values without error"); if let ColumnarValue::Scalar(ScalarValue::Date32(date)) = res { assert_eq!(19736, date.unwrap()); @@ -291,18 +299,15 @@ mod tests { let months = Arc::new((1..5).map(Some).collect::()); let days = Arc::new((11..15).map(Some).collect::()); let batch_len = years.len(); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let res = invoke_make_date_with_args( + vec![ ColumnarValue::Array(years), ColumnarValue::Array(months), ColumnarValue::Array(days), ], - number_rows: batch_len, - return_type: &DataType::Date32, - }; - let res = MakeDateFunc::new() - .invoke_with_args(args) - .expect("that make_date parsed values without error"); + batch_len, + ) + .unwrap(); if let ColumnarValue::Array(array) = res { assert_eq!(array.len(), 4); @@ -321,60 +326,52 @@ mod tests { // // invalid number of arguments - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(ScalarValue::Int32(Some(1)))], - number_rows: 1, - return_type: &DataType::Date32, - }; - let res = MakeDateFunc::new().invoke_with_args(args); + let res = 
invoke_make_date_with_args( + vec![ColumnarValue::Scalar(ScalarValue::Int32(Some(1)))], + 1, + ); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: make_date function requires 3 arguments, got 1" ); // invalid type - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let res = invoke_make_date_with_args( + vec![ ColumnarValue::Scalar(ScalarValue::IntervalYearMonth(Some(1))), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ], - number_rows: 1, - return_type: &DataType::Date32, - }; - let res = MakeDateFunc::new().invoke_with_args(args); + 1, + ); assert_eq!( res.err().unwrap().strip_backtrace(), "Arrow error: Cast error: Casting from Interval(YearMonth) to Int32 not supported" ); // overflow of month - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let res = invoke_make_date_with_args( + vec![ ColumnarValue::Scalar(ScalarValue::Int32(Some(2023))), ColumnarValue::Scalar(ScalarValue::UInt64(Some(u64::MAX))), ColumnarValue::Scalar(ScalarValue::Int32(Some(22))), ], - number_rows: 1, - return_type: &DataType::Date32, - }; - let res = MakeDateFunc::new().invoke_with_args(args); + 1, + ); assert_eq!( res.err().unwrap().strip_backtrace(), "Arrow error: Cast error: Can't cast value 18446744073709551615 to type Int32" ); // overflow of day - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let res = invoke_make_date_with_args( + vec![ ColumnarValue::Scalar(ScalarValue::Int32(Some(2023))), ColumnarValue::Scalar(ScalarValue::Int32(Some(22))), ColumnarValue::Scalar(ScalarValue::UInt32(Some(u32::MAX))), ], - number_rows: 1, - return_type: &DataType::Date32, - }; - let res = MakeDateFunc::new().invoke_with_args(args); + 1, + ); assert_eq!( res.err().unwrap().strip_backtrace(), "Arrow error: Cast error: Can't cast value 4294967295 to type Int32" diff --git a/datafusion/functions/src/datetime/now.rs b/datafusion/functions/src/datetime/now.rs index b26dc52cee4d..ffb3aed5a960 100644 --- a/datafusion/functions/src/datetime/now.rs +++ b/datafusion/functions/src/datetime/now.rs @@ -15,16 +15,16 @@ // specific language governing permissions and limitations // under the License. 
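The `now.rs` change below, like `current_date` and `current_time` above, constructs `Expr::Literal` with an extra trailing component (passed as `None`), and the pattern matches elsewhere in the patch ignore it with `_`. A tiny sketch of both sides, assuming only that the component is optional, as the patch passes it:

```rust
use datafusion_common::ScalarValue;
use datafusion_expr::Expr;

fn literal_round_trip() {
    // Construction: supply the extra component explicitly.
    let lit = Expr::Literal(ScalarValue::from("day"), None);

    // Matching: bind the value and ignore the extra component.
    if let Expr::Literal(value, _) = &lit {
        assert_eq!(value, &ScalarValue::from("day"));
    }
}
```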
-use arrow::datatypes::DataType; use arrow::datatypes::DataType::Timestamp; use arrow::datatypes::TimeUnit::Nanosecond; +use arrow::datatypes::{DataType, Field, FieldRef}; use std::any::Any; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs, ScalarUDFImpl, - Signature, Volatility, + ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; @@ -77,15 +77,17 @@ impl ScalarUDFImpl for NowFunc { &self.signature } - fn return_type_from_args(&self, _args: ReturnTypeArgs) -> Result { - Ok(ReturnInfo::new_non_nullable(Timestamp( - Nanosecond, - Some("+00:00".into()), - ))) + fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result { + Ok(Field::new( + self.name(), + Timestamp(Nanosecond, Some("+00:00".into())), + false, + ) + .into()) } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("return_type_from_args should be called instead") + internal_err!("return_field_from_args should be called instead") } fn invoke_with_args( @@ -106,6 +108,7 @@ impl ScalarUDFImpl for NowFunc { .timestamp_nanos_opt(); Ok(ExprSimplifyResult::Simplified(Expr::Literal( ScalarValue::TimestampNanosecond(now_ts, Some("+00:00".into())), + None, ))) } diff --git a/datafusion/functions/src/datetime/to_char.rs b/datafusion/functions/src/datetime/to_char.rs index 8b2e5ad87471..219a9b576423 100644 --- a/datafusion/functions/src/datetime/to_char.rs +++ b/datafusion/functions/src/datetime/to_char.rs @@ -165,6 +165,7 @@ impl ScalarUDFImpl for ToCharFunc { fn aliases(&self) -> &[String] { &self.aliases } + fn documentation(&self) -> Option<&Documentation> { self.doc() } @@ -303,7 +304,7 @@ mod tests { TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, }; - use arrow::datatypes::DataType; + use arrow::datatypes::{DataType, Field, TimeUnit}; use chrono::{NaiveDateTime, Timelike}; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -385,10 +386,15 @@ mod tests { ]; for (value, format, expected) in scalar_data { + let arg_fields = vec![ + Field::new("a", value.data_type(), false).into(), + Field::new("a", format.data_type(), false).into(), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(value), ColumnarValue::Scalar(format)], + arg_fields, number_rows: 1, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -465,13 +471,18 @@ mod tests { for (value, format, expected) in scalar_array_data { let batch_len = format.len(); + let arg_fields = vec![ + Field::new("a", value.data_type(), false).into(), + Field::new("a", format.data_type().to_owned(), false).into(), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(value), ColumnarValue::Array(Arc::new(format) as ArrayRef), ], + arg_fields, number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -596,13 +607,18 @@ mod tests { for (value, format, expected) in array_scalar_data { let batch_len = value.len(); + let arg_fields = vec![ + Field::new("a", value.data_type().clone(), false).into(), + Field::new("a", format.data_type(), false).into(), + 
]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Array(value as ArrayRef), ColumnarValue::Scalar(format), ], + arg_fields, number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -618,13 +634,18 @@ mod tests { for (value, format, expected) in array_array_data { let batch_len = value.len(); + let arg_fields = vec![ + Field::new("a", value.data_type().clone(), false).into(), + Field::new("a", format.data_type().clone(), false).into(), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Array(value), ColumnarValue::Array(Arc::new(format) as ArrayRef), ], + arg_fields, number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -643,10 +664,12 @@ mod tests { // // invalid number of arguments + let arg_field = Field::new("a", DataType::Int32, true).into(); let args = datafusion_expr::ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(ScalarValue::Int32(Some(1)))], + arg_fields: vec![arg_field], number_rows: 1, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }; let result = ToCharFunc::new().invoke_with_args(args); assert_eq!( @@ -655,13 +678,18 @@ mod tests { ); // invalid type + let arg_fields = vec![ + Field::new("a", DataType::Utf8, true).into(), + Field::new("a", DataType::Timestamp(TimeUnit::Nanosecond, None), true).into(), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::Int32(Some(1))), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ], + arg_fields, number_rows: 1, - return_type: &DataType::Utf8, + return_field: Field::new("f", DataType::Utf8, true).into(), }; let result = ToCharFunc::new().invoke_with_args(args); assert_eq!( diff --git a/datafusion/functions/src/datetime/to_date.rs b/datafusion/functions/src/datetime/to_date.rs index 91740b2c31c1..c9fd17dbef11 100644 --- a/datafusion/functions/src/datetime/to_date.rs +++ b/datafusion/functions/src/datetime/to_date.rs @@ -163,14 +163,32 @@ impl ScalarUDFImpl for ToDateFunc { #[cfg(test)] mod tests { use arrow::array::{Array, Date32Array, GenericStringArray, StringViewArray}; - use arrow::datatypes::DataType; + use arrow::datatypes::{DataType, Field}; use arrow::{compute::kernels::cast_utils::Parser, datatypes::Date32Type}; - use datafusion_common::ScalarValue; + use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; use std::sync::Arc; use super::ToDateFunc; + fn invoke_to_date_with_args( + args: Vec, + number_rows: usize, + ) -> Result { + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type(), true).into()) + .collect::>(); + + let args = datafusion_expr::ScalarFunctionArgs { + args, + arg_fields, + number_rows, + return_field: Field::new("f", DataType::Date32, true).into(), + }; + ToDateFunc::new().invoke_with_args(args) + } + #[test] fn test_to_date_without_format() { struct TestCase { @@ -208,12 +226,8 @@ mod tests { } fn test_scalar(sv: ScalarValue, tc: &TestCase) { - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(sv)], - number_rows: 1, - return_type: &DataType::Date32, - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + let to_date_result = + 
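// Illustrative sketch, not part of the patch: the to_char test updates above all follow
// the same shape — `ScalarFunctionArgs` now carries one `arg_fields` entry per argument
// plus a `return_field`, instead of a bare `return_type`. Assuming only the four fields
// exercised in this diff, a call site looks roughly like:
use arrow::datatypes::{DataType, Field};
use datafusion_common::ScalarValue;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};

fn example_args(value: ScalarValue) -> ScalarFunctionArgs {
    ScalarFunctionArgs {
        // one field per argument, in the same order as `args`
        arg_fields: vec![Field::new("a", value.data_type(), true).into()],
        args: vec![ColumnarValue::Scalar(value)],
        number_rows: 1,
        return_field: Field::new("f", DataType::Utf8, true).into(),
    }
}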
invoke_to_date_with_args(vec![ColumnarValue::Scalar(sv)], 1); match to_date_result { Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { @@ -234,12 +248,10 @@ mod tests { { let date_array = A::from(vec![tc.date_str]); let batch_len = date_array.len(); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ColumnarValue::Array(Arc::new(date_array))], - number_rows: batch_len, - return_type: &DataType::Date32, - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + let to_date_result = invoke_to_date_with_args( + vec![ColumnarValue::Array(Arc::new(date_array))], + batch_len, + ); match to_date_result { Ok(ColumnarValue::Array(a)) => { @@ -328,15 +340,13 @@ mod tests { fn test_scalar(sv: ScalarValue, tc: &TestCase) { let format_scalar = ScalarValue::Utf8(Some(tc.format_str.to_string())); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let to_date_result = invoke_to_date_with_args( + vec![ ColumnarValue::Scalar(sv), ColumnarValue::Scalar(format_scalar), ], - number_rows: 1, - return_type: &DataType::Date32, - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + 1, + ); match to_date_result { Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { @@ -358,15 +368,13 @@ mod tests { let format_array = A::from(vec![tc.format_str]); let batch_len = date_array.len(); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let to_date_result = invoke_to_date_with_args( + vec![ ColumnarValue::Array(Arc::new(date_array)), ColumnarValue::Array(Arc::new(format_array)), ], - number_rows: batch_len, - return_type: &DataType::Date32, - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + batch_len, + ); match to_date_result { Ok(ColumnarValue::Array(a)) => { @@ -398,16 +406,14 @@ mod tests { let format1_scalar = ScalarValue::Utf8(Some("%Y-%m-%d".into())); let format2_scalar = ScalarValue::Utf8(Some("%Y/%m/%d".into())); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let to_date_result = invoke_to_date_with_args( + vec![ ColumnarValue::Scalar(formatted_date_scalar), ColumnarValue::Scalar(format1_scalar), ColumnarValue::Scalar(format2_scalar), ], - number_rows: 1, - return_type: &DataType::Date32, - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + 1, + ); match to_date_result { Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { @@ -431,19 +437,17 @@ mod tests { for date_str in test_cases { let formatted_date_scalar = ScalarValue::Utf8(Some(date_str.into())); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(formatted_date_scalar)], - number_rows: 1, - return_type: &DataType::Date32, - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + let to_date_result = invoke_to_date_with_args( + vec![ColumnarValue::Scalar(formatted_date_scalar)], + 1, + ); match to_date_result { Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { let expected = Date32Type::parse_formatted("2020-09-08", "%Y-%m-%d"); assert_eq!(date_val, expected, "to_date created wrong value"); } - _ => panic!("Conversion of {} failed", date_str), + _ => panic!("Conversion of {date_str} failed"), } } } @@ -453,23 +457,18 @@ mod tests { let date_str = "20241231"; let date_scalar = ScalarValue::Utf8(Some(date_str.into())); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(date_scalar)], - number_rows: 1, - return_type: &DataType::Date32, - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); 
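// Illustrative sketch, not part of the patch: `invoke_to_date_with_args` above derives
// the argument fields from the `ColumnarValue`s themselves, which keeps the individual
// tests unchanged. The same helper generalizes to any UDF; the function and field names
// here are placeholders:
use arrow::datatypes::{DataType, Field, FieldRef};
use datafusion_common::Result;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};

fn invoke_with_derived_fields(
    udf: &dyn ScalarUDFImpl,
    args: Vec<ColumnarValue>,
    number_rows: usize,
    return_type: DataType,
) -> Result<ColumnarValue> {
    let arg_fields: Vec<FieldRef> = args
        .iter()
        .enumerate()
        // nullability is assumed `true`; tests that care should set it explicitly
        .map(|(i, a)| Field::new(format!("arg_{i}"), a.data_type(), true).into())
        .collect();
    udf.invoke_with_args(ScalarFunctionArgs {
        args,
        arg_fields,
        number_rows,
        return_field: Field::new("f", return_type, true).into(),
    })
}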
+ let to_date_result = + invoke_to_date_with_args(vec![ColumnarValue::Scalar(date_scalar)], 1); match to_date_result { Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { let expected = Date32Type::parse_formatted("2024-12-31", "%Y-%m-%d"); assert_eq!( date_val, expected, - "to_date created wrong value for {}", - date_str + "to_date created wrong value for {date_str}" ); } - _ => panic!("Conversion of {} failed", date_str), + _ => panic!("Conversion of {date_str} failed"), } } @@ -478,18 +477,11 @@ mod tests { let date_str = "202412311"; let date_scalar = ScalarValue::Utf8(Some(date_str.into())); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(date_scalar)], - number_rows: 1, - return_type: &DataType::Date32, - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + let to_date_result = + invoke_to_date_with_args(vec![ColumnarValue::Scalar(date_scalar)], 1); if let Ok(ColumnarValue::Scalar(ScalarValue::Date32(_))) = to_date_result { - panic!( - "Conversion of {} succeeded, but should have failed, ", - date_str - ); + panic!("Conversion of {date_str} succeeded, but should have failed. "); } } } diff --git a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs index 8dbef90cdc3f..b9ebe537d459 100644 --- a/datafusion/functions/src/datetime/to_local_time.rs +++ b/datafusion/functions/src/datetime/to_local_time.rs @@ -407,9 +407,9 @@ impl ScalarUDFImpl for ToLocalTimeFunc { mod tests { use std::sync::Arc; - use arrow::array::{types::TimestampNanosecondType, TimestampNanosecondArray}; + use arrow::array::{types::TimestampNanosecondType, Array, TimestampNanosecondArray}; use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; - use arrow::datatypes::{DataType, TimeUnit}; + use arrow::datatypes::{DataType, Field, TimeUnit}; use chrono::NaiveDateTime; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; @@ -538,11 +538,13 @@ mod tests { } fn test_to_local_time_helper(input: ScalarValue, expected: ScalarValue) { + let arg_field = Field::new("a", input.data_type(), true).into(); let res = ToLocalTimeFunc::new() .invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(input)], + arg_fields: vec![arg_field], number_rows: 1, - return_type: &expected.data_type(), + return_field: Field::new("f", expected.data_type(), true).into(), }) .unwrap(); match res { @@ -602,10 +604,17 @@ mod tests { .map(|s| Some(string_to_timestamp_nanos(s).unwrap())) .collect::(); let batch_size = input.len(); + let arg_field = Field::new("a", input.data_type().clone(), true).into(); let args = ScalarFunctionArgs { args: vec![ColumnarValue::Array(Arc::new(input))], + arg_fields: vec![arg_field], number_rows: batch_size, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), + return_field: Field::new( + "f", + DataType::Timestamp(TimeUnit::Nanosecond, None), + true, + ) + .into(), }; let result = ToLocalTimeFunc::new().invoke_with_args(args).unwrap(); if let ColumnarValue::Array(result) = result { diff --git a/datafusion/functions/src/datetime/to_timestamp.rs b/datafusion/functions/src/datetime/to_timestamp.rs index 52c86733f332..8b26a1c25950 100644 --- a/datafusion/functions/src/datetime/to_timestamp.rs +++ b/datafusion/functions/src/datetime/to_timestamp.rs @@ -639,7 +639,7 @@ mod tests { TimestampNanosecondArray, TimestampSecondArray, }; use arrow::array::{ArrayRef, Int64Array, StringBuilder}; - use 
arrow::datatypes::TimeUnit; + use arrow::datatypes::{Field, TimeUnit}; use chrono::Utc; use datafusion_common::{assert_contains, DataFusionError, ScalarValue}; use datafusion_expr::ScalarFunctionImplementation; @@ -1012,11 +1012,13 @@ mod tests { for udf in &udfs { for array in arrays { let rt = udf.return_type(&[array.data_type()]).unwrap(); + let arg_field = Field::new("arg", array.data_type().clone(), true).into(); assert!(matches!(rt, Timestamp(_, Some(_)))); let args = datafusion_expr::ScalarFunctionArgs { args: vec![array.clone()], + arg_fields: vec![arg_field], number_rows: 4, - return_type: &rt, + return_field: Field::new("f", rt, true).into(), }; let res = udf .invoke_with_args(args) @@ -1060,10 +1062,12 @@ mod tests { for array in arrays { let rt = udf.return_type(&[array.data_type()]).unwrap(); assert!(matches!(rt, Timestamp(_, None))); + let arg_field = Field::new("arg", array.data_type().clone(), true).into(); let args = datafusion_expr::ScalarFunctionArgs { args: vec![array.clone()], + arg_fields: vec![arg_field], number_rows: 5, - return_type: &rt, + return_field: Field::new("f", rt, true).into(), }; let res = udf .invoke_with_args(args) diff --git a/datafusion/functions/src/encoding/inner.rs b/datafusion/functions/src/encoding/inner.rs index 51e8c6968866..9a7b49105743 100644 --- a/datafusion/functions/src/encoding/inner.rs +++ b/datafusion/functions/src/encoding/inner.rs @@ -310,7 +310,7 @@ fn hex_decode(input: &[u8], buf: &mut [u8]) -> Result { let out_len = input.len() / 2; let buf = &mut buf[..out_len]; hex::decode_to_slice(input, buf).map_err(|e| { - DataFusionError::Internal(format!("Failed to decode from hex: {}", e)) + DataFusionError::Internal(format!("Failed to decode from hex: {e}")) })?; Ok(out_len) } @@ -319,7 +319,7 @@ fn base64_decode(input: &[u8], buf: &mut [u8]) -> Result { general_purpose::STANDARD_NO_PAD .decode_slice(input, buf) .map_err(|e| { - DataFusionError::Internal(format!("Failed to decode from base64: {}", e)) + DataFusionError::Internal(format!("Failed to decode from base64: {e}")) }) } @@ -419,15 +419,13 @@ impl Encoding { .decode(value) .map_err(|e| { DataFusionError::Internal(format!( - "Failed to decode value using base64: {}", - e + "Failed to decode value using base64: {e}" )) })? } Self::Hex => hex::decode(value).map_err(|e| { DataFusionError::Internal(format!( - "Failed to decode value using hex: {}", - e + "Failed to decode value using hex: {e}" )) })?, }; @@ -447,15 +445,13 @@ impl Encoding { .decode(value) .map_err(|e| { DataFusionError::Internal(format!( - "Failed to decode value using base64: {}", - e + "Failed to decode value using base64: {e}" )) })? 
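// Illustrative note, not part of the patch: the error-message edits in this file are
// mechanical — positional `{}` arguments become inlined named arguments, with no change
// to the rendered string:
fn hex_error(e: &str) -> String {
    // before: format!("Failed to decode from hex: {}", e)
    format!("Failed to decode from hex: {e}")
}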
} Self::Hex => hex::decode(value).map_err(|e| { DataFusionError::Internal(format!( - "Failed to decode value using hex: {}", - e + "Failed to decode value using hex: {e}" )) })?, }; diff --git a/datafusion/functions/src/lib.rs b/datafusion/functions/src/lib.rs index b65c4c543242..51cd5df8060d 100644 --- a/datafusion/functions/src/lib.rs +++ b/datafusion/functions/src/lib.rs @@ -209,8 +209,7 @@ mod tests { for alias in func.aliases() { assert!( names.insert(alias.to_string().to_lowercase()), - "duplicate function name: {}", - alias + "duplicate function name: {alias}" ); } } diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index d2849c3abba0..30ebf8654ea0 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -40,6 +40,7 @@ /// Exported functions accept: /// - `Vec` argument (single argument followed by a comma) /// - Variable number of `Expr` arguments (zero or more arguments, must be without commas) +#[macro_export] macro_rules! export_functions { ($(($FUNC:ident, $DOC:expr, $($arg:tt)*)),*) => { $( @@ -69,6 +70,7 @@ macro_rules! export_functions { /// named `$NAME` which returns that singleton. /// /// This is used to ensure creating the list of `ScalarUDF` only happens once. +#[macro_export] macro_rules! make_udf_function { ($UDF:ty, $NAME:ident) => { #[doc = concat!("Return a [`ScalarUDF`](datafusion_expr::ScalarUDF) implementation of ", stringify!($NAME))] diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index fd135f4c5ec0..23e267a323b9 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -210,7 +210,9 @@ impl ScalarUDFImpl for LogFunc { }; match number { - Expr::Literal(value) if value == ScalarValue::new_one(&number_datatype)? => { + Expr::Literal(value, _) + if value == ScalarValue::new_one(&number_datatype)? 
=> + { Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_zero( &info.get_data_type(&base)?, )?))) @@ -256,6 +258,7 @@ mod tests { use arrow::array::{Float32Array, Float64Array, Int64Array}; use arrow::compute::SortOptions; + use arrow::datatypes::Field; use datafusion_common::cast::{as_float32_array, as_float64_array}; use datafusion_common::DFSchema; use datafusion_expr::execution_props::ExecutionProps; @@ -264,6 +267,10 @@ mod tests { #[test] #[should_panic] fn test_log_invalid_base_type() { + let arg_fields = vec![ + Field::new("a", DataType::Float64, false).into(), + Field::new("a", DataType::Int64, false).into(), + ]; let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Float64Array::from(vec![ @@ -271,20 +278,23 @@ mod tests { ]))), // num ColumnarValue::Array(Arc::new(Int64Array::from(vec![5, 10, 15, 20]))), ], + arg_fields, number_rows: 4, - return_type: &DataType::Float64, + return_field: Field::new("f", DataType::Float64, true).into(), }; let _ = LogFunc::new().invoke_with_args(args); } #[test] fn test_log_invalid_value() { + let arg_field = Field::new("a", DataType::Int64, false).into(); let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Int64Array::from(vec![10]))), // num ], + arg_fields: vec![arg_field], number_rows: 1, - return_type: &DataType::Float64, + return_field: Field::new("f", DataType::Float64, true).into(), }; let result = LogFunc::new().invoke_with_args(args); @@ -293,12 +303,14 @@ mod tests { #[test] fn test_log_scalar_f32_unary() { + let arg_field = Field::new("a", DataType::Float32, false).into(); let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::Float32(Some(10.0))), // num ], + arg_fields: vec![arg_field], number_rows: 1, - return_type: &DataType::Float32, + return_field: Field::new("f", DataType::Float32, true).into(), }; let result = LogFunc::new() .invoke_with_args(args) @@ -320,12 +332,14 @@ mod tests { #[test] fn test_log_scalar_f64_unary() { + let arg_field = Field::new("a", DataType::Float64, false).into(); let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::Float64(Some(10.0))), // num ], + arg_fields: vec![arg_field], number_rows: 1, - return_type: &DataType::Float64, + return_field: Field::new("f", DataType::Float64, true).into(), }; let result = LogFunc::new() .invoke_with_args(args) @@ -347,13 +361,18 @@ mod tests { #[test] fn test_log_scalar_f32() { + let arg_fields = vec![ + Field::new("a", DataType::Float32, false).into(), + Field::new("a", DataType::Float32, false).into(), + ]; let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::Float32(Some(2.0))), // num ColumnarValue::Scalar(ScalarValue::Float32(Some(32.0))), // num ], + arg_fields, number_rows: 1, - return_type: &DataType::Float32, + return_field: Field::new("f", DataType::Float32, true).into(), }; let result = LogFunc::new() .invoke_with_args(args) @@ -375,13 +394,18 @@ mod tests { #[test] fn test_log_scalar_f64() { + let arg_fields = vec![ + Field::new("a", DataType::Float64, false).into(), + Field::new("a", DataType::Float64, false).into(), + ]; let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), // num ColumnarValue::Scalar(ScalarValue::Float64(Some(64.0))), // num ], + arg_fields, number_rows: 1, - return_type: &DataType::Float64, + return_field: Field::new("f", DataType::Float64, true).into(), }; let result = LogFunc::new() .invoke_with_args(args) @@ -403,14 +427,16 @@ mod tests { #[test] fn 
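// Illustrative sketch, not part of the patch: the simplify arms above now match
// `Expr::Literal(value, _)` because a literal carries an optional metadata slot next to
// its `ScalarValue`. Constructing and matching the two-field shape used throughout this
// diff (the exact metadata type is left inferred here):
use datafusion_common::ScalarValue;
use datafusion_expr::Expr;

fn one_literal() -> Expr {
    // second field: no attached metadata
    Expr::Literal(ScalarValue::Float64(Some(1.0)), None)
}

fn is_one(expr: &Expr) -> bool {
    matches!(
        expr,
        // ignore the metadata when only the value matters
        Expr::Literal(ScalarValue::Float64(Some(v)), _) if *v == 1.0
    )
}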
test_log_f64_unary() { + let arg_field = Field::new("a", DataType::Float64, false).into(); let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Float64Array::from(vec![ 10.0, 100.0, 1000.0, 10000.0, ]))), // num ], + arg_fields: vec![arg_field], number_rows: 4, - return_type: &DataType::Float64, + return_field: Field::new("f", DataType::Float64, true).into(), }; let result = LogFunc::new() .invoke_with_args(args) @@ -435,14 +461,16 @@ mod tests { #[test] fn test_log_f32_unary() { + let arg_field = Field::new("a", DataType::Float32, false).into(); let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Float32Array::from(vec![ 10.0, 100.0, 1000.0, 10000.0, ]))), // num ], + arg_fields: vec![arg_field], number_rows: 4, - return_type: &DataType::Float32, + return_field: Field::new("f", DataType::Float32, true).into(), }; let result = LogFunc::new() .invoke_with_args(args) @@ -467,6 +495,10 @@ mod tests { #[test] fn test_log_f64() { + let arg_fields = vec![ + Field::new("a", DataType::Float64, false).into(), + Field::new("a", DataType::Float64, false).into(), + ]; let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Float64Array::from(vec![ @@ -476,8 +508,9 @@ mod tests { 8.0, 4.0, 81.0, 625.0, ]))), // num ], + arg_fields, number_rows: 4, - return_type: &DataType::Float64, + return_field: Field::new("f", DataType::Float64, true).into(), }; let result = LogFunc::new() .invoke_with_args(args) @@ -502,6 +535,10 @@ mod tests { #[test] fn test_log_f32() { + let arg_fields = vec![ + Field::new("a", DataType::Float32, false).into(), + Field::new("a", DataType::Float32, false).into(), + ]; let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Float32Array::from(vec![ @@ -511,8 +548,9 @@ mod tests { 8.0, 4.0, 81.0, 625.0, ]))), // num ], + arg_fields, number_rows: 4, - return_type: &DataType::Float32, + return_field: Field::new("f", DataType::Float32, true).into(), }; let result = LogFunc::new() .invoke_with_args(args) diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index 028ec2fef793..465844704f59 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -156,12 +156,15 @@ impl ScalarUDFImpl for PowerFunc { let exponent_type = info.get_data_type(&exponent)?; match exponent { - Expr::Literal(value) if value == ScalarValue::new_zero(&exponent_type)? => { + Expr::Literal(value, _) + if value == ScalarValue::new_zero(&exponent_type)? => + { Ok(ExprSimplifyResult::Simplified(Expr::Literal( ScalarValue::new_one(&info.get_data_type(&base)?)?, + None, ))) } - Expr::Literal(value) if value == ScalarValue::new_one(&exponent_type)? => { + Expr::Literal(value, _) if value == ScalarValue::new_one(&exponent_type)? 
=> { Ok(ExprSimplifyResult::Simplified(base)) } Expr::ScalarFunction(ScalarFunction { func, mut args }) @@ -187,12 +190,17 @@ fn is_log(func: &ScalarUDF) -> bool { #[cfg(test)] mod tests { use arrow::array::Float64Array; + use arrow::datatypes::Field; use datafusion_common::cast::{as_float64_array, as_int64_array}; use super::*; #[test] fn test_power_f64() { + let arg_fields = vec![ + Field::new("a", DataType::Float64, true).into(), + Field::new("a", DataType::Float64, true).into(), + ]; let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Float64Array::from(vec![ @@ -202,8 +210,9 @@ mod tests { 3.0, 2.0, 4.0, 4.0, ]))), // exponent ], + arg_fields, number_rows: 4, - return_type: &DataType::Float64, + return_field: Field::new("f", DataType::Float64, true).into(), }; let result = PowerFunc::new() .invoke_with_args(args) @@ -227,13 +236,18 @@ mod tests { #[test] fn test_power_i64() { + let arg_fields = vec![ + Field::new("a", DataType::Int64, true).into(), + Field::new("a", DataType::Int64, true).into(), + ]; let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Int64Array::from(vec![2, 2, 3, 5]))), // base ColumnarValue::Array(Arc::new(Int64Array::from(vec![3, 2, 4, 4]))), // exponent ], + arg_fields, number_rows: 4, - return_type: &DataType::Int64, + return_field: Field::new("f", DataType::Int64, true).into(), }; let result = PowerFunc::new() .invoke_with_args(args) diff --git a/datafusion/functions/src/math/random.rs b/datafusion/functions/src/math/random.rs index 607f9fb09f2a..92b6ed1895ed 100644 --- a/datafusion/functions/src/math/random.rs +++ b/datafusion/functions/src/math/random.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use arrow::array::Float64Array; use arrow::datatypes::DataType; use arrow::datatypes::DataType::Float64; -use rand::{thread_rng, Rng}; +use rand::{rng, Rng}; use datafusion_common::{internal_err, Result}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; @@ -74,9 +74,9 @@ impl ScalarUDFImpl for RandomFunc { if !args.args.is_empty() { return internal_err!("{} function does not accept arguments", self.name()); } - let mut rng = thread_rng(); + let mut rng = rng(); let mut values = vec![0.0; args.number_rows]; - // Equivalent to set each element with rng.gen_range(0.0..1.0), but more efficient + // Equivalent to set each element with rng.random_range(0.0..1.0), but more efficient rng.fill(&mut values[..]); let array = Float64Array::from(values); diff --git a/datafusion/functions/src/math/signum.rs b/datafusion/functions/src/math/signum.rs index ba5422afa768..ec6ef5a78c6a 100644 --- a/datafusion/functions/src/math/signum.rs +++ b/datafusion/functions/src/math/signum.rs @@ -138,7 +138,7 @@ mod test { use std::sync::Arc; use arrow::array::{ArrayRef, Float32Array, Float64Array}; - use arrow::datatypes::DataType; + use arrow::datatypes::{DataType, Field}; use datafusion_common::cast::{as_float32_array, as_float64_array}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; @@ -157,10 +157,12 @@ mod test { f32::INFINITY, f32::NEG_INFINITY, ])); + let arg_fields = vec![Field::new("a", DataType::Float32, false).into()]; let args = ScalarFunctionArgs { args: vec![ColumnarValue::Array(Arc::clone(&array) as ArrayRef)], + arg_fields, number_rows: array.len(), - return_type: &DataType::Float32, + return_field: Field::new("f", DataType::Float32, true).into(), }; let result = SignumFunc::new() .invoke_with_args(args) @@ -201,10 +203,12 @@ mod test { f64::INFINITY, f64::NEG_INFINITY, ])); + let arg_fields = 
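// Illustrative sketch, not part of the patch: random() (and uuid() later in this diff)
// moves to the rand 0.9 names — `thread_rng()` becomes `rng()` and `gen_range` becomes
// `random_range`. Minimal usage, assuming rand 0.9:
use rand::Rng;

fn sample_unit_interval(n: usize) -> Vec<f64> {
    let mut rng = rand::rng();
    (0..n).map(|_| rng.random_range(0.0..1.0)).collect()
}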
vec![Field::new("a", DataType::Float64, false).into()]; let args = ScalarFunctionArgs { args: vec![ColumnarValue::Array(Arc::clone(&array) as ArrayRef)], + arg_fields, number_rows: array.len(), - return_type: &DataType::Float64, + return_field: Field::new("f", DataType::Float64, true).into(), }; let result = SignumFunc::new() .invoke_with_args(args) diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs index 8cb1a4ff3d60..52ab3d489ee3 100644 --- a/datafusion/functions/src/regex/regexpcount.rs +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -577,15 +577,12 @@ fn compile_regex(regex: &str, flags: Option<&str>) -> Result "regexp_count() does not support global flag".to_string(), )); } - format!("(?{}){}", flags, regex) + format!("(?{flags}){regex}") } }; Regex::new(&pattern).map_err(|_| { - ArrowError::ComputeError(format!( - "Regular expression did not compile: {}", - pattern - )) + ArrowError::ComputeError(format!("Regular expression did not compile: {pattern}")) }) } @@ -619,6 +616,7 @@ fn count_matches( mod tests { use super::*; use arrow::array::{GenericStringArray, StringViewArray}; + use arrow::datatypes::Field; use datafusion_expr::ScalarFunctionArgs; #[test] @@ -647,6 +645,26 @@ mod tests { test_case_regexp_count_cache_check::>(); } + fn regexp_count_with_scalar_values(args: &[ScalarValue]) -> Result { + let args_values = args + .iter() + .map(|sv| ColumnarValue::Scalar(sv.clone())) + .collect(); + + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, a)| Field::new(format!("arg_{idx}"), a.data_type(), true).into()) + .collect::>(); + + RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { + args: args_values, + arg_fields, + number_rows: args.len(), + return_field: Field::new("f", Int64, true).into(), + }) + } + fn test_case_sensitive_regexp_count_scalar() { let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"]; let regex = "abc"; @@ -657,11 +675,7 @@ mod tests { let v_sv = ScalarValue::Utf8(Some(v.to_string())); let regex_sv = ScalarValue::Utf8(Some(regex.to_string())); let expected = expected.get(pos).cloned(); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], - number_rows: 2, - return_type: &Int64, - }); + let re = regexp_count_with_scalar_values(&[v_sv, regex_sv]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -672,11 +686,7 @@ mod tests { // largeutf8 let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], - number_rows: 2, - return_type: &Int64, - }); + let re = regexp_count_with_scalar_values(&[v_sv, regex_sv]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -687,11 +697,7 @@ mod tests { // utf8view let v_sv = ScalarValue::Utf8View(Some(v.to_string())); let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], - number_rows: 2, - return_type: &Int64, - }); + let re = regexp_count_with_scalar_values(&[v_sv, regex_sv]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) 
=> { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -713,15 +719,7 @@ mod tests { let regex_sv = ScalarValue::Utf8(Some(regex.to_string())); let start_sv = ScalarValue::Int64(Some(start)); let expected = expected.get(pos).cloned(); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ], - number_rows: 3, - return_type: &Int64, - }); + let re = regexp_count_with_scalar_values(&[v_sv, regex_sv, start_sv.clone()]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -732,15 +730,7 @@ mod tests { // largeutf8 let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ], - number_rows: 3, - return_type: &Int64, - }); + let re = regexp_count_with_scalar_values(&[v_sv, regex_sv, start_sv.clone()]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -751,15 +741,7 @@ mod tests { // utf8view let v_sv = ScalarValue::Utf8View(Some(v.to_string())); let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ], - number_rows: 3, - return_type: &Int64, - }); + let re = regexp_count_with_scalar_values(&[v_sv, regex_sv, start_sv.clone()]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -783,16 +765,13 @@ mod tests { let start_sv = ScalarValue::Int64(Some(start)); let flags_sv = ScalarValue::Utf8(Some(flags.to_string())); let expected = expected.get(pos).cloned(); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ColumnarValue::Scalar(flags_sv.clone()), - ], - number_rows: 4, - return_type: &Int64, - }); + + let re = regexp_count_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + flags_sv.clone(), + ]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -804,16 +783,13 @@ mod tests { let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); let flags_sv = ScalarValue::LargeUtf8(Some(flags.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ColumnarValue::Scalar(flags_sv.clone()), - ], - number_rows: 4, - return_type: &Int64, - }); + + let re = regexp_count_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + flags_sv.clone(), + ]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -825,16 +801,13 @@ mod tests { let v_sv = ScalarValue::Utf8View(Some(v.to_string())); let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); let 
flags_sv = ScalarValue::Utf8View(Some(flags.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ColumnarValue::Scalar(flags_sv.clone()), - ], - number_rows: 4, - return_type: &Int64, - }); + + let re = regexp_count_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + flags_sv.clone(), + ]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -907,16 +880,12 @@ mod tests { let start_sv = ScalarValue::Int64(Some(start)); let flags_sv = ScalarValue::Utf8(flags.get(pos).map(|f| f.to_string())); let expected = expected.get(pos).cloned(); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ColumnarValue::Scalar(flags_sv.clone()), - ], - number_rows: 4, - return_type: &Int64, - }); + let re = regexp_count_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + flags_sv.clone(), + ]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -928,16 +897,12 @@ mod tests { let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); let regex_sv = ScalarValue::LargeUtf8(regex.get(pos).map(|s| s.to_string())); let flags_sv = ScalarValue::LargeUtf8(flags.get(pos).map(|f| f.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ColumnarValue::Scalar(flags_sv.clone()), - ], - number_rows: 4, - return_type: &Int64, - }); + let re = regexp_count_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + flags_sv.clone(), + ]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -949,16 +914,12 @@ mod tests { let v_sv = ScalarValue::Utf8View(Some(v.to_string())); let regex_sv = ScalarValue::Utf8View(regex.get(pos).map(|s| s.to_string())); let flags_sv = ScalarValue::Utf8View(flags.get(pos).map(|f| f.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ColumnarValue::Scalar(flags_sv.clone()), - ], - number_rows: 4, - return_type: &Int64, - }); + let re = regexp_count_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + flags_sv.clone(), + ]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); diff --git a/datafusion/functions/src/string/ascii.rs b/datafusion/functions/src/string/ascii.rs index 006492a0e07a..63c987906b0f 100644 --- a/datafusion/functions/src/string/ascii.rs +++ b/datafusion/functions/src/string/ascii.rs @@ -16,7 +16,7 @@ // under the License. 
use crate::utils::make_scalar_function; -use arrow::array::{ArrayAccessor, ArrayIter, ArrayRef, AsArray, Int32Array}; +use arrow::array::{ArrayRef, AsArray, Int32Array, StringArrayType}; use arrow::datatypes::DataType; use arrow::error::ArrowError; use datafusion_common::types::logical_string; @@ -103,19 +103,22 @@ impl ScalarUDFImpl for AsciiFunc { fn calculate_ascii<'a, V>(array: V) -> Result where - V: ArrayAccessor, + V: StringArrayType<'a, Item = &'a str>, { - let iter = ArrayIter::new(array); - let result = iter - .map(|string| { - string.map(|s| { - let mut chars = s.chars(); - chars.next().map_or(0, |v| v as i32) - }) + let values: Vec<_> = (0..array.len()) + .map(|i| { + if array.is_null(i) { + 0 + } else { + let s = array.value(i); + s.chars().next().map_or(0, |c| c as i32) + } }) - .collect::(); + .collect(); - Ok(Arc::new(result) as ArrayRef) + let array = Int32Array::new(values.into(), array.nulls().cloned()); + + Ok(Arc::new(array)) } /// Returns the numeric code of the first character of the argument. @@ -182,6 +185,7 @@ mod tests { test_ascii!(Some(String::from("x")), Ok(Some(120))); test_ascii!(Some(String::from("a")), Ok(Some(97))); test_ascii!(Some(String::from("")), Ok(Some(0))); + test_ascii!(Some(String::from("🚀")), Ok(Some(128640))); test_ascii!(None, Ok(None)); Ok(()) } diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index c47d08d579e4..64a527eac198 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -295,7 +295,7 @@ pub fn simplify_concat(args: Vec) -> Result { let data_types: Vec<_> = args .iter() .filter_map(|expr| match expr { - Expr::Literal(l) => Some(l.data_type()), + Expr::Literal(l, _) => Some(l.data_type()), _ => None, }) .collect(); @@ -304,25 +304,25 @@ pub fn simplify_concat(args: Vec) -> Result { for arg in args.clone() { match arg { - Expr::Literal(ScalarValue::Utf8(None)) => {} - Expr::Literal(ScalarValue::LargeUtf8(None)) => { + Expr::Literal(ScalarValue::Utf8(None), _) => {} + Expr::Literal(ScalarValue::LargeUtf8(None), _) => { } - Expr::Literal(ScalarValue::Utf8View(None)) => { } + Expr::Literal(ScalarValue::Utf8View(None), _) => { } // filter out `null` args // All literals have been converted to Utf8 or LargeUtf8 in type_coercion. // Concatenate it with the `contiguous_scalar`. - Expr::Literal(ScalarValue::Utf8(Some(v))) => { + Expr::Literal(ScalarValue::Utf8(Some(v)), _) => { contiguous_scalar += &v; } - Expr::Literal(ScalarValue::LargeUtf8(Some(v))) => { + Expr::Literal(ScalarValue::LargeUtf8(Some(v)), _) => { contiguous_scalar += &v; } - Expr::Literal(ScalarValue::Utf8View(Some(v))) => { + Expr::Literal(ScalarValue::Utf8View(Some(v)), _) => { contiguous_scalar += &v; } - Expr::Literal(x) => { + Expr::Literal(x, _) => { return internal_err!( "The scalar {x} should be casted to string type during the type coercion." 
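// Illustrative sketch, not part of the patch: the rewritten `calculate_ascii` above
// builds the value buffer in a single pass and reuses the input's null buffer instead of
// appending through an `ArrayIter`. The same pattern on a concrete `StringArray` (the
// generic version in the hunk is bounded by `StringArrayType`):
use std::sync::Arc;
use arrow::array::{Array, ArrayRef, Int32Array, StringArray};

fn ascii_codes(input: &StringArray) -> ArrayRef {
    let values: Vec<i32> = (0..input.len())
        .map(|i| {
            if input.is_null(i) {
                0 // placeholder value; masked out by the copied null buffer
            } else {
                input.value(i).chars().next().map_or(0, |c| c as i32)
            }
        })
        .collect();
    // values plus the original validity bitmap, with no per-row builder bookkeeping
    Arc::new(Int32Array::new(values.into(), input.nulls().cloned()))
}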
) @@ -376,6 +376,7 @@ mod tests { use crate::utils::test::test_function; use arrow::array::{Array, LargeStringArray, StringViewArray}; use arrow::array::{ArrayRef, StringArray}; + use arrow::datatypes::Field; use DataType::*; #[test] @@ -468,11 +469,22 @@ mod tests { None, Some("b"), ]))); + let arg_fields = vec![ + Field::new("a", Utf8, true), + Field::new("a", Utf8, true), + Field::new("a", Utf8, true), + Field::new("a", Utf8View, true), + Field::new("a", Utf8View, true), + ] + .into_iter() + .map(Arc::new) + .collect::>(); let args = ScalarFunctionArgs { args: vec![c0, c1, c2, c3, c4], + arg_fields, number_rows: 3, - return_type: &Utf8, + return_field: Field::new("f", Utf8, true).into(), }; let result = ConcatFunc::new().invoke_with_args(args)?; diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index c2bad206db15..1f45f8501e1f 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -312,6 +312,7 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result { match delimiter { // when the delimiter is an empty string, @@ -336,8 +337,8 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result {} - Expr::Literal(ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) | ScalarValue::Utf8View(Some(v))) => { + Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None) | ScalarValue::Utf8View(None), _) => {} + Expr::Literal(ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) | ScalarValue::Utf8View(Some(v)), _) => { match contiguous_scalar { None => contiguous_scalar = Some(v.to_string()), Some(mut pre) => { @@ -347,7 +348,7 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result return internal_err!("The scalar {s} should be casted to string type during the type coercion."), + Expr::Literal(s, _) => return internal_err!("The scalar {s} should be casted to string type during the type coercion."), // If the arg is not a literal, we should first push the current `contiguous_scalar` // to the `new_args` and reset it to None. // Then pushing this arg to the `new_args`. @@ -374,10 +375,11 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result Ok(ExprSimplifyResult::Simplified(Expr::Literal( ScalarValue::Utf8(None), + None, ))), } } - Expr::Literal(d) => internal_err!( + Expr::Literal(d, _) => internal_err!( "The scalar {d} should be casted to string type during the type coercion." 
), _ => { @@ -394,7 +396,7 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result bool { match expr { - Expr::Literal(v) => v.is_null(), + Expr::Literal(v, _) => v.is_null(), _ => false, } } @@ -403,10 +405,10 @@ fn is_null(expr: &Expr) -> bool { mod tests { use std::sync::Arc; + use crate::string::concat_ws::ConcatWsFunc; use arrow::array::{Array, ArrayRef, StringArray}; use arrow::datatypes::DataType::Utf8; - - use crate::string::concat_ws::ConcatWsFunc; + use arrow::datatypes::Field; use datafusion_common::Result; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; @@ -481,10 +483,16 @@ mod tests { Some("z"), ]))); + let arg_fields = vec![ + Field::new("a", Utf8, true).into(), + Field::new("a", Utf8, true).into(), + Field::new("a", Utf8, true).into(), + ]; let args = ScalarFunctionArgs { args: vec![c0, c1, c2], + arg_fields, number_rows: 3, - return_type: &Utf8, + return_field: Field::new("f", Utf8, true).into(), }; let result = ConcatWsFunc::new().invoke_with_args(args)?; @@ -511,10 +519,16 @@ mod tests { Some("z"), ]))); + let arg_fields = vec![ + Field::new("a", Utf8, true).into(), + Field::new("a", Utf8, true).into(), + Field::new("a", Utf8, true).into(), + ]; let args = ScalarFunctionArgs { args: vec![c0, c1, c2], + arg_fields, number_rows: 3, - return_type: &Utf8, + return_field: Field::new("f", Utf8, true).into(), }; let result = ConcatWsFunc::new().invoke_with_args(args)?; diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs index 05a3edf61c5a..215f8f7a25b9 100644 --- a/datafusion/functions/src/string/contains.rs +++ b/datafusion/functions/src/string/contains.rs @@ -150,10 +150,11 @@ fn contains(args: &[ArrayRef]) -> Result { #[cfg(test)] mod test { use super::ContainsFunc; + use crate::expr_fn::contains; use arrow::array::{BooleanArray, StringArray}; - use arrow::datatypes::DataType; + use arrow::datatypes::{DataType, Field}; use datafusion_common::ScalarValue; - use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; + use datafusion_expr::{ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDFImpl}; use std::sync::Arc; #[test] @@ -164,11 +165,16 @@ mod test { Some("yyy?()"), ]))); let scalar = ColumnarValue::Scalar(ScalarValue::Utf8(Some("x?(".to_string()))); + let arg_fields = vec![ + Field::new("a", DataType::Utf8, true).into(), + Field::new("a", DataType::Utf8, true).into(), + ]; let args = ScalarFunctionArgs { args: vec![array, scalar], + arg_fields, number_rows: 2, - return_type: &DataType::Boolean, + return_field: Field::new("f", DataType::Boolean, true).into(), }; let actual = udf.invoke_with_args(args).unwrap(); @@ -181,4 +187,19 @@ mod test { *expect.into_array(2).unwrap() ); } + + #[test] + fn test_contains_api() { + let expr = contains( + Expr::Literal( + ScalarValue::Utf8(Some("the quick brown fox".to_string())), + None, + ), + Expr::Literal(ScalarValue::Utf8(Some("row".to_string())), None), + ); + assert_eq!( + expr.to_string(), + "contains(Utf8(\"the quick brown fox\"), Utf8(\"row\"))" + ); + } } diff --git a/datafusion/functions/src/string/lower.rs b/datafusion/functions/src/string/lower.rs index 226275b13999..536c29a7cb25 100644 --- a/datafusion/functions/src/string/lower.rs +++ b/datafusion/functions/src/string/lower.rs @@ -98,15 +98,19 @@ impl ScalarUDFImpl for LowerFunc { mod tests { use super::*; use arrow::array::{Array, ArrayRef, StringArray}; + use arrow::datatypes::DataType::Utf8; + use arrow::datatypes::Field; use 
std::sync::Arc; fn to_lower(input: ArrayRef, expected: ArrayRef) -> Result<()> { let func = LowerFunc::new(); + let arg_fields = vec![Field::new("a", input.data_type().clone(), true).into()]; let args = ScalarFunctionArgs { number_rows: input.len(), args: vec![ColumnarValue::Array(input)], - return_type: &DataType::Utf8, + arg_fields, + return_field: Field::new("f", Utf8, true).into(), }; let result = match func.invoke_with_args(args)? { diff --git a/datafusion/functions/src/string/mod.rs b/datafusion/functions/src/string/mod.rs index 4c59e2644456..b4a026db9f89 100644 --- a/datafusion/functions/src/string/mod.rs +++ b/datafusion/functions/src/string/mod.rs @@ -140,7 +140,8 @@ pub mod expr_fn { "returns uuid v4 as a string value", ), ( contains, - "Return true if search_string is found within string.", + "Return true if `search_string` is found within `string`.", + string search_string )); #[doc = "Removes all characters, spaces by default, from both sides of a string"] diff --git a/datafusion/functions/src/string/starts_with.rs b/datafusion/functions/src/string/starts_with.rs index 71df83352f96..ecab1af132e0 100644 --- a/datafusion/functions/src/string/starts_with.rs +++ b/datafusion/functions/src/string/starts_with.rs @@ -130,7 +130,7 @@ impl ScalarUDFImpl for StartsWithFunc { args: Vec, info: &dyn SimplifyInfo, ) -> Result { - if let Expr::Literal(scalar_value) = &args[1] { + if let Expr::Literal(scalar_value, _) = &args[1] { // Convert starts_with(col, 'prefix') to col LIKE 'prefix%' with proper escaping // Example: starts_with(col, 'ja%') -> col LIKE 'ja\%%' // 1. 'ja%' (input pattern) @@ -141,8 +141,8 @@ impl ScalarUDFImpl for StartsWithFunc { | ScalarValue::LargeUtf8(Some(pattern)) | ScalarValue::Utf8View(Some(pattern)) => { let escaped_pattern = pattern.replace("%", "\\%"); - let like_pattern = format!("{}%", escaped_pattern); - Expr::Literal(ScalarValue::Utf8(Some(like_pattern))) + let like_pattern = format!("{escaped_pattern}%"); + Expr::Literal(ScalarValue::Utf8(Some(like_pattern)), None) } _ => return Ok(ExprSimplifyResult::Original(args)), }; diff --git a/datafusion/functions/src/string/to_hex.rs b/datafusion/functions/src/string/to_hex.rs index a3a1acfcf1f0..a739a4dfb20e 100644 --- a/datafusion/functions/src/string/to_hex.rs +++ b/datafusion/functions/src/string/to_hex.rs @@ -19,25 +19,29 @@ use std::any::Any; use std::fmt::Write; use std::sync::Arc; -use arrow::array::{ArrayRef, GenericStringBuilder, OffsetSizeTrait}; +use crate::utils::make_scalar_function; +use arrow::array::{ArrayRef, GenericStringBuilder}; +use arrow::datatypes::DataType::{ + Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8, Utf8, +}; use arrow::datatypes::{ - ArrowNativeType, ArrowPrimitiveType, DataType, Int32Type, Int64Type, + ArrowNativeType, ArrowPrimitiveType, DataType, Int16Type, Int32Type, Int64Type, + Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; - -use crate::utils::make_scalar_function; use datafusion_common::cast::as_primitive_array; use datafusion_common::Result; use datafusion_common::{exec_err, plan_err}; use datafusion_expr::{ColumnarValue, Documentation}; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr_common::signature::TypeSignature::Exact; use datafusion_macros::user_doc; /// Converts the number to its equivalent hexadecimal representation. 
/// to_hex(2147483647) = '7fffffff' pub fn to_hex(args: &[ArrayRef]) -> Result where - T::Native: OffsetSizeTrait, + T::Native: std::fmt::LowerHex, { let integer_array = as_primitive_array::(&args[0])?; @@ -96,9 +100,20 @@ impl Default for ToHexFunc { impl ToHexFunc { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::uniform(1, vec![Int64], Volatility::Immutable), + signature: Signature::one_of( + vec![ + Exact(vec![Int8]), + Exact(vec![Int16]), + Exact(vec![Int32]), + Exact(vec![Int64]), + Exact(vec![UInt8]), + Exact(vec![UInt16]), + Exact(vec![UInt32]), + Exact(vec![UInt64]), + ], + Volatility::Immutable, + ), } } } @@ -117,10 +132,8 @@ impl ScalarUDFImpl for ToHexFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(match arg_types[0] { - Int8 | Int16 | Int32 | Int64 => Utf8, + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 => Utf8, _ => { return plan_err!("The to_hex function can only accept integers."); } @@ -129,12 +142,14 @@ impl ScalarUDFImpl for ToHexFunc { fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { match args.args[0].data_type() { - DataType::Int32 => { - make_scalar_function(to_hex::, vec![])(&args.args) - } - DataType::Int64 => { - make_scalar_function(to_hex::, vec![])(&args.args) - } + Int64 => make_scalar_function(to_hex::, vec![])(&args.args), + UInt64 => make_scalar_function(to_hex::, vec![])(&args.args), + Int32 => make_scalar_function(to_hex::, vec![])(&args.args), + UInt32 => make_scalar_function(to_hex::, vec![])(&args.args), + Int16 => make_scalar_function(to_hex::, vec![])(&args.args), + UInt16 => make_scalar_function(to_hex::, vec![])(&args.args), + Int8 => make_scalar_function(to_hex::, vec![])(&args.args), + UInt8 => make_scalar_function(to_hex::, vec![])(&args.args), other => exec_err!("Unsupported data type {other:?} for function to_hex"), } } @@ -146,48 +161,92 @@ impl ScalarUDFImpl for ToHexFunc { #[cfg(test)] mod tests { - use arrow::array::{Int32Array, StringArray}; - + use arrow::array::{ + Int16Array, Int32Array, Int64Array, Int8Array, StringArray, UInt16Array, + UInt32Array, UInt64Array, UInt8Array, + }; use datafusion_common::cast::as_string_array; use super::*; - #[test] - // Test to_hex function for zero - fn to_hex_zero() -> Result<()> { - let array = vec![0].into_iter().collect::(); - let array_ref = Arc::new(array); - let hex_value_arc = to_hex::(&[array_ref])?; - let hex_value = as_string_array(&hex_value_arc)?; - let expected = StringArray::from(vec![Some("0")]); - assert_eq!(&expected, hex_value); - - Ok(()) + macro_rules! 
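// Illustrative sketch, not part of the patch: `to_hex` now accepts every integer width,
// and the new tests below expect negative inputs to render as the 64-bit two's-complement
// pattern (e.g. -1i8 -> "ffffffffffffffff"). One standalone way to get that behaviour for
// signed integers, independent of the arrow kernels:
fn hex_64bit<T: Into<i64>>(v: T) -> String {
    // widen to i64 first, then reinterpret as u64 so the sign bits print as f's
    format!("{:x}", v.into() as u64)
}

// hex_64bit(100i8) == "64"
// hex_64bit(-1i16) == "ffffffffffffffff"
// hex_64bit(0i32)  == "0"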
test_to_hex_type { + // Default test with standard input/output + ($name:ident, $arrow_type:ty, $array_type:ty) => { + test_to_hex_type!( + $name, + $arrow_type, + $array_type, + vec![Some(100), Some(0), None], + vec![Some("64"), Some("0"), None] + ); + }; + + // Custom test with custom input/output (eg: positive number) + ($name:ident, $arrow_type:ty, $array_type:ty, $input:expr, $expected:expr) => { + #[test] + fn $name() -> Result<()> { + let input = $input; + let expected = $expected; + + let array = <$array_type>::from(input); + let array_ref = Arc::new(array); + let hex_result = to_hex::<$arrow_type>(&[array_ref])?; + let hex_array = as_string_array(&hex_result)?; + let expected_array = StringArray::from(expected); + + assert_eq!(&expected_array, hex_array); + Ok(()) + } + }; } - #[test] - // Test to_hex function for positive number - fn to_hex_positive_number() -> Result<()> { - let array = vec![100].into_iter().collect::(); - let array_ref = Arc::new(array); - let hex_value_arc = to_hex::(&[array_ref])?; - let hex_value = as_string_array(&hex_value_arc)?; - let expected = StringArray::from(vec![Some("64")]); - assert_eq!(&expected, hex_value); - - Ok(()) - } + test_to_hex_type!( + to_hex_int8, + Int8Type, + Int8Array, + vec![Some(100), Some(0), None, Some(-1)], + vec![Some("64"), Some("0"), None, Some("ffffffffffffffff")] + ); + test_to_hex_type!( + to_hex_int16, + Int16Type, + Int16Array, + vec![Some(100), Some(0), None, Some(-1)], + vec![Some("64"), Some("0"), None, Some("ffffffffffffffff")] + ); + test_to_hex_type!( + to_hex_int32, + Int32Type, + Int32Array, + vec![Some(100), Some(0), None, Some(-1)], + vec![Some("64"), Some("0"), None, Some("ffffffffffffffff")] + ); + test_to_hex_type!( + to_hex_int64, + Int64Type, + Int64Array, + vec![Some(100), Some(0), None, Some(-1)], + vec![Some("64"), Some("0"), None, Some("ffffffffffffffff")] + ); - #[test] - // Test to_hex function for negative number - fn to_hex_negative_number() -> Result<()> { - let array = vec![-1].into_iter().collect::(); - let array_ref = Arc::new(array); - let hex_value_arc = to_hex::(&[array_ref])?; - let hex_value = as_string_array(&hex_value_arc)?; - let expected = StringArray::from(vec![Some("ffffffffffffffff")]); - assert_eq!(&expected, hex_value); - - Ok(()) - } + test_to_hex_type!(to_hex_uint8, UInt8Type, UInt8Array); + test_to_hex_type!(to_hex_uint16, UInt16Type, UInt16Array); + test_to_hex_type!(to_hex_uint32, UInt32Type, UInt32Array); + test_to_hex_type!(to_hex_uint64, UInt64Type, UInt64Array); + + test_to_hex_type!( + to_hex_large_signed, + Int64Type, + Int64Array, + vec![Some(i64::MAX), Some(i64::MIN)], + vec![Some("7fffffffffffffff"), Some("8000000000000000")] + ); + + test_to_hex_type!( + to_hex_large_unsigned, + UInt64Type, + UInt64Array, + vec![Some(u64::MAX), Some(u64::MIN)], + vec![Some("ffffffffffffffff"), Some("0")] + ); } diff --git a/datafusion/functions/src/string/upper.rs b/datafusion/functions/src/string/upper.rs index 2fec7305d183..882fb45eda4a 100644 --- a/datafusion/functions/src/string/upper.rs +++ b/datafusion/functions/src/string/upper.rs @@ -97,15 +97,19 @@ impl ScalarUDFImpl for UpperFunc { mod tests { use super::*; use arrow::array::{Array, ArrayRef, StringArray}; + use arrow::datatypes::DataType::Utf8; + use arrow::datatypes::Field; use std::sync::Arc; fn to_upper(input: ArrayRef, expected: ArrayRef) -> Result<()> { let func = UpperFunc::new(); + let arg_field = Field::new("a", input.data_type().clone(), true).into(); let args = ScalarFunctionArgs { number_rows: input.len(), 
args: vec![ColumnarValue::Array(input)], - return_type: &DataType::Utf8, + arg_fields: vec![arg_field], + return_field: Field::new("f", Utf8, true).into(), }; let result = match func.invoke_with_args(args)? { diff --git a/datafusion/functions/src/string/uuid.rs b/datafusion/functions/src/string/uuid.rs index d1f43d548066..29415a9b2080 100644 --- a/datafusion/functions/src/string/uuid.rs +++ b/datafusion/functions/src/string/uuid.rs @@ -86,7 +86,7 @@ impl ScalarUDFImpl for UuidFunc { } // Generate random u128 values - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let mut randoms = vec![0u128; args.number_rows]; rng.fill(&mut randoms[..]); diff --git a/datafusion/functions/src/unicode/character_length.rs b/datafusion/functions/src/unicode/character_length.rs index c2db253dc741..4ee5995f0a6b 100644 --- a/datafusion/functions/src/unicode/character_length.rs +++ b/datafusion/functions/src/unicode/character_length.rs @@ -17,7 +17,7 @@ use crate::utils::{make_scalar_function, utf8_to_int_type}; use arrow::array::{ - Array, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait, PrimitiveBuilder, + Array, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait, PrimitiveArray, StringArrayType, }; use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; @@ -131,46 +131,64 @@ where T::Native: OffsetSizeTrait, V: StringArrayType<'a>, { - let mut builder = PrimitiveBuilder::::with_capacity(array.len()); - // String characters are variable length encoded in UTF-8, counting the // number of chars requires expensive decoding, however checking if the // string is ASCII only is relatively cheap. // If strings are ASCII only, count bytes instead. let is_array_ascii_only = array.is_ascii(); - if array.null_count() == 0 { + let array = if array.null_count() == 0 { if is_array_ascii_only { - for i in 0..array.len() { - let value = array.value(i); - builder.append_value(T::Native::usize_as(value.len())); - } + let values: Vec<_> = (0..array.len()) + .map(|i| { + let value = array.value(i); + T::Native::usize_as(value.len()) + }) + .collect(); + PrimitiveArray::::new(values.into(), None) } else { - for i in 0..array.len() { - let value = array.value(i); - builder.append_value(T::Native::usize_as(value.chars().count())); - } + let values: Vec<_> = (0..array.len()) + .map(|i| { + let value = array.value(i); + if value.is_ascii() { + T::Native::usize_as(value.len()) + } else { + T::Native::usize_as(value.chars().count()) + } + }) + .collect(); + PrimitiveArray::::new(values.into(), None) } } else if is_array_ascii_only { - for i in 0..array.len() { - if array.is_null(i) { - builder.append_null(); - } else { - let value = array.value(i); - builder.append_value(T::Native::usize_as(value.len())); - } - } + let values: Vec<_> = (0..array.len()) + .map(|i| { + if array.is_null(i) { + T::default_value() + } else { + let value = array.value(i); + T::Native::usize_as(value.len()) + } + }) + .collect(); + PrimitiveArray::::new(values.into(), array.nulls().cloned()) } else { - for i in 0..array.len() { - if array.is_null(i) { - builder.append_null(); - } else { - let value = array.value(i); - builder.append_value(T::Native::usize_as(value.chars().count())); - } - } - } + let values: Vec<_> = (0..array.len()) + .map(|i| { + if array.is_null(i) { + T::default_value() + } else { + let value = array.value(i); + if value.is_ascii() { + T::Native::usize_as(value.len()) + } else { + T::Native::usize_as(value.chars().count()) + } + } + }) + .collect(); + PrimitiveArray::::new(values.into(), 
array.nulls().cloned()) + }; - Ok(Arc::new(builder.finish()) as ArrayRef) + Ok(Arc::new(array)) } #[cfg(test)] diff --git a/datafusion/functions/src/unicode/find_in_set.rs b/datafusion/functions/src/unicode/find_in_set.rs index c4a9f067e9f4..8b00c7be1ccf 100644 --- a/datafusion/functions/src/unicode/find_in_set.rs +++ b/datafusion/functions/src/unicode/find_in_set.rs @@ -348,7 +348,7 @@ mod tests { use crate::unicode::find_in_set::FindInSetFunc; use crate::utils::test::test_function; use arrow::array::{Array, Int32Array, StringArray}; - use arrow::datatypes::DataType::Int32; + use arrow::datatypes::{DataType::Int32, Field}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; use std::sync::Arc; @@ -471,10 +471,18 @@ mod tests { }) .unwrap_or(1); let return_type = fis.return_type(&type_array)?; + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, a)| { + Field::new(format!("arg_{idx}"), a.data_type(), true).into() + }) + .collect::>(); let result = fis.invoke_with_args(ScalarFunctionArgs { args, + arg_fields, number_rows: cardinality, - return_type: &return_type, + return_field: Field::new("f", return_type, true).into(), }); assert!(result.is_ok()); diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs index b3bc73a29585..1c81b46ec78e 100644 --- a/datafusion/functions/src/unicode/strpos.rs +++ b/datafusion/functions/src/unicode/strpos.rs @@ -22,7 +22,9 @@ use crate::utils::{make_scalar_function, utf8_to_int_type}; use arrow::array::{ ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray, StringArrayType, }; -use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; +use arrow::datatypes::{ + ArrowNativeType, DataType, Field, FieldRef, Int32Type, Int64Type, +}; use datafusion_common::types::logical_string; use datafusion_common::{exec_err, internal_err, Result}; use datafusion_expr::{ @@ -88,16 +90,23 @@ impl ScalarUDFImpl for StrposFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("return_type_from_args should be used instead") + internal_err!("return_field_from_args should be used instead") } - fn return_type_from_args( + fn return_field_from_args( &self, - args: datafusion_expr::ReturnTypeArgs, - ) -> Result { - utf8_to_int_type(&args.arg_types[0], "strpos/instr/position").map(|data_type| { - datafusion_expr::ReturnInfo::new(data_type, args.nullables.iter().any(|x| *x)) - }) + args: datafusion_expr::ReturnFieldArgs, + ) -> Result { + utf8_to_int_type(args.arg_fields[0].data_type(), "strpos/instr/position").map( + |data_type| { + Field::new( + self.name(), + data_type, + args.arg_fields.iter().any(|x| x.is_nullable()), + ) + .into() + }, + ) } fn invoke_with_args( @@ -228,7 +237,7 @@ mod tests { use arrow::array::{Array, Int32Array, Int64Array}; use arrow::datatypes::DataType::{Int32, Int64}; - use arrow::datatypes::DataType; + use arrow::datatypes::{DataType, Field}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -321,15 +330,15 @@ mod tests { fn nullable_return_type() { fn get_nullable(string_array_nullable: bool, substring_nullable: bool) -> bool { let strpos = StrposFunc::new(); - let args = datafusion_expr::ReturnTypeArgs { - arg_types: &[DataType::Utf8, DataType::Utf8], - nullables: &[string_array_nullable, substring_nullable], + let args = datafusion_expr::ReturnFieldArgs { + arg_fields: &[ + Field::new("f1", DataType::Utf8, string_array_nullable).into(), + 
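// The hunks above and below migrate test call sites from the old
// `return_type: &DataType` / `ReturnTypeArgs` shape to the Field-based API:
// per-argument `FieldRef`s plus a `return_field` describing the output. A
// consolidated sketch of the new invocation path, using only pieces visible in
// this diff; the helper name `invoke_unary_utf8` and the field names "arg_0" /
// "out" are illustrative, not part of the patch:
use std::sync::Arc;
use arrow::array::{Array, ArrayRef};
use arrow::datatypes::{DataType, Field};
use datafusion_common::Result;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};

fn invoke_unary_utf8(
    func: &dyn ScalarUDFImpl,
    input: ArrayRef,
) -> Result<ColumnarValue> {
    let args = ScalarFunctionArgs {
        args: vec![ColumnarValue::Array(Arc::clone(&input))],
        // one Field per argument, carrying both data type and nullability
        arg_fields: vec![Field::new("arg_0", input.data_type().clone(), true).into()],
        number_rows: input.len(),
        // the return type is now described by a Field as well
        return_field: Field::new("out", DataType::Utf8, true).into(),
    };
    func.invoke_with_args(args)
}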
Field::new("f2", DataType::Utf8, substring_nullable).into(), + ], scalar_arguments: &[None::<&ScalarValue>, None::<&ScalarValue>], }; - let (_, nullable) = strpos.return_type_from_args(args).unwrap().into_parts(); - - nullable + strpos.return_field_from_args(args).unwrap().is_nullable() } assert!(!get_nullable(false, false)); diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index 47f3121ba2ce..583ff48bff39 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -75,7 +75,7 @@ get_optimal_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32); /// Creates a scalar function implementation for the given function. /// * `inner` - the function to be executed /// * `hints` - hints to be used when expanding scalars to arrays -pub(super) fn make_scalar_function( +pub fn make_scalar_function( inner: F, hints: Vec, ) -> impl Fn(&[ColumnarValue]) -> Result @@ -133,7 +133,7 @@ pub mod test { let expected: Result> = $EXPECTED; let func = $FUNC; - let type_array = $ARGS.iter().map(|arg| arg.data_type()).collect::>(); + let data_array = $ARGS.iter().map(|arg| arg.data_type()).collect::>(); let cardinality = $ARGS .iter() .fold(Option::::None, |acc, arg| match arg { @@ -153,19 +153,28 @@ pub mod test { ColumnarValue::Array(a) => a.null_count() > 0, }).collect::>(); - let return_info = func.return_type_from_args(datafusion_expr::ReturnTypeArgs { - arg_types: &type_array, + let field_array = data_array.into_iter().zip(nullables).enumerate() + .map(|(idx, (data_type, nullable))| arrow::datatypes::Field::new(format!("field_{idx}"), data_type, nullable)) + .map(std::sync::Arc::new) + .collect::>(); + + let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs { + arg_fields: &field_array, scalar_arguments: &scalar_arguments_refs, - nullables: &nullables }); + let arg_fields = $ARGS.iter() + .enumerate() + .map(|(idx, arg)| arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true).into()) + .collect::>(); match expected { Ok(expected) => { - assert_eq!(return_info.is_ok(), true); - let (return_type, _nullable) = return_info.unwrap().into_parts(); - assert_eq!(return_type, $EXPECTED_DATA_TYPE); + assert_eq!(return_field.is_ok(), true); + let return_field = return_field.unwrap(); + let return_type = return_field.data_type(); + assert_eq!(return_type, &$EXPECTED_DATA_TYPE); - let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, number_rows: cardinality, return_type: &return_type}); + let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, arg_fields, number_rows: cardinality, return_field}); assert_eq!(result.is_ok(), true, "function returned an error: {}", result.unwrap_err()); let result = result.unwrap().to_array(cardinality).expect("Failed to convert to array"); @@ -179,17 +188,17 @@ pub mod test { }; } Err(expected_error) => { - if return_info.is_err() { - match return_info { + if return_field.is_err() { + match return_field { Ok(_) => assert!(false, "expected error"), Err(error) => { datafusion_common::assert_contains!(expected_error.strip_backtrace(), error.strip_backtrace()); } } } else { - let (return_type, _nullable) = return_info.unwrap().into_parts(); + let return_field = return_field.unwrap(); // invoke is expected error - cannot use .expect_err() due to Debug not being implemented - match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, number_rows: cardinality, return_type: &return_type}) { + match 
func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, arg_fields, number_rows: cardinality, return_field}) { Ok(_) => assert!(false, "expected error"), Err(error) => { assert!(expected_error.strip_backtrace().starts_with(&error.strip_backtrace())); diff --git a/datafusion/macros/Cargo.toml b/datafusion/macros/Cargo.toml index c6532aa04681..a18988967856 100644 --- a/datafusion/macros/Cargo.toml +++ b/datafusion/macros/Cargo.toml @@ -42,4 +42,4 @@ proc-macro = true [dependencies] datafusion-expr = { workspace = true } quote = "1.0.40" -syn = { version = "2.0.100", features = ["full"] } +syn = { version = "2.0.104", features = ["full"] } diff --git a/datafusion/macros/README.md b/datafusion/macros/README.md new file mode 100644 index 000000000000..c78c02f1ca3a --- /dev/null +++ b/datafusion/macros/README.md @@ -0,0 +1,31 @@ + + +# DataFusion Window Function Common Library + +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. + +This crate contains common macros used in DataFusion + +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + +[df]: https://crates.io/crates/datafusion +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/macros/src/user_doc.rs b/datafusion/macros/src/user_doc.rs index c6510c156423..31cf9bb1b750 100644 --- a/datafusion/macros/src/user_doc.rs +++ b/datafusion/macros/src/user_doc.rs @@ -206,7 +206,7 @@ pub fn user_doc(args: TokenStream, input: TokenStream) -> TokenStream { }; let doc_section_description = doc_section_desc .map(|desc| quote! { Some(#desc)}) - .unwrap_or(quote! { None }); + .unwrap_or_else(|| quote! { None }); let sql_example = sql_example.map(|ex| { quote! { diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index 61d101aab3f8..60358d20e2a1 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -55,6 +55,7 @@ regex-syntax = "0.8.0" [dev-dependencies] async-trait = { workspace = true } +criterion = { workspace = true } ctor = { workspace = true } datafusion-functions-aggregate = { workspace = true } datafusion-functions-window = { workspace = true } @@ -62,3 +63,7 @@ datafusion-functions-window-common = { workspace = true } datafusion-sql = { workspace = true } env_logger = { workspace = true } insta = { workspace = true } + +[[bench]] +name = "projection_unnecessary" +harness = false diff --git a/datafusion/optimizer/README.md b/datafusion/optimizer/README.md index 61bc1cd70145..1c9b37e09fc8 100644 --- a/datafusion/optimizer/README.md +++ b/datafusion/optimizer/README.md @@ -17,6 +17,15 @@ under the License. --> -Please see [Query Optimizer] in the Library User Guide +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. +This crate contains the DataFusion logical optimizer. +Please see [Query Optimizer] in the Library User Guide for more information. + +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. 
+ +[df]: https://crates.io/crates/datafusion +[`datafusion`]: https://crates.io/crates/datafusion [query optimizer]: https://datafusion.apache.org/library-user-guide/query-optimizer.html diff --git a/datafusion/optimizer/benches/projection_unnecessary.rs b/datafusion/optimizer/benches/projection_unnecessary.rs new file mode 100644 index 000000000000..c9f248fe49b5 --- /dev/null +++ b/datafusion/optimizer/benches/projection_unnecessary.rs @@ -0,0 +1,79 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::{DataType, Field, Schema}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_common::ToDFSchema; +use datafusion_common::{Column, TableReference}; +use datafusion_expr::{logical_plan::LogicalPlan, projection_schema, Expr}; +use datafusion_optimizer::optimize_projections::is_projection_unnecessary; +use std::sync::Arc; + +fn is_projection_unnecessary_old( + input: &LogicalPlan, + proj_exprs: &[Expr], +) -> datafusion_common::Result { + // First check if all expressions are trivial (cheaper operation than `projection_schema`) + if !proj_exprs + .iter() + .all(|expr| matches!(expr, Expr::Column(_) | Expr::Literal(_, _))) + { + return Ok(false); + } + let proj_schema = projection_schema(input, proj_exprs)?; + Ok(&proj_schema == input.schema()) +} + +fn create_plan_with_many_exprs(num_exprs: usize) -> (LogicalPlan, Vec) { + // Create schema with many fields + let fields = (0..num_exprs) + .map(|i| Field::new(format!("col{i}"), DataType::Int32, false)) + .collect::>(); + let schema = Schema::new(fields); + + // Create table scan + let table_scan = LogicalPlan::EmptyRelation(datafusion_expr::EmptyRelation { + produce_one_row: true, + schema: Arc::new(schema.clone().to_dfschema().unwrap()), + }); + + // Create projection expressions (just column references) + let exprs = (0..num_exprs) + .map(|i| Expr::Column(Column::new(None::, format!("col{i}")))) + .collect(); + + (table_scan, exprs) +} + +fn benchmark_is_projection_unnecessary(c: &mut Criterion) { + let (plan, exprs) = create_plan_with_many_exprs(1000); + + let mut group = c.benchmark_group("projection_unnecessary_comparison"); + + group.bench_function("is_projection_unnecessary_new", |b| { + b.iter(|| black_box(is_projection_unnecessary(&plan, &exprs).unwrap())) + }); + + group.bench_function("is_projection_unnecessary_old", |b| { + b.iter(|| black_box(is_projection_unnecessary_old(&plan, &exprs).unwrap())) + }); + + group.finish(); +} + +criterion_group!(benches, benchmark_is_projection_unnecessary); +criterion_main!(benches); diff --git a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs index f8a818563609..fa7ff1b8b19d 100644 --- 
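Since `datafusion/optimizer/Cargo.toml` above registers this file as a `[[bench]]` target with `harness = false`, the old-vs-new comparison can presumably be run on its own with `cargo bench -p datafusion-optimizer --bench projection_unnecessary` (assuming the crate keeps its usual `datafusion-optimizer` package name).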
a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs +++ b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs @@ -189,19 +189,19 @@ fn grouping_function_on_id( // Postgres allows grouping function for group by without grouping sets, the result is then // always 0 if !is_grouping_set { - return Ok(Expr::Literal(ScalarValue::from(0i32))); + return Ok(Expr::Literal(ScalarValue::from(0i32), None)); } let group_by_expr_count = group_by_expr.len(); let literal = |value: usize| { if group_by_expr_count < 8 { - Expr::Literal(ScalarValue::from(value as u8)) + Expr::Literal(ScalarValue::from(value as u8), None) } else if group_by_expr_count < 16 { - Expr::Literal(ScalarValue::from(value as u16)) + Expr::Literal(ScalarValue::from(value as u16), None) } else if group_by_expr_count < 32 { - Expr::Literal(ScalarValue::from(value as u32)) + Expr::Literal(ScalarValue::from(value as u32), None) } else { - Expr::Literal(ScalarValue::from(value as u64)) + Expr::Literal(ScalarValue::from(value as u64), None) } }; diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index d47f7ea6ce68..b5a3e9a2d585 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -41,7 +41,7 @@ use datafusion_expr::expr_schema::cast_subquery; use datafusion_expr::logical_plan::Subquery; use datafusion_expr::type_coercion::binary::{comparison_coercion, like_coercion}; use datafusion_expr::type_coercion::functions::{ - data_types_with_aggregate_udf, data_types_with_scalar_udf, + data_types_with_scalar_udf, fields_with_aggregate_udf, }; use datafusion_expr::type_coercion::other::{ get_coerce_type_for_case_expression, get_coerce_type_for_list, @@ -539,17 +539,18 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { ), ))) } - Expr::WindowFunction(WindowFunction { - fun, - params: - expr::WindowFunctionParams { - args, - partition_by, - order_by, - window_frame, - null_treatment, - }, - }) => { + Expr::WindowFunction(window_fun) => { + let WindowFunction { + fun, + params: + expr::WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + null_treatment, + }, + } = *window_fun; let window_frame = coerce_window_frame(window_frame, self.schema, &order_by)?; @@ -565,7 +566,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { }; Ok(Transformed::yes( - Expr::WindowFunction(WindowFunction::new(fun, args)) + Expr::from(WindowFunction::new(fun, args)) .partition_by(partition_by) .order_by(order_by) .window_frame(window_frame) @@ -578,7 +579,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { Expr::Alias(_) | Expr::Column(_) | Expr::ScalarVariable(_, _) - | Expr::Literal(_) + | Expr::Literal(_, _) | Expr::SimilarTo(_) | Expr::IsNotNull(_) | Expr::IsNull(_) @@ -718,6 +719,9 @@ fn coerce_frame_bound( fn extract_window_frame_target_type(col_type: &DataType) -> Result { if col_type.is_numeric() || is_utf8_or_utf8view_or_large_utf8(col_type) + || matches!(col_type, DataType::List(_)) + || matches!(col_type, DataType::LargeList(_)) + || matches!(col_type, DataType::FixedSizeList(_, _)) || matches!(col_type, DataType::Null) || matches!(col_type, DataType::Boolean) { @@ -808,12 +812,15 @@ fn coerce_arguments_for_signature_with_aggregate_udf( return Ok(expressions); } - let current_types = expressions + let current_fields = expressions .iter() - .map(|e| e.get_type(schema)) + .map(|e| e.to_field(schema).map(|(_, f)| f)) .collect::>>()?; - let new_types = 
data_types_with_aggregate_udf(¤t_types, func)?; + let new_types = fields_with_aggregate_udf(¤t_fields, func)? + .into_iter() + .map(|f| f.data_type().clone()) + .collect::>(); expressions .into_iter() @@ -1055,12 +1062,13 @@ mod test { use arrow::datatypes::DataType::Utf8; use arrow::datatypes::{DataType, Field, Schema, SchemaBuilder, TimeUnit}; + use insta::assert_snapshot; use crate::analyzer::type_coercion::{ coerce_case_expression, TypeCoercion, TypeCoercionRewriter, }; use crate::analyzer::Analyzer; - use crate::test::{assert_analyzed_plan_eq, assert_analyzed_plan_with_config_eq}; + use crate::assert_analyzed_plan_with_config_eq_snapshot; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{TransformedResult, TreeNode}; use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue, Spans}; @@ -1096,13 +1104,80 @@ mod test { })) } + macro_rules! assert_analyzed_plan_eq { + ( + $plan: expr, + @ $expected: literal $(,)? + ) => {{ + let options = ConfigOptions::default(); + let rule = Arc::new(TypeCoercion::new()); + assert_analyzed_plan_with_config_eq_snapshot!( + options, + rule, + $plan, + @ $expected, + ) + }}; + } + + macro_rules! coerce_on_output_if_viewtype { + ( + $is_viewtype: expr, + $plan: expr, + @ $expected: literal $(,)? + ) => {{ + let mut options = ConfigOptions::default(); + // coerce on output + if $is_viewtype {options.optimizer.expand_views_at_output = true;} + let rule = Arc::new(TypeCoercion::new()); + + assert_analyzed_plan_with_config_eq_snapshot!( + options, + rule, + $plan, + @ $expected, + ) + }}; + } + + fn assert_type_coercion_error( + plan: LogicalPlan, + expected_substr: &str, + ) -> Result<()> { + let options = ConfigOptions::default(); + let analyzer = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())]); + + match analyzer.execute_and_check(plan, &options, |_, _| {}) { + Ok(succeeded_plan) => { + panic!( + "Expected a type coercion error, but analysis succeeded: \n{succeeded_plan:#?}" + ); + } + Err(e) => { + let msg = e.to_string(); + assert!( + msg.contains(expected_substr), + "Error did not contain expected substring.\n expected to find: `{expected_substr}`\n actual error: `{msg}`" + ); + } + } + + Ok(()) + } + #[test] fn simple_case() -> Result<()> { let expr = col("a").lt(lit(2_u32)); let empty = empty_with_type(DataType::Float64); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = "Projection: a < CAST(UInt32(2) AS Float64)\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: a < CAST(UInt32(2) AS Float64) + EmptyRelation + " + ) } #[test] @@ -1137,28 +1212,15 @@ mod test { Arc::new(analyzed_union), )?); - let expected = "Projection: a\n Union\n Projection: CAST(datafusion.test.foo.a AS Int64) AS a\n EmptyRelation\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), top_level_plan, expected) - } - - fn coerce_on_output_if_viewtype(plan: LogicalPlan, expected: &str) -> Result<()> { - let mut options = ConfigOptions::default(); - options.optimizer.expand_views_at_output = true; - - assert_analyzed_plan_with_config_eq( - options, - Arc::new(TypeCoercion::new()), - plan.clone(), - expected, - ) - } - - fn do_not_coerce_on_output(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_analyzed_plan_with_config_eq( - ConfigOptions::default(), - Arc::new(TypeCoercion::new()), - plan.clone(), - expected, + assert_analyzed_plan_eq!( + top_level_plan, + @r" + 
Projection: a + Union + Projection: CAST(datafusion.test.foo.a AS Int64) AS a + EmptyRelation + EmptyRelation + " ) } @@ -1172,12 +1234,26 @@ mod test { vec![expr.clone()], Arc::clone(&empty), )?); + // Plan A: no coerce - let if_not_coerced = "Projection: a\n EmptyRelation"; - do_not_coerce_on_output(plan.clone(), if_not_coerced)?; + coerce_on_output_if_viewtype!( + false, + plan.clone(), + @r" + Projection: a + EmptyRelation + " + )?; + // Plan A: coerce requested: Utf8View => LargeUtf8 - let if_coerced = "Projection: CAST(a AS LargeUtf8)\n EmptyRelation"; - coerce_on_output_if_viewtype(plan.clone(), if_coerced)?; + coerce_on_output_if_viewtype!( + true, + plan.clone(), + @r" + Projection: CAST(a AS LargeUtf8) + EmptyRelation + " + )?; // Plan B // scenario: outermost bool projection @@ -1187,12 +1263,33 @@ mod test { Arc::clone(&empty), )?); // Plan B: no coerce - let if_not_coerced = - "Projection: a < CAST(Utf8(\"foo\") AS Utf8View)\n EmptyRelation"; - do_not_coerce_on_output(bool_plan.clone(), if_not_coerced)?; + coerce_on_output_if_viewtype!( + false, + bool_plan.clone(), + @r#" + Projection: a < CAST(Utf8("foo") AS Utf8View) + EmptyRelation + "# + )?; + + coerce_on_output_if_viewtype!( + false, + plan.clone(), + @r" + Projection: a + EmptyRelation + " + )?; + // Plan B: coerce requested: no coercion applied - let if_coerced = if_not_coerced; - coerce_on_output_if_viewtype(bool_plan, if_coerced)?; + coerce_on_output_if_viewtype!( + true, + plan.clone(), + @r" + Projection: CAST(a AS LargeUtf8) + EmptyRelation + " + )?; // Plan C // scenario: with a non-projection root logical plan node @@ -1202,13 +1299,29 @@ mod test { input: Arc::new(plan), fetch: None, }); + // Plan C: no coerce - let if_not_coerced = - "Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation"; - do_not_coerce_on_output(sort_plan.clone(), if_not_coerced)?; + coerce_on_output_if_viewtype!( + false, + sort_plan.clone(), + @r" + Sort: a ASC NULLS FIRST + Projection: a + EmptyRelation + " + )?; + // Plan C: coerce requested: Utf8View => LargeUtf8 - let if_coerced = "Projection: CAST(a AS LargeUtf8)\n Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation"; - coerce_on_output_if_viewtype(sort_plan.clone(), if_coerced)?; + coerce_on_output_if_viewtype!( + true, + sort_plan.clone(), + @r" + Projection: CAST(a AS LargeUtf8) + Sort: a ASC NULLS FIRST + Projection: a + EmptyRelation + " + )?; // Plan D // scenario: two layers of projections with view types @@ -1217,11 +1330,27 @@ mod test { Arc::new(sort_plan), )?); // Plan D: no coerce - let if_not_coerced = "Projection: a\n Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation"; - do_not_coerce_on_output(plan.clone(), if_not_coerced)?; + coerce_on_output_if_viewtype!( + false, + plan.clone(), + @r" + Projection: a + Sort: a ASC NULLS FIRST + Projection: a + EmptyRelation + " + )?; // Plan B: coerce requested: Utf8View => LargeUtf8 only on outermost - let if_coerced = "Projection: CAST(a AS LargeUtf8)\n Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation"; - coerce_on_output_if_viewtype(plan.clone(), if_coerced)?; + coerce_on_output_if_viewtype!( + true, + plan.clone(), + @r" + Projection: CAST(a AS LargeUtf8) + Sort: a ASC NULLS FIRST + Projection: a + EmptyRelation + " + )?; Ok(()) } @@ -1236,12 +1365,26 @@ mod test { vec![expr.clone()], Arc::clone(&empty), )?); + // Plan A: no coerce - let if_not_coerced = "Projection: a\n EmptyRelation"; - do_not_coerce_on_output(plan.clone(), if_not_coerced)?; + coerce_on_output_if_viewtype!( + false, + plan.clone(), + @r" + 
Projection: a + EmptyRelation + " + )?; + // Plan A: coerce requested: BinaryView => LargeBinary - let if_coerced = "Projection: CAST(a AS LargeBinary)\n EmptyRelation"; - coerce_on_output_if_viewtype(plan.clone(), if_coerced)?; + coerce_on_output_if_viewtype!( + true, + plan.clone(), + @r" + Projection: CAST(a AS LargeBinary) + EmptyRelation + " + )?; // Plan B // scenario: outermost bool projection @@ -1250,13 +1393,26 @@ mod test { vec![bool_expr], Arc::clone(&empty), )?); + // Plan B: no coerce - let if_not_coerced = - "Projection: a < CAST(Binary(\"8,1,8,1\") AS BinaryView)\n EmptyRelation"; - do_not_coerce_on_output(bool_plan.clone(), if_not_coerced)?; + coerce_on_output_if_viewtype!( + false, + bool_plan.clone(), + @r#" + Projection: a < CAST(Binary("8,1,8,1") AS BinaryView) + EmptyRelation + "# + )?; + // Plan B: coerce requested: no coercion applied - let if_coerced = if_not_coerced; - coerce_on_output_if_viewtype(bool_plan, if_coerced)?; + coerce_on_output_if_viewtype!( + true, + bool_plan.clone(), + @r#" + Projection: a < CAST(Binary("8,1,8,1") AS BinaryView) + EmptyRelation + "# + )?; // Plan C // scenario: with a non-projection root logical plan node @@ -1266,13 +1422,28 @@ mod test { input: Arc::new(plan), fetch: None, }); + // Plan C: no coerce - let if_not_coerced = - "Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation"; - do_not_coerce_on_output(sort_plan.clone(), if_not_coerced)?; + coerce_on_output_if_viewtype!( + false, + sort_plan.clone(), + @r" + Sort: a ASC NULLS FIRST + Projection: a + EmptyRelation + " + )?; // Plan C: coerce requested: BinaryView => LargeBinary - let if_coerced = "Projection: CAST(a AS LargeBinary)\n Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation"; - coerce_on_output_if_viewtype(sort_plan.clone(), if_coerced)?; + coerce_on_output_if_viewtype!( + true, + sort_plan.clone(), + @r" + Projection: CAST(a AS LargeBinary) + Sort: a ASC NULLS FIRST + Projection: a + EmptyRelation + " + )?; // Plan D // scenario: two layers of projections with view types @@ -1280,12 +1451,30 @@ mod test { vec![col("a")], Arc::new(sort_plan), )?); + // Plan D: no coerce - let if_not_coerced = "Projection: a\n Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation"; - do_not_coerce_on_output(plan.clone(), if_not_coerced)?; + coerce_on_output_if_viewtype!( + false, + plan.clone(), + @r" + Projection: a + Sort: a ASC NULLS FIRST + Projection: a + EmptyRelation + " + )?; + // Plan B: coerce requested: BinaryView => LargeBinary only on outermost - let if_coerced = "Projection: CAST(a AS LargeBinary)\n Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation"; - coerce_on_output_if_viewtype(plan.clone(), if_coerced)?; + coerce_on_output_if_viewtype!( + true, + plan.clone(), + @r" + Projection: CAST(a AS LargeBinary) + Sort: a ASC NULLS FIRST + Projection: a + EmptyRelation + " + )?; Ok(()) } @@ -1299,9 +1488,14 @@ mod test { vec![expr.clone().or(expr)], empty, )?); - let expected = "Projection: a < CAST(UInt32(2) AS Float64) OR a < CAST(UInt32(2) AS Float64)\ - \n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: a < CAST(UInt32(2) AS Float64) OR a < CAST(UInt32(2) AS Float64) + EmptyRelation + " + ) } #[derive(Debug, Clone)] @@ -1340,9 +1534,14 @@ mod test { }) .call(vec![lit(123_i32)]); let plan = LogicalPlan::Projection(Projection::try_new(vec![udf], empty)?); - let expected = - "Projection: TestScalarUDF(CAST(Int32(123) AS Float32))\n EmptyRelation"; - 
assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: TestScalarUDF(CAST(Int32(123) AS Float32)) + EmptyRelation + " + ) } #[test] @@ -1372,9 +1571,14 @@ mod test { vec![scalar_function_expr], empty, )?); - let expected = - "Projection: TestScalarUDF(CAST(Int64(10) AS Float32))\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: TestScalarUDF(CAST(Int64(10) AS Float32)) + EmptyRelation + " + ) } #[test] @@ -1397,8 +1601,14 @@ mod test { None, )); let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], empty)?); - let expected = "Projection: MY_AVG(CAST(Int64(10) AS Float64))\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: MY_AVG(CAST(Int64(10) AS Float64)) + EmptyRelation + " + ) } #[test] @@ -1413,8 +1623,8 @@ mod test { return_type, accumulator, vec![ - Field::new("count", DataType::UInt64, true), - Field::new("avg", DataType::Float64, true), + Field::new("count", DataType::UInt64, true).into(), + Field::new("avg", DataType::Float64, true).into(), ], )); let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf( @@ -1445,8 +1655,14 @@ mod test { None, )); let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?); - let expected = "Projection: avg(Float64(12))\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: avg(Float64(12)) + EmptyRelation + " + )?; let empty = empty_with_type(DataType::Int32); let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( @@ -1458,9 +1674,14 @@ mod test { None, )); let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?); - let expected = "Projection: avg(CAST(a AS Float64))\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - Ok(()) + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: avg(CAST(a AS Float64)) + EmptyRelation + " + ) } #[test] @@ -1489,10 +1710,14 @@ mod test { + lit(ScalarValue::new_interval_dt(123, 456)); let empty = empty(); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = - "Projection: CAST(Utf8(\"1998-03-18\") AS Date32) + IntervalDayTime(\"IntervalDayTime { days: 123, milliseconds: 456 }\")\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - Ok(()) + + assert_analyzed_plan_eq!( + plan, + @r#" + Projection: CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("IntervalDayTime { days: 123, milliseconds: 456 }") + EmptyRelation + "# + ) } #[test] @@ -1501,8 +1726,12 @@ mod test { let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false); let empty = empty_with_type(DataType::Int64); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = "Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)])\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; + assert_analyzed_plan_eq!( + plan, + @r" + Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)]) + EmptyRelation + ")?; // a in (1,4,8), a is decimal let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false); @@ -1514,8 +1743,12 @@ mod test { )?), })); let 
plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = "Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))])\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) + assert_analyzed_plan_eq!( + plan, + @r" + Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))]) + EmptyRelation + ") } #[test] @@ -1528,10 +1761,14 @@ mod test { ); let empty = empty_with_type(Utf8); let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?); - let expected = - "Filter: CAST(a AS Date32) BETWEEN CAST(Utf8(\"2002-05-08\") AS Date32) AND CAST(Utf8(\"2002-05-08\") AS Date32) + IntervalYearMonth(\"1\")\ - \n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) + + assert_analyzed_plan_eq!( + plan, + @r#" + Filter: CAST(a AS Date32) BETWEEN CAST(Utf8("2002-05-08") AS Date32) AND CAST(Utf8("2002-05-08") AS Date32) + IntervalYearMonth("1") + EmptyRelation + "# + ) } #[test] @@ -1544,11 +1781,15 @@ mod test { ); let empty = empty_with_type(Utf8); let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?); + // TODO: we should cast col(a). - let expected = - "Filter: CAST(a AS Date32) BETWEEN CAST(Utf8(\"2002-05-08\") AS Date32) + IntervalYearMonth(\"1\") AND CAST(Utf8(\"2002-12-08\") AS Date32)\ - \n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) + assert_analyzed_plan_eq!( + plan, + @r#" + Filter: CAST(a AS Date32) BETWEEN CAST(Utf8("2002-05-08") AS Date32) + IntervalYearMonth("1") AND CAST(Utf8("2002-12-08") AS Date32) + EmptyRelation + "# + ) } #[test] @@ -1556,10 +1797,14 @@ mod test { let expr = lit(ScalarValue::Null).between(lit(ScalarValue::Null), lit(2i64)); let empty = empty(); let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?); - let expected = - "Filter: CAST(NULL AS Int64) BETWEEN CAST(NULL AS Int64) AND Int64(2)\ - \n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) + + assert_analyzed_plan_eq!( + plan, + @r" + Filter: CAST(NULL AS Int64) BETWEEN CAST(NULL AS Int64) AND Int64(2) + EmptyRelation + " + ) } #[test] @@ -1569,37 +1814,60 @@ mod test { let empty = empty_with_type(DataType::Boolean); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr.clone()], empty)?); - let expected = "Projection: a IS TRUE\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: a IS TRUE + EmptyRelation + " + )?; let empty = empty_with_type(DataType::Int64); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let ret = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, ""); - let err = ret.unwrap_err().to_string(); - assert!(err.contains("Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean"), "{err}"); + assert_type_coercion_error( + plan, + "Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean" + )?; // is not true let expr = col("a").is_not_true(); let empty = empty_with_type(DataType::Boolean); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = "Projection: a IS NOT TRUE\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, 
expected)?; + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: a IS NOT TRUE + EmptyRelation + " + )?; // is false let expr = col("a").is_false(); let empty = empty_with_type(DataType::Boolean); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = "Projection: a IS FALSE\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: a IS FALSE + EmptyRelation + " + )?; // is not false let expr = col("a").is_not_false(); let empty = empty_with_type(DataType::Boolean); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = "Projection: a IS NOT FALSE\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - Ok(()) + assert_analyzed_plan_eq!( + plan, + @r" + Projection: a IS NOT FALSE + EmptyRelation + " + ) } #[test] @@ -1610,27 +1878,38 @@ mod test { let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false)); let empty = empty_with_type(Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?); - let expected = "Projection: a LIKE Utf8(\"abc\")\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; + + assert_analyzed_plan_eq!( + plan, + @r#" + Projection: a LIKE Utf8("abc") + EmptyRelation + "# + )?; let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::Null)); let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false)); let empty = empty_with_type(Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?); - let expected = "Projection: a LIKE CAST(NULL AS Utf8)\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: a LIKE CAST(NULL AS Utf8) + EmptyRelation + " + )?; let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::new_utf8("abc"))); let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false)); let empty = empty_with_type(DataType::Int64); let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?); - let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected); - assert!(err.is_err()); - assert!(err.unwrap_err().to_string().contains( - "There isn't a common type to coerce Int64 and Utf8 in LIKE expression" - )); + assert_type_coercion_error( + plan, + "There isn't a common type to coerce Int64 and Utf8 in LIKE expression", + )?; // ilike let expr = Box::new(col("a")); @@ -1638,27 +1917,39 @@ mod test { let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true)); let empty = empty_with_type(Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?); - let expected = "Projection: a ILIKE Utf8(\"abc\")\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; + + assert_analyzed_plan_eq!( + plan, + @r#" + Projection: a ILIKE Utf8("abc") + EmptyRelation + "# + )?; let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::Null)); let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true)); let empty = empty_with_type(Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?); - let expected = "Projection: a ILIKE CAST(NULL AS Utf8)\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; + + 
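// The test rewrites above and below replace hand-maintained expected strings with
// inline snapshots: `assert_analyzed_plan_eq!` / `coerce_on_output_if_viewtype!`
// forward an `@` literal to `assert_analyzed_plan_with_config_eq_snapshot!`, which
// appears to be built on insta (note the `use insta::assert_snapshot;` added above).
// A minimal standalone sketch of the underlying mechanism (test name and value are
// illustrative): insta dedents the inline literal before comparing, and
// `cargo insta review` can rewrite it in place after an intentional plan change.
#[test]
fn inline_snapshot_sketch() {
    let rendered_plan = "Projection: a IS NOT FALSE\n  EmptyRelation";
    insta::assert_snapshot!(rendered_plan, @r"
    Projection: a IS NOT FALSE
      EmptyRelation
    ");
}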
assert_analyzed_plan_eq!( + plan, + @r" + Projection: a ILIKE CAST(NULL AS Utf8) + EmptyRelation + " + )?; let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::new_utf8("abc"))); let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true)); let empty = empty_with_type(DataType::Int64); let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?); - let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected); - assert!(err.is_err()); - assert!(err.unwrap_err().to_string().contains( - "There isn't a common type to coerce Int64 and Utf8 in ILIKE expression" - )); + assert_type_coercion_error( + plan, + "There isn't a common type to coerce Int64 and Utf8 in ILIKE expression", + )?; + Ok(()) } @@ -1669,23 +1960,34 @@ mod test { let empty = empty_with_type(DataType::Boolean); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr.clone()], empty)?); - let expected = "Projection: a IS UNKNOWN\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; + + assert_analyzed_plan_eq!( + plan, + @r" + Projection: a IS UNKNOWN + EmptyRelation + " + )?; let empty = empty_with_type(Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let ret = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected); - let err = ret.unwrap_err().to_string(); - assert!(err.contains("Cannot infer common argument type for comparison operation Utf8 IS DISTINCT FROM Boolean"), "{err}"); + assert_type_coercion_error( + plan, + "Cannot infer common argument type for comparison operation Utf8 IS DISTINCT FROM Boolean" + )?; // is not unknown let expr = col("a").is_not_unknown(); let empty = empty_with_type(DataType::Boolean); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = "Projection: a IS NOT UNKNOWN\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - Ok(()) + assert_analyzed_plan_eq!( + plan, + @r" + Projection: a IS NOT UNKNOWN + EmptyRelation + " + ) } #[test] @@ -1694,21 +1996,19 @@ mod test { let args = [col("a"), lit("b"), lit(true), lit(false), lit(13)]; // concat-type signature - { - let expr = ScalarUDF::new_from_impl(TestScalarUDF { - signature: Signature::variadic(vec![Utf8], Volatility::Immutable), - }) - .call(args.to_vec()); - let plan = LogicalPlan::Projection(Projection::try_new( - vec![expr], - Arc::clone(&empty), - )?); - let expected = - "Projection: TestScalarUDF(a, Utf8(\"b\"), CAST(Boolean(true) AS Utf8), CAST(Boolean(false) AS Utf8), CAST(Int32(13) AS Utf8))\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - } - - Ok(()) + let expr = ScalarUDF::new_from_impl(TestScalarUDF { + signature: Signature::variadic(vec![Utf8], Volatility::Immutable), + }) + .call(args.to_vec()); + let plan = + LogicalPlan::Projection(Projection::try_new(vec![expr], Arc::clone(&empty))?); + assert_analyzed_plan_eq!( + plan, + @r#" + Projection: TestScalarUDF(a, Utf8("b"), CAST(Boolean(true) AS Utf8), CAST(Boolean(false) AS Utf8), CAST(Int32(13) AS Utf8)) + EmptyRelation + "# + ) } #[test] @@ -1758,10 +2058,14 @@ mod test { .eq(cast(lit("1998-03-18"), DataType::Date32)); let empty = empty(); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = - "Projection: CAST(Utf8(\"1998-03-18\") AS Timestamp(Nanosecond, None)) = CAST(CAST(Utf8(\"1998-03-18\") AS Date32) AS Timestamp(Nanosecond, None))\n 
EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - Ok(()) + + assert_analyzed_plan_eq!( + plan, + @r#" + Projection: CAST(Utf8("1998-03-18") AS Timestamp(Nanosecond, None)) = CAST(CAST(Utf8("1998-03-18") AS Date32) AS Timestamp(Nanosecond, None)) + EmptyRelation + "# + ) } fn cast_if_not_same_type( @@ -1882,12 +2186,9 @@ mod test { else_expr: Some(Box::new(col("string"))), }; let err = coerce_case_expression(case, &schema).unwrap_err(); - assert_eq!( + assert_snapshot!( err.strip_backtrace(), - "Error during planning: \ - Failed to coerce case (Interval(MonthDayNano)) and \ - when ([Float32, Binary, Utf8]) to common types in \ - CASE WHEN expression" + @"Error during planning: Failed to coerce case (Interval(MonthDayNano)) and when ([Float32, Binary, Utf8]) to common types in CASE WHEN expression" ); let case = Case { @@ -1900,12 +2201,9 @@ mod test { else_expr: Some(Box::new(col("timestamp"))), }; let err = coerce_case_expression(case, &schema).unwrap_err(); - assert_eq!( + assert_snapshot!( err.strip_backtrace(), - "Error during planning: \ - Failed to coerce then ([Date32, Float32, Binary]) and \ - else (Some(Timestamp(Nanosecond, None))) to common types \ - in CASE WHEN expression" + @"Error during planning: Failed to coerce then ([Date32, Float32, Binary]) and else (Some(Timestamp(Nanosecond, None))) to common types in CASE WHEN expression" ); Ok(()) @@ -2108,12 +2406,14 @@ mod test { let expr = col("a").eq(cast(col("a"), may_type_cutsom)); let empty = empty_with_type(map_type_entries); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = "Projection: a = CAST(CAST(a AS Map(Field { name: \"key_value\", data_type: Struct([Field { name: \"key\", data_type: Utf8, \ - nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: \"value\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), \ - nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, false)) AS Map(Field { name: \"entries\", data_type: Struct([Field { name: \"key\", data_type: Utf8, nullable: false, \ - dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: \"value\", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, false))\n \ - EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) + + assert_analyzed_plan_eq!( + plan, + @r#" + Projection: a = CAST(CAST(a AS Map(Field { name: "key_value", data_type: Struct([Field { name: "key", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "value", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, false)) AS Map(Field { name: "entries", data_type: Struct([Field { name: "key", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "value", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, false)) + EmptyRelation + "# + ) } #[test] @@ -2129,9 +2429,14 @@ mod test { )); let empty = empty(); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = "Projection: IntervalYearMonth(\"12\") + CAST(Utf8(\"2000-01-01T00:00:00\") AS 
Timestamp(Nanosecond, None))\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - Ok(()) + + assert_analyzed_plan_eq!( + plan, + @r#" + Projection: IntervalYearMonth("12") + CAST(Utf8("2000-01-01T00:00:00") AS Timestamp(Nanosecond, None)) + EmptyRelation + "# + ) } #[test] @@ -2149,10 +2454,14 @@ mod test { )); let empty = empty(); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = - "Projection: CAST(Utf8(\"1998-03-18\") AS Timestamp(Nanosecond, None)) - CAST(Utf8(\"1998-03-18\") AS Timestamp(Nanosecond, None))\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - Ok(()) + + assert_analyzed_plan_eq!( + plan, + @r#" + Projection: CAST(Utf8("1998-03-18") AS Timestamp(Nanosecond, None)) - CAST(Utf8("1998-03-18") AS Timestamp(Nanosecond, None)) + EmptyRelation + "# + ) } #[test] @@ -2171,14 +2480,17 @@ mod test { )); let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_int64)?); // add cast for subquery - let expected = "\ - Filter: a IN ()\ - \n Subquery:\ - \n Projection: CAST(a AS Int64)\ - \n EmptyRelation\ - \n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - Ok(()) + + assert_analyzed_plan_eq!( + plan, + @r" + Filter: a IN () + Subquery: + Projection: CAST(a AS Int64) + EmptyRelation + EmptyRelation + " + ) } #[test] @@ -2196,14 +2508,17 @@ mod test { false, )); let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_int32)?); + // add cast for subquery - let expected = "\ - Filter: CAST(a AS Int64) IN ()\ - \n Subquery:\ - \n EmptyRelation\ - \n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - Ok(()) + assert_analyzed_plan_eq!( + plan, + @r" + Filter: CAST(a AS Int64) IN () + Subquery: + EmptyRelation + EmptyRelation + " + ) } #[test] @@ -2221,13 +2536,17 @@ mod test { false, )); let plan = LogicalPlan::Filter(Filter::try_new(in_subquery_expr, empty_outside)?); + // add cast for subquery - let expected = "Filter: CAST(a AS Decimal128(13, 8)) IN ()\ - \n Subquery:\ - \n Projection: CAST(a AS Decimal128(13, 8))\ - \n EmptyRelation\ - \n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - Ok(()) + assert_analyzed_plan_eq!( + plan, + @r" + Filter: CAST(a AS Decimal128(13, 8)) IN () + Subquery: + Projection: CAST(a AS Decimal128(13, 8)) + EmptyRelation + EmptyRelation + " + ) } } diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 69b5fbb9f8c0..6a49e5d22087 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -803,23 +803,39 @@ mod test { use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder}; use super::*; + use crate::assert_optimized_plan_eq_snapshot; use crate::optimizer::OptimizerContext; use crate::test::*; - use crate::Optimizer; use datafusion_expr::test::function_stub::{avg, sum}; - fn assert_optimized_plan_eq( - expected: &str, - plan: LogicalPlan, - config: Option<&dyn OptimizerConfig>, - ) { - let optimizer = - Optimizer::with_rules(vec![Arc::new(CommonSubexprEliminate::new())]); - let default_config = OptimizerContext::new(); - let config = config.unwrap_or(&default_config); - let optimized_plan = optimizer.optimize(plan, config, |_, _| ()).unwrap(); - let formatted_plan = format!("{optimized_plan}"); - assert_eq!(expected, 
formatted_plan); + macro_rules! assert_optimized_plan_equal { + ( + $config:expr, + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let rules: Vec> = vec![Arc::new(CommonSubexprEliminate::new())]; + assert_optimized_plan_eq_snapshot!( + $config, + rules, + $plan, + @ $expected, + ) + }}; + + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let rules: Vec> = vec![Arc::new(CommonSubexprEliminate::new())]; + let optimizer_ctx = OptimizerContext::new(); + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } #[test] @@ -844,13 +860,14 @@ mod test { )? .build()?; - let expected = "Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test.a * Int32(1) - test.b), sum(__common_expr_1 AS test.a * Int32(1) - test.b * (Int32(1) + test.c))]]\ - \n Projection: test.a * (Int32(1) - test.b) AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test.a * Int32(1) - test.b), sum(__common_expr_1 AS test.a * Int32(1) - test.b * (Int32(1) + test.c))]] + Projection: test.a * (Int32(1) - test.b) AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -864,13 +881,14 @@ mod test { ])? .build()?; - let expected = "Projection: __common_expr_1 - test.c AS alias1 * __common_expr_1 AS test.a + test.b, __common_expr_1 AS test.a + test.b\ - \n Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 - test.c AS alias1 * __common_expr_1 AS test.a + test.b, __common_expr_1 AS test.a + test.b + Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -886,7 +904,7 @@ mod test { Signature::exact(vec![DataType::UInt32], Volatility::Stable), return_type.clone(), Arc::clone(&accumulator), - vec![Field::new("value", DataType::UInt32, true)], + vec![Field::new("value", DataType::UInt32, true).into()], ))), vec![inner], false, @@ -917,11 +935,14 @@ mod test { )? .build()?; - let expected = "Projection: __common_expr_1 AS col1, __common_expr_1 AS col2, col3, __common_expr_3 AS avg(test.c), __common_expr_2 AS col4, __common_expr_2 AS col5, col6, __common_expr_4 AS my_agg(test.c)\ - \n Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2, avg(test.b) AS col3, avg(test.c) AS __common_expr_3, my_agg(test.b) AS col6, my_agg(test.c) AS __common_expr_4]]\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS col1, __common_expr_1 AS col2, col3, __common_expr_3 AS avg(test.c), __common_expr_2 AS col4, __common_expr_2 AS col5, col6, __common_expr_4 AS my_agg(test.c) + Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2, avg(test.b) AS col3, avg(test.c) AS __common_expr_3, my_agg(test.b) AS col6, my_agg(test.c) AS __common_expr_4]] + TableScan: test + " + )?; // test: trafo after aggregate let plan = LogicalPlanBuilder::from(table_scan.clone()) @@ -936,11 +957,14 @@ mod test { )? 
.build()?; - let expected = "Projection: Int32(1) + __common_expr_1 AS avg(test.a), Int32(1) - __common_expr_1 AS avg(test.a), Int32(1) + __common_expr_2 AS my_agg(test.a), Int32(1) - __common_expr_2 AS my_agg(test.a)\ - \n Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2]]\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Projection: Int32(1) + __common_expr_1 AS avg(test.a), Int32(1) - __common_expr_1 AS avg(test.a), Int32(1) + __common_expr_2 AS my_agg(test.a), Int32(1) - __common_expr_2 AS my_agg(test.a) + Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2]] + TableScan: test + " + )?; // test: transformation before aggregate let plan = LogicalPlanBuilder::from(table_scan.clone()) @@ -953,11 +977,14 @@ mod test { )? .build()?; - let expected = "Aggregate: groupBy=[[]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]\ - \n Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Aggregate: groupBy=[[]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]] + Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + )?; // test: common between agg and group let plan = LogicalPlanBuilder::from(table_scan.clone()) @@ -970,11 +997,14 @@ mod test { )? .build()?; - let expected = "Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]\ - \n Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]] + Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + )?; // test: all mixed let plan = LogicalPlanBuilder::from(table_scan) @@ -991,14 +1021,15 @@ mod test { )? 
.build()?; - let expected = "Projection: UInt32(1) + test.a, UInt32(1) + __common_expr_2 AS col1, UInt32(1) - __common_expr_2 AS col2, __common_expr_4 AS avg(UInt32(1) + test.a), UInt32(1) + __common_expr_3 AS col3, UInt32(1) - __common_expr_3 AS col4, __common_expr_5 AS my_agg(UInt32(1) + test.a)\ - \n Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS __common_expr_2, my_agg(__common_expr_1) AS __common_expr_3, avg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_4, my_agg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_5]]\ - \n Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: UInt32(1) + test.a, UInt32(1) + __common_expr_2 AS col1, UInt32(1) - __common_expr_2 AS col2, __common_expr_4 AS avg(UInt32(1) + test.a), UInt32(1) + __common_expr_3 AS col3, UInt32(1) - __common_expr_3 AS col4, __common_expr_5 AS my_agg(UInt32(1) + test.a) + Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS __common_expr_2, my_agg(__common_expr_1) AS __common_expr_3, avg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_4, my_agg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_5]] + Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1018,14 +1049,15 @@ mod test { )? .build()?; - let expected = "Projection: table.test.col.a, UInt32(1) + __common_expr_2 AS avg(UInt32(1) + table.test.col.a), __common_expr_2 AS avg(UInt32(1) + table.test.col.a)\ - \n Aggregate: groupBy=[[table.test.col.a]], aggr=[[avg(__common_expr_1 AS UInt32(1) + table.test.col.a) AS __common_expr_2]]\ - \n Projection: UInt32(1) + table.test.col.a AS __common_expr_1, table.test.col.a\ - \n TableScan: table.test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: table.test.col.a, UInt32(1) + __common_expr_2 AS avg(UInt32(1) + table.test.col.a), __common_expr_2 AS avg(UInt32(1) + table.test.col.a) + Aggregate: groupBy=[[table.test.col.a]], aggr=[[avg(__common_expr_1 AS UInt32(1) + table.test.col.a) AS __common_expr_2]] + Projection: UInt32(1) + table.test.col.a AS __common_expr_1, table.test.col.a + TableScan: table.test + " + ) } #[test] @@ -1039,13 +1071,14 @@ mod test { ])? .build()?; - let expected = "Projection: __common_expr_1 AS first, __common_expr_1 AS second\ - \n Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS first, __common_expr_1 AS second + Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1056,13 +1089,14 @@ mod test { .project(vec![lit(1) + col("a"), col("a") + lit(1)])? 
.build()?; - let expected = "Projection: __common_expr_1 AS Int32(1) + test.a, __common_expr_1 AS test.a + Int32(1)\ - \n Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS Int32(1) + test.a, __common_expr_1 AS test.a + Int32(1) + Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1074,12 +1108,14 @@ mod test { .project(vec![lit(1) + col("a")])? .build()?; - let expected = "Projection: Int32(1) + test.a\ - \n Projection: Int32(1) + test.a, test.a\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: Int32(1) + test.a + Projection: Int32(1) + test.a, test.a + TableScan: test + " + ) } #[test] @@ -1193,14 +1229,15 @@ mod test { .filter((lit(1) + col("a") - lit(10)).gt(lit(1) + col("a")))? .build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 - Int32(10) > __common_expr_1\ - \n Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 - Int32(10) > __common_expr_1 + Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1226,7 +1263,7 @@ mod test { fn test_alias_collision() -> Result<()> { let table_scan = test_table_scan()?; - let config = &OptimizerContext::new(); + let config = OptimizerContext::new(); let common_expr_1 = config.alias_generator().next(CSE_PREFIX); let plan = LogicalPlanBuilder::from(table_scan.clone()) .project(vec![ @@ -1241,14 +1278,18 @@ mod test { ])? .build()?; - let expected = "Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 AS c3, __common_expr_2 AS c4\ - \n Projection: test.c + Int32(2) AS __common_expr_2, __common_expr_1, test.c\ - \n Projection: test.a + test.b AS __common_expr_1, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, Some(config)); - - let config = &OptimizerContext::new(); + assert_optimized_plan_equal!( + config, + plan, + @ r" + Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 AS c3, __common_expr_2 AS c4 + Projection: test.c + Int32(2) AS __common_expr_2, __common_expr_1, test.c + Projection: test.a + test.b AS __common_expr_1, test.c + TableScan: test + " + )?; + + let config = OptimizerContext::new(); let _common_expr_1 = config.alias_generator().next(CSE_PREFIX); let common_expr_2 = config.alias_generator().next(CSE_PREFIX); let plan = LogicalPlanBuilder::from(table_scan) @@ -1264,12 +1305,16 @@ mod test { ])? 
.build()?; - let expected = "Projection: __common_expr_2 AS c1, __common_expr_2 AS c2, __common_expr_3 AS c3, __common_expr_3 AS c4\ - \n Projection: test.c + Int32(2) AS __common_expr_3, __common_expr_2, test.c\ - \n Projection: test.a + test.b AS __common_expr_2, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, Some(config)); + assert_optimized_plan_equal!( + config, + plan, + @ r" + Projection: __common_expr_2 AS c1, __common_expr_2 AS c2, __common_expr_3 AS c3, __common_expr_3 AS c4 + Projection: test.c + Int32(2) AS __common_expr_3, __common_expr_2, test.c + Projection: test.a + test.b AS __common_expr_2, test.c + TableScan: test + " + )?; Ok(()) } @@ -1308,13 +1353,14 @@ mod test { ])? .build()?; - let expected = "Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 OR test.a - test.b = Int32(0) AS c3, __common_expr_2 AND test.a - test.b = Int32(0) AS c4, __common_expr_3 OR __common_expr_3 AS c5\ - \n Projection: test.a = Int32(0) OR test.b = Int32(0) AS __common_expr_1, test.a + test.b = Int32(0) AS __common_expr_2, test.a * test.b = Int32(0) AS __common_expr_3, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 OR test.a - test.b = Int32(0) AS c3, __common_expr_2 AND test.a - test.b = Int32(0) AS c4, __common_expr_3 OR __common_expr_3 AS c5 + Projection: test.a = Int32(0) OR test.b = Int32(0) AS __common_expr_1, test.a + test.b = Int32(0) AS __common_expr_2, test.a * test.b = Int32(0) AS __common_expr_3, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1331,13 +1377,14 @@ mod test { ])? .build()?; - let expected = "Projection: __common_expr_1 + random() AS c1, __common_expr_1 + random() AS c2\ - \n Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 + random() AS c1, __common_expr_1 + random() AS c2 + Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1360,13 +1407,14 @@ mod test { ])? .build()?; - let expected = "Projection: __common_expr_1 OR random() = Int32(0) AS c1, __common_expr_1 OR random() = Int32(0) AS c2, random() = Int32(0) OR test.b = Int32(0) AS c3, random() = Int32(0) OR test.b = Int32(0) AS c4\ - \n Projection: test.a = Int32(0) AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 OR random() = Int32(0) AS c1, __common_expr_1 OR random() = Int32(0) AS c2, random() = Int32(0) OR test.b = Int32(0) AS c3, random() = Int32(0) OR test.b = Int32(0) AS c4 + Projection: test.a = Int32(0) AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1382,14 +1430,15 @@ mod test { .project(vec![col("c1"), col("c2")])? 
.build()?; - let expected = "Projection: c1, c2\ - \n Projection: __common_expr_1 AS c1, __common_expr_1 AS c2\ - \n Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: c1, c2 + Projection: __common_expr_1 AS c1, __common_expr_1 AS c2 + Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1405,14 +1454,15 @@ mod test { ])? .build()?; - let expected = "Projection: __common_expr_1 AS c1, __common_expr_1 AS c2\ - \n Projection: __common_expr_2 * __common_expr_2 AS __common_expr_1, test.a, test.b, test.c\ - \n Projection: test.a + test.b AS __common_expr_2, test.a, test.b, test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS c1, __common_expr_1 AS c2 + Projection: __common_expr_2 * __common_expr_2 AS __common_expr_1, test.a, test.b, test.c + Projection: test.a + test.b AS __common_expr_2, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1422,13 +1472,15 @@ mod test { let expr = ((col("a") + col("b")) * (col("b") + col("a"))).eq(lit(30)); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 * __common_expr_1 = Int32(30)\ - \n Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 * __common_expr_1 = Int32(30) + Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1438,13 +1490,15 @@ mod test { let expr = ((col("a") * col("b")) + (col("b") * col("a"))).eq(lit(30)); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 + __common_expr_1 = Int32(30)\ - \n Projection: test.a * test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 + __common_expr_1 = Int32(30) + Projection: test.a * test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1454,13 +1508,15 @@ mod test { let expr = ((col("a") & col("b")) + (col("b") & col("a"))).eq(lit(30)); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 + __common_expr_1 = Int32(30)\ - \n Projection: test.a & test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 + __common_expr_1 = Int32(30) + Projection: test.a & test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1470,13 +1526,15 @@ mod test { let expr = ((col("a") | col("b")) + (col("b") | col("a"))).eq(lit(30)); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n 
Filter: __common_expr_1 + __common_expr_1 = Int32(30)\ - \n Projection: test.a | test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 + __common_expr_1 = Int32(30) + Projection: test.a | test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1486,13 +1544,15 @@ mod test { let expr = ((col("a") ^ col("b")) + (col("b") ^ col("a"))).eq(lit(30)); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 + __common_expr_1 = Int32(30)\ - \n Projection: test.a BIT_XOR test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 + __common_expr_1 = Int32(30) + Projection: test.a BIT_XOR test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1502,13 +1562,15 @@ mod test { let expr = (col("a").eq(col("b"))).and(col("b").eq(col("a"))); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 AND __common_expr_1\ - \n Projection: test.a = test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 AND __common_expr_1 + Projection: test.a = test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1518,13 +1580,15 @@ mod test { let expr = (col("a").not_eq(col("b"))).and(col("b").not_eq(col("a"))); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 AND __common_expr_1\ - \n Projection: test.a != test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 AND __common_expr_1 + Projection: test.a != test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } #[test] @@ -1535,11 +1599,15 @@ mod test { .eq(lit(30)); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 - __common_expr_1 = Int32(30)\ - \n Projection: test.a + test.b * test.c AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 - __common_expr_1 = Int32(30) + Projection: test.a + test.b * test.c AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + )?; // ((c1 + c2 / c3) * c3 <=> c3 * (c2 / c3 + c1)) let table_scan = test_table_scan()?; @@ -1548,11 +1616,16 @@ mod test { + col("a")) .eq(lit(30)); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 / __common_expr_1 + test.a = Int32(30)\ - \n Projection: (test.a + 
test.b / test.c) * test.c AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); + + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 / __common_expr_1 + test.a = Int32(30) + Projection: (test.a + test.b / test.c) * test.c AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + )?; // c2 / (c1 + c3) <=> c2 / (c3 + c1) let table_scan = test_table_scan()?; @@ -1560,11 +1633,15 @@ mod test { * (col("b") / (col("c") + col("a")))) .eq(lit(30)); let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; - let expected = "Projection: test.a, test.b, test.c\ - \n Filter: __common_expr_1 * __common_expr_1 = Int32(30)\ - \n Projection: test.b / (test.a + test.c) AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.b, test.c + Filter: __common_expr_1 * __common_expr_1 = Int32(30) + Projection: test.b / (test.a + test.c) AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + )?; Ok(()) } @@ -1612,10 +1689,14 @@ mod test { let plan = LogicalPlanBuilder::from(table_scan) .project(vec![expr1, expr2])? .build()?; - let expected = "Projection: __common_expr_1 AS NOT test.a = test.b, __common_expr_1 AS NOT test.b = test.a\ - \n Projection: NOT test.a = test.b AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS NOT test.a = test.b, __common_expr_1 AS NOT test.b = test.a + Projection: NOT test.a = test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + )?; // is_null(a == b) <=> is_null(b == a) let table_scan = test_table_scan()?; @@ -1624,10 +1705,14 @@ mod test { let plan = LogicalPlanBuilder::from(table_scan) .project(vec![expr1, expr2])? .build()?; - let expected = "Projection: __common_expr_1 AS test.a = test.b IS NULL, __common_expr_1 AS test.b = test.a IS NULL\ - \n Projection: test.a = test.b IS NULL AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS test.a = test.b IS NULL, __common_expr_1 AS test.b = test.a IS NULL + Projection: test.a = test.b IS NULL AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + )?; // a + b between 0 and 10 <=> b + a between 0 and 10 let table_scan = test_table_scan()?; @@ -1636,10 +1721,14 @@ mod test { let plan = LogicalPlanBuilder::from(table_scan) .project(vec![expr1, expr2])? 
.build()?; - let expected = "Projection: __common_expr_1 AS test.a + test.b BETWEEN Int32(0) AND Int32(10), __common_expr_1 AS test.b + test.a BETWEEN Int32(0) AND Int32(10)\ - \n Projection: test.a + test.b BETWEEN Int32(0) AND Int32(10) AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS test.a + test.b BETWEEN Int32(0) AND Int32(10), __common_expr_1 AS test.b + test.a BETWEEN Int32(0) AND Int32(10) + Projection: test.a + test.b BETWEEN Int32(0) AND Int32(10) AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + )?; // c between a + b and 10 <=> c between b + a and 10 let table_scan = test_table_scan()?; @@ -1648,10 +1737,14 @@ mod test { let plan = LogicalPlanBuilder::from(table_scan) .project(vec![expr1, expr2])? .build()?; - let expected = "Projection: __common_expr_1 AS test.c BETWEEN test.a + test.b AND Int32(10), __common_expr_1 AS test.c BETWEEN test.b + test.a AND Int32(10)\ - \n Projection: test.c BETWEEN test.a + test.b AND Int32(10) AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS test.c BETWEEN test.a + test.b AND Int32(10), __common_expr_1 AS test.c BETWEEN test.b + test.a AND Int32(10) + Projection: test.c BETWEEN test.a + test.b AND Int32(10) AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + )?; // function call with argument <=> function call with argument let udf = ScalarUDF::from(TestUdf::new()); @@ -1661,11 +1754,14 @@ mod test { let plan = LogicalPlanBuilder::from(table_scan) .project(vec![expr1, expr2])? .build()?; - let expected = "Projection: __common_expr_1 AS my_udf(test.a + test.b), __common_expr_1 AS my_udf(test.b + test.a)\ - \n Projection: my_udf(test.a + test.b) AS __common_expr_1, test.a, test.b, test.c\ - \n TableScan: test"; - assert_optimized_plan_eq(expected, plan, None); - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: __common_expr_1 AS my_udf(test.a + test.b), __common_expr_1 AS my_udf(test.b + test.a) + Projection: my_udf(test.a + test.b) AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) } /// returns a "random" function that is marked volatile (aka each invocation diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index 418619c8399e..63236787743a 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -71,6 +71,9 @@ pub struct PullUpCorrelatedExpr { pub collected_count_expr_map: HashMap, /// pull up having expr, which must be evaluated after the Join pub pull_up_having_expr: Option, + /// whether we have converted a scalar aggregation into a group aggregation. When unnesting + /// lateral joins, we need to produce a left outer join in such cases. + pub pulled_up_scalar_agg: bool, } impl Default for PullUpCorrelatedExpr { @@ -91,6 +94,7 @@ impl PullUpCorrelatedExpr { need_handle_count_bug: false, collected_count_expr_map: HashMap::new(), pull_up_having_expr: None, + pulled_up_scalar_agg: false, } } @@ -313,6 +317,11 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { missing_exprs.push(un_matched_row); } } + if aggregate.group_expr.is_empty() { + // TODO: how do we handle the case where we have pulled multiple aggregations? For example, + // a group agg with a scalar agg as child. 
+ self.pulled_up_scalar_agg = true; + } let new_plan = LogicalPlanBuilder::from((*aggregate.input).clone()) .aggregate(missing_exprs, aggregate.aggr_expr.to_vec())? .build()?; @@ -485,9 +494,12 @@ fn agg_exprs_evaluation_result_on_empty_batch( let new_expr = match expr { Expr::AggregateFunction(expr::AggregateFunction { func, .. }) => { if func.name() == "count" { - Transformed::yes(Expr::Literal(ScalarValue::Int64(Some(0)))) + Transformed::yes(Expr::Literal( + ScalarValue::Int64(Some(0)), + None, + )) } else { - Transformed::yes(Expr::Literal(ScalarValue::Null)) + Transformed::yes(Expr::Literal(ScalarValue::Null, None)) } } _ => Transformed::no(expr), @@ -578,10 +590,10 @@ fn filter_exprs_evaluation_result_on_empty_batch( let result_expr = simplifier.simplify(result_expr)?; match &result_expr { // evaluate to false or null on empty batch, no need to pull up - Expr::Literal(ScalarValue::Null) - | Expr::Literal(ScalarValue::Boolean(Some(false))) => None, + Expr::Literal(ScalarValue::Null, _) + | Expr::Literal(ScalarValue::Boolean(Some(false)), _) => None, // evaluate to true on empty batch, need to pull up the expr - Expr::Literal(ScalarValue::Boolean(Some(true))) => { + Expr::Literal(ScalarValue::Boolean(Some(true)), _) => { for (name, exprs) in input_expr_result_map_for_count_bug { expr_result_map_for_count_bug.insert(name.clone(), exprs.clone()); } @@ -596,7 +608,7 @@ fn filter_exprs_evaluation_result_on_empty_batch( Box::new(result_expr.clone()), Box::new(input_expr.clone()), )], - else_expr: Some(Box::new(Expr::Literal(ScalarValue::Null))), + else_expr: Some(Box::new(Expr::Literal(ScalarValue::Null, None))), }); let expr_key = new_expr.schema_name().to_string(); expr_result_map_for_count_bug.insert(expr_key, new_expr); diff --git a/datafusion/optimizer/src/decorrelate_lateral_join.rs b/datafusion/optimizer/src/decorrelate_lateral_join.rs new file mode 100644 index 000000000000..7d2072ad1ce9 --- /dev/null +++ b/datafusion/optimizer/src/decorrelate_lateral_join.rs @@ -0,0 +1,143 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`DecorrelateLateralJoin`] decorrelates logical plans produced by lateral joins. 
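The new rule plugs into the standard optimizer machinery, so its rewrite can be observed on a single plan in isolation. The sketch below is illustrative only and is not part of this patch: the helper name `decorrelate_lateral_only` is made up, and it assumes the new module ends up exported as `datafusion_optimizer::decorrelate_lateral_join` (the lib.rs re-export is not shown in this excerpt); `Optimizer::with_rules`, `OptimizerContext`, and `Optimizer::optimize` are the existing public APIs.

// Illustrative sketch, not part of this patch; module path is assumed.
use std::sync::Arc;

use datafusion_common::Result;
use datafusion_expr::LogicalPlan;
use datafusion_optimizer::decorrelate_lateral_join::DecorrelateLateralJoin;
use datafusion_optimizer::{Optimizer, OptimizerContext};

/// Runs only the lateral-join decorrelation rule so its rewrite is easy to inspect.
fn decorrelate_lateral_only(plan: LogicalPlan) -> Result<LogicalPlan> {
    let optimizer = Optimizer::with_rules(vec![Arc::new(DecorrelateLateralJoin::new())]);
    // A correlated scalar aggregate on the right side (pulled_up_scalar_agg == true)
    // comes back as a Left join; otherwise the Inner join is kept.
    optimizer.optimize(plan, &OptimizerContext::new(), |_plan, _rule| {})
}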
+ +use std::collections::BTreeSet; + +use crate::decorrelate::PullUpCorrelatedExpr; +use crate::optimizer::ApplyOrder; +use crate::{OptimizerConfig, OptimizerRule}; +use datafusion_expr::{lit, Join}; + +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, +}; +use datafusion_common::Result; +use datafusion_expr::logical_plan::JoinType; +use datafusion_expr::utils::conjunction; +use datafusion_expr::{LogicalPlan, LogicalPlanBuilder}; + +/// Optimizer rule for rewriting lateral joins to joins +#[derive(Default, Debug)] +pub struct DecorrelateLateralJoin {} + +impl DecorrelateLateralJoin { + #[allow(missing_docs)] + pub fn new() -> Self { + Self::default() + } +} + +impl OptimizerRule for DecorrelateLateralJoin { + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + // Find cross joins with outer column references on the right side (i.e., the apply operator). + let LogicalPlan::Join(join) = plan else { + return Ok(Transformed::no(plan)); + }; + + rewrite_internal(join) + } + + fn name(&self) -> &str { + "decorrelate_lateral_join" + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::TopDown) + } +} + +// Build the decorrelated join based on the original lateral join query. For now, we only support cross/inner +// lateral joins. +fn rewrite_internal(join: Join) -> Result> { + if join.join_type != JoinType::Inner { + return Ok(Transformed::no(LogicalPlan::Join(join))); + } + + match join.right.apply_with_subqueries(|p| { + // TODO: support outer joins + if p.contains_outer_reference() { + Ok(TreeNodeRecursion::Stop) + } else { + Ok(TreeNodeRecursion::Continue) + } + })? { + TreeNodeRecursion::Stop => {} + TreeNodeRecursion::Continue => { + // The left side contains outer references, we need to decorrelate it. + return Ok(Transformed::new( + LogicalPlan::Join(join), + false, + TreeNodeRecursion::Jump, + )); + } + TreeNodeRecursion::Jump => { + unreachable!("") + } + } + + let LogicalPlan::Subquery(subquery) = join.right.as_ref() else { + return Ok(Transformed::no(LogicalPlan::Join(join))); + }; + + if join.join_type != JoinType::Inner { + return Ok(Transformed::no(LogicalPlan::Join(join))); + } + let subquery_plan = subquery.subquery.as_ref(); + let mut pull_up = PullUpCorrelatedExpr::new().with_need_handle_count_bug(true); + let rewritten_subquery = subquery_plan.clone().rewrite(&mut pull_up).data()?; + if !pull_up.can_pull_up { + return Ok(Transformed::no(LogicalPlan::Join(join))); + } + + let mut all_correlated_cols = BTreeSet::new(); + pull_up + .correlated_subquery_cols_map + .values() + .for_each(|cols| all_correlated_cols.extend(cols.clone())); + let join_filter_opt = conjunction(pull_up.join_filters); + let join_filter = match join_filter_opt { + Some(join_filter) => join_filter, + None => lit(true), + }; + // -- inner join but the right side always has one row, we need to rewrite it to a left join + // SELECT * FROM t0, LATERAL (SELECT sum(v1) FROM t1 WHERE t0.v0 = t1.v0); + // -- inner join but the right side number of rows is related to the filter (join) condition, so keep inner join. + // SELECT * FROM t0, LATERAL (SELECT * FROM t1 WHERE t0.v0 = t1.v0); + let new_plan = LogicalPlanBuilder::from(join.left) + .join_on( + rewritten_subquery, + if pull_up.pulled_up_scalar_agg { + JoinType::Left + } else { + JoinType::Inner + }, + Some(join_filter), + )? 
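    // Why the join type flips for scalar aggregates: a lateral subquery with a no-GROUP-BY
    // aggregate (e.g. `SELECT sum(v1) FROM t1 WHERE t0.v0 = t1.v0`) returns exactly one row
    // for every t0 row, even when no t1 rows match. After decorrelation the aggregate is
    // grouped by the correlated key, so unmatched t0 values have no group at all; an Inner
    // join would drop those t0 rows, while a Left join keeps them with a NULL aggregate.
    // When the subquery is not a scalar aggregate, its row count already depends on the join
    // condition, so the Inner join is preserved. (The NULL-vs-0 discrepancy this leaves for
    // `count` is the count(*) bug referenced in the TODO just below.)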
+ .build()?; + // TODO: handle count(*) bug + Ok(Transformed::new(new_plan, true, TreeNodeRecursion::Jump)) +} diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index c18c48251daa..a72657bf689d 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -427,17 +427,23 @@ mod tests { use super::*; use crate::test::*; + use crate::assert_optimized_plan_eq_display_indent_snapshot; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_expr::builder::table_source; use datafusion_expr::{and, binary_expr, col, lit, not, out_ref_col, table_scan}; - fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), - plan, - expected, - ); - Ok(()) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let rule: Arc = Arc::new(DecorrelatePredicateSubquery::new()); + assert_optimized_plan_eq_display_indent_snapshot!( + rule, + $plan, + @ $expected, + ) + }}; } fn test_subquery_with_name(name: &str) -> Result> { @@ -461,17 +467,21 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: test.b = __correlated_sq_2.c [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: sq_1.c [c:UInt32]\ - \n TableScan: sq_1 [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_2 [c:UInt32]\ - \n Projection: sq_2.c [c:UInt32]\ - \n TableScan: sq_2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: test.b = __correlated_sq_2.c [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32] + Projection: sq_1.c [c:UInt32] + TableScan: sq_1 [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_2 [c:UInt32] + Projection: sq_2.c [c:UInt32] + TableScan: sq_2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } /// Test for IN subquery with additional AND filter @@ -489,15 +499,18 @@ mod tests { .project(vec![col("test.b")])? 
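The new `assert_optimized_plan_equal!` wrapper delegates to `assert_optimized_plan_eq_display_indent_snapshot!`, whose definition is not shown in this diff. As a rough mental model only, with `snapshot_optimized_plan` and its plumbing being assumptions rather than the project's actual helper, an insta-based inline-snapshot check over a single optimizer rule can be sketched as follows; the payoff is that the escaped, hand-maintained `\n `-joined expected strings above become plain multi-line snapshot literals that `cargo insta review` can update.

// Hedged sketch of an inline-snapshot plan assertion; not the helper used by this patch.
use std::sync::Arc;

use datafusion_common::Result;
use datafusion_expr::LogicalPlan;
use datafusion_optimizer::{Optimizer, OptimizerContext, OptimizerRule};

/// Optimizes `plan` with a single rule and snapshots its indented, schema-annotated display.
fn snapshot_optimized_plan(
    rule: Arc<dyn OptimizerRule + Send + Sync>,
    plan: LogicalPlan,
) -> Result<()> {
    let optimized = Optimizer::with_rules(vec![rule])
        .optimize(plan, &OptimizerContext::new(), |_plan, _rule| {})?;
    // The expected text lives in the test source as an `@ r"..."` inline snapshot
    // (or in a snapshot file) instead of a manually escaped string constant.
    insta::assert_snapshot!(optimized.display_indent_schema().to_string());
    Ok(())
}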
.build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n Filter: test.a = UInt32(1) AND test.b < UInt32(30) [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: sq.c [c:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.b [b:UInt32] + Filter: test.a = UInt32(1) AND test.b < UInt32(30) [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32] + Projection: sq.c [c:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } /// Test for nested IN subqueries @@ -515,18 +528,21 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: test.b = __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_2 [a:UInt32]\ - \n Projection: sq.a [a:UInt32]\ - \n LeftSemi Join: Filter: sq.a = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: sq_nested.c [c:UInt32]\ - \n TableScan: sq_nested [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: test.b = __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_2 [a:UInt32] + Projection: sq.a [a:UInt32] + LeftSemi Join: Filter: sq.a = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32] + Projection: sq_nested.c [c:UInt32] + TableScan: sq_nested [a:UInt32, b:UInt32, c:UInt32] + " + ) } /// Test multiple correlated subqueries @@ -551,23 +567,21 @@ mod tests { .build()?; debug!("plan to optimize:\n{}", plan.display_indent()); - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), - plan, - expected, - ); - Ok(()) + assert_optimized_plan_equal!( + plan, + @r###" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: 
customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + SubqueryAlias: __correlated_sq_2 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + "### + ) } /// Test recursive correlated subqueries @@ -601,23 +615,21 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n LeftSemi Join: Filter: orders.o_orderkey = __correlated_sq_1.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __correlated_sq_1 [l_orderkey:Int64]\ - \n Projection: lineitem.l_orderkey [l_orderkey:Int64]\ - \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_2 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + LeftSemi Join: Filter: orders.o_orderkey = __correlated_sq_1.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + SubqueryAlias: __correlated_sq_1 [l_orderkey:Int64] + Projection: lineitem.l_orderkey [l_orderkey:Int64] + TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64] + " + ) } /// Test for correlated IN subquery filter with additional subquery filters @@ -639,20 +651,18 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated IN subquery with no columns in schema @@ -673,19 +683,17 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for IN subquery with both columns in schema @@ -703,20 +711,18 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated IN subquery not equal @@ -737,19 +743,17 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND customer.c_custkey != __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND customer.c_custkey != __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated IN subquery less than @@ -770,19 +774,17 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND customer.c_custkey < __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND customer.c_custkey < __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated IN subquery filter with subquery disjunction @@ -804,20 +806,17 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND (customer.c_custkey = __correlated_sq_1.o_custkey OR __correlated_sq_1.o_orderkey = Int32(1)) [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64, o_orderkey:Int64]\ - \n Projection: orders.o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND (customer.c_custkey = __correlated_sq_1.o_custkey OR __correlated_sq_1.o_orderkey = Int32(1)) [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64, o_orderkey:Int64] + Projection: orders.o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated IN without projection @@ -861,19 +860,17 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey + Int32(1) = __correlated_sq_1.o_custkey AND customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey + Int32(1) = __correlated_sq_1.o_custkey AND customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated IN expressions @@ -894,19 +891,17 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.orders.o_custkey + Int32(1) AND customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [orders.o_custkey + Int32(1):Int64, o_custkey:Int64]\ - \n Projection: orders.o_custkey + Int32(1), orders.o_custkey [orders.o_custkey + Int32(1):Int64, o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.orders.o_custkey + Int32(1) AND customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [orders.o_custkey + Int32(1):Int64, o_custkey:Int64] + Projection: orders.o_custkey + Int32(1), orders.o_custkey [orders.o_custkey + Int32(1):Int64, o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated IN subquery multiple projected columns @@ -959,20 +954,18 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated IN subquery filter @@ -990,19 +983,17 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c AND test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32]\ - \n Projection: sq.c, sq.a [c:UInt32, a:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: test.c = __correlated_sq_1.c AND test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32] + Projection: sq.c, sq.a [c:UInt32, a:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } /// Test for single IN subquery filter @@ -1014,19 +1005,17 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: sq.c [c:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32] + Projection: sq.c [c:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } /// Test for single NOT IN subquery filter @@ -1038,19 +1027,17 @@ mod tests { .project(vec![col("test.b")])? 
.build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftAnti Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: sq.c [c:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: test.b [b:UInt32] + LeftAnti Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32] + Projection: sq.c [c:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1061,19 +1048,17 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftAnti Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: sq.c [c:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: test.b [b:UInt32] + LeftAnti Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32] + Projection: sq.c [c:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1087,19 +1072,17 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: sq.c [c:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32] + Projection: sq.c [c:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1116,19 +1099,17 @@ mod tests { .project(vec![col("test.b")])? 
.build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [sq.c * UInt32(2):UInt32]\ - \n Projection: sq.c * UInt32(2) [sq.c * UInt32(2):UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [sq.c * UInt32(2):UInt32] + Projection: sq.c * UInt32(2) [sq.c * UInt32(2):UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1150,20 +1131,18 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) AND test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [sq.c * UInt32(2):UInt32, a:UInt32]\ - \n Projection: sq.c * UInt32(2), sq.a [sq.c * UInt32(2):UInt32, a:UInt32]\ - \n Filter: sq.a + UInt32(1) = sq.b [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) AND test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [sq.c * UInt32(2):UInt32, a:UInt32] + Projection: sq.c * UInt32(2), sq.a [sq.c * UInt32(2):UInt32, a:UInt32] + Filter: sq.a + UInt32(1) = sq.b [a:UInt32, b:UInt32, c:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1186,20 +1165,18 @@ mod tests { .project(vec![col("test.b")])? 
.build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) AND test.a + test.b = __correlated_sq_1.a + __correlated_sq_1.b [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [sq.c * UInt32(2):UInt32, a:UInt32, b:UInt32]\ - \n Projection: sq.c * UInt32(2), sq.a, sq.b [sq.c * UInt32(2):UInt32, a:UInt32, b:UInt32]\ - \n Filter: sq.a + UInt32(1) = sq.b [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.sq.c * UInt32(2) AND test.a + test.b = __correlated_sq_1.a + __correlated_sq_1.b [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [sq.c * UInt32(2):UInt32, a:UInt32, b:UInt32] + Projection: sq.c * UInt32(2), sq.a, sq.b [sq.c * UInt32(2):UInt32, a:UInt32, b:UInt32] + Filter: sq.a + UInt32(1) = sq.b [a:UInt32, b:UInt32, c:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1228,24 +1205,22 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: Filter: test.c * UInt32(2) = __correlated_sq_2.sq2.c * UInt32(2) AND test.a > __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.sq1.c * UInt32(2) AND test.a > __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [sq1.c * UInt32(2):UInt32, a:UInt32]\ - \n Projection: sq1.c * UInt32(2), sq1.a [sq1.c * UInt32(2):UInt32, a:UInt32]\ - \n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_2 [sq2.c * UInt32(2):UInt32, a:UInt32]\ - \n Projection: sq2.c * UInt32(2), sq2.a [sq2.c * UInt32(2):UInt32, a:UInt32]\ - \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: test.b [b:UInt32] + Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join: Filter: test.c * UInt32(2) = __correlated_sq_2.sq2.c * UInt32(2) AND test.a > __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.sq1.c * UInt32(2) AND test.a > __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [sq1.c * UInt32(2):UInt32, a:UInt32] + Projection: sq1.c * UInt32(2), sq1.a [sq1.c * UInt32(2):UInt32, a:UInt32] + TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_2 [sq2.c * UInt32(2):UInt32, a:UInt32] + Projection: sq2.c * UInt32(2), sq2.a [sq2.c * UInt32(2):UInt32, a:UInt32] + TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1263,20 +1238,18 @@ mod tests { .build()?; // Subquery and outer query refer to the same table. 
- let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: test.a = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: test.c [c:UInt32]\ - \n Filter: test.a > test.b [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: test.a = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32] + Projection: test.c [c:UInt32] + Filter: test.a > test.b [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } /// Test for multiple exists subqueries in the same filter expression @@ -1297,17 +1270,21 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: __correlated_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ - \n LeftSemi Join: Filter: __correlated_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: __correlated_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8] + LeftSemi Join: Filter: __correlated_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + SubqueryAlias: __correlated_sq_2 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test recursive correlated subqueries @@ -1340,17 +1317,21 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: __correlated_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n LeftSemi Join: Filter: __correlated_sq_1.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __correlated_sq_1 [l_orderkey:Int64]\ - \n Projection: lineitem.l_orderkey [l_orderkey:Int64]\ - \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: __correlated_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_2 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + LeftSemi Join: Filter: __correlated_sq_1.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + SubqueryAlias: __correlated_sq_1 [l_orderkey:Int64] + Projection: lineitem.l_orderkey [l_orderkey:Int64] + TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64] + " + ) } /// Test for correlated exists subquery filter with additional subquery filters @@ -1372,15 +1353,18 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } #[test] @@ -1398,14 +1382,17 @@ mod tests { .build()?; // Other rule will pushdown `customer.c_custkey = 1`, - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = UInt32(1) [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = UInt32(1) [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for exists subquery with both columns in schema @@ -1423,14 +1410,18 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated exists subquery not equal @@ -1451,14 +1442,17 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey != __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey != __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated exists subquery less than @@ -1479,14 +1473,17 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey < __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey < __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated exists subquery filter with subquery disjunction @@ -1508,14 +1505,17 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey OR __correlated_sq_1.o_orderkey = Int32(1) [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64, o_orderkey:Int64]\ - \n Projection: orders.o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey OR __correlated_sq_1.o_orderkey = Int32(1) [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64, o_orderkey:Int64] + Projection: orders.o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated exists without projection @@ -1535,13 +1535,16 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated exists expressions @@ -1562,14 +1565,17 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [orders.o_custkey + Int32(1):Int64, o_custkey:Int64]\ - \n Projection: orders.o_custkey + Int32(1), orders.o_custkey [orders.o_custkey + Int32(1):Int64, o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: customer.c_custkey [c_custkey:Int64] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [orders.o_custkey + Int32(1):Int64, o_custkey:Int64] + Projection: orders.o_custkey + Int32(1), orders.o_custkey [orders.o_custkey + Int32(1):Int64, o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated exists subquery filter with additional filters @@ -1589,15 +1595,18 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8] + LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated exists subquery filter with disjunctions @@ -1615,16 +1624,19 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: __correlated_sq_1.mark OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, mark:Boolean]\ - \n LeftMark Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, mark:Boolean]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ - \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n Filter: customer.c_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: __correlated_sq_1.mark OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, mark:Boolean] + LeftMark Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, mark:Boolean] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + Filter: customer.c_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated EXISTS subquery filter @@ -1642,14 +1654,17 @@ mod tests { .project(vec![col("test.c")])? .build()?; - let expected = "Projection: test.c [c:UInt32]\ - \n LeftSemi Join: Filter: test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32]\ - \n Projection: sq.c, sq.a [c:UInt32, a:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.c [c:UInt32] + LeftSemi Join: Filter: test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32] + Projection: sq.c, sq.a [c:UInt32, a:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } /// Test for single exists subquery filter @@ -1661,13 +1676,17 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: sq.c [c:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32] + Projection: sq.c [c:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } /// Test for single NOT exists subquery filter @@ -1679,13 +1698,17 @@ mod tests { .project(vec![col("test.b")])? 
.build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftAnti Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: sq.c [c:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.b [b:UInt32] + LeftAnti Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32] + Projection: sq.c [c:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1712,19 +1735,22 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: Filter: test.a = __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: Filter: test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32]\ - \n Projection: sq1.c, sq1.a [c:UInt32, a:UInt32]\ - \n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_2 [c:UInt32, a:UInt32]\ - \n Projection: sq2.c, sq2.a [c:UInt32, a:UInt32]\ - \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.b [b:UInt32] + Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join: Filter: test.a = __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32] + LeftSemi Join: Filter: test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32] + Projection: sq1.c, sq1.a [c:UInt32, a:UInt32] + TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_2 [c:UInt32, a:UInt32] + Projection: sq2.c, sq2.a [c:UInt32, a:UInt32] + TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1743,14 +1769,17 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [UInt32(1):UInt32, a:UInt32]\ - \n Projection: UInt32(1), sq.a [UInt32(1):UInt32, a:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [UInt32(1):UInt32, a:UInt32] + Projection: UInt32(1), sq.a [UInt32(1):UInt32, a:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1768,15 +1797,18 @@ mod tests { .build()?; // Subquery and outer query refer to the same table. 
- let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: test.c [c:UInt32]\ - \n Filter: test.a > test.b [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32] + Projection: test.c [c:UInt32] + Filter: test.a > test.b [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1796,15 +1828,18 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32]\ - \n Distinct: [c:UInt32, a:UInt32]\ - \n Projection: sq.c, sq.a [c:UInt32, a:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32] + Distinct: [c:UInt32, a:UInt32] + Projection: sq.c, sq.a [c:UInt32, a:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1824,15 +1859,18 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [sq.b + sq.c:UInt32, a:UInt32]\ - \n Distinct: [sq.b + sq.c:UInt32, a:UInt32]\ - \n Projection: sq.b + sq.c, sq.a [sq.b + sq.c:UInt32, a:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [sq.b + sq.c:UInt32, a:UInt32] + Distinct: [sq.b + sq.c:UInt32, a:UInt32] + Projection: sq.b + sq.c, sq.a [sq.b + sq.c:UInt32, a:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1852,15 +1890,18 @@ mod tests { .project(vec![col("test.b")])? 
.build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [UInt32(1):UInt32, c:UInt32, a:UInt32]\ - \n Distinct: [UInt32(1):UInt32, c:UInt32, a:UInt32]\ - \n Projection: UInt32(1), sq.c, sq.a [UInt32(1):UInt32, c:UInt32, a:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: UInt32(1) + __correlated_sq_1.a > test.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [UInt32(1):UInt32, c:UInt32, a:UInt32] + Distinct: [UInt32(1):UInt32, c:UInt32, a:UInt32] + Projection: UInt32(1), sq.c, sq.a [UInt32(1):UInt32, c:UInt32, a:UInt32] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1884,13 +1925,17 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [arr:Int32;N]\ - \n Unnest: lists[sq.arr|depth=1] structs[] [arr:Int32;N]\ - \n TableScan: sq [arr:List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r#" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [arr:Int32;N] + Unnest: lists[sq.arr|depth=1] structs[] [arr:Int32;N] + TableScan: sq [arr:List(Field { name: "item", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N] + "# + ) } #[test] @@ -1915,14 +1960,17 @@ mod tests { .project(vec![col("test.b")])? .build()?; - let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: __correlated_sq_1.a = test.b [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [a:UInt32;N]\ - \n Unnest: lists[sq.a|depth=1] structs[] [a:UInt32;N]\ - \n TableScan: sq [a:List(Field { name: \"item\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r#" + Projection: test.b [b:UInt32] + LeftSemi Join: Filter: __correlated_sq_1.a = test.b [a:UInt32, b:UInt32, c:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __correlated_sq_1 [a:UInt32;N] + Unnest: lists[sq.a|depth=1] structs[] [a:UInt32;N] + TableScan: sq [a:List(Field { name: "item", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N] + "# + ) } #[test] @@ -1946,13 +1994,16 @@ mod tests { .project(vec![col("\"TEST_A\".\"B\"")])? 
.build()?; - let expected = "Projection: TEST_A.B [B:UInt32]\ - \n LeftSemi Join: Filter: __correlated_sq_1.A = TEST_A.A [A:UInt32, B:UInt32]\ - \n TableScan: TEST_A [A:UInt32, B:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [Int32(1):Int32, A:UInt32]\ - \n Projection: Int32(1), TEST_B.A [Int32(1):Int32, A:UInt32]\ - \n TableScan: TEST_B [A:UInt32, B:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: TEST_A.B [B:UInt32] + LeftSemi Join: Filter: __correlated_sq_1.A = TEST_A.A [A:UInt32, B:UInt32] + TableScan: TEST_A [A:UInt32, B:UInt32] + SubqueryAlias: __correlated_sq_1 [Int32(1):Int32, A:UInt32] + Projection: Int32(1), TEST_B.A [Int32(1):Int32, A:UInt32] + TableScan: TEST_B [A:UInt32, B:UInt32] + " + ) } } diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index d35572e6d34a..ae1d7df46d52 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use crate::join_key_set::JoinKeySet; use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::Result; +use datafusion_common::{NullEquality, Result}; use datafusion_expr::expr::{BinaryExpr, Expr}; use datafusion_expr::logical_plan::{ Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection, @@ -89,6 +89,7 @@ impl OptimizerRule for EliminateCrossJoin { let mut possible_join_keys = JoinKeySet::new(); let mut all_inputs: Vec = vec![]; let mut all_filters: Vec = vec![]; + let mut null_equality = NullEquality::NullEqualsNothing; let parent_predicate = if let LogicalPlan::Filter(filter) = plan { // if input isn't a join that can potentially be rewritten @@ -113,6 +114,12 @@ impl OptimizerRule for EliminateCrossJoin { let Filter { input, predicate, .. } = filter; + + // Extract null_equality setting from the input join + if let LogicalPlan::Join(join) = input.as_ref() { + null_equality = join.null_equality; + } + flatten_join_inputs( Arc::unwrap_or_clone(input), &mut possible_join_keys, @@ -122,26 +129,30 @@ impl OptimizerRule for EliminateCrossJoin { extract_possible_join_keys(&predicate, &mut possible_join_keys); Some(predicate) - } else if matches!( - plan, - LogicalPlan::Join(Join { - join_type: JoinType::Inner, - .. - }) - ) { - if !can_flatten_join_inputs(&plan) { - return Ok(Transformed::no(plan)); - } - flatten_join_inputs( - plan, - &mut possible_join_keys, - &mut all_inputs, - &mut all_filters, - )?; - None } else { - // recursively try to rewrite children - return rewrite_children(self, plan, config); + match plan { + LogicalPlan::Join(Join { + join_type: JoinType::Inner, + null_equality: original_null_equality, + .. 
+ }) => { + if !can_flatten_join_inputs(&plan) { + return Ok(Transformed::no(plan)); + } + flatten_join_inputs( + plan, + &mut possible_join_keys, + &mut all_inputs, + &mut all_filters, + )?; + null_equality = original_null_equality; + None + } + _ => { + // recursively try to rewrite children + return rewrite_children(self, plan, config); + } + } }; // Join keys are handled locally: @@ -153,6 +164,7 @@ impl OptimizerRule for EliminateCrossJoin { &mut all_inputs, &possible_join_keys, &mut all_join_keys, + null_equality, )?; } @@ -290,6 +302,7 @@ fn find_inner_join( rights: &mut Vec, possible_join_keys: &JoinKeySet, all_join_keys: &mut JoinKeySet, + null_equality: NullEquality, ) -> Result { for (i, right_input) in rights.iter().enumerate() { let mut join_keys = vec![]; @@ -328,7 +341,7 @@ fn find_inner_join( on: join_keys, filter: None, schema: join_schema, - null_equals_null: false, + null_equality, })); } } @@ -350,7 +363,7 @@ fn find_inner_join( filter: None, join_type: JoinType::Inner, join_constraint: JoinConstraint::On, - null_equals_null: false, + null_equality, })) } @@ -440,22 +453,28 @@ mod tests { logical_plan::builder::LogicalPlanBuilder, Operator::{And, Or}, }; + use insta::assert_snapshot; + + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let starting_schema = Arc::clone($plan.schema()); + let rule = EliminateCrossJoin::new(); + let Transformed {transformed: is_plan_transformed, data: optimized_plan, ..} = rule.rewrite($plan, &OptimizerContext::new()).unwrap(); + let formatted_plan = optimized_plan.display_indent_schema(); + // Ensure the rule was actually applied + assert!(is_plan_transformed, "failed to optimize plan"); + // Verify the schema remains unchanged + assert_eq!(&starting_schema, optimized_plan.schema()); + assert_snapshot!( + formatted_plan, + @ $expected, + ); - fn assert_optimized_plan_eq(plan: LogicalPlan, expected: Vec<&str>) { - let starting_schema = Arc::clone(plan.schema()); - let rule = EliminateCrossJoin::new(); - let transformed_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap(); - assert!(transformed_plan.transformed, "failed to optimize plan"); - let optimized_plan = transformed_plan.data; - let formatted = optimized_plan.display_indent_schema().to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - - assert_eq!(&starting_schema, optimized_plan.schema()) + Ok(()) + }}; } #[test] @@ -473,16 +492,15 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -501,16 +519,15 @@ mod tests { ))? 
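// ---------------------------------------------------------------------------
// Editor's aside (illustrative sketch, not part of the patch): the @r"..."
// arguments introduced by these macros are insta "inline snapshots". The
// expected plan text lives next to the assertion as a raw string and, on a
// mismatch, is typically rewritten in place with `cargo insta review` (or
// `cargo insta test --accept`) rather than hand-edited as a `\n`-joined
// string. A minimal, self-contained example of the same pattern; the rendered
// text below is made up for illustration:
#[cfg(test)]
mod inline_snapshot_sketch {
    use insta::assert_snapshot;

    #[test]
    fn renders_plan_like_text() {
        let rendered = "Filter: t2.c < UInt32(20)\n  TableScan: t2";
        // insta dedents inline snapshots, so the indentation inside the raw
        // string follows the surrounding code and is not significant.
        assert_snapshot!(rendered, @r"
        Filter: t2.c < UInt32(20)
          TableScan: t2
        ");
    }
}
// ---------------------------------------------------------------------------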
.build()?; - let expected = vec![ - "Filter: t1.a = t2.a OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t1.a = t2.a OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -528,16 +545,15 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -559,15 +575,15 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -589,15 +605,15 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.b = t2.b AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.b = t2.b AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -615,15 +631,15 @@ mod tests { ))? 
.build()?; - let expected = vec![ - "Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -644,19 +660,18 @@ mod tests { .filter(col("t1.a").gt(lit(15u32)))? .build()?; - let expected = vec![ - "Filter: t1.a > UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Filter: t1.a > UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]" - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t1.a > UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Filter: t1.a > UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -691,19 +706,18 @@ mod tests { ))? 
.build()?; - let expected = vec![ - "Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t3.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -765,22 +779,21 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - " Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t4 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ 
-840,22 +853,21 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - " Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t4 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -915,22 +927,21 @@ mod tests { ))? 
.build()?; - let expected = vec![ - "Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - " Filter: t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + Filter: t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t4 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -994,22 +1005,21 @@ mod tests { ))? 
.build()?; - let expected = vec![ - "Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Filter: t1.a = t2.a OR t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - " Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Filter: t1.a = t2.a OR t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t4 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1083,21 +1093,20 @@ mod tests { ))? 
.build()?; - let expected = vec![ - "Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Filter: t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Filter: t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t4 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1177,20 +1186,19 @@ mod tests { ))? 
.build()?; - let expected = vec![ - "Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t4 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1208,15 +1216,15 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1235,16 +1243,15 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t1.a + UInt32(100) = t2.a * UInt32(2) OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t1.a + UInt32(100) = t2.a * UInt32(2) OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1263,16 +1270,15 @@ mod tests { ))? 
.build()?; - let expected = vec![ - "Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1291,16 +1297,15 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); - - Ok(()) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -1328,17 +1333,81 @@ mod tests { ))? .build()?; - let expected = vec![ - "Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a * UInt32(2) = t3.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - ]; - - assert_optimized_plan_eq(plan, expected); + assert_optimized_plan_equal!( + plan, + @ r" + Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t3.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Inner Join: t1.a * UInt32(2) = t3.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) + } + + #[test] + fn preserve_null_equality_setting() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // Create an inner join with NullEquality::NullEqualsNull + let join_schema = Arc::new(build_join_schema( + t1.schema(), + 
t2.schema(), + &JoinType::Inner, + )?); + + let inner_join = LogicalPlan::Join(Join { + left: Arc::new(t1), + right: Arc::new(t2), + join_type: JoinType::Inner, + join_constraint: JoinConstraint::On, + on: vec![], + filter: None, + schema: join_schema, + null_equality: NullEquality::NullEqualsNull, // Test preservation + }); + + // Apply filter that can create join conditions + let plan = LogicalPlanBuilder::from(inner_join) + .filter(binary_expr( + col("t1.a").eq(col("t2.a")), + And, + col("t2.c").lt(lit(20u32)), + ))? + .build()?; + + let rule = EliminateCrossJoin::new(); + let optimized_plan = rule.rewrite(plan, &OptimizerContext::new())?.data; + + // Verify that null_equality is preserved in the optimized plan + fn check_null_equality_preserved(plan: &LogicalPlan) -> bool { + match plan { + LogicalPlan::Join(join) => { + // All joins in the optimized plan should preserve null equality + if join.null_equality == NullEquality::NullEqualsNothing { + return false; + } + // Recursively check child plans + plan.inputs() + .iter() + .all(|input| check_null_equality_preserved(input)) + } + _ => { + // Recursively check child plans for non-join nodes + plan.inputs() + .iter() + .all(|input| check_null_equality_preserved(input)) + } + } + } + + assert!( + check_null_equality_preserved(&optimized_plan), + "null_equality setting should be preserved after optimization" + ); Ok(()) } diff --git a/datafusion/optimizer/src/eliminate_duplicated_expr.rs b/datafusion/optimizer/src/eliminate_duplicated_expr.rs index 466950092095..a6651df938a7 100644 --- a/datafusion/optimizer/src/eliminate_duplicated_expr.rs +++ b/datafusion/optimizer/src/eliminate_duplicated_expr.rs @@ -118,16 +118,26 @@ impl OptimizerRule for EliminateDuplicatedExpr { #[cfg(test)] mod tests { use super::*; + use crate::assert_optimized_plan_eq_snapshot; use crate::test::*; + use crate::OptimizerContext; use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder}; use std::sync::Arc; - fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { - crate::test::assert_optimized_plan_eq( - Arc::new(EliminateDuplicatedExpr::new()), - plan, - expected, - ) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(EliminateDuplicatedExpr::new())]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } #[test] @@ -137,10 +147,12 @@ mod tests { .sort_by(vec![col("a"), col("a"), col("b"), col("c")])? .limit(5, Some(10))? .build()?; - let expected = "Limit: skip=5, fetch=10\ - \n Sort: test.a ASC NULLS LAST, test.b ASC NULLS LAST, test.c ASC NULLS LAST\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Limit: skip=5, fetch=10 + Sort: test.a ASC NULLS LAST, test.b ASC NULLS LAST, test.c ASC NULLS LAST + TableScan: test + ") } #[test] @@ -156,9 +168,11 @@ mod tests { .sort(sort_exprs)? .limit(5, Some(10))? 
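The new `preserve_null_equality_setting` test above exercises the `NullEquality` field that replaces the former `null_equals_null: bool` on `Join`. A minimal sketch of the assumed relationship between the old flag and the new enum follows; the variant names are taken from this diff, but the enum definition and the conversion helper here are illustrative only, not part of the patch.

```rust
// Sketch only: variant names mirror those used in the diff
// (`NullEqualsNull`, `NullEqualsNothing`); the real enum lives in
// datafusion_common and may carry additional derives/docs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum NullEquality {
    /// NULL keys never compare equal (old `null_equals_null = false`)
    NullEqualsNothing,
    /// NULL keys compare equal to NULL (old `null_equals_null = true`)
    NullEqualsNull,
}

// Hypothetical helper showing how the legacy boolean maps onto the enum.
fn from_legacy_flag(null_equals_null: bool) -> NullEquality {
    if null_equals_null {
        NullEquality::NullEqualsNull
    } else {
        NullEquality::NullEqualsNothing
    }
}
```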
.build()?; - let expected = "Limit: skip=5, fetch=10\ - \n Sort: test.a ASC NULLS FIRST, test.b ASC NULLS LAST\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Limit: skip=5, fetch=10 + Sort: test.a ASC NULLS FIRST, test.b ASC NULLS LAST + TableScan: test + ") } } diff --git a/datafusion/optimizer/src/eliminate_filter.rs b/datafusion/optimizer/src/eliminate_filter.rs index 4ed2ac8ba1a4..e28771be548b 100644 --- a/datafusion/optimizer/src/eliminate_filter.rs +++ b/datafusion/optimizer/src/eliminate_filter.rs @@ -60,7 +60,7 @@ impl OptimizerRule for EliminateFilter { ) -> Result> { match plan { LogicalPlan::Filter(Filter { - predicate: Expr::Literal(ScalarValue::Boolean(v)), + predicate: Expr::Literal(ScalarValue::Boolean(v), _), input, .. }) => match v { @@ -81,17 +81,29 @@ impl OptimizerRule for EliminateFilter { mod tests { use std::sync::Arc; + use crate::assert_optimized_plan_eq_snapshot; + use crate::OptimizerContext; use datafusion_common::{Result, ScalarValue}; - use datafusion_expr::{ - col, lit, logical_plan::builder::LogicalPlanBuilder, Expr, LogicalPlan, - }; + use datafusion_expr::{col, lit, logical_plan::builder::LogicalPlanBuilder, Expr}; use crate::eliminate_filter::EliminateFilter; use crate::test::*; use datafusion_expr::test::function_stub::sum; - fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq(Arc::new(EliminateFilter::new()), plan, expected) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(EliminateFilter::new())]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } #[test] @@ -105,13 +117,12 @@ mod tests { .build()?; // No aggregate / scan / limit - let expected = "EmptyRelation"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @"EmptyRelation") } #[test] fn filter_null() -> Result<()> { - let filter_expr = Expr::Literal(ScalarValue::Boolean(None)); + let filter_expr = Expr::Literal(ScalarValue::Boolean(None), None); let table_scan = test_table_scan().unwrap(); let plan = LogicalPlanBuilder::from(table_scan) @@ -120,8 +131,7 @@ mod tests { .build()?; // No aggregate / scan / limit - let expected = "EmptyRelation"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @"EmptyRelation") } #[test] @@ -139,11 +149,12 @@ mod tests { .build()?; // Left side is removed - let expected = "Union\ - \n EmptyRelation\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ - \n TableScan: test"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Union + EmptyRelation + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]] + TableScan: test + ") } #[test] @@ -156,9 +167,10 @@ mod tests { .filter(filter_expr)? 
.build()?; - let expected = "Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ - \n TableScan: test"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]] + TableScan: test + ") } #[test] @@ -176,12 +188,13 @@ mod tests { .build()?; // Filter is removed - let expected = "Union\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ - \n TableScan: test\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ - \n TableScan: test"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Union + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]] + TableScan: test + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]] + TableScan: test + ") } #[test] @@ -202,8 +215,9 @@ mod tests { .build()?; // Filter is removed - let expected = "Projection: test.a\ - \n EmptyRelation"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Projection: test.a + EmptyRelation + ") } } diff --git a/datafusion/optimizer/src/eliminate_group_by_constant.rs b/datafusion/optimizer/src/eliminate_group_by_constant.rs index 7e252d6dcea0..9c47ce024f91 100644 --- a/datafusion/optimizer/src/eliminate_group_by_constant.rs +++ b/datafusion/optimizer/src/eliminate_group_by_constant.rs @@ -101,7 +101,7 @@ fn is_constant_expression(expr: &Expr) -> bool { Expr::BinaryExpr(e) => { is_constant_expression(&e.left) && is_constant_expression(&e.right) } - Expr::Literal(_) => true, + Expr::Literal(_, _) => true, Expr::ScalarFunction(e) => { matches!( e.func.signature().volatility, @@ -115,7 +115,9 @@ fn is_constant_expression(expr: &Expr) -> bool { #[cfg(test)] mod tests { use super::*; + use crate::assert_optimized_plan_eq_snapshot; use crate::test::*; + use crate::OptimizerContext; use arrow::datatypes::DataType; use datafusion_common::Result; @@ -129,6 +131,22 @@ mod tests { use std::sync::Arc; + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(EliminateGroupByConstant::new())]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; + } + #[derive(Debug)] struct ScalarUDFMock { signature: Signature, @@ -167,17 +185,11 @@ mod tests { .aggregate(vec![col("a"), lit(1u32)], vec![count(col("c"))])? .build()?; - let expected = "\ - Projection: test.a, UInt32(1), count(test.c)\ - \n Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]\ - \n TableScan: test\ - "; - - assert_optimized_plan_eq( - Arc::new(EliminateGroupByConstant::new()), - plan, - expected, - ) + assert_optimized_plan_equal!(plan, @r" + Projection: test.a, UInt32(1), count(test.c) + Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]] + TableScan: test + ") } #[test] @@ -187,17 +199,11 @@ mod tests { .aggregate(vec![lit("test"), lit(123u32)], vec![count(col("c"))])? .build()?; - let expected = "\ - Projection: Utf8(\"test\"), UInt32(123), count(test.c)\ - \n Aggregate: groupBy=[[]], aggr=[[count(test.c)]]\ - \n TableScan: test\ - "; - - assert_optimized_plan_eq( - Arc::new(EliminateGroupByConstant::new()), - plan, - expected, - ) + assert_optimized_plan_equal!(plan, @r#" + Projection: Utf8("test"), UInt32(123), count(test.c) + Aggregate: groupBy=[[]], aggr=[[count(test.c)]] + TableScan: test + "#) } #[test] @@ -207,16 +213,10 @@ mod tests { .aggregate(vec![col("a"), col("b")], vec![count(col("c"))])? 
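The pattern updates from `Expr::Literal(value)` to `Expr::Literal(value, _)` throughout this patch reflect a second field on the `Literal` variant, which the updated tests construct as `None` and the rules ignore with `_` (presumably attached literal metadata). A small sketch of how a call site adapts; the helper function name is hypothetical.

```rust
use datafusion_common::ScalarValue;
use datafusion_expr::Expr;

// Sketch: extract a boolean literal while ignoring the new metadata slot,
// mirroring patterns such as `Expr::Literal(ScalarValue::Boolean(v), _)`
// in eliminate_filter and eliminate_join above.
fn boolean_literal_value(expr: &Expr) -> Option<Option<bool>> {
    match expr {
        Expr::Literal(ScalarValue::Boolean(v), _) => Some(*v),
        _ => None,
    }
}
```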
.build()?; - let expected = "\ - Aggregate: groupBy=[[test.a, test.b]], aggr=[[count(test.c)]]\ - \n TableScan: test\ - "; - - assert_optimized_plan_eq( - Arc::new(EliminateGroupByConstant::new()), - plan, - expected, - ) + assert_optimized_plan_equal!(plan, @r" + Aggregate: groupBy=[[test.a, test.b]], aggr=[[count(test.c)]] + TableScan: test + ") } #[test] @@ -226,16 +226,10 @@ mod tests { .aggregate(vec![lit(123u32)], Vec::::new())? .build()?; - let expected = "\ - Aggregate: groupBy=[[UInt32(123)]], aggr=[[]]\ - \n TableScan: test\ - "; - - assert_optimized_plan_eq( - Arc::new(EliminateGroupByConstant::new()), - plan, - expected, - ) + assert_optimized_plan_equal!(plan, @r" + Aggregate: groupBy=[[UInt32(123)]], aggr=[[]] + TableScan: test + ") } #[test] @@ -248,17 +242,11 @@ mod tests { )? .build()?; - let expected = "\ - Projection: UInt32(123) AS const, test.a, count(test.c)\ - \n Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]\ - \n TableScan: test\ - "; - - assert_optimized_plan_eq( - Arc::new(EliminateGroupByConstant::new()), - plan, - expected, - ) + assert_optimized_plan_equal!(plan, @r" + Projection: UInt32(123) AS const, test.a, count(test.c) + Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]] + TableScan: test + ") } #[test] @@ -273,17 +261,11 @@ mod tests { .aggregate(vec![udf_expr, col("a")], vec![count(col("c"))])? .build()?; - let expected = "\ - Projection: scalar_fn_mock(UInt32(123)), test.a, count(test.c)\ - \n Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]\ - \n TableScan: test\ - "; - - assert_optimized_plan_eq( - Arc::new(EliminateGroupByConstant::new()), - plan, - expected, - ) + assert_optimized_plan_equal!(plan, @r" + Projection: scalar_fn_mock(UInt32(123)), test.a, count(test.c) + Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]] + TableScan: test + ") } #[test] @@ -298,15 +280,9 @@ mod tests { .aggregate(vec![udf_expr, col("a")], vec![count(col("c"))])? 
.build()?; - let expected = "\ - Aggregate: groupBy=[[scalar_fn_mock(UInt32(123)), test.a]], aggr=[[count(test.c)]]\ - \n TableScan: test\ - "; - - assert_optimized_plan_eq( - Arc::new(EliminateGroupByConstant::new()), - plan, - expected, - ) + assert_optimized_plan_equal!(plan, @r" + Aggregate: groupBy=[[scalar_fn_mock(UInt32(123)), test.a]], aggr=[[count(test.c)]] + TableScan: test + ") } } diff --git a/datafusion/optimizer/src/eliminate_join.rs b/datafusion/optimizer/src/eliminate_join.rs index 789235595dab..dfc3a220d0f9 100644 --- a/datafusion/optimizer/src/eliminate_join.rs +++ b/datafusion/optimizer/src/eliminate_join.rs @@ -54,7 +54,7 @@ impl OptimizerRule for EliminateJoin { match plan { LogicalPlan::Join(join) if join.join_type == Inner && join.on.is_empty() => { match join.filter { - Some(Expr::Literal(ScalarValue::Boolean(Some(false)))) => Ok( + Some(Expr::Literal(ScalarValue::Boolean(Some(false)), _)) => Ok( Transformed::yes(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, schema: join.schema, @@ -74,15 +74,28 @@ impl OptimizerRule for EliminateJoin { #[cfg(test)] mod tests { + use crate::assert_optimized_plan_eq_snapshot; use crate::eliminate_join::EliminateJoin; - use crate::test::*; + use crate::OptimizerContext; use datafusion_common::Result; use datafusion_expr::JoinType::Inner; - use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder, LogicalPlan}; + use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder}; use std::sync::Arc; - fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq(Arc::new(EliminateJoin::new()), plan, expected) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(EliminateJoin::new())]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } #[test] @@ -95,7 +108,6 @@ mod tests { )? .build()?; - let expected = "EmptyRelation"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @"EmptyRelation") } } diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs index 5d3a1b223b7a..2007e0c82045 100644 --- a/datafusion/optimizer/src/eliminate_limit.rs +++ b/datafusion/optimizer/src/eliminate_limit.rs @@ -90,7 +90,6 @@ impl OptimizerRule for EliminateLimit { #[cfg(test)] mod tests { use super::*; - use crate::optimizer::Optimizer; use crate::test::*; use crate::OptimizerContext; use datafusion_common::Column; @@ -100,36 +99,43 @@ mod tests { }; use std::sync::Arc; + use crate::assert_optimized_plan_eq_snapshot; use crate::push_down_limit::PushDownLimit; use datafusion_expr::test::function_stub::sum; - fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} - fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { - let optimizer = Optimizer::with_rules(vec![Arc::new(EliminateLimit::new())]); - let optimized_plan = - optimizer.optimize(plan, &OptimizerContext::new(), observe)?; - - let formatted_plan = format!("{optimized_plan}"); - assert_eq!(formatted_plan, expected); - Ok(()) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? 
+ ) => {{ + let rules: Vec> = vec![Arc::new(EliminateLimit::new())]; + let optimizer_ctx = OptimizerContext::new(); + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } - fn assert_optimized_plan_eq_with_pushdown( - plan: LogicalPlan, - expected: &str, - ) -> Result<()> { - fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} - let config = OptimizerContext::new().with_max_passes(1); - let optimizer = Optimizer::with_rules(vec![ - Arc::new(PushDownLimit::new()), - Arc::new(EliminateLimit::new()), - ]); - let optimized_plan = optimizer - .optimize(plan, &config, observe) - .expect("failed to optimize plan"); - let formatted_plan = format!("{optimized_plan}"); - assert_eq!(formatted_plan, expected); - Ok(()) + macro_rules! assert_optimized_plan_eq_with_pushdown { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![ + Arc::new(PushDownLimit::new()), + Arc::new(EliminateLimit::new()) + ]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } #[test] @@ -140,8 +146,10 @@ mod tests { .limit(0, Some(0))? .build()?; // No aggregate / scan / limit - let expected = "EmptyRelation"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r"EmptyRelation" + ) } #[test] @@ -157,11 +165,15 @@ mod tests { .build()?; // Left side is removed - let expected = "Union\ - \n EmptyRelation\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Union + EmptyRelation + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]] + TableScan: test + " + ) } #[test] @@ -174,8 +186,10 @@ mod tests { .build()?; // No aggregate / scan / limit - let expected = "EmptyRelation"; - assert_optimized_plan_eq_with_pushdown(plan, expected) + assert_optimized_plan_eq_with_pushdown!( + plan, + @ "EmptyRelation" + ) } #[test] @@ -190,12 +204,16 @@ mod tests { // After remove global-state, we don't record the parent // So, bottom don't know parent info, so can't eliminate. - let expected = "Limit: skip=2, fetch=1\ - \n Sort: test.a ASC NULLS LAST, fetch=3\ - \n Limit: skip=0, fetch=2\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ - \n TableScan: test"; - assert_optimized_plan_eq_with_pushdown(plan, expected) + assert_optimized_plan_eq_with_pushdown!( + plan, + @ r" + Limit: skip=2, fetch=1 + Sort: test.a ASC NULLS LAST, fetch=3 + Limit: skip=0, fetch=2 + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]] + TableScan: test + " + ) } #[test] @@ -208,12 +226,16 @@ mod tests { .limit(0, Some(1))? .build()?; - let expected = "Limit: skip=0, fetch=1\ - \n Sort: test.a ASC NULLS LAST\ - \n Limit: skip=0, fetch=2\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Limit: skip=0, fetch=1 + Sort: test.a ASC NULLS LAST + Limit: skip=0, fetch=2 + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]] + TableScan: test + " + ) } #[test] @@ -226,12 +248,16 @@ mod tests { .limit(3, Some(1))? 
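Both wrapper macros in `eliminate_limit.rs` delegate to `assert_optimized_plan_eq_snapshot!`, which takes an optimizer context, a rule list, the plan, and an insta inline snapshot (`@ r"..."`). Conceptually this mirrors the helper being removed: run the optimizer, format the result, and compare it to the inline snapshot. A simplified sketch under that assumption (not the actual macro body; the snapshot literal is a placeholder):

```rust
use std::sync::Arc;

use crate::optimizer::Optimizer;
use crate::{OptimizerContext, OptimizerRule};
use datafusion_common::Result;
use datafusion_expr::LogicalPlan;

// Rough sketch of what the snapshot assertion boils down to.
fn assert_plan_snapshot_sketch(
    optimizer_ctx: OptimizerContext,
    rules: Vec<Arc<dyn OptimizerRule + Send + Sync>>,
    plan: LogicalPlan,
) -> Result<()> {
    let optimizer = Optimizer::with_rules(rules);
    let optimized = optimizer.optimize(plan, &optimizer_ctx, |_, _| {})?;
    // insta compares against the `@ r"..."` literal and can update it in
    // place via `cargo insta review`; "..." below is only a placeholder.
    insta::assert_snapshot!(format!("{optimized}"), @"...");
    Ok(())
}
```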
.build()?; - let expected = "Limit: skip=3, fetch=1\ - \n Sort: test.a ASC NULLS LAST\ - \n Limit: skip=2, fetch=1\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Limit: skip=3, fetch=1 + Sort: test.a ASC NULLS LAST + Limit: skip=2, fetch=1 + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]] + TableScan: test + " + ) } #[test] @@ -248,12 +274,16 @@ mod tests { .limit(3, Some(1))? .build()?; - let expected = "Limit: skip=3, fetch=1\ - \n Inner Join: Using test.a = test1.a\ - \n Limit: skip=2, fetch=1\ - \n TableScan: test\ - \n TableScan: test1"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Limit: skip=3, fetch=1 + Inner Join: Using test.a = test1.a + Limit: skip=2, fetch=1 + TableScan: test + TableScan: test1 + " + ) } #[test] @@ -264,8 +294,12 @@ mod tests { .limit(0, None)? .build()?; - let expected = "Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]] + TableScan: test + " + ) } } diff --git a/datafusion/optimizer/src/eliminate_nested_union.rs b/datafusion/optimizer/src/eliminate_nested_union.rs index 94da08243d78..f8f93727cd9b 100644 --- a/datafusion/optimizer/src/eliminate_nested_union.rs +++ b/datafusion/optimizer/src/eliminate_nested_union.rs @@ -116,7 +116,8 @@ mod tests { use super::*; use crate::analyzer::type_coercion::TypeCoercion; use crate::analyzer::Analyzer; - use crate::test::*; + use crate::assert_optimized_plan_eq_snapshot; + use crate::OptimizerContext; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::config::ConfigOptions; use datafusion_expr::{col, logical_plan::table_scan}; @@ -129,15 +130,23 @@ mod tests { ]) } - fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { - let options = ConfigOptions::default(); - let analyzed_plan = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())]) - .execute_and_check(plan, &options, |_, _| {})?; - assert_optimized_plan_eq( - Arc::new(EliminateNestedUnion::new()), - analyzed_plan, - expected, - ) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let options = ConfigOptions::default(); + let analyzed_plan = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())]) + .execute_and_check($plan, &options, |_, _| {})?; + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(EliminateNestedUnion::new())]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + analyzed_plan, + @ $expected, + ) + }}; } #[test] @@ -146,11 +155,11 @@ mod tests { let plan = plan_builder.clone().union(plan_builder.build()?)?.build()?; - let expected = "\ - Union\ - \n TableScan: table\ - \n TableScan: table"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Union + TableScan: table + TableScan: table + ") } #[test] @@ -162,11 +171,12 @@ mod tests { .union_distinct(plan_builder.build()?)? .build()?; - let expected = "Distinct:\ - \n Union\ - \n TableScan: table\ - \n TableScan: table"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Distinct: + Union + TableScan: table + TableScan: table + ") } #[test] @@ -180,13 +190,13 @@ mod tests { .union(plan_builder.build()?)? 
.build()?; - let expected = "\ - Union\ - \n TableScan: table\ - \n TableScan: table\ - \n TableScan: table\ - \n TableScan: table"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Union + TableScan: table + TableScan: table + TableScan: table + TableScan: table + ") } #[test] @@ -200,14 +210,15 @@ mod tests { .union(plan_builder.build()?)? .build()?; - let expected = "Union\ - \n Distinct:\ - \n Union\ - \n TableScan: table\ - \n TableScan: table\ - \n TableScan: table\ - \n TableScan: table"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Union + Distinct: + Union + TableScan: table + TableScan: table + TableScan: table + TableScan: table + ") } #[test] @@ -222,14 +233,15 @@ mod tests { .union_distinct(plan_builder.build()?)? .build()?; - let expected = "Distinct:\ - \n Union\ - \n TableScan: table\ - \n TableScan: table\ - \n TableScan: table\ - \n TableScan: table\ - \n TableScan: table"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Distinct: + Union + TableScan: table + TableScan: table + TableScan: table + TableScan: table + TableScan: table + ") } #[test] @@ -243,13 +255,14 @@ mod tests { .union_distinct(plan_builder.build()?)? .build()?; - let expected = "Distinct:\ - \n Union\ - \n TableScan: table\ - \n TableScan: table\ - \n TableScan: table\ - \n TableScan: table"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Distinct: + Union + TableScan: table + TableScan: table + TableScan: table + TableScan: table + ") } // We don't need to use project_with_column_index in logical optimizer, @@ -273,13 +286,14 @@ mod tests { )? .build()?; - let expected = "Union\ - \n TableScan: table\ - \n Projection: table.id AS id, table.key, table.value\ - \n TableScan: table\ - \n Projection: table.id AS id, table.key, table.value\ - \n TableScan: table"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Union + TableScan: table + Projection: table.id AS id, table.key, table.value + TableScan: table + Projection: table.id AS id, table.key, table.value + TableScan: table + ") } #[test] @@ -301,14 +315,15 @@ mod tests { )? .build()?; - let expected = "Distinct:\ - \n Union\ - \n TableScan: table\ - \n Projection: table.id AS id, table.key, table.value\ - \n TableScan: table\ - \n Projection: table.id AS id, table.key, table.value\ - \n TableScan: table"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Distinct: + Union + TableScan: table + Projection: table.id AS id, table.key, table.value + TableScan: table + Projection: table.id AS id, table.key, table.value + TableScan: table + ") } #[test] @@ -348,13 +363,14 @@ mod tests { .union(table_3.build()?)? 
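The expectations in these `eliminate_nested_union` tests show nested `Union`/`Distinct` trees collapsing into a single `Union`, wrapped in one `Distinct:` when any level was distinct. A conceptual sketch of the flattening idea, assuming a plain recursive collection of union inputs (the real rule also handles the coercion `Projection`s visible in the later cases):

```rust
use std::sync::Arc;

use datafusion_expr::logical_plan::Union;
use datafusion_expr::LogicalPlan;

// Illustrative only: gather the leaves of nested unions so that
// Union(Union(a, b), c) flattens to Union(a, b, c).
fn flatten_union_inputs(plan: Arc<LogicalPlan>, acc: &mut Vec<Arc<LogicalPlan>>) {
    match plan.as_ref() {
        LogicalPlan::Union(Union { inputs, .. }) => {
            for input in inputs {
                flatten_union_inputs(Arc::clone(input), acc);
            }
        }
        _ => acc.push(plan),
    }
}
```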
.build()?; - let expected = "Union\ - \n TableScan: table_1\ - \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ - \n TableScan: table_1\ - \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ - \n TableScan: table_1"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Union + TableScan: table_1 + Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value + TableScan: table_1 + Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value + TableScan: table_1 + ") } #[test] @@ -394,13 +410,14 @@ mod tests { .union_distinct(table_3.build()?)? .build()?; - let expected = "Distinct:\ - \n Union\ - \n TableScan: table_1\ - \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ - \n TableScan: table_1\ - \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ - \n TableScan: table_1"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!(plan, @r" + Distinct: + Union + TableScan: table_1 + Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value + TableScan: table_1 + Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value + TableScan: table_1 + ") } } diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index 1ecb32ca2a43..45877642f276 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -118,7 +118,7 @@ impl OptimizerRule for EliminateOuterJoin { on: join.on.clone(), filter: join.filter.clone(), schema: Arc::clone(&join.schema), - null_equals_null: join.null_equals_null, + null_equality: join.null_equality, })); Filter::try_new(filter.predicate, new_join) .map(|f| Transformed::yes(LogicalPlan::Filter(f))) @@ -304,7 +304,9 @@ fn extract_non_nullable_columns( #[cfg(test)] mod tests { use super::*; + use crate::assert_optimized_plan_eq_snapshot; use crate::test::*; + use crate::OptimizerContext; use arrow::datatypes::DataType; use datafusion_expr::{ binary_expr, cast, col, lit, @@ -313,8 +315,20 @@ mod tests { Operator::{And, Or}, }; - fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq(Arc::new(EliminateOuterJoin::new()), plan, expected) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(EliminateOuterJoin::new())]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } #[test] @@ -332,12 +346,13 @@ mod tests { )? .filter(col("t2.b").is_null())? .build()?; - let expected = "\ - Filter: t2.b IS NULL\ - \n Left Join: t1.a = t2.a\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Filter: t2.b IS NULL + Left Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") } #[test] @@ -355,12 +370,13 @@ mod tests { )? .filter(col("t2.b").is_not_null())? 
.build()?; - let expected = "\ - Filter: t2.b IS NOT NULL\ - \n Inner Join: t1.a = t2.a\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Filter: t2.b IS NOT NULL + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") } #[test] @@ -382,12 +398,13 @@ mod tests { col("t1.c").lt(lit(20u32)), ))? .build()?; - let expected = "\ - Filter: t1.b > UInt32(10) OR t1.c < UInt32(20)\ - \n Inner Join: t1.a = t2.a\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Filter: t1.b > UInt32(10) OR t1.c < UInt32(20) + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") } #[test] @@ -409,12 +426,13 @@ mod tests { col("t2.c").lt(lit(20u32)), ))? .build()?; - let expected = "\ - Filter: t1.b > UInt32(10) AND t2.c < UInt32(20)\ - \n Inner Join: t1.a = t2.a\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Filter: t1.b > UInt32(10) AND t2.c < UInt32(20) + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") } #[test] @@ -436,11 +454,12 @@ mod tests { try_cast(col("t2.c"), DataType::Int64).lt(lit(20u32)), ))? .build()?; - let expected = "\ - Filter: CAST(t1.b AS Int64) > UInt32(10) AND TRY_CAST(t2.c AS Int64) < UInt32(20)\ - \n Inner Join: t1.a = t2.a\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Filter: CAST(t1.b AS Int64) > UInt32(10) AND TRY_CAST(t2.c AS Int64) < UInt32(20) + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") } } diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs index 48191ec20631..55cf33ef4304 100644 --- a/datafusion/optimizer/src/extract_equijoin_predicate.rs +++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs @@ -75,7 +75,7 @@ impl OptimizerRule for ExtractEquijoinPredicate { join_type, join_constraint, schema, - null_equals_null, + null_equality, }) => { let left_schema = left.schema(); let right_schema = right.schema(); @@ -92,7 +92,7 @@ impl OptimizerRule for ExtractEquijoinPredicate { join_type, join_constraint, schema, - null_equals_null, + null_equality, }))) } else { Ok(Transformed::no(LogicalPlan::Join(Join { @@ -103,7 +103,7 @@ impl OptimizerRule for ExtractEquijoinPredicate { join_type, join_constraint, schema, - null_equals_null, + null_equality, }))) } } @@ -155,6 +155,7 @@ fn split_eq_and_noneq_join_predicate( #[cfg(test)] mod tests { use super::*; + use crate::assert_optimized_plan_eq_display_indent_snapshot; use crate::test::*; use arrow::datatypes::DataType; use datafusion_expr::{ @@ -162,14 +163,18 @@ mod tests { }; use std::sync::Arc; - fn assert_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq_display_indent( - Arc::new(ExtractEquijoinPredicate {}), - plan, - expected, - ); - - Ok(()) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let rule: Arc = Arc::new(ExtractEquijoinPredicate {}); + assert_optimized_plan_eq_display_indent_snapshot!( + rule, + $plan, + @ $expected, + ) + }}; } #[test] @@ -180,11 +185,15 @@ mod tests { let plan = LogicalPlanBuilder::from(t1) .join_on(t2, JoinType::Left, Some(col("t1.a").eq(col("t2.a"))))? 
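The `extract_equijoin_predicate` cases above illustrate what `split_eq_and_noneq_join_predicate` does: conjuncts of the join filter that equate one left-side expression with one right-side expression become equijoin keys, while everything else remains in the residual `Filter:` shown in the expected plans. A rough sketch of classifying a single conjunct; the function is hypothetical and deliberately simplified (the real code also verifies that each side of the equality only references columns of one join input):

```rust
use datafusion_expr::{BinaryExpr, Expr, Operator};

// Sketch: an `=` conjunct becomes an equijoin candidate pair,
// anything else is returned as residual filter material.
fn classify_conjunct(expr: Expr) -> (Option<(Expr, Expr)>, Option<Expr>) {
    match expr {
        Expr::BinaryExpr(BinaryExpr { left, op: Operator::Eq, right }) => {
            (Some((*left, *right)), None)
        }
        other => (None, Some(other)),
    }
}
```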
.build()?; - let expected = "Left Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\ - \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Left Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -199,11 +208,15 @@ mod tests { Some((col("t1.a") + lit(10i64)).eq(col("t2.a") * lit(2u32))), )? .build()?; - let expected = "Left Join: t1.a + Int64(10) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\ - \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Left Join: t1.a + Int64(10) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -222,11 +235,15 @@ mod tests { ), )? .build()?; - let expected = "Left Join: Filter: t1.a + Int64(10) >= t2.a * UInt32(2) AND t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\ - \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Left Join: Filter: t1.a + Int64(10) >= t2.a * UInt32(2) AND t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -249,11 +266,15 @@ mod tests { ), )? .build()?; - let expected = "Left Join: t1.a + UInt32(11) = t2.a * UInt32(2), t1.a + Int64(10) = t2.a * UInt32(2) Filter: t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\ - \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Left Join: t1.a + UInt32(11) = t2.a * UInt32(2), t1.a + Int64(10) = t2.a * UInt32(2) Filter: t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -275,11 +296,15 @@ mod tests { ), )? .build()?; - let expected = "Left Join: t1.a = t2.a, t1.b = t2.b Filter: t1.c = t2.c OR t1.a + t1.b > t2.b + t2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\ - \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Left Join: t1.a = t2.a, t1.b = t2.b Filter: t1.c = t2.c OR t1.a + t1.b > t2.b + t2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -310,13 +335,17 @@ mod tests { ), )? 
.build()?; - let expected = "Left Join: t1.a = t2.a Filter: t1.c + t2.c + t3.c < UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]\ - \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ - \n Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\ - \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Left Join: t1.a = t2.a Filter: t1.c + t2.c + t3.c < UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -343,13 +372,17 @@ mod tests { Some(col("t1.a").eq(col("t2.a")).and(col("t2.c").eq(col("t3.c")))), )? .build()?; - let expected = "Left Join: t1.a = t2.a Filter: t2.c = t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]\ - \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ - \n Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\ - \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Left Join: t1.a = t2.a Filter: t2.c = t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -369,10 +402,14 @@ mod tests { let plan = LogicalPlanBuilder::from(t1) .join_on(t2, JoinType::Left, Some(filter))? 
.build()?; - let expected = "Left Join: t1.a + CAST(Int64(1) AS UInt32) = t2.a + CAST(Int32(2) AS UInt32) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\ - \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Left Join: t1.a + CAST(Int64(1) AS UInt32) = t2.a + CAST(Int32(2) AS UInt32) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] + TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] + " + ) } } diff --git a/datafusion/optimizer/src/filter_null_join_keys.rs b/datafusion/optimizer/src/filter_null_join_keys.rs index 2e7a751ca4c5..8ad7fa53c0e3 100644 --- a/datafusion/optimizer/src/filter_null_join_keys.rs +++ b/datafusion/optimizer/src/filter_null_join_keys.rs @@ -21,7 +21,7 @@ use crate::optimizer::ApplyOrder; use crate::push_down_filter::on_lr_is_preserved; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; -use datafusion_common::Result; +use datafusion_common::{NullEquality, Result}; use datafusion_expr::utils::conjunction; use datafusion_expr::{logical_plan::Filter, Expr, ExprSchemable, LogicalPlan}; use std::sync::Arc; @@ -51,7 +51,8 @@ impl OptimizerRule for FilterNullJoinKeys { } match plan { LogicalPlan::Join(mut join) - if !join.on.is_empty() && !join.null_equals_null => + if !join.on.is_empty() + && join.null_equality == NullEquality::NullEqualsNothing => { let (left_preserved, right_preserved) = on_lr_is_preserved(join.join_type); @@ -107,35 +108,52 @@ fn create_not_null_predicate(filters: Vec) -> Expr { #[cfg(test)] mod tests { use super::*; - use crate::test::assert_optimized_plan_eq; + use crate::assert_optimized_plan_eq_snapshot; + use crate::OptimizerContext; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::Column; use datafusion_expr::logical_plan::table_scan; use datafusion_expr::{col, lit, JoinType, LogicalPlanBuilder}; - fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq(Arc::new(FilterNullJoinKeys {}), plan, expected) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? 
+ ) => {{ + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(FilterNullJoinKeys {})]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } #[test] fn left_nullable() -> Result<()> { let (t1, t2) = test_tables()?; let plan = build_plan(t1, t2, "t1.optional_id", "t2.id", JoinType::Inner)?; - let expected = "Inner Join: t1.optional_id = t2.id\ - \n Filter: t1.optional_id IS NOT NULL\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Inner Join: t1.optional_id = t2.id + Filter: t1.optional_id IS NOT NULL + TableScan: t1 + TableScan: t2 + ") } #[test] fn left_nullable_left_join() -> Result<()> { let (t1, t2) = test_tables()?; let plan = build_plan(t1, t2, "t1.optional_id", "t2.id", JoinType::Left)?; - let expected = "Left Join: t1.optional_id = t2.id\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Left Join: t1.optional_id = t2.id + TableScan: t1 + TableScan: t2 + ") } #[test] @@ -144,22 +162,26 @@ mod tests { // Note: order of tables is reversed let plan = build_plan(t_right, t_left, "t2.id", "t1.optional_id", JoinType::Left)?; - let expected = "Left Join: t2.id = t1.optional_id\ - \n TableScan: t2\ - \n Filter: t1.optional_id IS NOT NULL\ - \n TableScan: t1"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Left Join: t2.id = t1.optional_id + TableScan: t2 + Filter: t1.optional_id IS NOT NULL + TableScan: t1 + ") } #[test] fn left_nullable_on_condition_reversed() -> Result<()> { let (t1, t2) = test_tables()?; let plan = build_plan(t1, t2, "t2.id", "t1.optional_id", JoinType::Inner)?; - let expected = "Inner Join: t1.optional_id = t2.id\ - \n Filter: t1.optional_id IS NOT NULL\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Inner Join: t1.optional_id = t2.id + Filter: t1.optional_id IS NOT NULL + TableScan: t1 + TableScan: t2 + ") } #[test] @@ -189,14 +211,16 @@ mod tests { None, )? .build()?; - let expected = "Inner Join: t3.t1_id = t1.id, t3.t2_id = t2.id\ - \n Filter: t3.t1_id IS NOT NULL AND t3.t2_id IS NOT NULL\ - \n TableScan: t3\ - \n Inner Join: t1.optional_id = t2.id\ - \n Filter: t1.optional_id IS NOT NULL\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Inner Join: t3.t1_id = t1.id, t3.t2_id = t2.id + Filter: t3.t1_id IS NOT NULL AND t3.t2_id IS NOT NULL + TableScan: t3 + Inner Join: t1.optional_id = t2.id + Filter: t1.optional_id IS NOT NULL + TableScan: t1 + TableScan: t2 + ") } #[test] @@ -213,11 +237,13 @@ mod tests { None, )? .build()?; - let expected = "Inner Join: t1.optional_id + UInt32(1) = t2.id + UInt32(1)\ - \n Filter: t1.optional_id + UInt32(1) IS NOT NULL\ - \n TableScan: t1\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Inner Join: t1.optional_id + UInt32(1) = t2.id + UInt32(1) + Filter: t1.optional_id + UInt32(1) IS NOT NULL + TableScan: t1 + TableScan: t2 + ") } #[test] @@ -234,11 +260,13 @@ mod tests { None, )? 
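The `filter_null_join_keys` tests that follow add `IS NOT NULL` guards on nullable join keys, which is only sound when the join does not treat NULL keys as equal; hence the rule now fires only for `NullEquality::NullEqualsNothing`. A small sketch of assembling such a guard from the key expressions, using the `conjunction` helper the rule already imports (`keys` is a hypothetical slice of one side's nullable join-key expressions):

```rust
use datafusion_expr::utils::conjunction;
use datafusion_expr::Expr;

// Sketch: build `k1 IS NOT NULL AND k2 IS NOT NULL AND ...`;
// returns None when there are no keys to guard.
fn not_null_guard(keys: &[Expr]) -> Option<Expr> {
    conjunction(keys.iter().map(|k| k.clone().is_not_null()))
}
```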
.build()?; - let expected = "Inner Join: t1.id + UInt32(1) = t2.optional_id + UInt32(1)\ - \n TableScan: t1\ - \n Filter: t2.optional_id + UInt32(1) IS NOT NULL\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Inner Join: t1.id + UInt32(1) = t2.optional_id + UInt32(1) + TableScan: t1 + Filter: t2.optional_id + UInt32(1) IS NOT NULL + TableScan: t2 + ") } #[test] @@ -255,13 +283,14 @@ mod tests { None, )? .build()?; - let expected = - "Inner Join: t1.optional_id + UInt32(1) = t2.optional_id + UInt32(1)\ - \n Filter: t1.optional_id + UInt32(1) IS NOT NULL\ - \n TableScan: t1\ - \n Filter: t2.optional_id + UInt32(1) IS NOT NULL\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan, expected) + + assert_optimized_plan_equal!(plan, @r" + Inner Join: t1.optional_id + UInt32(1) = t2.optional_id + UInt32(1) + Filter: t1.optional_id + UInt32(1) IS NOT NULL + TableScan: t1 + Filter: t2.optional_id + UInt32(1) IS NOT NULL + TableScan: t2 + ") } #[test] @@ -283,13 +312,22 @@ mod tests { None, )? .build()?; - let expected = "Inner Join: t1.optional_id = t2.optional_id\ - \n Filter: t1.optional_id IS NOT NULL\ - \n TableScan: t1\ - \n Filter: t2.optional_id IS NOT NULL\ - \n TableScan: t2"; - assert_optimized_plan_equal(plan_from_cols, expected)?; - assert_optimized_plan_equal(plan_from_exprs, expected) + + assert_optimized_plan_equal!(plan_from_cols, @r" + Inner Join: t1.optional_id = t2.optional_id + Filter: t1.optional_id IS NOT NULL + TableScan: t1 + Filter: t2.optional_id IS NOT NULL + TableScan: t2 + ")?; + + assert_optimized_plan_equal!(plan_from_exprs, @r" + Inner Join: t1.optional_id = t2.optional_id + Filter: t1.optional_id IS NOT NULL + TableScan: t1 + Filter: t2.optional_id IS NOT NULL + TableScan: t2 + ") } fn build_plan( diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index 893cb249a2a8..280010e3d92c 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -40,6 +40,7 @@ pub mod analyzer; pub mod common_subexpr_eliminate; pub mod decorrelate; +pub mod decorrelate_lateral_join; pub mod decorrelate_predicate_subquery; pub mod eliminate_cross_join; pub mod eliminate_duplicated_expr; diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index b3a09e2dcbcc..023ee4ea5a84 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -31,8 +31,7 @@ use datafusion_common::{ use datafusion_expr::expr::Alias; use datafusion_expr::Unnest; use datafusion_expr::{ - logical_plan::LogicalPlan, projection_schema, Aggregate, Distinct, Expr, Projection, - TableScan, Window, + logical_plan::LogicalPlan, Aggregate, Distinct, Expr, Projection, TableScan, Window, }; use crate::optimize_projections::required_indices::RequiredIndices; @@ -455,6 +454,17 @@ fn merge_consecutive_projections(proj: Projection) -> Result::new(); expr.iter() @@ -523,7 +533,7 @@ fn merge_consecutive_projections(proj: Projection) -> Result bool { - matches!(expr, Expr::Column(_) | Expr::Literal(_)) + matches!(expr, Expr::Column(_) | Expr::Literal(_, _)) } /// Rewrites a projection expression using the projection before it (i.e. 
its input) @@ -573,8 +583,18 @@ fn is_expr_trivial(expr: &Expr) -> bool { fn rewrite_expr(expr: Expr, input: &Projection) -> Result> { expr.transform_up(|expr| { match expr { - // remove any intermediate aliases - Expr::Alias(alias) => Ok(Transformed::yes(*alias.expr)), + // remove any intermediate aliases if they do not carry metadata + Expr::Alias(alias) => { + match alias + .metadata + .as_ref() + .map(|h| h.is_empty()) + .unwrap_or(true) + { + true => Ok(Transformed::yes(*alias.expr)), + false => Ok(Transformed::no(Expr::Alias(alias))), + } + } Expr::Column(col) => { // Find index of column: let idx = input.schema.index_of_column(&col)?; @@ -652,10 +672,10 @@ fn outer_columns_helper_multi<'a, 'b>( /// Depending on the join type, it divides the requirement indices into those /// that apply to the left child and those that apply to the right child. /// -/// - For `INNER`, `LEFT`, `RIGHT` and `FULL` joins, the requirements are split -/// between left and right children. The right child indices are adjusted to -/// point to valid positions within the right child by subtracting the length -/// of the left child. +/// - For `INNER`, `LEFT`, `RIGHT`, `FULL`, `LEFTMARK`, and `RIGHTMARK` joins, +/// the requirements are split between left and right children. The right +/// child indices are adjusted to point to valid positions within the right +/// child by subtracting the length of the left child. /// /// - For `LEFT ANTI`, `LEFT SEMI`, `RIGHT SEMI` and `RIGHT ANTI` joins, all /// requirements are re-routed to either the left child or the right child @@ -684,7 +704,8 @@ fn split_join_requirements( | JoinType::Left | JoinType::Right | JoinType::Full - | JoinType::LeftMark => { + | JoinType::LeftMark + | JoinType::RightMark => { // Decrease right side indices by `left_len` so that they point to valid // positions within the right child: indices.split_off(left_len) @@ -774,9 +795,24 @@ fn rewrite_projection_given_requirements( /// Projection is unnecessary, when /// - input schema of the projection, output schema of the projection are same, and /// - all projection expressions are either Column or Literal -fn is_projection_unnecessary(input: &LogicalPlan, proj_exprs: &[Expr]) -> Result { - let proj_schema = projection_schema(input, proj_exprs)?; - Ok(&proj_schema == input.schema() && proj_exprs.iter().all(is_expr_trivial)) +pub fn is_projection_unnecessary( + input: &LogicalPlan, + proj_exprs: &[Expr], +) -> Result { + // First check if the number of expressions is equal to the number of fields in the input schema. 
+ if proj_exprs.len() != input.schema().fields().len() { + return Ok(false); + } + Ok(input.schema().iter().zip(proj_exprs.iter()).all( + |((field_relation, field_name), expr)| { + // Check if the expression is a column and if it matches the field name + if let Expr::Column(col) = expr { + col.relation.as_ref() == field_relation && col.name.eq(field_name.name()) + } else { + false + } + }, + )) } #[cfg(test)] @@ -791,8 +827,8 @@ mod tests { use crate::optimize_projections::OptimizeProjections; use crate::optimizer::Optimizer; use crate::test::{ - assert_fields_eq, assert_optimized_plan_eq, scan_empty, test_table_scan, - test_table_scan_fields, test_table_scan_with_name, + assert_fields_eq, scan_empty, test_table_scan, test_table_scan_fields, + test_table_scan_with_name, }; use crate::{OptimizerContext, OptimizerRule}; use arrow::datatypes::{DataType, Field, Schema}; @@ -810,13 +846,27 @@ mod tests { not, try_cast, when, BinaryExpr, Expr, Extension, Like, LogicalPlan, Operator, Projection, UserDefinedLogicalNodeCore, WindowFunctionDefinition, }; + use insta::assert_snapshot; + use crate::assert_optimized_plan_eq_snapshot; use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::expr_fn::{count, max, min}; use datafusion_functions_aggregate::min_max::max_udaf; - fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq(Arc::new(OptimizeProjections::new()), plan, expected) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec> = vec![Arc::new(OptimizeProjections::new())]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } #[derive(Debug, Hash, PartialEq, Eq)] @@ -1005,9 +1055,13 @@ mod tests { .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])? .build()?; - let expected = "Projection: Int32(1) + test.a\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: Int32(1) + test.a + TableScan: test projection=[a] + " + ) } #[test] @@ -1019,9 +1073,13 @@ mod tests { .project(vec![binary_expr(lit(1), Operator::Plus, col("a"))])? .build()?; - let expected = "Projection: Int32(1) + test.a\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: Int32(1) + test.a + TableScan: test projection=[a] + " + ) } #[test] @@ -1032,9 +1090,13 @@ mod tests { .project(vec![col("a").alias("alias")])? .build()?; - let expected = "Projection: test.a AS alias\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a AS alias + TableScan: test projection=[a] + " + ) } #[test] @@ -1045,9 +1107,13 @@ mod tests { .project(vec![col("alias2").alias("alias")])? .build()?; - let expected = "Projection: test.a AS alias\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a AS alias + TableScan: test projection=[a] + " + ) } #[test] @@ -1065,11 +1131,15 @@ mod tests { .build() .unwrap(); - let expected = "Aggregate: groupBy=[[]], aggr=[[count(Int32(1))]]\ - \n Projection: \ - \n Aggregate: groupBy=[[]], aggr=[[count(Int32(1))]]\ - \n TableScan: ?table? 
projection=[]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[]], aggr=[[count(Int32(1))]] + Projection: + Aggregate: groupBy=[[]], aggr=[[count(Int32(1))]] + TableScan: ?table? projection=[] + " + ) } #[test] @@ -1079,9 +1149,13 @@ mod tests { .project(vec![-col("a")])? .build()?; - let expected = "Projection: (- test.a)\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: (- test.a) + TableScan: test projection=[a] + " + ) } #[test] @@ -1091,9 +1165,13 @@ mod tests { .project(vec![col("a").is_null()])? .build()?; - let expected = "Projection: test.a IS NULL\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a IS NULL + TableScan: test projection=[a] + " + ) } #[test] @@ -1103,9 +1181,13 @@ mod tests { .project(vec![col("a").is_not_null()])? .build()?; - let expected = "Projection: test.a IS NOT NULL\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a IS NOT NULL + TableScan: test projection=[a] + " + ) } #[test] @@ -1115,9 +1197,13 @@ mod tests { .project(vec![col("a").is_true()])? .build()?; - let expected = "Projection: test.a IS TRUE\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a IS TRUE + TableScan: test projection=[a] + " + ) } #[test] @@ -1127,9 +1213,13 @@ mod tests { .project(vec![col("a").is_not_true()])? .build()?; - let expected = "Projection: test.a IS NOT TRUE\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a IS NOT TRUE + TableScan: test projection=[a] + " + ) } #[test] @@ -1139,9 +1229,13 @@ mod tests { .project(vec![col("a").is_false()])? .build()?; - let expected = "Projection: test.a IS FALSE\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a IS FALSE + TableScan: test projection=[a] + " + ) } #[test] @@ -1151,9 +1245,13 @@ mod tests { .project(vec![col("a").is_not_false()])? .build()?; - let expected = "Projection: test.a IS NOT FALSE\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a IS NOT FALSE + TableScan: test projection=[a] + " + ) } #[test] @@ -1163,9 +1261,13 @@ mod tests { .project(vec![col("a").is_unknown()])? .build()?; - let expected = "Projection: test.a IS UNKNOWN\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a IS UNKNOWN + TableScan: test projection=[a] + " + ) } #[test] @@ -1175,9 +1277,13 @@ mod tests { .project(vec![col("a").is_not_unknown()])? .build()?; - let expected = "Projection: test.a IS NOT UNKNOWN\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a IS NOT UNKNOWN + TableScan: test projection=[a] + " + ) } #[test] @@ -1187,9 +1293,13 @@ mod tests { .project(vec![not(col("a"))])? 
.build()?; - let expected = "Projection: NOT test.a\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: NOT test.a + TableScan: test projection=[a] + " + ) } #[test] @@ -1199,9 +1309,13 @@ mod tests { .project(vec![try_cast(col("a"), DataType::Float64)])? .build()?; - let expected = "Projection: TRY_CAST(test.a AS Float64)\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: TRY_CAST(test.a AS Float64) + TableScan: test projection=[a] + " + ) } #[test] @@ -1215,9 +1329,13 @@ mod tests { .project(vec![similar_to_expr])? .build()?; - let expected = "Projection: test.a SIMILAR TO Utf8(\"[0-9]\")\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r#" + Projection: test.a SIMILAR TO Utf8("[0-9]") + TableScan: test projection=[a] + "# + ) } #[test] @@ -1227,9 +1345,13 @@ mod tests { .project(vec![col("a").between(lit(1), lit(3))])? .build()?; - let expected = "Projection: test.a BETWEEN Int32(1) AND Int32(3)\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a BETWEEN Int32(1) AND Int32(3) + TableScan: test projection=[a] + " + ) } // Test Case expression @@ -1246,9 +1368,13 @@ mod tests { ])? .build()?; - let expected = "Projection: test.a, CASE WHEN test.a = Int32(1) THEN Int32(10) ELSE Int32(0) END AS d\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, CASE WHEN test.a = Int32(1) THEN Int32(10) ELSE Int32(0) END AS d + TableScan: test projection=[a] + " + ) } // Test outer projection isn't discarded despite the same schema as inner @@ -1266,11 +1392,14 @@ mod tests { ])? .build()?; - let expected = - "Projection: a, CASE WHEN a = Int32(1) THEN Int32(10) ELSE d END AS d\ - \n Projection: test.a + Int32(1) AS a, Int32(0) AS d\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: a, CASE WHEN a = Int32(1) THEN Int32(10) ELSE d END AS d + Projection: test.a + Int32(1) AS a, Int32(0) AS d + TableScan: test projection=[a] + " + ) } // Since only column `a` is referred at the output. Scan should only contain projection=[a]. @@ -1288,10 +1417,14 @@ mod tests { .project(vec![col("a"), lit(0).alias("d")])? .build()?; - let expected = "Projection: test.a, Int32(0) AS d\ - \n NoOpUserDefined\ - \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, Int32(0) AS d + NoOpUserDefined + TableScan: test projection=[a] + " + ) } // Only column `a` is referred at the output. However, User defined node itself uses column `b` @@ -1315,10 +1448,14 @@ mod tests { .project(vec![col("a"), lit(0).alias("d")])? .build()?; - let expected = "Projection: test.a, Int32(0) AS d\ - \n NoOpUserDefined\ - \n TableScan: test projection=[a, b]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, Int32(0) AS d + NoOpUserDefined + TableScan: test projection=[a, b] + " + ) } // Only column `a` is referred at the output. 
However, User defined node itself uses expression `b+c` @@ -1350,10 +1487,14 @@ mod tests { .project(vec![col("a"), lit(0).alias("d")])? .build()?; - let expected = "Projection: test.a, Int32(0) AS d\ - \n NoOpUserDefined\ - \n TableScan: test projection=[a, b, c]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, Int32(0) AS d + NoOpUserDefined + TableScan: test projection=[a, b, c] + " + ) } // Columns `l.a`, `l.c`, `r.a` is referred at the output. @@ -1374,11 +1515,15 @@ mod tests { .project(vec![col("l.a"), col("l.c"), col("r.a"), lit(0).alias("d")])? .build()?; - let expected = "Projection: l.a, l.c, r.a, Int32(0) AS d\ - \n UserDefinedCrossJoin\ - \n TableScan: l projection=[a, c]\ - \n TableScan: r projection=[a]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: l.a, l.c, r.a, Int32(0) AS d + UserDefinedCrossJoin + TableScan: l projection=[a, c] + TableScan: r projection=[a] + " + ) } #[test] @@ -1389,10 +1534,13 @@ mod tests { .aggregate(Vec::::new(), vec![max(col("b"))])? .build()?; - let expected = "Aggregate: groupBy=[[]], aggr=[[max(test.b)]]\ - \n TableScan: test projection=[b]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[]], aggr=[[max(test.b)]] + TableScan: test projection=[b] + " + ) } #[test] @@ -1403,10 +1551,13 @@ mod tests { .aggregate(vec![col("c")], vec![max(col("b"))])? .build()?; - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[max(test.b)]]\ - \n TableScan: test projection=[b, c]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.c]], aggr=[[max(test.b)]] + TableScan: test projection=[b, c] + " + ) } #[test] @@ -1418,11 +1569,14 @@ mod tests { .aggregate(vec![col("c")], vec![max(col("b"))])? .build()?; - let expected = "Aggregate: groupBy=[[a.c]], aggr=[[max(a.b)]]\ - \n SubqueryAlias: a\ - \n TableScan: test projection=[b, c]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[a.c]], aggr=[[max(a.b)]] + SubqueryAlias: a + TableScan: test projection=[b, c] + " + ) } #[test] @@ -1434,12 +1588,15 @@ mod tests { .aggregate(Vec::::new(), vec![max(col("b"))])? .build()?; - let expected = "Aggregate: groupBy=[[]], aggr=[[max(test.b)]]\ - \n Projection: test.b\ - \n Filter: test.c > Int32(1)\ - \n TableScan: test projection=[b, c]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[]], aggr=[[max(test.b)]] + Projection: test.b + Filter: test.c > Int32(1) + TableScan: test projection=[b, c] + " + ) } #[test] @@ -1460,11 +1617,13 @@ mod tests { .project([col(Column::new_unqualified("tag.one"))])? .build()?; - let expected = "\ - Aggregate: groupBy=[[]], aggr=[[max(m4.tag.one) AS tag.one]]\ - \n TableScan: m4 projection=[tag.one]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[]], aggr=[[max(m4.tag.one) AS tag.one]] + TableScan: m4 projection=[tag.one] + " + ) } #[test] @@ -1475,10 +1634,13 @@ mod tests { .project(vec![col("a"), col("b"), col("c")])? .project(vec![col("a"), col("c"), col("b")])? 
.build()?; - let expected = "Projection: test.a, test.c, test.b\ - \n TableScan: test projection=[a, b, c]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, test.c, test.b + TableScan: test projection=[a, b, c] + " + ) } #[test] @@ -1486,9 +1648,10 @@ mod tests { let schema = Schema::new(test_table_scan_fields()); let plan = table_scan(Some("test"), &schema, Some(vec![1, 0, 2]))?.build()?; - let expected = "TableScan: test projection=[b, a, c]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @"TableScan: test projection=[b, a, c]" + ) } #[test] @@ -1498,10 +1661,13 @@ mod tests { let plan = table_scan(Some("test"), &schema, Some(vec![1, 0, 2]))? .project(vec![col("a"), col("b")])? .build()?; - let expected = "Projection: test.a, test.b\ - \n TableScan: test projection=[b, a]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, test.b + TableScan: test projection=[b, a] + " + ) } #[test] @@ -1511,10 +1677,13 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .project(vec![col("c"), col("b"), col("a")])? .build()?; - let expected = "Projection: test.c, test.b, test.a\ - \n TableScan: test projection=[a, b, c]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.c, test.b, test.a + TableScan: test projection=[a, b, c] + " + ) } #[test] @@ -1529,14 +1698,18 @@ mod tests { .filter(col("a").gt(lit(1)))? .project(vec![col("a"), col("c"), col("b")])? .build()?; - let expected = "Projection: test.a, test.c, test.b\ - \n Filter: test.a > Int32(1)\ - \n Filter: test.b > Int32(1)\ - \n Projection: test.c, test.a, test.b\ - \n Filter: test.c > Int32(1)\ - \n Projection: test.c, test.b, test.a\ - \n TableScan: test projection=[a, b, c]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, test.c, test.b + Filter: test.a > Int32(1) + Filter: test.b > Int32(1) + Projection: test.c, test.a, test.b + Filter: test.c > Int32(1) + Projection: test.c, test.b, test.a + TableScan: test projection=[a, b, c] + " + ) } #[test] @@ -1551,14 +1724,17 @@ mod tests { .project(vec![col("a"), col("b"), col("c1")])? .build()?; - // make sure projections are pushed down to both table scans - let expected = "Left Join: test.a = test2.c1\ - \n TableScan: test projection=[a, b]\ - \n TableScan: test2 projection=[c1]"; - let optimized_plan = optimize(plan)?; - let formatted_plan = format!("{optimized_plan}"); - assert_eq!(formatted_plan, expected); + + // make sure projections are pushed down to both table scans + assert_snapshot!( + optimized_plan.clone(), + @r" + Left Join: test.a = test2.c1 + TableScan: test projection=[a, b] + TableScan: test2 projection=[c1] + " + ); // make sure schema for join node include both join columns let optimized_join = optimized_plan; @@ -1602,15 +1778,18 @@ mod tests { .project(vec![col("a"), col("b")])? 
.build()?; - // make sure projections are pushed down to both table scans - let expected = "Projection: test.a, test.b\ - \n Left Join: test.a = test2.c1\ - \n TableScan: test projection=[a, b]\ - \n TableScan: test2 projection=[c1]"; - let optimized_plan = optimize(plan)?; - let formatted_plan = format!("{optimized_plan}"); - assert_eq!(formatted_plan, expected); + + // make sure projections are pushed down to both table scans + assert_snapshot!( + optimized_plan.clone(), + @r" + Projection: test.a, test.b + Left Join: test.a = test2.c1 + TableScan: test projection=[a, b] + TableScan: test2 projection=[c1] + " + ); // make sure schema for join node include both join columns let optimized_join = optimized_plan.inputs()[0]; @@ -1648,19 +1827,22 @@ mod tests { let table2_scan = scan_empty(Some("test2"), &schema, None)?.build()?; let plan = LogicalPlanBuilder::from(table_scan) - .join_using(table2_scan, JoinType::Left, vec!["a"])? + .join_using(table2_scan, JoinType::Left, vec!["a".into()])? .project(vec![col("a"), col("b")])? .build()?; - // make sure projections are pushed down to table scan - let expected = "Projection: test.a, test.b\ - \n Left Join: Using test.a = test2.a\ - \n TableScan: test projection=[a, b]\ - \n TableScan: test2 projection=[a]"; - let optimized_plan = optimize(plan)?; - let formatted_plan = format!("{optimized_plan}"); - assert_eq!(formatted_plan, expected); + + // make sure projections are pushed down to table scan + assert_snapshot!( + optimized_plan.clone(), + @r" + Projection: test.a, test.b + Left Join: Using test.a = test2.a + TableScan: test projection=[a, b] + TableScan: test2 projection=[a] + " + ); // make sure schema for join node include both join columns let optimized_join = optimized_plan.inputs()[0]; @@ -1692,17 +1874,20 @@ mod tests { fn cast() -> Result<()> { let table_scan = test_table_scan()?; - let projection = LogicalPlanBuilder::from(table_scan) + let plan = LogicalPlanBuilder::from(table_scan) .project(vec![Expr::Cast(Cast::new( Box::new(col("c")), DataType::Float64, ))])? 
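// The cast test above spells out `Expr::Cast(Cast::new(Box::new(col("c")),
// DataType::Float64))` by hand; the `cast`/`try_cast` helpers from
// `datafusion_expr` (the latter already used in the TRY_CAST test earlier)
// build the same expressions more compactly. Illustrative sketch only, not
// part of this patch.
#[allow(dead_code)]
fn cast_exprs_sketch() -> (datafusion_expr::Expr, datafusion_expr::Expr) {
    use arrow::datatypes::DataType;
    use datafusion_expr::{cast, col, try_cast};

    // CAST(test.c AS Float64): errors on values that cannot be converted.
    let strict = cast(col("c"), DataType::Float64);
    // TRY_CAST(test.a AS Float64): yields NULL instead of erroring.
    let lenient = try_cast(col("a"), DataType::Float64);
    (strict, lenient)
}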
.build()?; - let expected = "Projection: CAST(test.c AS Float64)\ - \n TableScan: test projection=[c]"; - - assert_optimized_plan_equal(projection, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: CAST(test.c AS Float64) + TableScan: test projection=[c] + " + ) } #[test] @@ -1716,9 +1901,10 @@ mod tests { assert_fields_eq(&table_scan, vec!["a", "b", "c"]); assert_fields_eq(&plan, vec!["a", "b"]); - let expected = "TableScan: test projection=[a, b]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @"TableScan: test projection=[a, b]" + ) } #[test] @@ -1737,9 +1923,10 @@ mod tests { assert_fields_eq(&plan, vec!["a", "b"]); - let expected = "TableScan: test projection=[a, b]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @"TableScan: test projection=[a, b]" + ) } #[test] @@ -1755,11 +1942,14 @@ mod tests { assert_fields_eq(&plan, vec!["c", "a"]); - let expected = "Limit: skip=0, fetch=5\ - \n Projection: test.c, test.a\ - \n TableScan: test projection=[a, c]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=5 + Projection: test.c, test.a + TableScan: test projection=[a, c] + " + ) } #[test] @@ -1767,8 +1957,10 @@ mod tests { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan).build()?; // should expand projection to all columns without projection - let expected = "TableScan: test projection=[a, b, c]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @"TableScan: test projection=[a, b, c]" + ) } #[test] @@ -1777,9 +1969,13 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .project(vec![lit(1_i64), lit(2_i64)])? 
.build()?; - let expected = "Projection: Int64(1), Int64(2)\ - \n TableScan: test projection=[]"; - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: Int64(1), Int64(2) + TableScan: test projection=[] + " + ) } /// tests that it removes unused columns in projections @@ -1799,13 +1995,15 @@ mod tests { assert_fields_eq(&plan, vec!["c", "max(test.a)"]); let plan = optimize(plan).expect("failed to optimize plan"); - let expected = "\ - Aggregate: groupBy=[[test.c]], aggr=[[max(test.a)]]\ - \n Filter: test.c > Int32(1)\ - \n Projection: test.c, test.a\ - \n TableScan: test projection=[a, c]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.c]], aggr=[[max(test.a)]] + Filter: test.c > Int32(1) + Projection: test.c, test.a + TableScan: test projection=[a, c] + " + ) } /// tests that it removes un-needed projections @@ -1823,11 +2021,13 @@ mod tests { assert_fields_eq(&plan, vec!["a"]); - let expected = "\ - Projection: Int32(1) AS a\ - \n TableScan: test projection=[]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: Int32(1) AS a + TableScan: test projection=[] + " + ) } #[test] @@ -1852,11 +2052,13 @@ mod tests { assert_fields_eq(&plan, vec!["a"]); - let expected = "\ - Projection: Int32(1) AS a\ - \n TableScan: test projection=[], full_filters=[b = Int32(1)]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: Int32(1) AS a + TableScan: test projection=[], full_filters=[b = Int32(1)] + " + ) } /// tests that optimizing twice yields same plan @@ -1895,12 +2097,15 @@ mod tests { assert_fields_eq(&plan, vec!["c", "a", "max(test.b)"]); - let expected = "Projection: test.c, test.a, max(test.b)\ - \n Filter: test.c > Int32(1)\ - \n Aggregate: groupBy=[[test.a, test.c]], aggr=[[max(test.b)]]\ - \n TableScan: test projection=[a, b, c]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.c, test.a, max(test.b) + Filter: test.c > Int32(1) + Aggregate: groupBy=[[test.a, test.c]], aggr=[[max(test.b)]] + TableScan: test projection=[a, b, c] + " + ) } #[test] @@ -1917,10 +2122,13 @@ mod tests { )? .build()?; - let expected = "Aggregate: groupBy=[[test.a]], aggr=[[count(test.b), count(test.b) FILTER (WHERE test.c > Int32(42)) AS count2]]\ - \n TableScan: test projection=[a, b, c]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.a]], aggr=[[count(test.b), count(test.b) FILTER (WHERE test.c > Int32(42)) AS count2]] + TableScan: test projection=[a, b, c] + " + ) } #[test] @@ -1933,18 +2141,21 @@ mod tests { .project(vec![col("a")])? 
.build()?; - let expected = "Projection: test.a\ - \n Distinct:\ - \n TableScan: test projection=[a, b]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a + Distinct: + TableScan: test projection=[a, b] + " + ) } #[test] fn test_window() -> Result<()> { let table_scan = test_table_scan()?; - let max1 = Expr::WindowFunction(expr::WindowFunction::new( + let max1 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("test.a")], )) @@ -1952,7 +2163,7 @@ mod tests { .build() .unwrap(); - let max2 = Expr::WindowFunction(expr::WindowFunction::new( + let max2 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("test.b")], )); @@ -1965,13 +2176,16 @@ mod tests { .project(vec![col1, col2])? .build()?; - let expected = "Projection: max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, max(test.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\ - \n WindowAggr: windowExpr=[[max(test.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ - \n Projection: test.b, max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\ - \n WindowAggr: windowExpr=[[max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ - \n TableScan: test projection=[a, b]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, max(test.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + WindowAggr: windowExpr=[[max(test.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] + Projection: test.b, max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + WindowAggr: windowExpr=[[max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] + TableScan: test projection=[a, b] + " + ) } fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index b40121dbfeb7..4d2c2c7c79cd 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -33,6 +33,7 @@ use datafusion_common::{internal_err, DFSchema, DataFusionError, HashSet, Result use datafusion_expr::logical_plan::LogicalPlan; use crate::common_subexpr_eliminate::CommonSubexprEliminate; +use crate::decorrelate_lateral_join::DecorrelateLateralJoin; use crate::decorrelate_predicate_subquery::DecorrelatePredicateSubquery; use crate::eliminate_cross_join::EliminateCrossJoin; use crate::eliminate_duplicated_expr::EliminateDuplicatedExpr; @@ -226,6 +227,7 @@ impl Optimizer { Arc::new(EliminateJoin::new()), Arc::new(DecorrelatePredicateSubquery::new()), Arc::new(ScalarSubqueryToJoin::new()), + Arc::new(DecorrelateLateralJoin::new()), Arc::new(ExtractEquijoinPredicate::new()), Arc::new(EliminateDuplicatedExpr::new()), Arc::new(EliminateFilter::new()), @@ -413,7 +415,7 @@ impl Optimizer { previous_plans.insert(LogicalPlanSignature::new(&new_plan)); if !plan_is_fresh { // plan did not change, so no need to continue trying to optimize - debug!("optimizer pass {} did not make changes", i); + debug!("optimizer pass {i} did not make changes"); break; } i += 1; diff --git a/datafusion/optimizer/src/propagate_empty_relation.rs 
b/datafusion/optimizer/src/propagate_empty_relation.rs
index 344707ae8dbe..4fb9e117e2af 100644
--- a/datafusion/optimizer/src/propagate_empty_relation.rs
+++ b/datafusion/optimizer/src/propagate_empty_relation.rs
@@ -242,17 +242,31 @@ mod tests {
         binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, Operator,
     };
 
+    use crate::assert_optimized_plan_eq_snapshot;
     use crate::eliminate_filter::EliminateFilter;
     use crate::eliminate_nested_union::EliminateNestedUnion;
     use crate::test::{
-        assert_optimized_plan_eq, assert_optimized_plan_with_rules, test_table_scan,
-        test_table_scan_fields, test_table_scan_with_name,
+        assert_optimized_plan_with_rules, test_table_scan, test_table_scan_fields,
+        test_table_scan_with_name,
     };
+    use crate::OptimizerContext;
 
     use super::*;
 
-    fn assert_eq(plan: LogicalPlan, expected: &str) -> Result<()> {
-        assert_optimized_plan_eq(Arc::new(PropagateEmptyRelation::new()), plan, expected)
+    macro_rules! assert_optimized_plan_equal {
+        (
+            $plan:expr,
+            @ $expected:literal $(,)?
+        ) => {{
+            let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
+            let rules: Vec<Arc<dyn OptimizerRule + Send + Sync>> = vec![Arc::new(PropagateEmptyRelation::new())];
+            assert_optimized_plan_eq_snapshot!(
+                optimizer_ctx,
+                rules,
+                $plan,
+                @ $expected,
+            )
+        }};
     }
 
     fn assert_together_optimized_plan(
@@ -280,8 +294,7 @@ mod tests {
             .project(vec![binary_expr(lit(1), Operator::Plus, lit(1))])?
             .build()?;
 
-        let expected = "EmptyRelation";
-        assert_eq(plan, expected)
+        assert_optimized_plan_equal!(plan, @"EmptyRelation")
     }
 
     #[test]
diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs
index c9617514e453..bcb867f6e7fa 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -40,6 +40,7 @@ use datafusion_expr::{
 };
 
 use crate::optimizer::ApplyOrder;
+use crate::simplify_expressions::simplify_predicates;
 use crate::utils::{has_all_column_refs, is_restrict_null_predicate};
 use crate::{OptimizerConfig, OptimizerRule};
 
@@ -168,7 +169,7 @@ pub(crate) fn lr_is_preserved(join_type: JoinType) -> (bool, bool) {
         JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => (true, false),
         // No columns from the left side of the join can be referenced in output
         // predicates for semi/anti joins, so whether we specify t/f doesn't matter.
-        JoinType::RightSemi | JoinType::RightAnti => (false, true),
+        JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => (false, true),
     }
 }
 
@@ -191,6 +192,7 @@ pub(crate) fn on_lr_is_preserved(join_type: JoinType) -> (bool, bool) {
         JoinType::LeftAnti => (false, true),
         JoinType::RightAnti => (true, false),
         JoinType::LeftMark => (false, true),
+        JoinType::RightMark => (true, false),
     }
 }
 
@@ -254,7 +256,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result<bool> {
     let mut is_evaluate = true;
    predicate.apply(|expr| match expr {
         Expr::Column(_)
-        | Expr::Literal(_)
+        | Expr::Literal(_, _)
         | Expr::Placeholder(_)
         | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Jump),
         Expr::Exists { .. }
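// The new test macros in this patch funnel everything through an
// `OptimizerContext` plus an explicit rule list, the same entry points the
// default rule set in optimizer.rs (which now also registers
// `DecorrelateLateralJoin`) is built on. A hedged, self-contained sketch of
// driving a single rule that way through the public `datafusion_optimizer`
// API; the function names below are illustrative and not part of this patch.
use std::sync::Arc;

use datafusion_common::Result;
use datafusion_expr::LogicalPlan;
use datafusion_optimizer::optimizer::{Optimizer, OptimizerContext, OptimizerRule};
use datafusion_optimizer::push_down_filter::PushDownFilter;

fn observe_nothing(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}

/// Run `PushDownFilter` over `plan` with at most one optimizer pass.
fn optimize_with_one_rule(plan: LogicalPlan) -> Result<LogicalPlan> {
    let rules: Vec<Arc<dyn OptimizerRule + Send + Sync>> =
        vec![Arc::new(PushDownFilter::new())];
    let optimizer = Optimizer::with_rules(rules);
    let config = OptimizerContext::new().with_max_passes(1);
    // The third argument is an observer callback invoked after each rule.
    optimizer.optimize(plan, &config, observe_nothing)
}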
@@ -691,7 +693,7 @@ fn infer_join_predicates_from_on_filters(
                 inferred_predicates,
             )
         }
-        JoinType::Right | JoinType::RightSemi => {
+        JoinType::Right | JoinType::RightSemi | JoinType::RightMark => {
             infer_join_predicates_impl::(
                 join_col_keys,
                 on_filters,
@@ -778,6 +780,18 @@ impl OptimizerRule for PushDownFilter {
             return Ok(Transformed::no(plan));
         };
 
+        let predicate = split_conjunction_owned(filter.predicate.clone());
+        let old_predicate_len = predicate.len();
+        let new_predicates = simplify_predicates(predicate)?;
+        if old_predicate_len != new_predicates.len() {
+            let Some(new_predicate) = conjunction(new_predicates) else {
+                // new_predicates is empty - remove the filter entirely
+                // Return the child plan without the filter
+                return Ok(Transformed::yes(Arc::unwrap_or_clone(filter.input)));
+            };
+            filter.predicate = new_predicate;
+        }
+
         match Arc::unwrap_or_clone(filter.input) {
             LogicalPlan::Filter(child_filter) => {
                 let parents_predicates = split_conjunction_owned(filter.predicate);
@@ -1391,7 +1405,7 @@ mod tests {
     use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
     use async_trait::async_trait;
 
-    use datafusion_common::{DFSchemaRef, ScalarValue};
+    use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue};
     use datafusion_expr::expr::{ScalarFunction, WindowFunction};
     use datafusion_expr::logical_plan::table_scan;
     use datafusion_expr::{
@@ -1401,38 +1415,47 @@ mod tests {
         WindowFunctionDefinition,
     };
 
+    use crate::assert_optimized_plan_eq_snapshot;
     use crate::optimizer::Optimizer;
     use crate::simplify_expressions::SimplifyExpressions;
     use crate::test::*;
     use crate::OptimizerContext;
     use datafusion_expr::test::function_stub::sum;
+    use insta::assert_snapshot;
 
    use super::*;
 
     fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
 
-    fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> {
-        crate::test::assert_optimized_plan_eq(
-            Arc::new(PushDownFilter::new()),
-            plan,
-            expected,
-        )
+    macro_rules! assert_optimized_plan_equal {
+        (
+            $plan:expr,
+            @ $expected:literal $(,)?
+        ) => {{
+            let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
+            let rules: Vec<Arc<dyn OptimizerRule + Send + Sync>> = vec![Arc::new(PushDownFilter::new())];
+            assert_optimized_plan_eq_snapshot!(
+                optimizer_ctx,
+                rules,
+                $plan,
+                @ $expected,
+            )
+        }};
     }
 
-    fn assert_optimized_plan_eq_with_rewrite_predicate(
-        plan: LogicalPlan,
-        expected: &str,
-    ) -> Result<()> {
-        let optimizer = Optimizer::with_rules(vec![
-            Arc::new(SimplifyExpressions::new()),
-            Arc::new(PushDownFilter::new()),
-        ]);
-        let optimized_plan =
-            optimizer.optimize(plan, &OptimizerContext::new(), observe)?;
-
-        let formatted_plan = format!("{optimized_plan}");
-        assert_eq!(expected, formatted_plan);
-        Ok(())
+    macro_rules! assert_optimized_plan_eq_with_rewrite_predicate {
+        (
+            $plan:expr,
+            @ $expected:literal $(,)?
+        ) => {{
+            let optimizer = Optimizer::with_rules(vec![
+                Arc::new(SimplifyExpressions::new()),
+                Arc::new(PushDownFilter::new()),
+            ]);
+            let optimized_plan = optimizer.optimize($plan, &OptimizerContext::new(), observe)?;
+            assert_snapshot!(optimized_plan, @ $expected);
+            Ok::<(), DataFusionError>(())
+        }};
     }
 
     #[test]
@@ -1443,10 +1466,13 @@ mod tests {
             .filter(col("a").eq(lit(1i64)))?
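// The `simplify_predicates` step added to `PushDownFilter::rewrite` above
// splits the filter predicate into its AND-ed parts, simplifies them, and
// then either re-joins the survivors with `conjunction` or removes the
// `Filter` node entirely when nothing is left. A small sketch of just the
// split/re-join half, assuming the public `datafusion_expr` helpers (the
// `simplify_predicates` function itself is internal to the optimizer crate):
#[allow(dead_code)]
fn split_and_rejoin_sketch() {
    use datafusion_expr::utils::{conjunction, split_conjunction_owned};
    use datafusion_expr::{col, lit, Expr};

    // test.a > Int32(5) AND test.b = Int32(1)
    let predicate: Expr = col("a").gt(lit(5)).and(col("b").eq(lit(1)));

    // Break the conjunction into its parts.
    let parts: Vec<Expr> = split_conjunction_owned(predicate);
    assert_eq!(parts.len(), 2);

    // Pretend a simplification pass kept only the first part, then glue the
    // survivors back together. `None` here would mean the whole Filter node
    // can be dropped, which is what the new code does via `Transformed::yes`.
    let kept = vec![parts[0].clone()];
    let rebuilt: Option<Expr> = conjunction(kept);
    assert!(rebuilt.is_some());
}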
.build()?; // filter is before projection - let expected = "\ - Projection: test.a, test.b\ - \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, test.b + TableScan: test, full_filters=[test.a = Int64(1)] + " + ) } #[test] @@ -1458,12 +1484,15 @@ mod tests { .filter(col("a").eq(lit(1i64)))? .build()?; // filter is before single projection - let expected = "\ - Filter: test.a = Int64(1)\ - \n Limit: skip=0, fetch=10\ - \n Projection: test.a, test.b\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test.a = Int64(1) + Limit: skip=0, fetch=10 + Projection: test.a, test.b + TableScan: test + " + ) } #[test] @@ -1472,8 +1501,10 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(lit(0i64).eq(lit(1i64)))? .build()?; - let expected = "TableScan: test, full_filters=[Int64(0) = Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @"TableScan: test, full_filters=[Int64(0) = Int64(1)]" + ) } #[test] @@ -1485,11 +1516,14 @@ mod tests { .filter(col("a").eq(lit(1i64)))? .build()?; // filter is before double projection - let expected = "\ - Projection: test.c, test.b\ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.c, test.b + Projection: test.a, test.b, test.c + TableScan: test, full_filters=[test.a = Int64(1)] + " + ) } #[test] @@ -1500,10 +1534,13 @@ mod tests { .filter(col("a").gt(lit(10i64)))? .build()?; // filter of key aggregation is commutative - let expected = "\ - Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS total_salary]]\ - \n TableScan: test, full_filters=[test.a > Int64(10)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS total_salary]] + TableScan: test, full_filters=[test.a > Int64(10)] + " + ) } #[test] @@ -1513,10 +1550,14 @@ mod tests { .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])? .filter(col("b").gt(lit(10i64)))? .build()?; - let expected = "Filter: test.b > Int64(10)\ - \n Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test.b > Int64(10) + Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]] + TableScan: test + " + ) } #[test] @@ -1525,10 +1566,13 @@ mod tests { .aggregate(vec![add(col("b"), col("a"))], vec![sum(col("a")), col("b")])? .filter(col("test.b + test.a").gt(lit(10i64)))? .build()?; - let expected = - "Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]\ - \n TableScan: test, full_filters=[test.b + test.a > Int64(10)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]] + TableScan: test, full_filters=[test.b + test.a > Int64(10)] + " + ) } #[test] @@ -1539,11 +1583,14 @@ mod tests { .filter(col("b").gt(lit(10i64)))? 
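// Several tests in this file switch from constructing the window-expression
// enum variant directly to `Expr::from(WindowFunction::new(..))`, combined
// with the `ExprFunctionExt` builder for PARTITION BY / ORDER BY. A hedged
// sketch of that construction, reusing the `rank_udwf` window function these
// tests already pull in (illustrative only, not part of the patch):
#[allow(dead_code)]
fn rank_partitioned_by_a() -> datafusion_expr::Expr {
    use datafusion_expr::expr::WindowFunction;
    use datafusion_expr::{col, Expr, ExprFunctionExt, WindowFunctionDefinition};
    use datafusion_functions_window::rank::rank_udwf;

    // rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST]
    Expr::from(WindowFunction::new(
        WindowFunctionDefinition::WindowUDF(rank_udwf()),
        vec![], // rank() takes no arguments
    ))
    .partition_by(vec![col("test.a")])
    .order_by(vec![col("test.c").sort(true, true)])
    .build()
    .expect("valid window expression")
}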
.build()?; // filter of aggregate is after aggregation since they are non-commutative - let expected = "\ - Filter: b > Int64(10)\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS b]]\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: b > Int64(10) + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS b]] + TableScan: test + " + ) } /// verifies that when partitioning by 'a' and 'b', and filtering by 'b', 'b' is pushed @@ -1551,7 +1598,7 @@ mod tests { fn filter_move_window() -> Result<()> { let table_scan = test_table_scan()?; - let window = Expr::WindowFunction(WindowFunction::new( + let window = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1567,10 +1614,13 @@ mod tests { .filter(col("b").gt(lit(10i64)))? .build()?; - let expected = "\ - WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n TableScan: test, full_filters=[test.b > Int64(10)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: test, full_filters=[test.b > Int64(10)] + " + ) } /// verifies that when partitioning by 'a' and 'b', and filtering by 'a' and 'b', both 'a' and @@ -1579,7 +1629,7 @@ mod tests { fn filter_move_complex_window() -> Result<()> { let table_scan = test_table_scan()?; - let window = Expr::WindowFunction(WindowFunction::new( + let window = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1595,10 +1645,13 @@ mod tests { .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))? .build()?; - let expected = "\ - WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n TableScan: test, full_filters=[test.a > Int64(10), test.b = Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + WindowAggr: windowExpr=[[rank() PARTITION BY [test.a, test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: test, full_filters=[test.a > Int64(10), test.b = Int64(1)] + " + ) } /// verifies that when partitioning by 'a' and filtering by 'a' and 'b', only 'a' is pushed @@ -1606,7 +1659,7 @@ mod tests { fn filter_move_partial_window() -> Result<()> { let table_scan = test_table_scan()?; - let window = Expr::WindowFunction(WindowFunction::new( + let window = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1622,11 +1675,14 @@ mod tests { .filter(and(col("a").gt(lit(10i64)), col("b").eq(lit(1i64))))? 
.build()?; - let expected = "\ - Filter: test.b = Int64(1)\ - \n WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n TableScan: test, full_filters=[test.a > Int64(10)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test.b = Int64(1) + WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: test, full_filters=[test.a > Int64(10)] + " + ) } /// verifies that filters on partition expressions are not pushed, as the single expression @@ -1635,7 +1691,7 @@ mod tests { fn filter_expression_keep_window() -> Result<()> { let table_scan = test_table_scan()?; - let window = Expr::WindowFunction(WindowFunction::new( + let window = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1653,11 +1709,14 @@ mod tests { .filter(add(col("a"), col("b")).gt(lit(10i64)))? .build()?; - let expected = "\ - Filter: test.a + test.b > Int64(10)\ - \n WindowAggr: windowExpr=[[rank() PARTITION BY [test.a + test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test.a + test.b > Int64(10) + WindowAggr: windowExpr=[[rank() PARTITION BY [test.a + test.b] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: test + " + ) } /// verifies that filters are not pushed on order by columns (that are not used in partitioning) @@ -1665,7 +1724,7 @@ mod tests { fn filter_order_keep_window() -> Result<()> { let table_scan = test_table_scan()?; - let window = Expr::WindowFunction(WindowFunction::new( + let window = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1681,11 +1740,14 @@ mod tests { .filter(col("c").gt(lit(10i64)))? .build()?; - let expected = "\ - Filter: test.c > Int64(10)\ - \n WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test.c > Int64(10) + WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: test + " + ) } /// verifies that when we use multiple window functions with a common partition key, the filter @@ -1694,7 +1756,7 @@ mod tests { fn filter_multiple_windows_common_partitions() -> Result<()> { let table_scan = test_table_scan()?; - let window1 = Expr::WindowFunction(WindowFunction::new( + let window1 = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1705,7 +1767,7 @@ mod tests { .build() .unwrap(); - let window2 = Expr::WindowFunction(WindowFunction::new( + let window2 = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1721,10 +1783,13 @@ mod tests { .filter(col("a").gt(lit(10i64)))? 
// a appears in both window functions .build()?; - let expected = "\ - WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n TableScan: test, full_filters=[test.a > Int64(10)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: test, full_filters=[test.a > Int64(10)] + " + ) } /// verifies that when we use multiple window functions with different partitions keys, the @@ -1733,7 +1798,7 @@ mod tests { fn filter_multiple_windows_disjoint_partitions() -> Result<()> { let table_scan = test_table_scan()?; - let window1 = Expr::WindowFunction(WindowFunction::new( + let window1 = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1744,7 +1809,7 @@ mod tests { .build() .unwrap(); - let window2 = Expr::WindowFunction(WindowFunction::new( + let window2 = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1760,11 +1825,14 @@ mod tests { .filter(col("b").gt(lit(10i64)))? // b only appears in one window function .build()?; - let expected = "\ - Filter: test.b > Int64(10)\ - \n WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test.b > Int64(10) + WindowAggr: windowExpr=[[rank() PARTITION BY [test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, rank() PARTITION BY [test.b, test.a] ORDER BY [test.c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] + TableScan: test + " + ) } /// verifies that a filter is pushed to before a projection, the filter expression is correctly re-written @@ -1776,10 +1844,13 @@ mod tests { .filter(col("b").eq(lit(1i64)))? 
.build()?; // filter is before projection - let expected = "\ - Projection: test.a AS b, test.c\ - \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a AS b, test.c + TableScan: test, full_filters=[test.a = Int64(1)] + " + ) } fn add(left: Expr, right: Expr) -> Expr { @@ -1811,19 +1882,21 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "\ - Filter: b = Int64(1)\ - \n Projection: test.a * Int32(2) + test.c AS b, test.c\ - \n TableScan: test" + assert_snapshot!(plan, + @r" + Filter: b = Int64(1) + Projection: test.a * Int32(2) + test.c AS b, test.c + TableScan: test + ", ); - // filter is before projection - let expected = "\ - Projection: test.a * Int32(2) + test.c AS b, test.c\ - \n TableScan: test, full_filters=[test.a * Int32(2) + test.c = Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a * Int32(2) + test.c AS b, test.c + TableScan: test, full_filters=[test.a * Int32(2) + test.c = Int64(1)] + " + ) } /// verifies that when a filter is pushed to after 2 projections, the filter expression is correctly re-written @@ -1841,21 +1914,23 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "\ - Filter: a = Int64(1)\ - \n Projection: b * Int32(3) AS a, test.c\ - \n Projection: test.a * Int32(2) + test.c AS b, test.c\ - \n TableScan: test" + assert_snapshot!(plan, + @r" + Filter: a = Int64(1) + Projection: b * Int32(3) AS a, test.c + Projection: test.a * Int32(2) + test.c AS b, test.c + TableScan: test + ", ); - // filter is before the projections - let expected = "\ - Projection: b * Int32(3) AS a, test.c\ - \n Projection: test.a * Int32(2) + test.c AS b, test.c\ - \n TableScan: test, full_filters=[(test.a * Int32(2) + test.c) * Int32(3) = Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: b * Int32(3) AS a, test.c + Projection: test.a * Int32(2) + test.c AS b, test.c + TableScan: test, full_filters=[(test.a * Int32(2) + test.c) * Int32(3) = Int64(1)] + " + ) } #[derive(Debug, PartialEq, Eq, Hash)] @@ -1930,10 +2005,13 @@ mod tests { .build()?; // Push filter below NoopPlan - let expected = "\ - NoopPlan\ - \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(plan, expected)?; + assert_optimized_plan_equal!( + plan, + @r" + NoopPlan + TableScan: test, full_filters=[test.a = Int64(1)] + " + )?; let custom_plan = LogicalPlan::Extension(Extension { node: Arc::new(NoopPlan { @@ -1946,11 +2024,14 @@ mod tests { .build()?; // Push only predicate on `a` below NoopPlan - let expected = "\ - Filter: test.c = Int64(2)\ - \n NoopPlan\ - \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(plan, expected)?; + assert_optimized_plan_equal!( + plan, + @r" + Filter: test.c = Int64(2) + NoopPlan + TableScan: test, full_filters=[test.a = Int64(1)] + " + )?; let custom_plan = LogicalPlan::Extension(Extension { node: Arc::new(NoopPlan { @@ -1963,11 +2044,14 @@ mod tests { .build()?; // Push filter below NoopPlan for each child branch - let expected = "\ - NoopPlan\ - \n TableScan: test, full_filters=[test.a = Int64(1)]\ - \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(plan, expected)?; + assert_optimized_plan_equal!( + plan, + @r" + NoopPlan + TableScan: test, 
full_filters=[test.a = Int64(1)] + TableScan: test, full_filters=[test.a = Int64(1)] + " + )?; let custom_plan = LogicalPlan::Extension(Extension { node: Arc::new(NoopPlan { @@ -1980,12 +2064,15 @@ mod tests { .build()?; // Push only predicate on `a` below NoopPlan - let expected = "\ - Filter: test.c = Int64(2)\ - \n NoopPlan\ - \n TableScan: test, full_filters=[test.a = Int64(1)]\ - \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test.c = Int64(2) + NoopPlan + TableScan: test, full_filters=[test.a = Int64(1)] + TableScan: test, full_filters=[test.a = Int64(1)] + " + ) } /// verifies that when two filters apply after an aggregation that only allows one to be pushed, one is pushed @@ -2002,23 +2089,25 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "\ - Filter: sum(test.c) > Int64(10)\ - \n Filter: b > Int64(10)\ - \n Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]\ - \n Projection: test.a AS b, test.c\ - \n TableScan: test" + assert_snapshot!(plan, + @r" + Filter: sum(test.c) > Int64(10) + Filter: b > Int64(10) + Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]] + Projection: test.a AS b, test.c + TableScan: test + ", ); - // filter is before the projections - let expected = "\ - Filter: sum(test.c) > Int64(10)\ - \n Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]\ - \n Projection: test.a AS b, test.c\ - \n TableScan: test, full_filters=[test.a > Int64(10)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: sum(test.c) > Int64(10) + Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]] + Projection: test.a AS b, test.c + TableScan: test, full_filters=[test.a > Int64(10)] + " + ) } /// verifies that when a filter with two predicates is applied after an aggregation that only allows one to be pushed, one is pushed @@ -2037,22 +2126,24 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "\ - Filter: sum(test.c) > Int64(10) AND b > Int64(10) AND sum(test.c) < Int64(20)\ - \n Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]\ - \n Projection: test.a AS b, test.c\ - \n TableScan: test" + assert_snapshot!(plan, + @r" + Filter: sum(test.c) > Int64(10) AND b > Int64(10) AND sum(test.c) < Int64(20) + Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]] + Projection: test.a AS b, test.c + TableScan: test + ", ); - // filter is before the projections - let expected = "\ - Filter: sum(test.c) > Int64(10) AND sum(test.c) < Int64(20)\ - \n Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]\ - \n Projection: test.a AS b, test.c\ - \n TableScan: test, full_filters=[test.a > Int64(10)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: sum(test.c) > Int64(10) AND sum(test.c) < Int64(20) + Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]] + Projection: test.a AS b, test.c + TableScan: test, full_filters=[test.a > Int64(10)] + " + ) } /// verifies that when two limits are in place, we jump neither @@ -2067,14 +2158,17 @@ mod tests { .filter(col("a").eq(lit(1i64)))? 
.build()?; // filter does not just any of the limits - let expected = "\ - Projection: test.a, test.b\ - \n Filter: test.a = Int64(1)\ - \n Limit: skip=0, fetch=10\ - \n Limit: skip=0, fetch=20\ - \n Projection: test.a, test.b\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, test.b + Filter: test.a = Int64(1) + Limit: skip=0, fetch=10 + Limit: skip=0, fetch=20 + Projection: test.a, test.b + TableScan: test + " + ) } #[test] @@ -2086,10 +2180,14 @@ mod tests { .filter(col("a").eq(lit(1i64)))? .build()?; // filter appears below Union - let expected = "Union\ - \n TableScan: test, full_filters=[test.a = Int64(1)]\ - \n TableScan: test2, full_filters=[test2.a = Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Union + TableScan: test, full_filters=[test.a = Int64(1)] + TableScan: test2, full_filters=[test2.a = Int64(1)] + " + ) } #[test] @@ -2106,13 +2204,18 @@ mod tests { .build()?; // filter appears below Union - let expected = "Union\n SubqueryAlias: test2\ - \n Projection: test.a AS b\ - \n TableScan: test, full_filters=[test.a = Int64(1)]\ - \n SubqueryAlias: test2\ - \n Projection: test.a AS b\ - \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Union + SubqueryAlias: test2 + Projection: test.a AS b + TableScan: test, full_filters=[test.a = Int64(1)] + SubqueryAlias: test2 + Projection: test.a AS b + TableScan: test, full_filters=[test.a = Int64(1)] + " + ) } #[test] @@ -2136,14 +2239,17 @@ mod tests { .filter(filter)? .build()?; - let expected = "Projection: test.a, test1.d\ - \n Cross Join: \ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test, full_filters=[test.a = Int32(1)]\ - \n Projection: test1.d, test1.e, test1.f\ - \n TableScan: test1, full_filters=[test1.d > Int32(2)]"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, test1.d + Cross Join: + Projection: test.a, test.b, test.c + TableScan: test, full_filters=[test.a = Int32(1)] + Projection: test1.d, test1.e, test1.f + TableScan: test1, full_filters=[test1.d > Int32(2)] + " + ) } #[test] @@ -2163,13 +2269,17 @@ mod tests { .filter(filter)? 
.build()?; - let expected = "Projection: test.a, test1.a\ - \n Cross Join: \ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test, full_filters=[test.a = Int32(1)]\ - \n Projection: test1.a, test1.b, test1.c\ - \n TableScan: test1, full_filters=[test1.a > Int32(2)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, test1.a + Cross Join: + Projection: test.a, test.b, test.c + TableScan: test, full_filters=[test.a = Int32(1)] + Projection: test1.a, test1.b, test1.c + TableScan: test1, full_filters=[test1.a > Int32(2)] + " + ) } /// verifies that filters with the same columns are correctly placed @@ -2186,24 +2296,26 @@ mod tests { // Should be able to move both filters below the projections // not part of the test - assert_eq!( - format!("{plan}"), - "Filter: test.a >= Int64(1)\ - \n Projection: test.a\ - \n Limit: skip=0, fetch=1\ - \n Filter: test.a <= Int64(1)\ - \n Projection: test.a\ - \n TableScan: test" + assert_snapshot!(plan, + @r" + Filter: test.a >= Int64(1) + Projection: test.a + Limit: skip=0, fetch=1 + Filter: test.a <= Int64(1) + Projection: test.a + TableScan: test + ", ); - - let expected = "\ - Projection: test.a\ - \n Filter: test.a >= Int64(1)\ - \n Limit: skip=0, fetch=1\ - \n Projection: test.a\ - \n TableScan: test, full_filters=[test.a <= Int64(1)]"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a + Filter: test.a >= Int64(1) + Limit: skip=0, fetch=1 + Projection: test.a + TableScan: test, full_filters=[test.a <= Int64(1)] + " + ) } /// verifies that filters to be placed on the same depth are ANDed @@ -2218,22 +2330,24 @@ mod tests { .build()?; // not part of the test - assert_eq!( - format!("{plan}"), - "Projection: test.a\ - \n Filter: test.a >= Int64(1)\ - \n Filter: test.a <= Int64(1)\ - \n Limit: skip=0, fetch=1\ - \n TableScan: test" + assert_snapshot!(plan, + @r" + Projection: test.a + Filter: test.a >= Int64(1) + Filter: test.a <= Int64(1) + Limit: skip=0, fetch=1 + TableScan: test + ", ); - - let expected = "\ - Projection: test.a\ - \n Filter: test.a >= Int64(1) AND test.a <= Int64(1)\ - \n Limit: skip=0, fetch=1\ - \n TableScan: test"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a + Filter: test.a >= Int64(1) AND test.a <= Int64(1) + Limit: skip=0, fetch=1 + TableScan: test + " + ) } /// verifies that filters on a plan with user nodes are not lost @@ -2247,19 +2361,21 @@ mod tests { let plan = user_defined::new(plan); - let expected = "\ - TestUserDefined\ - \n Filter: test.a <= Int64(1)\ - \n TableScan: test"; - // not part of the test - assert_eq!(format!("{plan}"), expected); - - let expected = "\ - TestUserDefined\ - \n TableScan: test, full_filters=[test.a <= Int64(1)]"; - - assert_optimized_plan_eq(plan, expected) + assert_snapshot!(plan, + @r" + TestUserDefined + Filter: test.a <= Int64(1) + TableScan: test + ", + ); + assert_optimized_plan_equal!( + plan, + @r" + TestUserDefined + TableScan: test, full_filters=[test.a <= Int64(1)] + " + ) } /// post-on-join predicates on a column common to both sides is pushed to both sides @@ -2282,22 +2398,25 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Filter: test.a <= Int64(1)\ - \n Inner Join: test.a = test2.a\ - \n TableScan: test\ - \n Projection: test2.a\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Filter: test.a <= Int64(1) + Inner 
Join: test.a = test2.a + TableScan: test + Projection: test2.a + TableScan: test2 + ", ); - // filter sent to side before the join - let expected = "\ - Inner Join: test.a = test2.a\ - \n TableScan: test, full_filters=[test.a <= Int64(1)]\ - \n Projection: test2.a\ - \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Inner Join: test.a = test2.a + TableScan: test, full_filters=[test.a <= Int64(1)] + Projection: test2.a + TableScan: test2, full_filters=[test2.a <= Int64(1)] + " + ) } /// post-using-join predicates on a column common to both sides is pushed to both sides @@ -2319,22 +2438,25 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Filter: test.a <= Int64(1)\ - \n Inner Join: Using test.a = test2.a\ - \n TableScan: test\ - \n Projection: test2.a\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Filter: test.a <= Int64(1) + Inner Join: Using test.a = test2.a + TableScan: test + Projection: test2.a + TableScan: test2 + ", ); - // filter sent to side before the join - let expected = "\ - Inner Join: Using test.a = test2.a\ - \n TableScan: test, full_filters=[test.a <= Int64(1)]\ - \n Projection: test2.a\ - \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Inner Join: Using test.a = test2.a + TableScan: test, full_filters=[test.a <= Int64(1)] + Projection: test2.a + TableScan: test2, full_filters=[test2.a <= Int64(1)] + " + ) } /// post-join predicates with columns from both sides are converted to join filters @@ -2359,24 +2481,27 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Filter: test.c <= test2.b\ - \n Inner Join: test.a = test2.a\ - \n Projection: test.a, test.c\ - \n TableScan: test\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Filter: test.c <= test2.b + Inner Join: test.a = test2.a + Projection: test.a, test.c + TableScan: test + Projection: test2.a, test2.b + TableScan: test2 + ", ); - // Filter is converted to Join Filter - let expected = "\ - Inner Join: test.a = test2.a Filter: test.c <= test2.b\ - \n Projection: test.a, test.c\ - \n TableScan: test\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Inner Join: test.a = test2.a Filter: test.c <= test2.b + Projection: test.a, test.c + TableScan: test + Projection: test2.a, test2.b + TableScan: test2 + " + ) } /// post-join predicates with columns from one side of a join are pushed only to that side @@ -2402,23 +2527,26 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Filter: test.b <= Int64(1)\ - \n Inner Join: test.a = test2.a\ - \n Projection: test.a, test.b\ - \n TableScan: test\ - \n Projection: test2.a, test2.c\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Filter: test.b <= Int64(1) + Inner Join: test.a = test2.a + Projection: test.a, test.b + TableScan: test + Projection: test2.a, test2.c + TableScan: test2 + ", ); - - let expected = "\ - Inner Join: test.a = test2.a\ - \n Projection: test.a, test.b\ - \n TableScan: test, full_filters=[test.b <= Int64(1)]\ - \n Projection: test2.a, test2.c\ - \n TableScan: test2"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + 
plan, + @r" + Inner Join: test.a = test2.a + Projection: test.a, test.b + TableScan: test, full_filters=[test.b <= Int64(1)] + Projection: test2.a, test2.c + TableScan: test2 + " + ) } /// post-join predicates on the right side of a left join are not duplicated @@ -2441,23 +2569,26 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Filter: test2.a <= Int64(1)\ - \n Left Join: Using test.a = test2.a\ - \n TableScan: test\ - \n Projection: test2.a\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Filter: test2.a <= Int64(1) + Left Join: Using test.a = test2.a + TableScan: test + Projection: test2.a + TableScan: test2 + ", ); - // filter not duplicated nor pushed down - i.e. noop - let expected = "\ - Filter: test2.a <= Int64(1)\ - \n Left Join: Using test.a = test2.a\ - \n TableScan: test, full_filters=[test.a <= Int64(1)]\ - \n Projection: test2.a\ - \n TableScan: test2"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test2.a <= Int64(1) + Left Join: Using test.a = test2.a + TableScan: test, full_filters=[test.a <= Int64(1)] + Projection: test2.a + TableScan: test2 + " + ) } /// post-join predicates on the left side of a right join are not duplicated @@ -2479,23 +2610,26 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Filter: test.a <= Int64(1)\ - \n Right Join: Using test.a = test2.a\ - \n TableScan: test\ - \n Projection: test2.a\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Filter: test.a <= Int64(1) + Right Join: Using test.a = test2.a + TableScan: test + Projection: test2.a + TableScan: test2 + ", ); - // filter not duplicated nor pushed down - i.e. noop - let expected = "\ - Filter: test.a <= Int64(1)\ - \n Right Join: Using test.a = test2.a\ - \n TableScan: test\ - \n Projection: test2.a\ - \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test.a <= Int64(1) + Right Join: Using test.a = test2.a + TableScan: test + Projection: test2.a + TableScan: test2, full_filters=[test2.a <= Int64(1)] + " + ) } /// post-left-join predicate on a column common to both sides is only pushed to the left side @@ -2518,22 +2652,25 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Filter: test.a <= Int64(1)\ - \n Left Join: Using test.a = test2.a\ - \n TableScan: test\ - \n Projection: test2.a\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Filter: test.a <= Int64(1) + Left Join: Using test.a = test2.a + TableScan: test + Projection: test2.a + TableScan: test2 + ", ); - // filter sent to left side of the join, not the right - let expected = "\ - Left Join: Using test.a = test2.a\ - \n TableScan: test, full_filters=[test.a <= Int64(1)]\ - \n Projection: test2.a\ - \n TableScan: test2"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Left Join: Using test.a = test2.a + TableScan: test, full_filters=[test.a <= Int64(1)] + Projection: test2.a + TableScan: test2 + " + ) } /// post-right-join predicate on a column common to both sides is only pushed to the right side @@ -2556,22 +2693,25 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Filter: test2.a <= Int64(1)\ - \n Right Join: Using test.a = test2.a\ - \n TableScan: test\ - \n Projection: test2.a\ - \n 
TableScan: test2" + assert_snapshot!(plan, + @r" + Filter: test2.a <= Int64(1) + Right Join: Using test.a = test2.a + TableScan: test + Projection: test2.a + TableScan: test2 + ", ); - // filter sent to right side of join, not duplicated to the left - let expected = "\ - Right Join: Using test.a = test2.a\ - \n TableScan: test\ - \n Projection: test2.a\ - \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Right Join: Using test.a = test2.a + TableScan: test + Projection: test2.a + TableScan: test2, full_filters=[test2.a <= Int64(1)] + " + ) } /// single table predicate parts of ON condition should be pushed to both inputs @@ -2599,22 +2739,25 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Inner Join: test.a = test2.a Filter: test.c > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test\ - \n Projection: test2.a, test2.b, test2.c\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Inner Join: test.a = test2.a Filter: test.c > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4) + Projection: test.a, test.b, test.c + TableScan: test + Projection: test2.a, test2.b, test2.c + TableScan: test2 + ", ); - - let expected = "\ - Inner Join: test.a = test2.a Filter: test.b < test2.b\ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test, full_filters=[test.c > UInt32(1)]\ - \n Projection: test2.a, test2.b, test2.c\ - \n TableScan: test2, full_filters=[test2.c > UInt32(4)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Inner Join: test.a = test2.a Filter: test.b < test2.b + Projection: test.a, test.b, test.c + TableScan: test, full_filters=[test.c > UInt32(1)] + Projection: test2.a, test2.b, test2.c + TableScan: test2, full_filters=[test2.c > UInt32(4)] + " + ) } /// join filter should be completely removed after pushdown @@ -2641,22 +2784,25 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Inner Join: test.a = test2.a Filter: test.b > UInt32(1) AND test2.c > UInt32(4)\ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test\ - \n Projection: test2.a, test2.b, test2.c\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Inner Join: test.a = test2.a Filter: test.b > UInt32(1) AND test2.c > UInt32(4) + Projection: test.a, test.b, test.c + TableScan: test + Projection: test2.a, test2.b, test2.c + TableScan: test2 + ", ); - - let expected = "\ - Inner Join: test.a = test2.a\ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test, full_filters=[test.b > UInt32(1)]\ - \n Projection: test2.a, test2.b, test2.c\ - \n TableScan: test2, full_filters=[test2.c > UInt32(4)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Inner Join: test.a = test2.a + Projection: test.a, test.b, test.c + TableScan: test, full_filters=[test.b > UInt32(1)] + Projection: test2.a, test2.b, test2.c + TableScan: test2, full_filters=[test2.c > UInt32(4)] + " + ) } /// predicate on join key in filter expression should be pushed down to both inputs @@ -2681,22 +2827,25 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Inner Join: test.a = test2.b Filter: test.a > UInt32(1)\ - \n Projection: test.a\ - \n TableScan: test\ - \n Projection: test2.b\ - \n TableScan: test2" + 
assert_snapshot!(plan, + @r" + Inner Join: test.a = test2.b Filter: test.a > UInt32(1) + Projection: test.a + TableScan: test + Projection: test2.b + TableScan: test2 + ", ); - - let expected = "\ - Inner Join: test.a = test2.b\ - \n Projection: test.a\ - \n TableScan: test, full_filters=[test.a > UInt32(1)]\ - \n Projection: test2.b\ - \n TableScan: test2, full_filters=[test2.b > UInt32(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Inner Join: test.a = test2.b + Projection: test.a + TableScan: test, full_filters=[test.a > UInt32(1)] + Projection: test2.b + TableScan: test2, full_filters=[test2.b > UInt32(1)] + " + ) } /// single table predicate parts of ON condition should be pushed to right input @@ -2724,22 +2873,25 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test\ - \n Projection: test2.a, test2.b, test2.c\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4) + Projection: test.a, test.b, test.c + TableScan: test + Projection: test2.a, test2.b, test2.c + TableScan: test2 + ", ); - - let expected = "\ - Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b\ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test\ - \n Projection: test2.a, test2.b, test2.c\ - \n TableScan: test2, full_filters=[test2.c > UInt32(4)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b + Projection: test.a, test.b, test.c + TableScan: test + Projection: test2.a, test2.b, test2.c + TableScan: test2, full_filters=[test2.c > UInt32(4)] + " + ) } /// single table predicate parts of ON condition should be pushed to left input @@ -2767,22 +2919,25 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Right Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test\ - \n Projection: test2.a, test2.b, test2.c\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Right Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4) + Projection: test.a, test.b, test.c + TableScan: test + Projection: test2.a, test2.b, test2.c + TableScan: test2 + ", ); - - let expected = "\ - Right Join: test.a = test2.a Filter: test.b < test2.b AND test2.c > UInt32(4)\ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test, full_filters=[test.a > UInt32(1)]\ - \n Projection: test2.a, test2.b, test2.c\ - \n TableScan: test2"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Right Join: test.a = test2.a Filter: test.b < test2.b AND test2.c > UInt32(4) + Projection: test.a, test.b, test.c + TableScan: test, full_filters=[test.a > UInt32(1)] + Projection: test2.a, test2.b, test2.c + TableScan: test2 + " + ) } /// single table predicate parts of ON condition should not be pushed @@ -2810,17 +2965,25 @@ mod tests { .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > 
UInt32(4)\ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test\ - \n Projection: test2.a, test2.b, test2.c\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4) + Projection: test.a, test.b, test.c + TableScan: test + Projection: test2.a, test2.b, test2.c + TableScan: test2 + ", ); - - let expected = &format!("{plan}"); - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4) + Projection: test.a, test.b, test.c + TableScan: test + Projection: test2.a, test2.b, test2.c + TableScan: test2 + " + ) } struct PushDownProvider { @@ -2887,9 +3050,10 @@ mod tests { fn filter_with_table_provider_exact() -> Result<()> { let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Exact)?; - let expected = "\ - TableScan: test, full_filters=[a = Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @"TableScan: test, full_filters=[a = Int64(1)]" + ) } #[test] @@ -2897,10 +3061,13 @@ mod tests { let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?; - let expected = "\ - Filter: a = Int64(1)\ - \n TableScan: test, partial_filters=[a = Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: a = Int64(1) + TableScan: test, partial_filters=[a = Int64(1)] + " + ) } #[test] @@ -2913,13 +3080,15 @@ mod tests { .expect("failed to optimize plan") .data; - let expected = "\ - Filter: a = Int64(1)\ - \n TableScan: test, partial_filters=[a = Int64(1)]"; - // Optimizing the same plan multiple times should produce the same plan // each time. - assert_optimized_plan_eq(optimized_plan, expected) + assert_optimized_plan_equal!( + optimized_plan, + @r" + Filter: a = Int64(1) + TableScan: test, partial_filters=[a = Int64(1)] + " + ) } #[test] @@ -2927,10 +3096,13 @@ mod tests { let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Unsupported)?; - let expected = "\ - Filter: a = Int64(1)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: a = Int64(1) + TableScan: test + " + ) } #[test] @@ -2944,11 +3116,14 @@ mod tests { .project(vec![col("a"), col("b")])? .build()?; - let expected = "Projection: a, b\ - \n Filter: a = Int64(10) AND b > Int64(11)\ - \n TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: a, b + Filter: a = Int64(10) AND b > Int64(11) + TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)] + " + ) } #[test] @@ -2962,13 +3137,13 @@ mod tests { .project(vec![col("a"), col("b")])? 
.build()?; - let expected = r#" -Projection: a, b - TableScan: test projection=[a], full_filters=[a = Int64(10), b > Int64(11)] - "# - .trim(); - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: a, b + TableScan: test projection=[a], full_filters=[a = Int64(10), b > Int64(11)] + " + ) } #[test] @@ -2983,20 +3158,21 @@ Projection: a, b .build()?; // filter on col b - assert_eq!( - format!("{plan}"), - "Filter: b > Int64(10) AND test.c > Int64(10)\ - \n Projection: test.a AS b, test.c\ - \n TableScan: test" + assert_snapshot!(plan, + @r" + Filter: b > Int64(10) AND test.c > Int64(10) + Projection: test.a AS b, test.c + TableScan: test + ", ); - // rewrite filter col b to test.a - let expected = "\ - Projection: test.a AS b, test.c\ - \n TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]\ - "; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a AS b, test.c + TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)] + " + ) } #[test] @@ -3012,23 +3188,23 @@ Projection: a, b .build()?; // filter on col b - assert_eq!( - format!("{plan}"), - "Filter: b > Int64(10) AND test.c > Int64(10)\ - \n Projection: b, test.c\ - \n Projection: test.a AS b, test.c\ - \n TableScan: test\ - " + assert_snapshot!(plan, + @r" + Filter: b > Int64(10) AND test.c > Int64(10) + Projection: b, test.c + Projection: test.a AS b, test.c + TableScan: test + ", ); - // rewrite filter col b to test.a - let expected = "\ - Projection: b, test.c\ - \n Projection: test.a AS b, test.c\ - \n TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]\ - "; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: b, test.c + Projection: test.a AS b, test.c + TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)] + " + ) } #[test] @@ -3040,20 +3216,21 @@ Projection: a, b .build()?; // filter on col b and d - assert_eq!( - format!("{plan}"), - "Filter: b > Int64(10) AND d > Int64(10)\ - \n Projection: test.a AS b, test.c AS d\ - \n TableScan: test\ - " + assert_snapshot!(plan, + @r" + Filter: b > Int64(10) AND d > Int64(10) + Projection: test.a AS b, test.c AS d + TableScan: test + ", ); - // rewrite filter col b to test.a, col d to test.c - let expected = "\ - Projection: test.a AS b, test.c AS d\ - \n TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a AS b, test.c AS d + TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)] + " + ) } /// predicate on join key in filter expression should be pushed down to both inputs @@ -3077,23 +3254,26 @@ Projection: a, b )? 
.build()?; - assert_eq!( - format!("{plan}"), - "Inner Join: c = d Filter: c > UInt32(1)\ - \n Projection: test.a AS c\ - \n TableScan: test\ - \n Projection: test2.b AS d\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Inner Join: c = d Filter: c > UInt32(1) + Projection: test.a AS c + TableScan: test + Projection: test2.b AS d + TableScan: test2 + ", ); - // Change filter on col `c`, 'd' to `test.a`, 'test.b' - let expected = "\ - Inner Join: c = d\ - \n Projection: test.a AS c\ - \n TableScan: test, full_filters=[test.a > UInt32(1)]\ - \n Projection: test2.b AS d\ - \n TableScan: test2, full_filters=[test2.b > UInt32(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Inner Join: c = d + Projection: test.a AS c + TableScan: test, full_filters=[test.a > UInt32(1)] + Projection: test2.b AS d + TableScan: test2, full_filters=[test2.b > UInt32(1)] + " + ) } #[test] @@ -3109,20 +3289,21 @@ Projection: a, b .build()?; // filter on col b - assert_eq!( - format!("{plan}"), - "Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])\ - \n Projection: test.a AS b, test.c\ - \n TableScan: test\ - " + assert_snapshot!(plan, + @r" + Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)]) + Projection: test.a AS b, test.c + TableScan: test + ", ); - // rewrite filter col b to test.a - let expected = "\ - Projection: test.a AS b, test.c\ - \n TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a AS b, test.c + TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])] + " + ) } #[test] @@ -3139,22 +3320,23 @@ Projection: a, b .build()?; // filter on col b - assert_eq!( - format!("{plan}"), - "Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])\ - \n Projection: b, test.c\ - \n Projection: test.a AS b, test.c\ - \n TableScan: test\ - " + assert_snapshot!(plan, + @r" + Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)]) + Projection: b, test.c + Projection: test.a AS b, test.c + TableScan: test + ", ); - // rewrite filter col b to test.a - let expected = "\ - Projection: b, test.c\ - \n Projection: test.a AS b, test.c\ - \n TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: b, test.c + Projection: test.a AS b, test.c + TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])] + " + ) } #[test] @@ -3174,23 +3356,27 @@ Projection: a, b .build()?; // filter on col b in subquery - let expected_before = "\ - Filter: b IN (<subquery>)\ - \n Subquery:\ - \n Projection: sq.c\ - \n TableScan: sq\ - \n Projection: test.a AS b, test.c\ - \n TableScan: test"; - assert_eq!(format!("{plan}"), expected_before); - + assert_snapshot!(plan, + @r" + Filter: b IN (<subquery>) + Subquery: + Projection: sq.c + TableScan: sq + Projection: test.a AS b, test.c + TableScan: test + ", + ); // rewrite filter col b to test.a - let expected_after = "\ - Projection: test.a AS b, test.c\ - \n TableScan: test, full_filters=[test.a IN (<subquery>)]\ - \n Subquery:\ - \n Projection: sq.c\ - \n TableScan: sq"; - assert_optimized_plan_eq(plan, expected_after) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a AS b, test.c + TableScan: test, full_filters=[test.a IN (<subquery>)] + Subquery: + Projection: sq.c + TableScan: sq + " + ) }
#[test] @@ -3205,25 +3391,31 @@ Projection: a, b .project(vec![col("b.a")])? .build()?; - let expected_before = "Projection: b.a\ - \n Filter: b.a = Int64(1)\ - \n SubqueryAlias: b\ - \n Projection: b.a\ - \n SubqueryAlias: b\ - \n Projection: Int64(0) AS a\ - \n EmptyRelation"; - assert_eq!(format!("{plan}"), expected_before); - + assert_snapshot!(plan, + @r" + Projection: b.a + Filter: b.a = Int64(1) + SubqueryAlias: b + Projection: b.a + SubqueryAlias: b + Projection: Int64(0) AS a + EmptyRelation + ", + ); // Ensure that the predicate without any columns (0 = 1) is // still there. - let expected_after = "Projection: b.a\ - \n SubqueryAlias: b\ - \n Projection: b.a\ - \n SubqueryAlias: b\ - \n Projection: Int64(0) AS a\ - \n Filter: Int64(0) = Int64(1)\ - \n EmptyRelation"; - assert_optimized_plan_eq(plan, expected_after) + assert_optimized_plan_equal!( + plan, + @r" + Projection: b.a + SubqueryAlias: b + Projection: b.a + SubqueryAlias: b + Projection: Int64(0) AS a + Filter: Int64(0) = Int64(1) + EmptyRelation + " + ) } #[test] @@ -3245,13 +3437,14 @@ Projection: a, b .cross_join(right)? .filter(filter)? .build()?; - let expected = "\ - Inner Join: Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)\ - \n Projection: test.a, test.b, test.c\ - \n TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]\ - \n Projection: test1.a AS d, test1.a AS e\ - \n TableScan: test1"; - assert_optimized_plan_eq_with_rewrite_predicate(plan.clone(), expected)?; + + assert_optimized_plan_eq_with_rewrite_predicate!(plan.clone(), @r" + Inner Join: Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10) + Projection: test.a, test.b, test.c + TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)] + Projection: test1.a AS d, test1.a AS e + TableScan: test1 + ")?; // Originally global state which can help to avoid duplicate Filters been generated and pushed down. // Now the global state is removed. Need to double confirm that avoid duplicate Filters. @@ -3259,7 +3452,16 @@ Projection: a, b .rewrite(plan, &OptimizerContext::new()) .expect("failed to optimize plan") .data; - assert_optimized_plan_eq(optimized_plan, expected) + assert_optimized_plan_equal!( + optimized_plan, + @r" + Inner Join: Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10) + Projection: test.a, test.b, test.c + TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)] + Projection: test1.a AS d, test1.a AS e + TableScan: test1 + " + ) } #[test] @@ -3283,23 +3485,26 @@ Projection: a, b .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Filter: test2.a <= Int64(1)\ - \n LeftSemi Join: test1.a = test2.a\ - \n TableScan: test1\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2" + assert_snapshot!(plan, + @r" + Filter: test2.a <= Int64(1) + LeftSemi Join: test1.a = test2.a + TableScan: test1 + Projection: test2.a, test2.b + TableScan: test2 + ", ); - // Inferred the predicate `test1.a <= Int64(1)` and push it down to the left side. 
- let expected = "\ - Filter: test2.a <= Int64(1)\ - \n LeftSemi Join: test1.a = test2.a\ - \n TableScan: test1, full_filters=[test1.a <= Int64(1)]\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test2.a <= Int64(1) + LeftSemi Join: test1.a = test2.a + TableScan: test1, full_filters=[test1.a <= Int64(1)] + Projection: test2.a, test2.b + TableScan: test2 + " + ) } #[test] @@ -3326,21 +3531,24 @@ Projection: a, b .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "LeftSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)\ - \n TableScan: test1\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2", + assert_snapshot!(plan, + @r" + LeftSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2) + TableScan: test1 + Projection: test2.a, test2.b + TableScan: test2 + ", ); - // Both side will be pushed down. - let expected = "\ - LeftSemi Join: test1.a = test2.a\ - \n TableScan: test1, full_filters=[test1.b > UInt32(1)]\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2, full_filters=[test2.b > UInt32(2)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + LeftSemi Join: test1.a = test2.a + TableScan: test1, full_filters=[test1.b > UInt32(1)] + Projection: test2.a, test2.b + TableScan: test2, full_filters=[test2.b > UInt32(2)] + " + ) } #[test] @@ -3364,23 +3572,26 @@ Projection: a, b .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Filter: test1.a <= Int64(1)\ - \n RightSemi Join: test1.a = test2.a\ - \n TableScan: test1\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2", + assert_snapshot!(plan, + @r" + Filter: test1.a <= Int64(1) + RightSemi Join: test1.a = test2.a + TableScan: test1 + Projection: test2.a, test2.b + TableScan: test2 + ", ); - // Inferred the predicate `test2.a <= Int64(1)` and push it down to the right side. - let expected = "\ - Filter: test1.a <= Int64(1)\ - \n RightSemi Join: test1.a = test2.a\ - \n TableScan: test1\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test1.a <= Int64(1) + RightSemi Join: test1.a = test2.a + TableScan: test1 + Projection: test2.a, test2.b + TableScan: test2, full_filters=[test2.a <= Int64(1)] + " + ) } #[test] @@ -3407,21 +3618,24 @@ Projection: a, b .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "RightSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)\ - \n TableScan: test1\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2", + assert_snapshot!(plan, + @r" + RightSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2) + TableScan: test1 + Projection: test2.a, test2.b + TableScan: test2 + ", ); - // Both side will be pushed down. 
- let expected = "\ - RightSemi Join: test1.a = test2.a\ - \n TableScan: test1, full_filters=[test1.b > UInt32(1)]\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2, full_filters=[test2.b > UInt32(2)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + RightSemi Join: test1.a = test2.a + TableScan: test1, full_filters=[test1.b > UInt32(1)] + Projection: test2.a, test2.b + TableScan: test2, full_filters=[test2.b > UInt32(2)] + " + ) } #[test] @@ -3448,25 +3662,28 @@ Projection: a, b .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Filter: test2.a > UInt32(2)\ - \n LeftAnti Join: test1.a = test2.a\ - \n Projection: test1.a, test1.b\ - \n TableScan: test1\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2", + assert_snapshot!(plan, + @r" + Filter: test2.a > UInt32(2) + LeftAnti Join: test1.a = test2.a + Projection: test1.a, test1.b + TableScan: test1 + Projection: test2.a, test2.b + TableScan: test2 + ", ); - // For left anti, filter of the right side filter can be pushed down. - let expected = "\ - Filter: test2.a > UInt32(2)\ - \n LeftAnti Join: test1.a = test2.a\ - \n Projection: test1.a, test1.b\ - \n TableScan: test1, full_filters=[test1.a > UInt32(2)]\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test2.a > UInt32(2) + LeftAnti Join: test1.a = test2.a + Projection: test1.a, test1.b + TableScan: test1, full_filters=[test1.a > UInt32(2)] + Projection: test2.a, test2.b + TableScan: test2 + " + ) } #[test] @@ -3496,23 +3713,26 @@ Projection: a, b .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)\ - \n Projection: test1.a, test1.b\ - \n TableScan: test1\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2", + assert_snapshot!(plan, + @r" + LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2) + Projection: test1.a, test1.b + TableScan: test1 + Projection: test2.a, test2.b + TableScan: test2 + ", ); - // For left anti, filter of the right side filter can be pushed down. - let expected = "\ - LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1)\ - \n Projection: test1.a, test1.b\ - \n TableScan: test1\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2, full_filters=[test2.b > UInt32(2)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) + Projection: test1.a, test1.b + TableScan: test1 + Projection: test2.a, test2.b + TableScan: test2, full_filters=[test2.b > UInt32(2)] + " + ) } #[test] @@ -3539,25 +3759,28 @@ Projection: a, b .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "Filter: test1.a > UInt32(2)\ - \n RightAnti Join: test1.a = test2.a\ - \n Projection: test1.a, test1.b\ - \n TableScan: test1\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2", + assert_snapshot!(plan, + @r" + Filter: test1.a > UInt32(2) + RightAnti Join: test1.a = test2.a + Projection: test1.a, test1.b + TableScan: test1 + Projection: test2.a, test2.b + TableScan: test2 + ", ); - // For right anti, filter of the left side can be pushed down. 
- let expected = "\ - Filter: test1.a > UInt32(2)\ - \n RightAnti Join: test1.a = test2.a\ - \n Projection: test1.a, test1.b\ - \n TableScan: test1\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2, full_filters=[test2.a > UInt32(2)]"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test1.a > UInt32(2) + RightAnti Join: test1.a = test2.a + Projection: test1.a, test1.b + TableScan: test1 + Projection: test2.a, test2.b + TableScan: test2, full_filters=[test2.a > UInt32(2)] + " + ) } #[test] @@ -3587,22 +3810,26 @@ Projection: a, b .build()?; // not part of the test, just good to know: - assert_eq!( - format!("{plan}"), - "RightAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)\ - \n Projection: test1.a, test1.b\ - \n TableScan: test1\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2", + assert_snapshot!(plan, + @r" + RightAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2) + Projection: test1.a, test1.b + TableScan: test1 + Projection: test2.a, test2.b + TableScan: test2 + ", ); - // For right anti, filter of the left side can be pushed down. - let expected = "RightAnti Join: test1.a = test2.a Filter: test2.b > UInt32(2)\ - \n Projection: test1.a, test1.b\ - \n TableScan: test1, full_filters=[test1.b > UInt32(1)]\ - \n Projection: test2.a, test2.b\ - \n TableScan: test2"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + RightAnti Join: test1.a = test2.a Filter: test2.b > UInt32(2) + Projection: test1.a, test1.b + TableScan: test1, full_filters=[test1.b > UInt32(1)] + Projection: test2.a, test2.b + TableScan: test2 + " + ) } #[derive(Debug)] @@ -3648,21 +3875,27 @@ Projection: a, b .project(vec![col("t.a"), col("t.r")])? .build()?; - let expected_before = "Projection: t.a, t.r\ - \n Filter: t.a > Int32(5) AND t.r > Float64(0.5)\ - \n SubqueryAlias: t\ - \n Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r\ - \n Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]\ - \n TableScan: test1"; - assert_eq!(format!("{plan}"), expected_before); - - let expected_after = "Projection: t.a, t.r\ - \n SubqueryAlias: t\ - \n Filter: r > Float64(0.5)\ - \n Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r\ - \n Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]\ - \n TableScan: test1, full_filters=[test1.a > Int32(5)]"; - assert_optimized_plan_eq(plan, expected_after) + assert_snapshot!(plan, + @r" + Projection: t.a, t.r + Filter: t.a > Int32(5) AND t.r > Float64(0.5) + SubqueryAlias: t + Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r + Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]] + TableScan: test1 + ", + ); + assert_optimized_plan_equal!( + plan, + @r" + Projection: t.a, t.r + SubqueryAlias: t + Filter: r > Float64(0.5) + Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r + Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]] + TableScan: test1, full_filters=[test1.a > Int32(5)] + " + ) } #[test] @@ -3692,23 +3925,29 @@ Projection: a, b .project(vec![col("t.a"), col("t.r")])? 
.build()?; - let expected_before = "Projection: t.a, t.r\ - \n Filter: t.r > Float64(0.8)\ - \n SubqueryAlias: t\ - \n Projection: test1.a AS a, TestScalarUDF() AS r\ - \n Inner Join: test1.a = test2.a\ - \n TableScan: test1\ - \n TableScan: test2"; - assert_eq!(format!("{plan}"), expected_before); - - let expected = "Projection: t.a, t.r\ - \n SubqueryAlias: t\ - \n Filter: r > Float64(0.8)\ - \n Projection: test1.a AS a, TestScalarUDF() AS r\ - \n Inner Join: test1.a = test2.a\ - \n TableScan: test1\ - \n TableScan: test2"; - assert_optimized_plan_eq(plan, expected) + assert_snapshot!(plan, + @r" + Projection: t.a, t.r + Filter: t.r > Float64(0.8) + SubqueryAlias: t + Projection: test1.a AS a, TestScalarUDF() AS r + Inner Join: test1.a = test2.a + TableScan: test1 + TableScan: test2 + ", + ); + assert_optimized_plan_equal!( + plan, + @r" + Projection: t.a, t.r + SubqueryAlias: t + Filter: r > Float64(0.8) + Projection: test1.a AS a, TestScalarUDF() AS r + Inner Join: test1.a = test2.a + TableScan: test1 + TableScan: test2 + " + ) } #[test] @@ -3724,15 +3963,21 @@ Projection: a, b .filter(expr.gt(lit(0.1)))? .build()?; - let expected_before = "Filter: TestScalarUDF() > Float64(0.1)\ - \n Projection: test.a, test.b\ - \n TableScan: test"; - assert_eq!(format!("{plan}"), expected_before); - - let expected_after = "Projection: test.a, test.b\ - \n Filter: TestScalarUDF() > Float64(0.1)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected_after) + assert_snapshot!(plan, + @r" + Filter: TestScalarUDF() > Float64(0.1) + Projection: test.a, test.b + TableScan: test + ", + ); + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, test.b + Filter: TestScalarUDF() > Float64(0.1) + TableScan: test + " + ) } #[test] @@ -3752,15 +3997,21 @@ Projection: a, b )? .build()?; - let expected_before = "Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)\ - \n Projection: test.a, test.b\ - \n TableScan: test"; - assert_eq!(format!("{plan}"), expected_before); - - let expected_after = "Projection: test.a, test.b\ - \n Filter: TestScalarUDF() > Float64(0.1)\ - \n TableScan: test, full_filters=[t.a > Int32(5), t.b > Int32(10)]"; - assert_optimized_plan_eq(plan, expected_after) + assert_snapshot!(plan, + @r" + Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10) + Projection: test.a, test.b + TableScan: test + ", + ); + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, test.b + Filter: TestScalarUDF() > Float64(0.1) + TableScan: test, full_filters=[t.a > Int32(5), t.b > Int32(10)] + " + ) } #[test] @@ -3783,15 +4034,21 @@ Projection: a, b )? 
.build()?; - let expected_before = "Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10)\ - \n Projection: a, b\ - \n TableScan: test"; - assert_eq!(format!("{plan}"), expected_before); - - let expected_after = "Projection: a, b\ - \n Filter: t.a > Int32(5) AND t.b > Int32(10) AND TestScalarUDF() > Float64(0.1)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected_after) + assert_snapshot!(plan, + @r" + Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10) + Projection: a, b + TableScan: test + ", + ); + assert_optimized_plan_equal!( + plan, + @r" + Projection: a, b + Filter: t.a > Int32(5) AND t.b > Int32(10) AND TestScalarUDF() > Float64(0.1) + TableScan: test + " + ) } #[test] @@ -3864,12 +4121,19 @@ Projection: a, b let plan = LogicalPlanBuilder::from(node).filter(lit(false))?.build()?; // Check the original plan format (not part of the test assertions) - let expected_before = "Filter: Boolean(false)\ - \n TestUserNode"; - assert_eq!(format!("{plan}"), expected_before); - + assert_snapshot!(plan, + @r" + Filter: Boolean(false) + TestUserNode + ", + ); // Check that the filter is pushed down to the user-defined node - let expected_after = "Filter: Boolean(false)\n TestUserNode"; - assert_optimized_plan_eq(plan, expected_after) + assert_optimized_plan_equal!( + plan, + @r" + Filter: Boolean(false) + TestUserNode + " + ) } } diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 1e9ef16bde67..ec042dd350ca 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -276,8 +276,10 @@ mod test { use std::vec; use super::*; + use crate::assert_optimized_plan_eq_snapshot; use crate::test::*; + use crate::OptimizerContext; use datafusion_common::DFSchemaRef; use datafusion_expr::{ col, exists, logical_plan::builder::LogicalPlanBuilder, Expr, Extension, @@ -285,8 +287,20 @@ mod test { }; use datafusion_functions_aggregate::expr_fn::max; - fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq(Arc::new(PushDownLimit::new()), plan, expected) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(PushDownLimit::new())]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } #[derive(Debug, PartialEq, Eq, Hash)] @@ -408,12 +422,15 @@ mod test { .limit(0, Some(1000))? .build()?; - let expected = "Limit: skip=0, fetch=1000\ - \n NoopPlan\ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test, fetch=1000"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=1000 + NoopPlan + Limit: skip=0, fetch=1000 + TableScan: test, fetch=1000 + " + ) } #[test] @@ -430,12 +447,15 @@ mod test { .limit(10, Some(1000))? .build()?; - let expected = "Limit: skip=10, fetch=1000\ - \n NoopPlan\ - \n Limit: skip=0, fetch=1010\ - \n TableScan: test, fetch=1010"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=10, fetch=1000 + NoopPlan + Limit: skip=0, fetch=1010 + TableScan: test, fetch=1010 + " + ) } #[test] @@ -453,12 +473,15 @@ mod test { .limit(20, Some(500))?
.build()?; - let expected = "Limit: skip=30, fetch=500\ - \n NoopPlan\ - \n Limit: skip=0, fetch=530\ - \n TableScan: test, fetch=530"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=30, fetch=500 + NoopPlan + Limit: skip=0, fetch=530 + TableScan: test, fetch=530 + " + ) } #[test] @@ -475,14 +498,17 @@ mod test { .limit(0, Some(1000))? .build()?; - let expected = "Limit: skip=0, fetch=1000\ - \n NoopPlan\ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test, fetch=1000\ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test, fetch=1000"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=1000 + NoopPlan + Limit: skip=0, fetch=1000 + TableScan: test, fetch=1000 + Limit: skip=0, fetch=1000 + TableScan: test, fetch=1000 + " + ) } #[test] @@ -499,11 +525,14 @@ mod test { .limit(0, Some(1000))? .build()?; - let expected = "Limit: skip=0, fetch=1000\ - \n NoLimitNoopPlan\ - \n TableScan: test"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=1000 + NoLimitNoopPlan + TableScan: test + " + ) } #[test] @@ -517,11 +546,14 @@ mod test { // Should push the limit down to table provider // When it has a select - let expected = "Projection: test.a\ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test, fetch=1000"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a + Limit: skip=0, fetch=1000 + TableScan: test, fetch=1000 + " + ) } #[test] @@ -536,10 +568,13 @@ mod test { // Should push down the smallest limit // Towards table scan // This rule doesn't replace multiple limits - let expected = "Limit: skip=0, fetch=10\ - \n TableScan: test, fetch=10"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=10 + TableScan: test, fetch=10 + " + ) } #[test] @@ -552,11 +587,14 @@ mod test { .build()?; // Limit should *not* push down aggregate node - let expected = "Limit: skip=0, fetch=1000\ - \n Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]]\ - \n TableScan: test"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=1000 + Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]] + TableScan: test + " + ) } #[test] @@ -569,14 +607,17 @@ mod test { .build()?; // Limit should push down through union - let expected = "Limit: skip=0, fetch=1000\ - \n Union\ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test, fetch=1000\ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test, fetch=1000"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=1000 + Union + Limit: skip=0, fetch=1000 + TableScan: test, fetch=1000 + Limit: skip=0, fetch=1000 + TableScan: test, fetch=1000 + " + ) } #[test] @@ -589,11 +630,14 @@ mod test { .build()?; // Should push down limit to sort - let expected = "Limit: skip=0, fetch=10\ - \n Sort: test.a ASC NULLS LAST, fetch=10\ - \n TableScan: test"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=10 + Sort: test.a ASC NULLS LAST, fetch=10 + TableScan: test + " + ) } #[test] @@ -606,11 +650,14 @@ mod test { .build()?; // Should push down limit to sort - let expected = "Limit: skip=5, fetch=10\ - \n Sort: test.a ASC NULLS LAST, fetch=15\ - \n TableScan: test"; - - 
assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=5, fetch=10 + Sort: test.a ASC NULLS LAST, fetch=15 + TableScan: test + " + ) } #[test] @@ -624,12 +671,15 @@ mod test { .build()?; // Limit should use deeper LIMIT 1000, but Limit 10 shouldn't push down aggregation - let expected = "Limit: skip=0, fetch=10\ - \n Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]]\ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test, fetch=1000"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=10 + Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]] + Limit: skip=0, fetch=1000 + TableScan: test, fetch=1000 + " + ) } #[test] @@ -641,10 +691,13 @@ mod test { // Should not push any limit down to table provider // When it has a select - let expected = "Limit: skip=10, fetch=None\ - \n TableScan: test"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=10, fetch=None + TableScan: test + " + ) } #[test] @@ -658,11 +711,14 @@ mod test { // Should push the limit down to table provider // When it has a select - let expected = "Projection: test.a\ - \n Limit: skip=10, fetch=1000\ - \n TableScan: test, fetch=1010"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a + Limit: skip=10, fetch=1000 + TableScan: test, fetch=1010 + " + ) } #[test] @@ -675,11 +731,14 @@ mod test { .limit(10, None)? .build()?; - let expected = "Projection: test.a\ - \n Limit: skip=10, fetch=990\ - \n TableScan: test, fetch=1000"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a + Limit: skip=10, fetch=990 + TableScan: test, fetch=1000 + " + ) } #[test] @@ -692,11 +751,14 @@ mod test { .limit(0, Some(1000))? .build()?; - let expected = "Projection: test.a\ - \n Limit: skip=10, fetch=1000\ - \n TableScan: test, fetch=1010"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a + Limit: skip=10, fetch=1000 + TableScan: test, fetch=1010 + " + ) } #[test] @@ -709,10 +771,13 @@ mod test { .limit(0, Some(10))? 
.build()?; - let expected = "Limit: skip=10, fetch=10\ - \n TableScan: test, fetch=20"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=10, fetch=10 + TableScan: test, fetch=20 + " + ) } #[test] @@ -725,11 +790,14 @@ mod test { .build()?; // Limit should *not* push down aggregate node - let expected = "Limit: skip=10, fetch=1000\ - \n Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]]\ - \n TableScan: test"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=10, fetch=1000 + Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]] + TableScan: test + " + ) } #[test] @@ -742,14 +810,17 @@ mod test { .build()?; // Limit should push down through union - let expected = "Limit: skip=10, fetch=1000\ - \n Union\ - \n Limit: skip=0, fetch=1010\ - \n TableScan: test, fetch=1010\ - \n Limit: skip=0, fetch=1010\ - \n TableScan: test, fetch=1010"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=10, fetch=1000 + Union + Limit: skip=0, fetch=1010 + TableScan: test, fetch=1010 + Limit: skip=0, fetch=1010 + TableScan: test, fetch=1010 + " + ) } #[test] @@ -768,12 +839,15 @@ mod test { .build()?; // Limit pushdown Not supported in Join - let expected = "Limit: skip=10, fetch=1000\ - \n Inner Join: test.a = test2.a\ - \n TableScan: test\ - \n TableScan: test2"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=10, fetch=1000 + Inner Join: test.a = test2.a + TableScan: test + TableScan: test2 + " + ) } #[test] @@ -792,12 +866,15 @@ mod test { .build()?; // Limit pushdown Not supported in Join - let expected = "Limit: skip=10, fetch=1000\ - \n Inner Join: test.a = test2.a\ - \n TableScan: test\ - \n TableScan: test2"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=10, fetch=1000 + Inner Join: test.a = test2.a + TableScan: test + TableScan: test2 + " + ) } #[test] @@ -817,16 +894,19 @@ mod test { .build()?; // Limit pushdown Not supported in sub_query - let expected = "Limit: skip=10, fetch=100\ - \n Filter: EXISTS (<subquery>)\ - \n Subquery:\ - \n Filter: test1.a = test1.a\ - \n Projection: test1.a\ - \n TableScan: test1\ - \n Projection: test2.a\ - \n TableScan: test2"; - - assert_optimized_plan_equal(outer_query, expected) + assert_optimized_plan_equal!( + outer_query, + @r" + Limit: skip=10, fetch=100 + Filter: EXISTS (<subquery>) + Subquery: + Filter: test1.a = test1.a + Projection: test1.a + TableScan: test1 + Projection: test2.a + TableScan: test2 + " + ) } #[test] @@ -846,16 +926,19 @@ mod test { .build()?; // Limit pushdown Not supported in sub_query - let expected = "Limit: skip=10, fetch=100\ - \n Filter: EXISTS (<subquery>)\ - \n Subquery:\ - \n Filter: test1.a = test1.a\ - \n Projection: test1.a\ - \n TableScan: test1\ - \n Projection: test2.a\ - \n TableScan: test2"; - - assert_optimized_plan_equal(outer_query, expected) + assert_optimized_plan_equal!( + outer_query, + @r" + Limit: skip=10, fetch=100 + Filter: EXISTS (<subquery>) + Subquery: + Filter: test1.a = test1.a + Projection: test1.a + TableScan: test1 + Projection: test2.a + TableScan: test2 + " + ) } #[test] @@ -874,13 +957,16 @@ mod test { .build()?; // Limit pushdown Not supported in Join - let expected = "Limit: skip=10, fetch=1000\ - \n Left Join: test.a = test2.a\ - \n Limit: skip=0, fetch=1010\ - \n TableScan: test, fetch=1010\ - \n TableScan: test2"; - -
assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=10, fetch=1000 + Left Join: test.a = test2.a + Limit: skip=0, fetch=1010 + TableScan: test, fetch=1010 + TableScan: test2 + " + ) } #[test] @@ -899,13 +985,16 @@ mod test { .build()?; // Limit pushdown Not supported in Join - let expected = "Limit: skip=0, fetch=1000\ - \n Right Join: test.a = test2.a\ - \n TableScan: test\ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test2, fetch=1000"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=1000 + Right Join: test.a = test2.a + TableScan: test + Limit: skip=0, fetch=1000 + TableScan: test2, fetch=1000 + " + ) } #[test] @@ -924,13 +1013,16 @@ mod test { .build()?; // Limit pushdown with offset supported in right outer join - let expected = "Limit: skip=10, fetch=1000\ - \n Right Join: test.a = test2.a\ - \n TableScan: test\ - \n Limit: skip=0, fetch=1010\ - \n TableScan: test2, fetch=1010"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=10, fetch=1000 + Right Join: test.a = test2.a + TableScan: test + Limit: skip=0, fetch=1010 + TableScan: test2, fetch=1010 + " + ) } #[test] @@ -943,14 +1035,17 @@ mod test { .limit(0, Some(1000))? .build()?; - let expected = "Limit: skip=0, fetch=1000\ - \n Cross Join: \ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test, fetch=1000\ - \n Limit: skip=0, fetch=1000\ - \n TableScan: test2, fetch=1000"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=1000 + Cross Join: + Limit: skip=0, fetch=1000 + TableScan: test, fetch=1000 + Limit: skip=0, fetch=1000 + TableScan: test2, fetch=1000 + " + ) } #[test] @@ -963,14 +1058,17 @@ mod test { .limit(1000, Some(1000))? .build()?; - let expected = "Limit: skip=1000, fetch=1000\ - \n Cross Join: \ - \n Limit: skip=0, fetch=2000\ - \n TableScan: test, fetch=2000\ - \n Limit: skip=0, fetch=2000\ - \n TableScan: test2, fetch=2000"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=1000, fetch=1000 + Cross Join: + Limit: skip=0, fetch=2000 + TableScan: test, fetch=2000 + Limit: skip=0, fetch=2000 + TableScan: test2, fetch=2000 + " + ) } #[test] @@ -982,10 +1080,13 @@ mod test { .limit(1000, None)? .build()?; - let expected = "Limit: skip=1000, fetch=0\ - \n TableScan: test, fetch=0"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=1000, fetch=0 + TableScan: test, fetch=0 + " + ) } #[test] @@ -997,10 +1098,13 @@ mod test { .limit(1000, None)? .build()?; - let expected = "Limit: skip=1000, fetch=0\ - \n TableScan: test, fetch=0"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=1000, fetch=0 + TableScan: test, fetch=0 + " + ) } #[test] @@ -1013,10 +1117,13 @@ mod test { .limit(1000, None)? 
.build()?; - let expected = "SubqueryAlias: a\ - \n Limit: skip=1000, fetch=0\ - \n TableScan: test, fetch=0"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + SubqueryAlias: a + Limit: skip=1000, fetch=0 + TableScan: test, fetch=0 + " + ) } } diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index 48b2828faf45..2383787fa0e8 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -186,21 +186,29 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { mod tests { use std::sync::Arc; + use crate::assert_optimized_plan_eq_snapshot; use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate; use crate::test::*; + use crate::OptimizerContext; use datafusion_common::Result; - use datafusion_expr::{ - col, logical_plan::builder::LogicalPlanBuilder, Expr, LogicalPlan, - }; + use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder, Expr}; use datafusion_functions_aggregate::sum::sum; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq( - Arc::new(ReplaceDistinctWithAggregate::new()), - plan.clone(), - expected, - ) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let optimizer_ctx = OptimizerContext::new().with_max_passes(1); + let rules: Vec<Arc<dyn crate::OptimizerRule + Send + Sync>> = vec![Arc::new(ReplaceDistinctWithAggregate::new())]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } #[test] @@ -212,8 +220,11 @@ mod tests { .distinct()? .build()?; - let expected = "Projection: test.c\n Aggregate: groupBy=[[test.c]], aggr=[[]]\n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal!(plan, @r" + Projection: test.c + Aggregate: groupBy=[[test.c]], aggr=[[]] + TableScan: test + ") } #[test] @@ -225,9 +236,11 @@ mod tests { .distinct()? .build()?; - let expected = - "Projection: test.a, test.b\n Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal!(plan, @r" + Projection: test.a, test.b + Aggregate: groupBy=[[test.a, test.b]], aggr=[[]] + TableScan: test + ") } #[test] @@ -238,8 +251,11 @@ mod tests { .distinct()? .build()?; - let expected = "Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n Projection: test.a, test.b\n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal!(plan, @r" + Aggregate: groupBy=[[test.a, test.b]], aggr=[[]] + Projection: test.a, test.b + TableScan: test + ") } #[test] @@ -251,8 +267,11 @@ mod tests { .distinct()?
.build()?; - let expected = - "Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n Projection: test.a, test.b\n Aggregate: groupBy=[[test.a, test.b, test.c]], aggr=[[sum(test.c)]]\n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal!(plan, @r" + Aggregate: groupBy=[[test.a, test.b]], aggr=[[]] + Projection: test.a, test.b + Aggregate: groupBy=[[test.a, test.b, test.c]], aggr=[[sum(test.c)]] + TableScan: test + ") } } diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 5c89bc29a596..2f9a2f6bb9ed 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -335,7 +335,7 @@ fn build_join( .join_on( sub_query_alias, JoinType::Left, - vec![Expr::Literal(ScalarValue::Boolean(Some(true)))], + vec![Expr::Literal(ScalarValue::Boolean(Some(true)), None)], )? .build()? } @@ -365,7 +365,7 @@ fn build_join( ), ( Box::new(Expr::Not(Box::new(filter.clone()))), - Box::new(Expr::Literal(ScalarValue::Null)), + Box::new(Expr::Literal(ScalarValue::Null, None)), ), ], else_expr: Some(Box::new(Expr::Column(Column::new_unqualified( @@ -407,9 +407,24 @@ mod tests { use arrow::datatypes::DataType; use datafusion_expr::test::function_stub::sum; + use crate::assert_optimized_plan_eq_display_indent_snapshot; use datafusion_expr::{col, lit, out_ref_col, scalar_subquery, Between}; use datafusion_functions_aggregate::min_max::{max, min}; + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let rule: Arc<dyn crate::OptimizerRule + Send + Sync> = Arc::new(ScalarSubqueryToJoin::new()); + assert_optimized_plan_eq_display_indent_snapshot!( + rule, + $plan, + @ $expected, + ) + }}; + } + /// Test multiple correlated subqueries #[test] fn multiple_subqueries() -> Result<()> { @@ -433,25 +448,24 @@ mod tests { .project(vec![col("customer.c_custkey")])?
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: Int32(1) < __scalar_sq_1.max(orders.o_custkey) AND Int32(1) < __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n Left Join: Filter: __scalar_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n Left Join: Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: Int32(1) < __scalar_sq_1.max(orders.o_custkey) AND Int32(1) < __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join: Filter: __scalar_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join: Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + 
Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test recursive correlated subqueries @@ -488,26 +502,25 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_acctbal < __scalar_sq_1.sum(orders.o_totalprice) [c_custkey:Int64, c_name:Utf8, sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n Left Join: Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [sum(orders.o_totalprice):Float64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Projection: sum(orders.o_totalprice), orders.o_custkey, __always_true [sum(orders.o_totalprice):Float64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[sum(orders.o_totalprice)]] [o_custkey:Int64, __always_true:Boolean, sum(orders.o_totalprice):Float64;N]\ - \n Filter: orders.o_totalprice < __scalar_sq_2.sum(lineitem.l_extendedprice) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N, __always_true:Boolean;N]\ - \n Left Join: Filter: __scalar_sq_2.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N, __always_true:Boolean;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __scalar_sq_2 [sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64, __always_true:Boolean]\ - \n Projection: sum(lineitem.l_extendedprice), lineitem.l_orderkey, __always_true [sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64, __always_true:Boolean]\ - \n Aggregate: groupBy=[[lineitem.l_orderkey, Boolean(true) AS __always_true]], aggr=[[sum(lineitem.l_extendedprice)]] [l_orderkey:Int64, __always_true:Boolean, sum(lineitem.l_extendedprice):Float64;N]\ - \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_acctbal < __scalar_sq_1.sum(orders.o_totalprice) [c_custkey:Int64, c_name:Utf8, sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join: Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N, __always_true:Boolean;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [sum(orders.o_totalprice):Float64;N, o_custkey:Int64, __always_true:Boolean] + Projection: sum(orders.o_totalprice), orders.o_custkey, __always_true 
[sum(orders.o_totalprice):Float64;N, o_custkey:Int64, __always_true:Boolean] + Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[sum(orders.o_totalprice)]] [o_custkey:Int64, __always_true:Boolean, sum(orders.o_totalprice):Float64;N] + Filter: orders.o_totalprice < __scalar_sq_2.sum(lineitem.l_extendedprice) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N, __always_true:Boolean;N] + Left Join: Filter: __scalar_sq_2.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N, __always_true:Boolean;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + SubqueryAlias: __scalar_sq_2 [sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64, __always_true:Boolean] + Projection: sum(lineitem.l_extendedprice), lineitem.l_orderkey, __always_true [sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64, __always_true:Boolean] + Aggregate: groupBy=[[lineitem.l_orderkey, Boolean(true) AS __always_true]], aggr=[[sum(lineitem.l_extendedprice)]] [l_orderkey:Int64, __always_true:Boolean, sum(lineitem.l_extendedprice):Float64;N] + TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64] + " + ) } /// Test for correlated scalar subquery filter with additional subquery filters @@ -530,22 +543,20 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]\ - \n Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Projection: max(orders.o_custkey), orders.o_custkey, 
__always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N] + Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated scalar subquery with no columns in schema @@ -568,20 +579,19 @@ mod tests { .build()?; // it will optimize, but fail for the same reason the unoptimized query would - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\ - \n Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]\ - \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] + Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for scalar subquery with both columns in schema @@ -600,22 +610,20 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\ - \n Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]\ - \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ - \n Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] + Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated scalar subquery not equal @@ -638,21 +646,19 @@ mod tests { .build()?; // Unsupported predicate, subquery should not be decorrelated - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8]\ - \n Subquery: [max(orders.o_custkey):Int64;N]\ - \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ - \n Filter: outer_ref(customer.c_custkey) != orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8] + Subquery: [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + Filter: outer_ref(customer.c_custkey) != orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + " + ) } /// Test for correlated scalar subquery less than @@ -675,21 +681,19 @@ mod tests { .build()?; // Unsupported predicate, subquery should not be 
decorrelated - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8]\ - \n Subquery: [max(orders.o_custkey):Int64;N]\ - \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ - \n Filter: outer_ref(customer.c_custkey) < orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8] + Subquery: [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + Filter: outer_ref(customer.c_custkey) < orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + " + ) } /// Test for correlated scalar subquery filter with subquery disjunction @@ -713,21 +717,19 @@ mod tests { .build()?; // Unsupported predicate, subquery should not be decorrelated - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8]\ - \n Subquery: [max(orders.o_custkey):Int64;N]\ - \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ - \n Filter: outer_ref(customer.c_custkey) = orders.o_custkey OR orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8] + Subquery: [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + Filter: outer_ref(customer.c_custkey) = orders.o_custkey OR orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + " + ) } /// Test for correlated scalar without projection @@ -768,21 +770,19 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) + Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Projection: max(orders.o_custkey) + Int32(1), orders.o_custkey, __always_true [max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) + Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64, __always_true:Boolean] + Projection: max(orders.o_custkey) + Int32(1), orders.o_custkey, __always_true [max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64, __always_true:Boolean] + Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated scalar subquery with non-strong project @@ -812,20 +812,18 @@ mod tests { .project(vec![col("customer.c_custkey"), scalar_subquery(sq)])? 
.build()?; - let expected = "Projection: customer.c_custkey, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN CASE WHEN CAST(NULL AS Boolean) THEN Utf8(\"a\") ELSE Utf8(\"b\") END ELSE __scalar_sq_1.CASE WHEN max(orders.o_totalprice) THEN Utf8(\"a\") ELSE Utf8(\"b\") END END AS CASE WHEN max(orders.o_totalprice) THEN Utf8(\"a\") ELSE Utf8(\"b\") END [c_custkey:Int64, CASE WHEN max(orders.o_totalprice) THEN Utf8(\"a\") ELSE Utf8(\"b\") END:Utf8;N]\ - \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, CASE WHEN max(orders.o_totalprice) THEN Utf8(\"a\") ELSE Utf8(\"b\") END:Utf8;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [CASE WHEN max(orders.o_totalprice) THEN Utf8(\"a\") ELSE Utf8(\"b\") END:Utf8, o_custkey:Int64, __always_true:Boolean]\ - \n Projection: CASE WHEN max(orders.o_totalprice) THEN Utf8(\"a\") ELSE Utf8(\"b\") END, orders.o_custkey, __always_true [CASE WHEN max(orders.o_totalprice) THEN Utf8(\"a\") ELSE Utf8(\"b\") END:Utf8, o_custkey:Int64, __always_true:Boolean]\ - \n Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_totalprice)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_totalprice):Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r#" + Projection: customer.c_custkey, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN CASE WHEN CAST(NULL AS Boolean) THEN Utf8("a") ELSE Utf8("b") END ELSE __scalar_sq_1.CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END END AS CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END [c_custkey:Int64, CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8;N] + Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8;N, o_custkey:Int64;N, __always_true:Boolean;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8, o_custkey:Int64, __always_true:Boolean] + Projection: CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END, orders.o_custkey, __always_true [CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8, o_custkey:Int64, __always_true:Boolean] + Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_totalprice)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_totalprice):Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + "# + ) } /// Test for correlated scalar subquery multiple projected columns @@ -875,21 +873,19 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey >= __scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey >= __scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } #[test] @@ -914,21 +910,19 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated scalar subquery filter with disjunctions @@ -954,21 +948,19 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } /// Test for correlated scalar subquery filter @@ -987,21 +979,19 @@ mod tests { .project(vec![col("test.c")])? 
.build()?; - let expected = "Projection: test.c [c:UInt32]\ - \n Filter: test.c < __scalar_sq_1.min(sq.c) [a:UInt32, b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N, __always_true:Boolean;N]\ - \n Left Join: Filter: test.a = __scalar_sq_1.a [a:UInt32, b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N, __always_true:Boolean;N]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __scalar_sq_1 [min(sq.c):UInt32;N, a:UInt32, __always_true:Boolean]\ - \n Projection: min(sq.c), sq.a, __always_true [min(sq.c):UInt32;N, a:UInt32, __always_true:Boolean]\ - \n Aggregate: groupBy=[[sq.a, Boolean(true) AS __always_true]], aggr=[[min(sq.c)]] [a:UInt32, __always_true:Boolean, min(sq.c):UInt32;N]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: test.c [c:UInt32] + Filter: test.c < __scalar_sq_1.min(sq.c) [a:UInt32, b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N, __always_true:Boolean;N] + Left Join: Filter: test.a = __scalar_sq_1.a [a:UInt32, b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N, __always_true:Boolean;N] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __scalar_sq_1 [min(sq.c):UInt32;N, a:UInt32, __always_true:Boolean] + Projection: min(sq.c), sq.a, __always_true [min(sq.c):UInt32;N, a:UInt32, __always_true:Boolean] + Aggregate: groupBy=[[sq.a, Boolean(true) AS __always_true]], aggr=[[min(sq.c)]] [a:UInt32, __always_true:Boolean, min(sq.c):UInt32;N] + TableScan: sq [a:UInt32, b:UInt32, c:UInt32] + " + ) } /// Test for non-correlated scalar subquery with no filters @@ -1019,21 +1009,19 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey < __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\ - \n Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]\ - \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey < __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] + Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } #[test] @@ -1050,21 +1038,19 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\ - \n Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]\ - \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] + Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } #[test] @@ -1102,26 +1088,24 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey BETWEEN __scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n Left Join: Filter: customer.c_custkey = __scalar_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Projection: min(orders.o_custkey), orders.o_custkey, __always_true [min(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[min(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, min(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean]\ - \n Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, 
o_totalprice:Float64;N]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey BETWEEN __scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join: Filter: customer.c_custkey = __scalar_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, __always_true:Boolean;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Projection: min(orders.o_custkey), orders.o_custkey, __always_true [min(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[min(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, min(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Projection: max(orders.o_custkey), orders.o_custkey, __always_true [max(orders.o_custkey):Int64;N, o_custkey:Int64, __always_true:Boolean] + Aggregate: groupBy=[[orders.o_custkey, Boolean(true) AS __always_true]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, __always_true:Boolean, max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } #[test] @@ -1151,25 +1135,23 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
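These inline snapshots are insta snapshots: the optimized plan is rendered one node per line together with its schema, and compared against the `@r"..."` literal, which `cargo insta review` can rewrite in place instead of hand-editing escaped `\n` strings. A minimal sketch of the pattern, assuming a hypothetical `assert_plan_snapshot!` helper rather than the exact `assert_optimized_plan_equal!` macro these tests use (that macro is defined elsewhere in the crate):

    // Sketch only: names below are illustrative, not the real macro body.
    macro_rules! assert_plan_snapshot {
        ($plan:expr, @ $expected:literal $(,)?) => {{
            let optimizer =
                Optimizer::with_rules(vec![Arc::new(ScalarSubqueryToJoin::new())]);
            let optimized =
                optimizer.optimize($plan, &OptimizerContext::new(), |_, _| {})?;
            // display_indent_schema() produces the `Node [col:Type;N]` lines
            // seen in the snapshots above.
            insta::assert_snapshot!(optimized.display_indent_schema().to_string(), @ $expected);
            Ok(())
        }};
    }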
.build()?; - let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey BETWEEN __scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N]\ - \n Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N]\ - \n Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N]\ - \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N]\ - \n Projection: min(orders.o_custkey) [min(orders.o_custkey):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[min(orders.o_custkey)]] [min(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N]\ - \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - - assert_multi_rules_optimized_plan_eq_display_indent( - vec![Arc::new(ScalarSubqueryToJoin::new())], + assert_optimized_plan_equal!( plan, - expected, - ); - Ok(()) + @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey BETWEEN __scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N] + Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N] + Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N] + Projection: min(orders.o_custkey) [min(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[min(orders.o_custkey)]] [min(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + " + ) } } diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 8e25bb753436..2be7a2b0bd6e 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -34,11 +34,11 @@ use datafusion_common::{ use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue}; use datafusion_expr::{ and, binary::BinaryTypeCoercer, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, - Operator, Volatility, WindowFunctionDefinition, + Operator, Volatility, }; use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval}; use datafusion_expr::{ - expr::{InList, InSubquery, WindowFunction}, + expr::{InList, InSubquery}, utils::{iter_conjunction, iter_conjunction_owned}, }; use datafusion_expr::{simplify::ExprSimplifyResult, 
Cast, TryCast}; @@ -58,6 +58,7 @@ use crate::{ analyzer::type_coercion::TypeCoercionRewriter, simplify_expressions::unwrap_cast::try_cast_literal_to_type, }; +use datafusion_expr::expr::FieldMetadata; use indexmap::IndexSet; use regex::Regex; @@ -188,7 +189,7 @@ impl ExprSimplifier { /// assert_eq!(expr, b_lt_2); /// ``` pub fn simplify(&self, expr: Expr) -> Result { - Ok(self.simplify_with_cycle_count(expr)?.0) + Ok(self.simplify_with_cycle_count_transformed(expr)?.0.data) } /// Like [Self::simplify], simplifies this [`Expr`] as much as possible, evaluating @@ -198,7 +199,34 @@ impl ExprSimplifier { /// /// See [Self::simplify] for details and usage examples. /// + #[deprecated( + since = "48.0.0", + note = "Use `simplify_with_cycle_count_transformed` instead" + )] + #[allow(unused_mut)] pub fn simplify_with_cycle_count(&self, mut expr: Expr) -> Result<(Expr, u32)> { + let (transformed, cycle_count) = + self.simplify_with_cycle_count_transformed(expr)?; + Ok((transformed.data, cycle_count)) + } + + /// Like [Self::simplify], simplifies this [`Expr`] as much as possible, evaluating + /// constants and applying algebraic simplifications. Additionally returns a `u32` + /// representing the number of simplification cycles performed, which can be useful for testing + /// optimizations. + /// + /// # Returns + /// + /// A tuple containing: + /// - The simplified expression wrapped in a `Transformed` indicating if changes were made + /// - The number of simplification cycles that were performed + /// + /// See [Self::simplify] for details and usage examples. + /// + pub fn simplify_with_cycle_count_transformed( + &self, + mut expr: Expr, + ) -> Result<(Transformed, u32)> { let mut simplifier = Simplifier::new(&self.info); let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?; let mut shorten_in_list_simplifier = ShortenInListSimplifier::new(); @@ -212,6 +240,7 @@ impl ExprSimplifier { // simplifications can enable new constant evaluation // see `Self::with_max_cycles` let mut num_cycles = 0; + let mut has_transformed = false; loop { let Transformed { data, transformed, .. @@ -221,13 +250,18 @@ impl ExprSimplifier { .transform_data(|expr| expr.rewrite(&mut guarantee_rewriter))?; expr = data; num_cycles += 1; + // Track if any transformation occurred + has_transformed = has_transformed || transformed; if !transformed || num_cycles >= self.max_simplifier_cycles { break; } } // shorten inlist should be started after other inlist rules are applied expr = expr.rewrite(&mut shorten_in_list_simplifier).data()?; - Ok((expr, num_cycles)) + Ok(( + Transformed::new_transformed(expr, has_transformed), + num_cycles, + )) } /// Apply type coercion to an [`Expr`] so that it can be @@ -392,15 +426,15 @@ impl ExprSimplifier { /// let expr = col("a").is_not_null(); /// /// // When using default maximum cycles, 2 cycles will be performed. - /// let (simplified_expr, count) = simplifier.simplify_with_cycle_count(expr.clone()).unwrap(); - /// assert_eq!(simplified_expr, lit(true)); + /// let (simplified_expr, count) = simplifier.simplify_with_cycle_count_transformed(expr.clone()).unwrap(); + /// assert_eq!(simplified_expr.data, lit(true)); /// // 2 cycles were executed, but only 1 was needed /// assert_eq!(count, 2); /// /// // Only 1 simplification pass is necessary here, so we can set the maximum cycles to 1. 
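A standalone sketch of calling the new `simplify_with_cycle_count_transformed` entry point outside a doc comment; the schema, column name, and import paths here are illustrative assumptions:

    use std::sync::Arc;
    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::{Result, ToDFSchema};
    use datafusion_expr::{col, execution_props::ExecutionProps, lit};
    use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyContext};

    fn cycle_count_example() -> Result<()> {
        // `b` is declared non-nullable, so `b IS NOT NULL` folds to `true`.
        let schema = Schema::new(vec![Field::new("b", DataType::Boolean, false)])
            .to_dfschema_ref()?;
        let props = ExecutionProps::new();
        let simplifier =
            ExprSimplifier::new(SimplifyContext::new(&props).with_schema(schema));

        let (result, cycles) =
            simplifier.simplify_with_cycle_count_transformed(col("b").is_not_null())?;
        assert_eq!(result.data, lit(true)); // the simplified expression
        assert!(result.transformed);        // at least one rewrite happened
        assert!(cycles >= 1);               // number of simplifier passes run
        Ok(())
    }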
- /// let (simplified_expr, count) = simplifier.with_max_cycles(1).simplify_with_cycle_count(expr.clone()).unwrap(); + /// let (simplified_expr, count) = simplifier.with_max_cycles(1).simplify_with_cycle_count_transformed(expr.clone()).unwrap(); /// // Expression has been rewritten to: (c = a AND b = 1) - /// assert_eq!(simplified_expr, lit(true)); + /// assert_eq!(simplified_expr.data, lit(true)); /// // Only 1 cycle was executed /// assert_eq!(count, 1); /// @@ -444,7 +478,7 @@ impl TreeNodeRewriter for Canonicalizer { }))) } // - (Expr::Literal(_a), Expr::Column(_b), Some(swapped_op)) => { + (Expr::Literal(_a, _), Expr::Column(_b), Some(swapped_op)) => { Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { left: right, op: swapped_op, @@ -487,11 +521,12 @@ struct ConstEvaluator<'a> { #[allow(dead_code)] /// The simplify result of ConstEvaluator +#[allow(clippy::large_enum_variant)] enum ConstSimplifyResult { // Expr was simplified and contains the new expression - Simplified(ScalarValue), + Simplified(ScalarValue, Option), // Expr was not simplified and original value is returned - NotSimplified(ScalarValue), + NotSimplified(ScalarValue, Option), // Evaluation encountered an error, contains the original expression SimplifyRuntimeError(DataFusionError, Expr), } @@ -533,11 +568,11 @@ impl TreeNodeRewriter for ConstEvaluator<'_> { // any error is countered during simplification, return the original // so that normal evaluation can occur Some(true) => match self.evaluate_to_scalar(expr) { - ConstSimplifyResult::Simplified(s) => { - Ok(Transformed::yes(Expr::Literal(s))) + ConstSimplifyResult::Simplified(s, m) => { + Ok(Transformed::yes(Expr::Literal(s, m))) } - ConstSimplifyResult::NotSimplified(s) => { - Ok(Transformed::no(Expr::Literal(s))) + ConstSimplifyResult::NotSimplified(s, m) => { + Ok(Transformed::no(Expr::Literal(s, m))) } ConstSimplifyResult::SimplifyRuntimeError(_, expr) => { Ok(Transformed::yes(expr)) @@ -606,7 +641,7 @@ impl<'a> ConstEvaluator<'a> { Expr::ScalarFunction(ScalarFunction { func, .. }) => { Self::volatility_ok(func.signature().volatility) } - Expr::Literal(_) + Expr::Literal(_, _) | Expr::Alias(..) | Expr::Unnest(_) | Expr::BinaryExpr { .. 
} @@ -632,8 +667,8 @@ impl<'a> ConstEvaluator<'a> { /// Internal helper to evaluates an Expr pub(crate) fn evaluate_to_scalar(&mut self, expr: Expr) -> ConstSimplifyResult { - if let Expr::Literal(s) = expr { - return ConstSimplifyResult::NotSimplified(s); + if let Expr::Literal(s, m) = expr { + return ConstSimplifyResult::NotSimplified(s, m); } let phys_expr = @@ -641,6 +676,16 @@ impl<'a> ConstEvaluator<'a> { Ok(e) => e, Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, expr), }; + let metadata = phys_expr + .return_field(self.input_batch.schema_ref()) + .ok() + .and_then(|f| { + let m = f.metadata(); + match m.is_empty() { + true => None, + false => Some(FieldMetadata::from(m)), + } + }); let col_val = match phys_expr.evaluate(&self.input_batch) { Ok(v) => v, Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, expr), @@ -653,13 +698,15 @@ impl<'a> ConstEvaluator<'a> { expr, ) } else if as_list_array(&a).is_ok() { - ConstSimplifyResult::Simplified(ScalarValue::List( - a.as_list::().to_owned().into(), - )) + ConstSimplifyResult::Simplified( + ScalarValue::List(a.as_list::().to_owned().into()), + metadata, + ) } else if as_large_list_array(&a).is_ok() { - ConstSimplifyResult::Simplified(ScalarValue::LargeList( - a.as_list::().to_owned().into(), - )) + ConstSimplifyResult::Simplified( + ScalarValue::LargeList(a.as_list::().to_owned().into()), + metadata, + ) } else { // Non-ListArray match ScalarValue::try_from_array(&a, 0) { @@ -671,7 +718,7 @@ impl<'a> ConstEvaluator<'a> { expr, ) } else { - ConstSimplifyResult::Simplified(s) + ConstSimplifyResult::Simplified(s, metadata) } } Err(err) => ConstSimplifyResult::SimplifyRuntimeError(err, expr), @@ -689,7 +736,7 @@ impl<'a> ConstEvaluator<'a> { expr, ) } else { - ConstSimplifyResult::Simplified(s) + ConstSimplifyResult::Simplified(s, metadata) } } } @@ -1104,9 +1151,10 @@ impl TreeNodeRewriter for Simplifier<'_, S> { && !info.get_data_type(&left)?.is_floating() && is_one(&right) => { - Transformed::yes(Expr::Literal(ScalarValue::new_zero( - &info.get_data_type(&left)?, - )?)) + Transformed::yes(Expr::Literal( + ScalarValue::new_zero(&info.get_data_type(&left)?)?, + None, + )) } // @@ -1147,9 +1195,10 @@ impl TreeNodeRewriter for Simplifier<'_, S> { op: BitwiseAnd, right, }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { - Transformed::yes(Expr::Literal(ScalarValue::new_zero( - &info.get_data_type(&left)?, - )?)) + Transformed::yes(Expr::Literal( + ScalarValue::new_zero(&info.get_data_type(&left)?)?, + None, + )) } // A & !A -> 0 (if A not nullable) @@ -1158,9 +1207,10 @@ impl TreeNodeRewriter for Simplifier<'_, S> { op: BitwiseAnd, right, }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { - Transformed::yes(Expr::Literal(ScalarValue::new_zero( - &info.get_data_type(&left)?, - )?)) + Transformed::yes(Expr::Literal( + ScalarValue::new_zero(&info.get_data_type(&left)?)?, + None, + )) } // (..A..) & A --> (..A..) @@ -1233,9 +1283,10 @@ impl TreeNodeRewriter for Simplifier<'_, S> { op: BitwiseOr, right, }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { - Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( - &info.get_data_type(&left)?, - )?)) + Transformed::yes(Expr::Literal( + ScalarValue::new_negative_one(&info.get_data_type(&left)?)?, + None, + )) } // A | !A -> -1 (if A not nullable) @@ -1244,9 +1295,10 @@ impl TreeNodeRewriter for Simplifier<'_, S> { op: BitwiseOr, right, }) if is_negative_of(&right, &left) && !info.nullable(&left)? 
=> { - Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( - &info.get_data_type(&left)?, - )?)) + Transformed::yes(Expr::Literal( + ScalarValue::new_negative_one(&info.get_data_type(&left)?)?, + None, + )) } // (..A..) | A --> (..A..) @@ -1319,9 +1371,10 @@ impl TreeNodeRewriter for Simplifier<'_, S> { op: BitwiseXor, right, }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { - Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( - &info.get_data_type(&left)?, - )?)) + Transformed::yes(Expr::Literal( + ScalarValue::new_negative_one(&info.get_data_type(&left)?)?, + None, + )) } // A ^ !A -> -1 (if A not nullable) @@ -1330,9 +1383,10 @@ impl TreeNodeRewriter for Simplifier<'_, S> { op: BitwiseXor, right, }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { - Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( - &info.get_data_type(&left)?, - )?)) + Transformed::yes(Expr::Literal( + ScalarValue::new_negative_one(&info.get_data_type(&left)?)?, + None, + )) } // (..A..) ^ A --> (the expression without A, if number of A is odd, otherwise one A) @@ -1343,7 +1397,10 @@ impl TreeNodeRewriter for Simplifier<'_, S> { }) if expr_contains(&left, &right, BitwiseXor) => { let expr = delete_xor_in_complex_expr(&left, &right, false); Transformed::yes(if expr == *right { - Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&right)?)?) + Expr::Literal( + ScalarValue::new_zero(&info.get_data_type(&right)?)?, + None, + ) } else { expr }) @@ -1357,7 +1414,10 @@ impl TreeNodeRewriter for Simplifier<'_, S> { }) if expr_contains(&right, &left, BitwiseXor) => { let expr = delete_xor_in_complex_expr(&right, &left, true); Transformed::yes(if expr == *left { - Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?) + Expr::Literal( + ScalarValue::new_zero(&info.get_data_type(&left)?)?, + None, + ) } else { expr }) @@ -1489,12 +1549,9 @@ impl TreeNodeRewriter for Simplifier<'_, S> { (_, expr) => Transformed::no(expr), }, - Expr::WindowFunction(WindowFunction { - fun: WindowFunctionDefinition::WindowUDF(ref udwf), - .. - }) => match (udwf.simplify(), expr) { + Expr::WindowFunction(ref window_fun) => match (window_fun.simplify(), expr) { (Some(simplify_function), Expr::WindowFunction(wf)) => { - Transformed::yes(simplify_function(wf, info)?) + Transformed::yes(simplify_function(*wf, info)?) } (_, expr) => Transformed::no(expr), }, @@ -1573,8 +1630,9 @@ impl TreeNodeRewriter for Simplifier<'_, S> { })) } Some(pattern_str) - if !pattern_str - .contains(['%', '_', escape_char].as_ref()) => + if !like.case_insensitive + && !pattern_str + .contains(['%', '_', escape_char].as_ref()) => { // If the pattern does not contain any wildcards, we can simplify the like expression to an equality expression // TODO: handle escape characters @@ -1610,7 +1668,7 @@ impl TreeNodeRewriter for Simplifier<'_, S> { expr, list, negated, - }) if list.is_empty() && *expr != Expr::Literal(ScalarValue::Null) => { + }) if list.is_empty() && *expr != Expr::Literal(ScalarValue::Null, None) => { Transformed::yes(lit(negated)) } @@ -1793,7 +1851,7 @@ impl TreeNodeRewriter for Simplifier<'_, S> { info, &left, op, &right, ) && op.supports_propagation() => { - unwrap_cast_in_comparison_for_binary(info, left, right, op)? + unwrap_cast_in_comparison_for_binary(info, *left, *right, op)? 
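Throughout this file the pattern is mechanical: `Expr::Literal` now carries a second, optional field-metadata slot, so every construction site passes `None` and every match binds (and usually ignores) the extra field; the `ConstEvaluator` change above is the one place that actually populates it, via `FieldMetadata::from` on the evaluated field's metadata. A minimal sketch of the new shape, with an illustrative value:

    use datafusion_common::ScalarValue;
    use datafusion_expr::Expr;

    fn literal_shape() {
        // Construction: plain literals carry no metadata.
        let one = Expr::Literal(ScalarValue::Int32(Some(1)), None);

        // Matching: bind the metadata slot even when only the value matters.
        if let Expr::Literal(value, _metadata) = &one {
            assert_eq!(value, &ScalarValue::Int32(Some(1)));
        }
    }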
} // literal op try_cast/cast(expr as data_type) // --> @@ -1806,8 +1864,8 @@ impl TreeNodeRewriter for Simplifier<'_, S> { { unwrap_cast_in_comparison_for_binary( info, - right, - left, + *right, + *left, op.swap().unwrap(), )? } @@ -1836,7 +1894,7 @@ impl TreeNodeRewriter for Simplifier<'_, S> { .into_iter() .map(|right| { match right { - Expr::Literal(right_lit_value) => { + Expr::Literal(right_lit_value, _) => { // if the right_lit_value can be casted to the type of internal_left_expr // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal let Some(value) = try_cast_literal_to_type(&right_lit_value, &expr_type) else { @@ -1870,18 +1928,18 @@ impl TreeNodeRewriter for Simplifier<'_, S> { fn as_string_scalar(expr: &Expr) -> Option<(DataType, &Option)> { match expr { - Expr::Literal(ScalarValue::Utf8(s)) => Some((DataType::Utf8, s)), - Expr::Literal(ScalarValue::LargeUtf8(s)) => Some((DataType::LargeUtf8, s)), - Expr::Literal(ScalarValue::Utf8View(s)) => Some((DataType::Utf8View, s)), + Expr::Literal(ScalarValue::Utf8(s), _) => Some((DataType::Utf8, s)), + Expr::Literal(ScalarValue::LargeUtf8(s), _) => Some((DataType::LargeUtf8, s)), + Expr::Literal(ScalarValue::Utf8View(s), _) => Some((DataType::Utf8View, s)), _ => None, } } fn to_string_scalar(data_type: DataType, value: Option) -> Expr { match data_type { - DataType::Utf8 => Expr::Literal(ScalarValue::Utf8(value)), - DataType::LargeUtf8 => Expr::Literal(ScalarValue::LargeUtf8(value)), - DataType::Utf8View => Expr::Literal(ScalarValue::Utf8View(value)), + DataType::Utf8 => Expr::Literal(ScalarValue::Utf8(value), None), + DataType::LargeUtf8 => Expr::Literal(ScalarValue::LargeUtf8(value), None), + DataType::Utf8View => Expr::Literal(ScalarValue::Utf8View(value), None), _ => unreachable!(), } } @@ -1927,12 +1985,12 @@ fn as_inlist(expr: &Expr) -> Option> { Expr::InList(inlist) => Some(Cow::Borrowed(inlist)), Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == Operator::Eq => { match (left.as_ref(), right.as_ref()) { - (Expr::Column(_), Expr::Literal(_)) => Some(Cow::Owned(InList { + (Expr::Column(_), Expr::Literal(_, _)) => Some(Cow::Owned(InList { expr: left.clone(), list: vec![*right.clone()], negated: false, })), - (Expr::Literal(_), Expr::Column(_)) => Some(Cow::Owned(InList { + (Expr::Literal(_, _), Expr::Column(_)) => Some(Cow::Owned(InList { expr: right.clone(), list: vec![*left.clone()], negated: false, @@ -1952,12 +2010,12 @@ fn to_inlist(expr: Expr) -> Option { op: Operator::Eq, right, }) => match (left.as_ref(), right.as_ref()) { - (Expr::Column(_), Expr::Literal(_)) => Some(InList { + (Expr::Column(_), Expr::Literal(_, _)) => Some(InList { expr: left, list: vec![*right], negated: false, }), - (Expr::Literal(_), Expr::Column(_)) => Some(InList { + (Expr::Literal(_, _), Expr::Column(_)) => Some(InList { expr: right, list: vec![*left], negated: false, @@ -2109,10 +2167,13 @@ fn simplify_null_div_other_case( #[cfg(test)] mod tests { + use super::*; use crate::simplify_expressions::SimplifyContext; use crate::test::test_table_scan_with_name; + use arrow::datatypes::FieldRef; use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema}; use datafusion_expr::{ + expr::WindowFunction, function::{ AccumulatorArgs, AggregateFunctionSimplification, WindowFunctionSimplification, @@ -2128,8 +2189,6 @@ mod tests { sync::Arc, }; - use super::*; - // ------------------------------ // --- ExprSimplifier tests ----- // ------------------------------ @@ -2375,7 +2434,7 @@ mod tests { #[test] fn 
test_simplify_multiply_by_null() { - let null = Expr::Literal(ScalarValue::Null); + let null = Expr::Literal(ScalarValue::Null, None); // A * null --> null { let expr = col("c2") * null.clone(); @@ -3310,6 +3369,15 @@ mod tests { simplifier.simplify(expr) } + fn coerce(expr: Expr) -> Expr { + let schema = expr_test_schema(); + let execution_props = ExecutionProps::new(); + let simplifier = ExprSimplifier::new( + SimplifyContext::new(&execution_props).with_schema(Arc::clone(&schema)), + ); + simplifier.coerce(expr, schema.as_ref()).unwrap() + } + fn simplify(expr: Expr) -> Expr { try_simplify(expr).unwrap() } @@ -3320,7 +3388,8 @@ mod tests { let simplifier = ExprSimplifier::new( SimplifyContext::new(&execution_props).with_schema(schema), ); - simplifier.simplify_with_cycle_count(expr) + let (expr, count) = simplifier.simplify_with_cycle_count_transformed(expr)?; + Ok((expr.data, count)) } fn simplify_with_cycle_count(expr: Expr) -> (Expr, u32) { @@ -3352,6 +3421,7 @@ mod tests { Field::new("c2_non_null", DataType::Boolean, false), Field::new("c3_non_null", DataType::Int64, false), Field::new("c4_non_null", DataType::UInt32, false), + Field::new("c5", DataType::FixedSizeBinary(3), true), ] .into(), HashMap::new(), @@ -4058,6 +4128,11 @@ mod tests { assert_eq!(simplify(expr), col("c1").like(lit("a_"))); let expr = col("c1").not_like(lit("a_")); assert_eq!(simplify(expr), col("c1").not_like(lit("a_"))); + + let expr = col("c1").ilike(lit("a")); + assert_eq!(simplify(expr), col("c1").ilike(lit("a"))); + let expr = col("c1").not_ilike(lit("a")); + assert_eq!(simplify(expr), col("c1").not_ilike(lit("a"))); } #[test] @@ -4338,8 +4413,7 @@ mod tests { let udwf = WindowFunctionDefinition::WindowUDF( WindowUDF::new_from_impl(SimplifyMockUdwf::new_with_simplify()).into(), ); - let window_function_expr = - Expr::WindowFunction(WindowFunction::new(udwf, vec![])); + let window_function_expr = Expr::from(WindowFunction::new(udwf, vec![])); let expected = col("result_column"); assert_eq!(simplify(window_function_expr), expected); @@ -4347,8 +4421,7 @@ mod tests { let udwf = WindowFunctionDefinition::WindowUDF( WindowUDF::new_from_impl(SimplifyMockUdwf::new_without_simplify()).into(), ); - let window_function_expr = - Expr::WindowFunction(WindowFunction::new(udwf, vec![])); + let window_function_expr = Expr::from(WindowFunction::new(udwf, vec![])); let expected = window_function_expr.clone(); assert_eq!(simplify(window_function_expr), expected); @@ -4400,7 +4473,7 @@ mod tests { unimplemented!("not needed for tests") } - fn field(&self, _field_args: WindowUDFFieldArgs) -> Result { + fn field(&self, _field_args: WindowUDFFieldArgs) -> Result { unimplemented!("not needed for tests") } } @@ -4475,6 +4548,34 @@ mod tests { } } + #[test] + fn simplify_fixed_size_binary_eq_lit() { + let bytes = [1u8, 2, 3].as_slice(); + + // The expression starts simple. + let expr = col("c5").eq(lit(bytes)); + + // The type coercer introduces a cast. + let coerced = coerce(expr.clone()); + let schema = expr_test_schema(); + assert_eq!( + coerced, + col("c5") + .cast_to(&DataType::Binary, schema.as_ref()) + .unwrap() + .eq(lit(bytes)) + ); + + // The simplifier removes the cast. 
+ assert_eq!( + simplify(coerced), + col("c5").eq(Expr::Literal( + ScalarValue::FixedSizeBinary(3, Some(bytes.to_vec()),), + None + )) + ); + } + fn if_not_null(expr: Expr, then: bool) -> Expr { Expr::Case(Case { expr: Some(expr.is_not_null().into()), diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index 4700ab97b5f3..bbb023cfbad9 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -84,7 +84,7 @@ impl TreeNodeRewriter for GuaranteeRewriter<'_> { low, high, }) => { - if let (Some(interval), Expr::Literal(low), Expr::Literal(high)) = ( + if let (Some(interval), Expr::Literal(low, _), Expr::Literal(high, _)) = ( self.guarantees.get(inner.as_ref()), low.as_ref(), high.as_ref(), @@ -115,7 +115,7 @@ impl TreeNodeRewriter for GuaranteeRewriter<'_> { .get(left.as_ref()) .map(|interval| Cow::Borrowed(*interval)) .or_else(|| { - if let Expr::Literal(value) = left.as_ref() { + if let Expr::Literal(value, _) = left.as_ref() { Some(Cow::Owned(value.clone().into())) } else { None @@ -126,7 +126,7 @@ impl TreeNodeRewriter for GuaranteeRewriter<'_> { .get(right.as_ref()) .map(|interval| Cow::Borrowed(*interval)) .or_else(|| { - if let Expr::Literal(value) = right.as_ref() { + if let Expr::Literal(value, _) = right.as_ref() { Some(Cow::Owned(value.clone().into())) } else { None @@ -168,7 +168,7 @@ impl TreeNodeRewriter for GuaranteeRewriter<'_> { let new_list: Vec = list .iter() .filter_map(|expr| { - if let Expr::Literal(item) = expr { + if let Expr::Literal(item, _) = expr { match interval .contains(NullableInterval::from(item.clone())) { @@ -244,8 +244,7 @@ mod tests { let expected = lit(ScalarValue::from(expected_value.clone())); assert_eq!( output, expected, - "{} simplified to {}, but expected {}", - expr, output, expected + "{expr} simplified to {output}, but expected {expected}" ); } } @@ -255,8 +254,7 @@ mod tests { let output = expr.clone().rewrite(rewriter).data().unwrap(); assert_eq!( &output, expr, - "{} was simplified to {}, but expected it to be unchanged", - expr, output + "{expr} was simplified to {output}, but expected it to be unchanged" ); } } @@ -417,7 +415,7 @@ mod tests { let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); let output = col("x").rewrite(&mut rewriter).data().unwrap(); - assert_eq!(output, Expr::Literal(scalar.clone())); + assert_eq!(output, Expr::Literal(scalar.clone(), None)); } } diff --git a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs index c8638eb72395..a1c1dc17d294 100644 --- a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs @@ -39,10 +39,10 @@ impl TreeNodeRewriter for ShortenInListSimplifier { // if expr is a single column reference: // expr IN (A, B, ...) 
--> (expr = A) OR (expr = B) OR (expr = C) if let Expr::InList(InList { - expr, - list, + ref expr, + ref list, negated, - }) = expr.clone() + }) = expr { if !list.is_empty() && ( @@ -57,7 +57,7 @@ impl TreeNodeRewriter for ShortenInListSimplifier { { let first_val = list[0].clone(); if negated { - return Ok(Transformed::yes(list.into_iter().skip(1).fold( + return Ok(Transformed::yes(list.iter().skip(1).cloned().fold( (*expr.clone()).not_eq(first_val), |acc, y| { // Note that `A and B and C and D` is a left-deep tree structure @@ -81,7 +81,7 @@ impl TreeNodeRewriter for ShortenInListSimplifier { }, ))); } else { - return Ok(Transformed::yes(list.into_iter().skip(1).fold( + return Ok(Transformed::yes(list.iter().skip(1).cloned().fold( (*expr.clone()).eq(first_val), |acc, y| { // Same reasoning as above diff --git a/datafusion/optimizer/src/simplify_expressions/mod.rs b/datafusion/optimizer/src/simplify_expressions/mod.rs index 5fbee02e3909..7ae38eec9a3a 100644 --- a/datafusion/optimizer/src/simplify_expressions/mod.rs +++ b/datafusion/optimizer/src/simplify_expressions/mod.rs @@ -23,6 +23,7 @@ mod guarantees; mod inlist_simplifier; mod regex; pub mod simplify_exprs; +mod simplify_predicates; mod unwrap_cast; mod utils; @@ -31,6 +32,7 @@ pub use datafusion_expr::simplify::{SimplifyContext, SimplifyInfo}; pub use expr_simplifier::*; pub use simplify_exprs::*; +pub use simplify_predicates::simplify_predicates; // Export for test in datafusion/core/tests/optimizer_integration.rs pub use guarantees::GuaranteeRewriter; diff --git a/datafusion/optimizer/src/simplify_expressions/regex.rs b/datafusion/optimizer/src/simplify_expressions/regex.rs index 0b47cdee212f..82c5ea3d8d82 100644 --- a/datafusion/optimizer/src/simplify_expressions/regex.rs +++ b/datafusion/optimizer/src/simplify_expressions/regex.rs @@ -46,7 +46,7 @@ pub fn simplify_regex_expr( ) -> Result { let mode = OperatorMode::new(&op); - if let Expr::Literal(ScalarValue::Utf8(Some(pattern))) = right.as_ref() { + if let Expr::Literal(ScalarValue::Utf8(Some(pattern)), _) = right.as_ref() { // Handle the special case for ".*" pattern if pattern == ANY_CHAR_REGEX_PATTERN { let new_expr = if mode.not { @@ -121,7 +121,7 @@ impl OperatorMode { let like = Like { negated: self.not, expr, - pattern: Box::new(Expr::Literal(ScalarValue::from(pattern))), + pattern: Box::new(Expr::Literal(ScalarValue::from(pattern), None)), escape_char: None, case_insensitive: self.i, }; @@ -255,9 +255,9 @@ fn partial_anchored_literal_to_like(v: &[Hir]) -> Option { }; if match_begin { - Some(format!("{}%", lit)) + Some(format!("{lit}%")) } else { - Some(format!("%{}", lit)) + Some(format!("%{lit}")) } } @@ -331,7 +331,7 @@ fn lower_simple(mode: &OperatorMode, left: &Expr, hir: &Hir) -> Option { } HirKind::Concat(inner) => { if let Some(pattern) = partial_anchored_literal_to_like(inner) - .or(collect_concat_to_like_string(inner)) + .or_else(|| collect_concat_to_like_string(inner)) { return Some(mode.expr(Box::new(left.clone()), pattern)); } diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs index e33869ca2b63..ccf90893e17e 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs @@ -123,10 +123,11 @@ impl SimplifyExpressions { let name_preserver = NamePreserver::new(&plan); let mut rewrite_expr = |expr: Expr| { let name = name_preserver.save(&expr); - let expr = simplifier.simplify(expr)?; - 
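The `ShortenInListSimplifier` above keeps the documented rewrite (`expr IN (A, B, ...)` becomes `(expr = A) OR (expr = B) OR ...`, and the negated form becomes `!=` comparisons joined by `AND`); the change is only that it borrows the `InList` instead of cloning the whole expression up front. A small sketch of the equivalent left-deep fold, using illustrative column and literal values:

    use datafusion_expr::{col, lit, Expr};

    // expr IN (A, B, C)  -->  expr = A OR expr = B OR expr = C
    fn expand_in_list(expr: Expr, list: Vec<Expr>) -> Expr {
        let mut items = list.into_iter();
        let first = items.next().expect("list must be non-empty");
        // Left-deep OR chain, mirroring the fold in the rewriter.
        items.fold(expr.clone().eq(first), |acc, item| acc.or(expr.clone().eq(item)))
    }

    // expand_in_list(col("c1"), vec![lit(1), lit(2)])
    //   == col("c1").eq(lit(1)).or(col("c1").eq(lit(2)))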
// TODO it would be nice to have a way to know if the expression was simplified - // or not. For now conservatively return Transformed::yes - Ok(Transformed::yes(name.restore(expr))) + let expr = simplifier.simplify_with_cycle_count_transformed(expr)?.0; + Ok(Transformed::new_transformed( + name.restore(expr.data), + expr.transformed, + )) }; plan.map_expressions(|expr| { @@ -154,12 +155,12 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema}; use chrono::{DateTime, Utc}; - use crate::optimizer::Optimizer; use datafusion_expr::logical_plan::builder::table_scan_with_filters; use datafusion_expr::logical_plan::table_scan; use datafusion_expr::*; use datafusion_functions_aggregate::expr_fn::{max, min}; + use crate::assert_optimized_plan_eq_snapshot; use crate::test::{assert_fields_eq, test_table_scan_with_name}; use crate::OptimizerContext; @@ -179,15 +180,20 @@ mod tests { .expect("building plan") } - fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { - // Use Optimizer to do plan traversal - fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} - let optimizer = Optimizer::with_rules(vec![Arc::new(SimplifyExpressions::new())]); - let optimized_plan = - optimizer.optimize(plan, &OptimizerContext::new(), observe)?; - let formatted_plan = format!("{optimized_plan}"); - assert_eq!(formatted_plan, expected); - Ok(()) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let rules: Vec> = vec![Arc::new(SimplifyExpressions::new())]; + let optimizer_ctx = OptimizerContext::new(); + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; } #[test] @@ -210,9 +216,10 @@ mod tests { assert_eq!(1, table_scan.schema().fields().len()); assert_fields_eq(&table_scan, vec!["a"]); - let expected = "TableScan: test projection=[a], full_filters=[Boolean(true)]"; - - assert_optimized_plan_eq(table_scan, expected) + assert_optimized_plan_equal!( + table_scan, + @ r"TableScan: test projection=[a], full_filters=[Boolean(true)]" + ) } #[test] @@ -223,12 +230,13 @@ mod tests { .filter(and(col("b").gt(lit(1)), col("b").gt(lit(1))))? .build()?; - assert_optimized_plan_eq( + assert_optimized_plan_equal!( plan, - "\ - Filter: test.b > Int32(1)\ - \n Projection: test.a\ - \n TableScan: test", + @ r" + Filter: test.b > Int32(1) + Projection: test.a + TableScan: test + " ) } @@ -240,12 +248,13 @@ mod tests { .filter(and(col("b").gt(lit(1)), col("b").gt(lit(1))))? .build()?; - assert_optimized_plan_eq( + assert_optimized_plan_equal!( plan, - "\ - Filter: test.b > Int32(1)\ - \n Projection: test.a\ - \n TableScan: test", + @ r" + Filter: test.b > Int32(1) + Projection: test.a + TableScan: test + " ) } @@ -257,12 +266,13 @@ mod tests { .filter(or(col("b").gt(lit(1)), col("b").gt(lit(1))))? .build()?; - assert_optimized_plan_eq( + assert_optimized_plan_equal!( plan, - "\ - Filter: test.b > Int32(1)\ - \n Projection: test.a\ - \n TableScan: test", + @ r" + Filter: test.b > Int32(1) + Projection: test.a + TableScan: test + " ) } @@ -278,12 +288,13 @@ mod tests { ))? .build()?; - assert_optimized_plan_eq( + assert_optimized_plan_equal!( plan, - "\ - Filter: test.a > Int32(5) AND test.b < Int32(6)\ - \n Projection: test.a, test.b\ - \n TableScan: test", + @ r" + Filter: test.a > Int32(5) AND test.b < Int32(6) + Projection: test.a, test.b + TableScan: test + " ) } @@ -296,13 +307,15 @@ mod tests { .project(vec![col("a")])? 
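The test changes below migrate hand-escaped expected strings to insta inline snapshots via the new assert_optimized_plan_equal! macro. For readers unfamiliar with the style, a minimal illustration (not project code): the expected value lives inline after `@` and is rewritten in place by `cargo insta review` when the plan output changes.

#[test]
fn inline_snapshot_style() {
    let plan = "Filter: test.b > Int32(1)\n  Projection: test.a\n    TableScan: test";
    // insta trims the common leading indentation of multi-line inline snapshots,
    // which is why the expected plans below can be indented to match the macro call.
    insta::assert_snapshot!(plan, @r"
    Filter: test.b > Int32(1)
      Projection: test.a
        TableScan: test
    ");
}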
.build()?; - let expected = "\ - Projection: test.a\ - \n Filter: NOT test.c\ - \n Filter: test.b\ - \n TableScan: test"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a + Filter: NOT test.c + Filter: test.b + TableScan: test + " + ) } #[test] @@ -315,14 +328,16 @@ mod tests { .project(vec![col("a")])? .build()?; - let expected = "\ - Projection: test.a\ - \n Limit: skip=0, fetch=1\ - \n Filter: test.c\ - \n Filter: NOT test.b\ - \n TableScan: test"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a + Limit: skip=0, fetch=1 + Filter: test.c + Filter: NOT test.b + TableScan: test + " + ) } #[test] @@ -333,12 +348,14 @@ mod tests { .project(vec![col("a")])? .build()?; - let expected = "\ - Projection: test.a\ - \n Filter: NOT test.b AND test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a + Filter: NOT test.b AND test.c + TableScan: test + " + ) } #[test] @@ -349,12 +366,14 @@ mod tests { .project(vec![col("a")])? .build()?; - let expected = "\ - Projection: test.a\ - \n Filter: NOT test.b OR NOT test.c\ - \n TableScan: test"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a + Filter: NOT test.b OR NOT test.c + TableScan: test + " + ) } #[test] @@ -365,12 +384,14 @@ mod tests { .project(vec![col("a")])? .build()?; - let expected = "\ - Projection: test.a\ - \n Filter: test.b\ - \n TableScan: test"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a + Filter: test.b + TableScan: test + " + ) } #[test] @@ -380,11 +401,13 @@ mod tests { .project(vec![col("a"), col("d"), col("b").eq(lit(false))])? .build()?; - let expected = "\ - Projection: test.a, test.d, NOT test.b AS test.b = Boolean(false)\ - \n TableScan: test"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Projection: test.a, test.d, NOT test.b AS test.b = Boolean(false) + TableScan: test + " + ) } #[test] @@ -398,12 +421,14 @@ mod tests { )? .build()?; - let expected = "\ - Aggregate: groupBy=[[test.a, test.c]], aggr=[[max(test.b) AS max(test.b = Boolean(true)), min(test.b)]]\ - \n Projection: test.a, test.c, test.b\ - \n TableScan: test"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Aggregate: groupBy=[[test.a, test.c]], aggr=[[max(test.b) AS max(test.b = Boolean(true)), min(test.b)]] + Projection: test.a, test.c, test.b + TableScan: test + " + ) } #[test] @@ -421,10 +446,10 @@ mod tests { let values = vec![vec![expr1, expr2]]; let plan = LogicalPlanBuilder::values(values)?.build()?; - let expected = "\ - Values: (Int32(3) AS Int32(1) + Int32(2), Int32(1) AS Int32(2) - Int32(1))"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ "Values: (Int32(3) AS Int32(1) + Int32(2), Int32(1) AS Int32(2) - Int32(1))" + ) } fn get_optimized_plan_formatted( @@ -481,10 +506,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("d").gt(lit(10)).not())? 
.build()?; - let expected = "Filter: test.d <= Int32(10)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.d <= Int32(10) + TableScan: test + " + ) } #[test] @@ -494,10 +523,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("d").gt(lit(10)).and(col("d").lt(lit(100))).not())? .build()?; - let expected = "Filter: test.d <= Int32(10) OR test.d >= Int32(100)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.d <= Int32(10) OR test.d >= Int32(100) + TableScan: test + " + ) } #[test] @@ -507,10 +540,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("d").gt(lit(10)).or(col("d").lt(lit(100))).not())? .build()?; - let expected = "Filter: test.d <= Int32(10) AND test.d >= Int32(100)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.d <= Int32(10) AND test.d >= Int32(100) + TableScan: test + " + ) } #[test] @@ -520,10 +557,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("d").gt(lit(10)).not().not())? .build()?; - let expected = "Filter: test.d > Int32(10)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.d > Int32(10) + TableScan: test + " + ) } #[test] @@ -533,10 +574,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("e").is_null().not())? .build()?; - let expected = "Filter: test.e IS NOT NULL\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.e IS NOT NULL + TableScan: test + " + ) } #[test] @@ -546,10 +591,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("e").is_not_null().not())? .build()?; - let expected = "Filter: test.e IS NULL\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.e IS NULL + TableScan: test + " + ) } #[test] @@ -559,11 +608,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("d").in_list(vec![lit(1), lit(2), lit(3)], false).not())? .build()?; - let expected = - "Filter: test.d != Int32(1) AND test.d != Int32(2) AND test.d != Int32(3)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.d != Int32(1) AND test.d != Int32(2) AND test.d != Int32(3) + TableScan: test + " + ) } #[test] @@ -573,11 +625,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("d").in_list(vec![lit(1), lit(2), lit(3)], true).not())? .build()?; - let expected = - "Filter: test.d = Int32(1) OR test.d = Int32(2) OR test.d = Int32(3)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.d = Int32(1) OR test.d = Int32(2) OR test.d = Int32(3) + TableScan: test + " + ) } #[test] @@ -588,10 +643,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(qual.not())? 
.build()?; - let expected = "Filter: test.d < Int32(1) OR test.d > Int32(10)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.d < Int32(1) OR test.d > Int32(10) + TableScan: test + " + ) } #[test] @@ -602,10 +661,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(qual.not())? .build()?; - let expected = "Filter: test.d >= Int32(1) AND test.d <= Int32(10)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.d >= Int32(1) AND test.d <= Int32(10) + TableScan: test + " + ) } #[test] @@ -622,10 +685,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("a").like(col("b")).not())? .build()?; - let expected = "Filter: test.a NOT LIKE test.b\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.a NOT LIKE test.b + TableScan: test + " + ) } #[test] @@ -642,10 +709,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("a").not_like(col("b")).not())? .build()?; - let expected = "Filter: test.a LIKE test.b\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.a LIKE test.b + TableScan: test + " + ) } #[test] @@ -662,10 +733,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("a").ilike(col("b")).not())? .build()?; - let expected = "Filter: test.a NOT ILIKE test.b\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.a NOT ILIKE test.b + TableScan: test + " + ) } #[test] @@ -675,10 +750,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(binary_expr(col("d"), Operator::IsDistinctFrom, lit(10)).not())? .build()?; - let expected = "Filter: test.d IS NOT DISTINCT FROM Int32(10)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.d IS NOT DISTINCT FROM Int32(10) + TableScan: test + " + ) } #[test] @@ -688,10 +767,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(binary_expr(col("d"), Operator::IsNotDistinctFrom, lit(10)).not())? .build()?; - let expected = "Filter: test.d IS DISTINCT FROM Int32(10)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.d IS DISTINCT FROM Int32(10) + TableScan: test + " + ) } #[test] @@ -713,11 +796,14 @@ mod tests { // before simplify: t1.a + CAST(Int64(1), UInt32) = t2.a + CAST(Int64(2), UInt32) // after simplify: t1.a + UInt32(1) = t2.a + UInt32(2) AS t1.a + Int64(1) = t2.a + Int64(2) - let expected = "Inner Join: t1.a + UInt32(1) = t2.a + UInt32(2)\ - \n TableScan: t1\ - \n TableScan: t2"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Inner Join: t1.a + UInt32(1) = t2.a + UInt32(2) + TableScan: t1 + TableScan: t2 + " + ) } #[test] @@ -727,10 +813,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("d").is_not_null())? 
.build()?; - let expected = "Filter: Boolean(true)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: Boolean(true) + TableScan: test + " + ) } #[test] @@ -740,10 +830,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter(col("d").is_null())? .build()?; - let expected = "Filter: Boolean(false)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Filter: Boolean(false) + TableScan: test + " + ) } #[test] @@ -760,10 +854,13 @@ mod tests { )? .build()?; - let expected = "Aggregate: groupBy=[[GROUPING SETS ((Int32(43) AS age, test.a), (Boolean(false) AS cond), (test.d AS e, Int32(3) AS Int32(1) + Int32(2)))]], aggr=[[]]\ - \n TableScan: test"; - - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r" + Aggregate: groupBy=[[GROUPING SETS ((Int32(43) AS age, test.a), (Boolean(false) AS cond), (test.d AS e, Int32(3) AS Int32(1) + Int32(2)))]], aggr=[[]] + TableScan: test + " + ) } #[test] @@ -778,19 +875,27 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan.clone()) .filter(binary_expr(col("a"), Operator::RegexMatch, lit(".*")))? .build()?; - let expected = "Filter: test.a IS NOT NULL\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected)?; + assert_optimized_plan_equal!( + plan, + @ r" + Filter: test.a IS NOT NULL + TableScan: test + " + )?; // Test `!= ".*"` transforms to checking if the column is empty let plan = LogicalPlanBuilder::from(table_scan.clone()) .filter(binary_expr(col("a"), Operator::RegexNotMatch, lit(".*")))? .build()?; - let expected = "Filter: test.a = Utf8(\"\")\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected)?; + assert_optimized_plan_equal!( + plan, + @ r#" + Filter: test.a = Utf8("") + TableScan: test + "# + )?; // Test case-insensitive versions @@ -798,18 +903,26 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan.clone()) .filter(binary_expr(col("b"), Operator::RegexIMatch, lit(".*")))? .build()?; - let expected = "Filter: Boolean(true)\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected)?; + assert_optimized_plan_equal!( + plan, + @ r" + Filter: Boolean(true) + TableScan: test + " + )?; // Test `!~ ".*"` (case-insensitive) transforms to checking if the column is empty let plan = LogicalPlanBuilder::from(table_scan.clone()) .filter(binary_expr(col("a"), Operator::RegexNotIMatch, lit(".*")))? .build()?; - let expected = "Filter: test.a = Utf8(\"\")\ - \n TableScan: test"; - assert_optimized_plan_eq(plan, expected) + assert_optimized_plan_equal!( + plan, + @ r#" + Filter: test.a = Utf8("") + TableScan: test + "# + ) } } diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_predicates.rs b/datafusion/optimizer/src/simplify_expressions/simplify_predicates.rs new file mode 100644 index 000000000000..32b2315e15d5 --- /dev/null +++ b/datafusion/optimizer/src/simplify_expressions/simplify_predicates.rs @@ -0,0 +1,247 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Simplifies predicates by reducing redundant or overlapping conditions. +//! +//! This module provides functionality to optimize logical predicates used in query planning +//! by eliminating redundant conditions, thus reducing the number of predicates to evaluate. +//! Unlike the simplifier in `simplify_expressions/simplify_exprs.rs`, which focuses on +//! general expression simplification (e.g., constant folding and algebraic simplifications), +//! this module specifically targets predicate optimization by handling containment relationships. +//! For example, it can simplify `x > 5 AND x > 6` to just `x > 6`, as the latter condition +//! encompasses the former, resulting in fewer checks during query execution. + +use datafusion_common::{Column, Result, ScalarValue}; +use datafusion_expr::{BinaryExpr, Cast, Expr, Operator}; +use std::collections::BTreeMap; + +/// Simplifies a list of predicates by removing redundancies. +/// +/// This function takes a vector of predicate expressions and groups them by the column they reference. +/// Predicates that reference a single column and are comparison operations (e.g., >, >=, <, <=, =) +/// are analyzed to remove redundant conditions. For instance, `x > 5 AND x > 6` is simplified to +/// `x > 6`. Other predicates that do not fit this pattern are retained as-is. +/// +/// # Arguments +/// * `predicates` - A vector of `Expr` representing the predicates to simplify. +/// +/// # Returns +/// A `Result` containing a vector of simplified `Expr` predicates. +pub fn simplify_predicates(predicates: Vec) -> Result> { + // Early return for simple cases + if predicates.len() <= 1 { + return Ok(predicates); + } + + // Group predicates by their column reference + let mut column_predicates: BTreeMap> = BTreeMap::new(); + let mut other_predicates = Vec::new(); + + for pred in predicates { + match &pred { + Expr::BinaryExpr(BinaryExpr { + left, + op: + Operator::Gt + | Operator::GtEq + | Operator::Lt + | Operator::LtEq + | Operator::Eq, + right, + }) => { + let left_col = extract_column_from_expr(left); + let right_col = extract_column_from_expr(right); + if let (Some(col), Some(_)) = (&left_col, right.as_literal()) { + column_predicates.entry(col.clone()).or_default().push(pred); + } else if let (Some(_), Some(col)) = (left.as_literal(), &right_col) { + column_predicates.entry(col.clone()).or_default().push(pred); + } else { + other_predicates.push(pred); + } + } + _ => other_predicates.push(pred), + } + } + + // Process each column's predicates to remove redundancies + let mut result = other_predicates; + for (_, preds) in column_predicates { + let simplified = simplify_column_predicates(preds)?; + result.extend(simplified); + } + + Ok(result) +} + +/// Simplifies predicates related to a single column. +/// +/// This function processes a list of predicates that all reference the same column and +/// simplifies them based on their operators. It groups predicates into greater-than (>, >=), +/// less-than (<, <=), and equality (=) categories, then selects the most restrictive condition +/// in each category to reduce redundancy. 
For example, among `x > 5` and `x > 6`, only `x > 6` +/// is retained as it is more restrictive. +/// +/// # Arguments +/// * `predicates` - A vector of `Expr` representing predicates for a single column. +/// +/// # Returns +/// A `Result` containing a vector of simplified `Expr` predicates for the column. +fn simplify_column_predicates(predicates: Vec) -> Result> { + if predicates.len() <= 1 { + return Ok(predicates); + } + + // Group by operator type, but combining similar operators + let mut greater_predicates = Vec::new(); // Combines > and >= + let mut less_predicates = Vec::new(); // Combines < and <= + let mut eq_predicates = Vec::new(); + + for pred in predicates { + match &pred { + Expr::BinaryExpr(BinaryExpr { left: _, op, right }) => { + match (op, right.as_literal().is_some()) { + (Operator::Gt, true) + | (Operator::Lt, false) + | (Operator::GtEq, true) + | (Operator::LtEq, false) => greater_predicates.push(pred), + (Operator::Lt, true) + | (Operator::Gt, false) + | (Operator::LtEq, true) + | (Operator::GtEq, false) => less_predicates.push(pred), + (Operator::Eq, _) => eq_predicates.push(pred), + _ => unreachable!("Unexpected operator: {}", op), + } + } + _ => unreachable!("Unexpected predicate {}", pred.to_string()), + } + } + + let mut result = Vec::new(); + + if !eq_predicates.is_empty() { + // If there are many equality predicates, we can only keep one if they are all the same + if eq_predicates.len() == 1 + || eq_predicates.iter().all(|e| e == &eq_predicates[0]) + { + result.push(eq_predicates.pop().unwrap()); + } else { + // If they are not the same, add a false predicate + result.push(Expr::Literal(ScalarValue::Boolean(Some(false)), None)); + } + } + + // Handle all greater-than-style predicates (keep the most restrictive - highest value) + if !greater_predicates.is_empty() { + if let Some(most_restrictive) = + find_most_restrictive_predicate(&greater_predicates, true)? + { + result.push(most_restrictive); + } else { + result.extend(greater_predicates); + } + } + + // Handle all less-than-style predicates (keep the most restrictive - lowest value) + if !less_predicates.is_empty() { + if let Some(most_restrictive) = + find_most_restrictive_predicate(&less_predicates, false)? + { + result.push(most_restrictive); + } else { + result.extend(less_predicates); + } + } + + Ok(result) +} + +/// Finds the most restrictive predicate from a list based on literal values. +/// +/// This function iterates through a list of predicates to identify the most restrictive one +/// by comparing their literal values. For greater-than predicates, the highest value is most +/// restrictive, while for less-than predicates, the lowest value is most restrictive. +/// +/// # Arguments +/// * `predicates` - A slice of `Expr` representing predicates to compare. +/// * `find_greater` - A boolean indicating whether to find the highest value (true for >, >=) +/// or the lowest value (false for <, <=). +/// +/// # Returns +/// A `Result` containing an `Option` with the most restrictive predicate, if any. 
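A concrete usage sketch of the simplify_predicates function this file adds (and which simplify_expressions/mod.rs re-exports above); the particular bounds are assumptions chosen for illustration:

use datafusion_common::Result;
use datafusion_expr::{col, lit, Expr};
use datafusion_optimizer::simplify_expressions::simplify_predicates;

fn main() -> Result<()> {
    // x > 5 AND x > 6 AND x < 20: the two lower bounds collapse to the tighter one.
    let predicates = vec![
        col("x").gt(lit(5)),
        col("x").gt(lit(6)),
        col("x").lt(lit(20)),
    ];
    let simplified: Vec<Expr> = simplify_predicates(predicates)?;
    // Expected to keep only `x > 6` and `x < 20`.
    for p in &simplified {
        println!("{p}");
    }
    Ok(())
}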
+fn find_most_restrictive_predicate( + predicates: &[Expr], + find_greater: bool, +) -> Result> { + if predicates.is_empty() { + return Ok(None); + } + + let mut most_restrictive_idx = 0; + let mut best_value: Option<&ScalarValue> = None; + + for (idx, pred) in predicates.iter().enumerate() { + if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = pred { + // Extract the literal value based on which side has it + let scalar_value = match (right.as_literal(), left.as_literal()) { + (Some(scalar), _) => Some(scalar), + (_, Some(scalar)) => Some(scalar), + _ => None, + }; + + if let Some(scalar) = scalar_value { + if let Some(current_best) = best_value { + if let Some(comparison) = scalar.partial_cmp(current_best) { + let is_better = if find_greater { + comparison == std::cmp::Ordering::Greater + } else { + comparison == std::cmp::Ordering::Less + }; + + if is_better { + best_value = Some(scalar); + most_restrictive_idx = idx; + } + } + } else { + best_value = Some(scalar); + most_restrictive_idx = idx; + } + } + } + } + + Ok(Some(predicates[most_restrictive_idx].clone())) +} + +/// Extracts a column reference from an expression, if present. +/// +/// This function checks if the given expression is a column reference or contains one, +/// such as within a cast operation. It returns the `Column` if found. +/// +/// # Arguments +/// * `expr` - A reference to an `Expr` to inspect for a column reference. +/// +/// # Returns +/// An `Option` containing the column reference if found, otherwise `None`. +fn extract_column_from_expr(expr: &Expr) -> Option { + match expr { + Expr::Column(col) => Some(col.clone()), + // Handle cases where the column might be wrapped in a cast or other operation + Expr::Cast(Cast { expr, .. }) => extract_column_from_expr(expr), + _ => None, + } +} diff --git a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs index be71a8cd19b0..7c8ff8305e84 100644 --- a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs +++ b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs @@ -69,14 +69,14 @@ use datafusion_expr::{simplify::SimplifyInfo, Cast, Expr, Operator, TryCast}; pub(super) fn unwrap_cast_in_comparison_for_binary( info: &S, - cast_expr: Box, - literal: Box, + cast_expr: Expr, + literal: Expr, op: Operator, ) -> Result> { - match (*cast_expr, *literal) { + match (cast_expr, literal) { ( Expr::TryCast(TryCast { expr, .. }) | Expr::Cast(Cast { expr, .. }), - Expr::Literal(lit_value), + Expr::Literal(lit_value, _), ) => { let Ok(expr_type) = info.get_data_type(&expr) else { return internal_err!("Can't get the data type of the expr {:?}", &expr); @@ -126,7 +126,7 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary< | Expr::Cast(Cast { expr: left_expr, .. 
}), - Expr::Literal(lit_val), + Expr::Literal(lit_val, _), ) => { let Ok(expr_type) = info.get_data_type(left_expr) else { return false; @@ -183,7 +183,7 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist< } match right { - Expr::Literal(lit_val) + Expr::Literal(lit_val, _) if try_cast_literal_to_type(lit_val, &expr_type).is_some() => {} _ => return false, } @@ -197,6 +197,7 @@ fn is_supported_type(data_type: &DataType) -> bool { is_supported_numeric_type(data_type) || is_supported_string_type(data_type) || is_supported_dictionary_type(data_type) + || is_supported_binary_type(data_type) } /// Returns true if unwrap_cast_in_comparison support this numeric type @@ -230,6 +231,10 @@ fn is_supported_dictionary_type(data_type: &DataType) -> bool { DataType::Dictionary(_, inner) if is_supported_type(inner)) } +fn is_supported_binary_type(data_type: &DataType) -> bool { + matches!(data_type, DataType::Binary | DataType::FixedSizeBinary(_)) +} + ///// Tries to move a cast from an expression (such as column) to the literal other side of a comparison operator./ /// /// Specifically, rewrites @@ -292,6 +297,7 @@ pub(super) fn try_cast_literal_to_type( try_cast_numeric_literal(lit_value, target_type) .or_else(|| try_cast_string_literal(lit_value, target_type)) .or_else(|| try_cast_dictionary(lit_value, target_type)) + .or_else(|| try_cast_binary(lit_value, target_type)) } /// Convert a numeric value from one numeric data type to another @@ -501,6 +507,20 @@ fn cast_between_timestamp(from: &DataType, to: &DataType, value: i128) -> Option } } +fn try_cast_binary( + lit_value: &ScalarValue, + target_type: &DataType, +) -> Option { + match (lit_value, target_type) { + (ScalarValue::Binary(Some(v)), DataType::FixedSizeBinary(n)) + if v.len() == *n as usize => + { + Some(ScalarValue::FixedSizeBinary(*n, Some(v.clone()))) + } + _ => None, + } +} + #[cfg(test)] mod tests { use super::*; @@ -1450,4 +1470,13 @@ mod tests { ) } } + + #[test] + fn try_cast_to_fixed_size_binary() { + expect_cast( + ScalarValue::Binary(Some(vec![1, 2, 3])), + DataType::FixedSizeBinary(3), + ExpectedCast::Value(ScalarValue::FixedSizeBinary(3, Some(vec![1, 2, 3]))), + ) + } } diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs index cf182175e48e..4df0e125eb18 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -139,34 +139,34 @@ pub fn delete_xor_in_complex_expr(expr: &Expr, needle: &Expr, is_left: bool) -> pub fn is_zero(s: &Expr) -> bool { match s { - Expr::Literal(ScalarValue::Int8(Some(0))) - | Expr::Literal(ScalarValue::Int16(Some(0))) - | Expr::Literal(ScalarValue::Int32(Some(0))) - | Expr::Literal(ScalarValue::Int64(Some(0))) - | Expr::Literal(ScalarValue::UInt8(Some(0))) - | Expr::Literal(ScalarValue::UInt16(Some(0))) - | Expr::Literal(ScalarValue::UInt32(Some(0))) - | Expr::Literal(ScalarValue::UInt64(Some(0))) => true, - Expr::Literal(ScalarValue::Float32(Some(v))) if *v == 0. => true, - Expr::Literal(ScalarValue::Float64(Some(v))) if *v == 0. 
=> true, - Expr::Literal(ScalarValue::Decimal128(Some(v), _p, _s)) if *v == 0 => true, + Expr::Literal(ScalarValue::Int8(Some(0)), _) + | Expr::Literal(ScalarValue::Int16(Some(0)), _) + | Expr::Literal(ScalarValue::Int32(Some(0)), _) + | Expr::Literal(ScalarValue::Int64(Some(0)), _) + | Expr::Literal(ScalarValue::UInt8(Some(0)), _) + | Expr::Literal(ScalarValue::UInt16(Some(0)), _) + | Expr::Literal(ScalarValue::UInt32(Some(0)), _) + | Expr::Literal(ScalarValue::UInt64(Some(0)), _) => true, + Expr::Literal(ScalarValue::Float32(Some(v)), _) if *v == 0. => true, + Expr::Literal(ScalarValue::Float64(Some(v)), _) if *v == 0. => true, + Expr::Literal(ScalarValue::Decimal128(Some(v), _p, _s), _) if *v == 0 => true, _ => false, } } pub fn is_one(s: &Expr) -> bool { match s { - Expr::Literal(ScalarValue::Int8(Some(1))) - | Expr::Literal(ScalarValue::Int16(Some(1))) - | Expr::Literal(ScalarValue::Int32(Some(1))) - | Expr::Literal(ScalarValue::Int64(Some(1))) - | Expr::Literal(ScalarValue::UInt8(Some(1))) - | Expr::Literal(ScalarValue::UInt16(Some(1))) - | Expr::Literal(ScalarValue::UInt32(Some(1))) - | Expr::Literal(ScalarValue::UInt64(Some(1))) => true, - Expr::Literal(ScalarValue::Float32(Some(v))) if *v == 1. => true, - Expr::Literal(ScalarValue::Float64(Some(v))) if *v == 1. => true, - Expr::Literal(ScalarValue::Decimal128(Some(v), _p, s)) => { + Expr::Literal(ScalarValue::Int8(Some(1)), _) + | Expr::Literal(ScalarValue::Int16(Some(1)), _) + | Expr::Literal(ScalarValue::Int32(Some(1)), _) + | Expr::Literal(ScalarValue::Int64(Some(1)), _) + | Expr::Literal(ScalarValue::UInt8(Some(1)), _) + | Expr::Literal(ScalarValue::UInt16(Some(1)), _) + | Expr::Literal(ScalarValue::UInt32(Some(1)), _) + | Expr::Literal(ScalarValue::UInt64(Some(1)), _) => true, + Expr::Literal(ScalarValue::Float32(Some(v)), _) if *v == 1. => true, + Expr::Literal(ScalarValue::Float64(Some(v)), _) if *v == 1. => true, + Expr::Literal(ScalarValue::Decimal128(Some(v), _p, s), _) => { *s >= 0 && POWS_OF_TEN .get(*s as usize) @@ -179,7 +179,7 @@ pub fn is_one(s: &Expr) -> bool { pub fn is_true(expr: &Expr) -> bool { match expr { - Expr::Literal(ScalarValue::Boolean(Some(v))) => *v, + Expr::Literal(ScalarValue::Boolean(Some(v)), _) => *v, _ => false, } } @@ -187,24 +187,24 @@ pub fn is_true(expr: &Expr) -> bool { /// returns true if expr is a /// `Expr::Literal(ScalarValue::Boolean(v))` , false otherwise pub fn is_bool_lit(expr: &Expr) -> bool { - matches!(expr, Expr::Literal(ScalarValue::Boolean(_))) + matches!(expr, Expr::Literal(ScalarValue::Boolean(_), _)) } /// Return a literal NULL value of Boolean data type pub fn lit_bool_null() -> Expr { - Expr::Literal(ScalarValue::Boolean(None)) + Expr::Literal(ScalarValue::Boolean(None), None) } pub fn is_null(expr: &Expr) -> bool { match expr { - Expr::Literal(v) => v.is_null(), + Expr::Literal(v, _) => v.is_null(), _ => false, } } pub fn is_false(expr: &Expr) -> bool { match expr { - Expr::Literal(ScalarValue::Boolean(Some(v))) => !(*v), + Expr::Literal(ScalarValue::Boolean(Some(v)), _) => !(*v), _ => false, } } @@ -247,7 +247,7 @@ pub fn is_negative_of(not_expr: &Expr, expr: &Expr) -> bool { /// `Expr::Literal(ScalarValue::Boolean(v))`. 
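The mechanical change running through these helpers is that Expr::Literal now carries optional field metadata as a second tuple element: constructors pass None and pattern matches ignore it with `_`. A tiny self-contained sketch of the updated shape, using only the variant as it appears in this diff:

use datafusion_common::ScalarValue;
use datafusion_expr::Expr;

// Matching: ignore the metadata slot with `_`.
fn is_boolean_true(expr: &Expr) -> bool {
    matches!(expr, Expr::Literal(ScalarValue::Boolean(Some(true)), _))
}

// Constructing: pass `None` when no field metadata is attached.
fn null_bool_literal() -> Expr {
    Expr::Literal(ScalarValue::Boolean(None), None)
}

fn main() {
    assert!(!is_boolean_true(&null_bool_literal()));
}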
pub fn as_bool_lit(expr: &Expr) -> Result> { match expr { - Expr::Literal(ScalarValue::Boolean(v)) => Ok(*v), + Expr::Literal(ScalarValue::Boolean(v), _) => Ok(*v), _ => internal_err!("Expected boolean literal, got {expr:?}"), } } diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 7337d2ffce5c..50783a214342 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -206,7 +206,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { // if the aggregate function is not distinct, we need to rewrite it like two phase aggregation } else { index += 1; - let alias_str = format!("alias{}", index); + let alias_str = format!("alias{index}"); inner_aggr_exprs.push( Expr::AggregateFunction(AggregateFunction::new_udf( Arc::clone(&func), @@ -280,6 +280,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { #[cfg(test)] mod tests { use super::*; + use crate::assert_optimized_plan_eq_display_indent_snapshot; use crate::test::*; use datafusion_expr::expr::GroupingSet; use datafusion_expr::ExprFunctionExt; @@ -300,13 +301,18 @@ mod tests { )) } - fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq_display_indent( - Arc::new(SingleDistinctToGroupBy::new()), - plan, - expected, - ); - Ok(()) + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let rule: Arc = Arc::new(SingleDistinctToGroupBy::new()); + assert_optimized_plan_eq_display_indent_snapshot!( + rule, + $plan, + @ $expected, + ) + }}; } #[test] @@ -318,11 +324,13 @@ mod tests { .build()?; // Do nothing - let expected = - "Aggregate: groupBy=[[]], aggr=[[max(test.b)]] [max(test.b):UInt32;N]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[]], aggr=[[max(test.b)]] [max(test.b):UInt32;N] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -334,12 +342,15 @@ mod tests { .build()?; // Should work - let expected = "Projection: count(alias1) AS count(DISTINCT test.b) [count(DISTINCT test.b):Int64]\ - \n Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64]\ - \n Aggregate: groupBy=[[test.b AS alias1]], aggr=[[]] [alias1:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: count(alias1) AS count(DISTINCT test.b) [count(DISTINCT test.b):Int64] + Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64] + Aggregate: groupBy=[[test.b AS alias1]], aggr=[[]] [alias1:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET @@ -357,10 +368,13 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } // 
Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET @@ -375,10 +389,13 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET @@ -394,10 +411,13 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -408,12 +428,15 @@ mod tests { .aggregate(Vec::::new(), vec![count_distinct(lit(2) * col("b"))])? .build()?; - let expected = "Projection: count(alias1) AS count(DISTINCT Int32(2) * test.b) [count(DISTINCT Int32(2) * test.b):Int64]\ - \n Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64]\ - \n Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], aggr=[[]] [alias1:Int64]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: count(alias1) AS count(DISTINCT Int32(2) * test.b) [count(DISTINCT Int32(2) * test.b):Int64] + Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64] + Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], aggr=[[]] [alias1:Int64] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -425,12 +448,15 @@ mod tests { .build()?; // Should work - let expected = "Projection: test.a, count(alias1) AS count(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64]\ - \n Aggregate: groupBy=[[test.a]], aggr=[[count(alias1)]] [a:UInt32, count(alias1):Int64]\ - \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, count(alias1) AS count(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64] + Aggregate: groupBy=[[test.a]], aggr=[[count(alias1)]] [a:UInt32, count(alias1):Int64] + Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -445,10 +471,13 @@ mod tests { .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(DISTINCT test.c)]] [a:UInt32, count(DISTINCT test.b):Int64, count(DISTINCT test.c):Int64]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), 
count(DISTINCT test.c)]] [a:UInt32, count(DISTINCT test.b):Int64, count(DISTINCT test.c):Int64] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -461,13 +490,17 @@ mod tests { vec![count_distinct(col("b")), max_distinct(col("b"))], )? .build()?; - // Should work - let expected = "Projection: test.a, count(alias1) AS count(DISTINCT test.b), max(alias1) AS max(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64, max(DISTINCT test.b):UInt32;N]\ - \n Aggregate: groupBy=[[test.a]], aggr=[[count(alias1), max(alias1)]] [a:UInt32, count(alias1):Int64, max(alias1):UInt32;N]\ - \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(plan, expected) + // Should work + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, count(alias1) AS count(DISTINCT test.b), max(alias1) AS max(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64, max(DISTINCT test.b):UInt32;N] + Aggregate: groupBy=[[test.a]], aggr=[[count(alias1), max(alias1)]] [a:UInt32, count(alias1):Int64, max(alias1):UInt32;N] + Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -482,10 +515,13 @@ mod tests { .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(test.c)]] [a:UInt32, count(DISTINCT test.b):Int64, count(test.c):Int64]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(test.c)]] [a:UInt32, count(DISTINCT test.b):Int64, count(test.c):Int64] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -497,12 +533,15 @@ mod tests { .build()?; // Should work - let expected = "Projection: group_alias_0 AS test.a + Int32(1), count(alias1) AS count(DISTINCT test.c) [test.a + Int32(1):Int64, count(DISTINCT test.c):Int64]\ - \n Aggregate: groupBy=[[group_alias_0]], aggr=[[count(alias1)]] [group_alias_0:Int64, count(alias1):Int64]\ - \n Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int64, alias1:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) + assert_optimized_plan_equal!( + plan, + @r" + Projection: group_alias_0 AS test.a + Int32(1), count(alias1) AS count(DISTINCT test.c) [test.a + Int32(1):Int64, count(DISTINCT test.c):Int64] + Aggregate: groupBy=[[group_alias_0]], aggr=[[count(alias1)]] [group_alias_0:Int64, count(alias1):Int64] + Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int64, alias1:UInt32] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -519,13 +558,17 @@ mod tests { ], )? 
.build()?; - // Should work - let expected = "Projection: test.a, sum(alias2) AS sum(test.c), count(alias1) AS count(DISTINCT test.b), max(alias1) AS max(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, count(DISTINCT test.b):Int64, max(DISTINCT test.b):UInt32;N]\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), count(alias1), max(alias1)]] [a:UInt32, sum(alias2):UInt64;N, count(alias1):Int64, max(alias1):UInt32;N]\ - \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[sum(test.c) AS alias2]] [a:UInt32, alias1:UInt32, alias2:UInt64;N]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(plan, expected) + // Should work + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, sum(alias2) AS sum(test.c), count(alias1) AS count(DISTINCT test.b), max(alias1) AS max(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, count(DISTINCT test.b):Int64, max(DISTINCT test.b):UInt32;N] + Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), count(alias1), max(alias1)]] [a:UInt32, sum(alias2):UInt64;N, count(alias1):Int64, max(alias1):UInt32;N] + Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[sum(test.c) AS alias2]] [a:UInt32, alias1:UInt32, alias2:UInt64;N] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -538,13 +581,17 @@ mod tests { vec![sum(col("c")), max(col("c")), count_distinct(col("b"))], )? .build()?; - // Should work - let expected = "Projection: test.a, sum(alias2) AS sum(test.c), max(alias3) AS max(test.c), count(alias1) AS count(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, max(test.c):UInt32;N, count(DISTINCT test.b):Int64]\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), max(alias3), count(alias1)]] [a:UInt32, sum(alias2):UInt64;N, max(alias3):UInt32;N, count(alias1):Int64]\ - \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[sum(test.c) AS alias2, max(test.c) AS alias3]] [a:UInt32, alias1:UInt32, alias2:UInt64;N, alias3:UInt32;N]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(plan, expected) + // Should work + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.a, sum(alias2) AS sum(test.c), max(alias3) AS max(test.c), count(alias1) AS count(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, max(test.c):UInt32;N, count(DISTINCT test.b):Int64] + Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), max(alias3), count(alias1)]] [a:UInt32, sum(alias2):UInt64;N, max(alias3):UInt32;N, count(alias1):Int64] + Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[sum(test.c) AS alias2, max(test.c) AS alias3]] [a:UInt32, alias1:UInt32, alias2:UInt64;N, alias3:UInt32;N] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -557,13 +604,17 @@ mod tests { vec![min(col("a")), count_distinct(col("b"))], )? 
.build()?; - // Should work - let expected = "Projection: test.c, min(alias2) AS min(test.a), count(alias1) AS count(DISTINCT test.b) [c:UInt32, min(test.a):UInt32;N, count(DISTINCT test.b):Int64]\ - \n Aggregate: groupBy=[[test.c]], aggr=[[min(alias2), count(alias1)]] [c:UInt32, min(alias2):UInt32;N, count(alias1):Int64]\ - \n Aggregate: groupBy=[[test.c, test.b AS alias1]], aggr=[[min(test.a) AS alias2]] [c:UInt32, alias1:UInt32, alias2:UInt32;N]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(plan, expected) + // Should work + assert_optimized_plan_equal!( + plan, + @r" + Projection: test.c, min(alias2) AS min(test.a), count(alias1) AS count(DISTINCT test.b) [c:UInt32, min(test.a):UInt32;N, count(DISTINCT test.b):Int64] + Aggregate: groupBy=[[test.c]], aggr=[[min(alias2), count(alias1)]] [c:UInt32, min(alias2):UInt32;N, count(alias1):Int64] + Aggregate: groupBy=[[test.c, test.b AS alias1]], aggr=[[min(test.a) AS alias2]] [c:UInt32, alias1:UInt32, alias2:UInt32;N] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -582,11 +633,15 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])? .build()?; - // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) FILTER (WHERE test.a > Int32(5)), count(DISTINCT test.b)]] [c:UInt32, sum(test.a) FILTER (WHERE test.a > Int32(5)):UInt64;N, count(DISTINCT test.b):Int64]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(plan, expected) + // Do nothing + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) FILTER (WHERE test.a > Int32(5)), count(DISTINCT test.b)]] [c:UInt32, sum(test.a) FILTER (WHERE test.a > Int32(5)):UInt64;N, count(DISTINCT test.b):Int64] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -602,11 +657,15 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![sum(col("a")), expr])? .build()?; - // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5))]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)):Int64]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(plan, expected) + // Do nothing + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5))]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)):Int64] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -625,11 +684,15 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])? 
.build()?; - // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) ORDER BY [test.a ASC NULLS LAST], count(DISTINCT test.b)]] [c:UInt32, sum(test.a) ORDER BY [test.a ASC NULLS LAST]:UInt64;N, count(DISTINCT test.b):Int64]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(plan, expected) + // Do nothing + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) ORDER BY [test.a ASC NULLS LAST], count(DISTINCT test.b)]] [c:UInt32, sum(test.a) ORDER BY [test.a ASC NULLS LAST]:UInt64;N, count(DISTINCT test.b):Int64] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -645,11 +708,15 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![sum(col("a")), expr])? .build()?; - // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]:Int64]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(plan, expected) + // Do nothing + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]:Int64] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } #[test] @@ -666,10 +733,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![sum(col("a")), expr])? .build()?; - // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]:Int64]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(plan, expected) + // Do nothing + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]:Int64] + TableScan: test [a:UInt32, b:UInt32, c:UInt32] + " + ) } } diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index 94d07a0791b3..6e0b734bb928 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -21,7 +21,7 @@ use crate::{OptimizerContext, OptimizerRule}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::config::ConfigOptions; use datafusion_common::{assert_contains, Result}; -use datafusion_expr::{col, logical_plan::table_scan, LogicalPlan, LogicalPlanBuilder}; +use datafusion_expr::{logical_plan::table_scan, LogicalPlan, LogicalPlanBuilder}; use std::sync::Arc; pub mod user_defined; @@ -64,15 +64,6 @@ pub fn assert_fields_eq(plan: &LogicalPlan, expected: Vec<&str>) { assert_eq!(actual, expected); } -pub fn test_subquery_with_name(name: &str) -> Result> { - let table_scan = test_table_scan_with_name(name)?; - Ok(Arc::new( - LogicalPlanBuilder::from(table_scan) - .project(vec![col("c")])? 
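Stepping back from the snapshot migration in the single_distinct_to_groupby tests above: the rule they exercise splits a single distinct aggregate into two aggregation levels, an inner de-duplicating group-by plus an outer plain aggregate. A minimal sketch of a plan shape that triggers the rewrite, built with the public LogicalPlanBuilder API rather than the tests' own scan helpers:

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
use datafusion_expr::{col, logical_plan::table_scan, LogicalPlan};
use datafusion_functions_aggregate::expr_fn::count_distinct;

fn single_distinct_plan() -> Result<LogicalPlan> {
    let schema = Schema::new(vec![
        Field::new("a", DataType::UInt32, false),
        Field::new("b", DataType::UInt32, false),
    ]);
    // SELECT a, COUNT(DISTINCT b) FROM test GROUP BY a, which the rule rewrites into
    //   Aggregate: groupBy=[a],                aggr=[count(alias1)]
    //     Aggregate: groupBy=[a, b AS alias1], aggr=[]
    table_scan(Some("test"), &schema, None)?
        .aggregate(vec![col("a")], vec![count_distinct(col("b"))])?
        .build()
}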
- .build()?, - )) -} - pub fn scan_tpch_table(table: &str) -> LogicalPlan { let schema = Arc::new(get_tpch_table_schema(table)); table_scan(Some(table), &schema, None) @@ -108,43 +99,20 @@ pub fn get_tpch_table_schema(table: &str) -> Schema { } } -pub fn assert_analyzed_plan_eq( - rule: Arc, - plan: LogicalPlan, - expected: &str, -) -> Result<()> { - let options = ConfigOptions::default(); - assert_analyzed_plan_with_config_eq(options, rule, plan, expected)?; - - Ok(()) -} +#[macro_export] +macro_rules! assert_analyzed_plan_with_config_eq_snapshot { + ( + $options:expr, + $rule:expr, + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let analyzed_plan = $crate::Analyzer::with_rules(vec![$rule]).execute_and_check($plan, &$options, |_, _| {})?; -pub fn assert_analyzed_plan_with_config_eq( - options: ConfigOptions, - rule: Arc, - plan: LogicalPlan, - expected: &str, -) -> Result<()> { - let analyzed_plan = - Analyzer::with_rules(vec![rule]).execute_and_check(plan, &options, |_, _| {})?; - let formatted_plan = format!("{analyzed_plan}"); - assert_eq!(formatted_plan, expected); + insta::assert_snapshot!(analyzed_plan, @ $expected); - Ok(()) -} - -pub fn assert_analyzed_plan_eq_display_indent( - rule: Arc, - plan: LogicalPlan, - expected: &str, -) -> Result<()> { - let options = ConfigOptions::default(); - let analyzed_plan = - Analyzer::with_rules(vec![rule]).execute_and_check(plan, &options, |_, _| {})?; - let formatted_plan = analyzed_plan.display_indent_schema().to_string(); - assert_eq!(formatted_plan, expected); - - Ok(()) + Ok::<(), datafusion_common::DataFusionError>(()) + }}; } pub fn assert_analyzer_check_err( @@ -165,27 +133,26 @@ pub fn assert_analyzer_check_err( fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} -pub fn assert_optimized_plan_eq( - rule: Arc, - plan: LogicalPlan, - expected: &str, -) -> Result<()> { - // Apply the rule once - let opt_context = OptimizerContext::new().with_max_passes(1); +#[macro_export] +macro_rules! assert_optimized_plan_eq_snapshot { + ( + $optimizer_context:expr, + $rules:expr, + $plan:expr, + @ $expected:literal $(,)? 
+ ) => {{ + let optimizer = $crate::Optimizer::with_rules($rules); + let optimized_plan = optimizer.optimize($plan, &$optimizer_context, |_, _| {})?; + insta::assert_snapshot!(optimized_plan, @ $expected); - let optimizer = Optimizer::with_rules(vec![Arc::clone(&rule)]); - let optimized_plan = optimizer.optimize(plan, &opt_context, observe)?; - let formatted_plan = format!("{optimized_plan}"); - assert_eq!(formatted_plan, expected); - - Ok(()) + Ok::<(), datafusion_common::DataFusionError>(()) + }}; } fn generate_optimized_plan_with_rules( rules: Vec>, plan: LogicalPlan, ) -> LogicalPlan { - fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} let config = &mut OptimizerContext::new() .with_max_passes(1) .with_skip_failing_rules(false); @@ -211,60 +178,20 @@ pub fn assert_optimized_plan_with_rules( Ok(()) } -pub fn assert_optimized_plan_eq_display_indent( - rule: Arc, - plan: LogicalPlan, - expected: &str, -) { - let optimizer = Optimizer::with_rules(vec![rule]); - let optimized_plan = optimizer - .optimize(plan, &OptimizerContext::new(), observe) - .expect("failed to optimize plan"); - let formatted_plan = optimized_plan.display_indent_schema().to_string(); - assert_eq!(formatted_plan, expected); -} - -pub fn assert_multi_rules_optimized_plan_eq_display_indent( - rules: Vec>, - plan: LogicalPlan, - expected: &str, -) { - let optimizer = Optimizer::with_rules(rules); - let optimized_plan = optimizer - .optimize(plan, &OptimizerContext::new(), observe) - .expect("failed to optimize plan"); - let formatted_plan = optimized_plan.display_indent_schema().to_string(); - assert_eq!(formatted_plan, expected); -} - -pub fn assert_optimizer_err( - rule: Arc, - plan: LogicalPlan, - expected: &str, -) { - let optimizer = Optimizer::with_rules(vec![rule]); - let res = optimizer.optimize(plan, &OptimizerContext::new(), observe); - match res { - Ok(plan) => assert_eq!(format!("{}", plan.display_indent()), "An error"), - Err(ref e) => { - let actual = format!("{e}"); - if expected.is_empty() || !actual.contains(expected) { - assert_eq!(actual, expected) - } - } - } -} - -pub fn assert_optimization_skipped( - rule: Arc, - plan: LogicalPlan, -) -> Result<()> { - let optimizer = Optimizer::with_rules(vec![rule]); - let new_plan = optimizer.optimize(plan.clone(), &OptimizerContext::new(), observe)?; - - assert_eq!( - format!("{}", plan.display_indent()), - format!("{}", new_plan.display_indent()) - ); - Ok(()) +#[macro_export] +macro_rules! assert_optimized_plan_eq_display_indent_snapshot { + ( + $rule:expr, + $plan:expr, + @ $expected:literal $(,)? 
+ ) => {{ + let optimizer = $crate::Optimizer::with_rules(vec![$rule]); + let optimized_plan = optimizer + .optimize($plan, &$crate::OptimizerContext::new(), |_, _| {}) + .expect("failed to optimize plan"); + let formatted_plan = optimized_plan.display_indent_schema(); + insta::assert_snapshot!(formatted_plan, @ $expected); + + Ok::<(), datafusion_common::DataFusionError>(()) + }}; } diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 41c40ec06d65..0aa0bf3ea430 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -163,7 +163,11 @@ mod tests { (Expr::IsNotNull(Box::new(col("a"))), true), // a = NULL ( - binary_expr(col("a"), Operator::Eq, Expr::Literal(ScalarValue::Null)), + binary_expr( + col("a"), + Operator::Eq, + Expr::Literal(ScalarValue::Null, None), + ), true, ), // a > 8 @@ -226,12 +230,16 @@ mod tests { ), // a IN (NULL) ( - in_list(col("a"), vec![Expr::Literal(ScalarValue::Null)], false), + in_list( + col("a"), + vec![Expr::Literal(ScalarValue::Null, None)], + false, + ), true, ), // a NOT IN (NULL) ( - in_list(col("a"), vec![Expr::Literal(ScalarValue::Null)], true), + in_list(col("a"), vec![Expr::Literal(ScalarValue::Null, None)], true), true, ), ]; @@ -241,7 +249,7 @@ mod tests { let join_cols_of_predicate = std::iter::once(&column_a); let actual = is_restrict_null_predicate(predicate.clone(), join_cols_of_predicate)?; - assert_eq!(actual, expected, "{}", predicate); + assert_eq!(actual, expected, "{predicate}"); } Ok(()) diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index 941e5bd7b4d7..95a9db6c8abd 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -250,7 +250,7 @@ fn between_date32_plus_interval() -> Result<()> { format!("{plan}"), @r#" Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] - Projection: + Projection: Filter: test.col_date32 >= Date32("1998-03-18") AND test.col_date32 <= Date32("1998-06-16") TableScan: test projection=[col_date32] "# @@ -268,7 +268,7 @@ fn between_date64_plus_interval() -> Result<()> { format!("{plan}"), @r#" Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] - Projection: + Projection: Filter: test.col_date64 >= Date64("1998-03-18") AND test.col_date64 <= Date64("1998-06-16") TableScan: test projection=[col_date64] "# @@ -492,7 +492,32 @@ fn test_sql(sql: &str) -> Result { .with_expr_planners(vec![ Arc::new(AggregateFunctionPlanner), Arc::new(WindowFunctionPlanner), - ]); + ]) + .with_schema( + "test", + Schema::new_with_metadata( + vec![ + Field::new("col_int32", DataType::Int32, true), + Field::new("col_uint32", DataType::UInt32, true), + Field::new("col_utf8", DataType::Utf8, true), + Field::new("col_date32", DataType::Date32, true), + Field::new("col_date64", DataType::Date64, true), + // timestamp with no timezone + Field::new( + "col_ts_nano_none", + DataType::Timestamp(TimeUnit::Nanosecond, None), + true, + ), + // timestamp with UTC timezone + Field::new( + "col_ts_nano_utc", + DataType::Timestamp(TimeUnit::Nanosecond, Some("+00:00".into())), + true, + ), + ], + HashMap::new(), + ), + ); let sql_to_rel = SqlToRel::new(&context_provider); let plan = sql_to_rel.sql_statement_to_plan(statement.clone())?; @@ -510,6 +535,7 @@ fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} #[derive(Default)] struct MyContextProvider { options: ConfigOptions, + tables: HashMap>, udafs: HashMap>, expr_planners: Vec>, } @@ -525,38 
+551,23 @@ impl MyContextProvider { self.expr_planners = expr_planners; self } + + fn with_schema(mut self, name: impl Into, schema: Schema) -> Self { + self.tables.insert( + name.into(), + Arc::new(MyTableSource { + schema: Arc::new(schema), + }), + ); + self + } } impl ContextProvider for MyContextProvider { fn get_table_source(&self, name: TableReference) -> Result> { let table_name = name.table(); - if table_name.starts_with("test") { - let schema = Schema::new_with_metadata( - vec![ - Field::new("col_int32", DataType::Int32, true), - Field::new("col_uint32", DataType::UInt32, true), - Field::new("col_utf8", DataType::Utf8, true), - Field::new("col_date32", DataType::Date32, true), - Field::new("col_date64", DataType::Date64, true), - // timestamp with no timezone - Field::new( - "col_ts_nano_none", - DataType::Timestamp(TimeUnit::Nanosecond, None), - true, - ), - // timestamp with UTC timezone - Field::new( - "col_ts_nano_utc", - DataType::Timestamp(TimeUnit::Nanosecond, Some("+00:00".into())), - true, - ), - ], - HashMap::new(), - ); - - Ok(Arc::new(MyTableSource { - schema: Arc::new(schema), - })) + if let Some(table) = self.tables.get(table_name) { + Ok(table.clone()) } else { plan_err!("table does not exist") } diff --git a/datafusion/physical-expr-common/README.md b/datafusion/physical-expr-common/README.md index 7a1eff77d3b4..fab03fb49752 100644 --- a/datafusion/physical-expr-common/README.md +++ b/datafusion/physical-expr-common/README.md @@ -24,4 +24,9 @@ This crate is a submodule of DataFusion that provides shared APIs for implementing physical expressions such as `PhysicalExpr` and `PhysicalSortExpr`. +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + [df]: https://crates.io/crates/datafusion +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index 3bc41d2652d9..b4cb08715f53 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -25,9 +25,11 @@ use crate::utils::scatter; use arrow::array::BooleanArray; use arrow::compute::filter_record_batch; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::{DataType, Field, FieldRef, Schema}; use arrow::record_batch::RecordBatch; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, +}; use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue}; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::interval_arithmetic::Interval; @@ -71,11 +73,23 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + DynEq + DynHash { /// downcast to a specific implementation. 
fn as_any(&self) -> &dyn Any; /// Get the data type of this expression, given the schema of the input - fn data_type(&self, input_schema: &Schema) -> Result; + fn data_type(&self, input_schema: &Schema) -> Result { + Ok(self.return_field(input_schema)?.data_type().to_owned()) + } /// Determine whether this expression is nullable, given the schema of the input - fn nullable(&self, input_schema: &Schema) -> Result; + fn nullable(&self, input_schema: &Schema) -> Result { + Ok(self.return_field(input_schema)?.is_nullable()) + } /// Evaluate an expression against a RecordBatch fn evaluate(&self, batch: &RecordBatch) -> Result; + /// The output field associated with this expression + fn return_field(&self, input_schema: &Schema) -> Result { + Ok(Arc::new(Field::new( + format!("{self}"), + self.data_type(input_schema)?, + self.nullable(input_schema)?, + ))) + } /// Evaluate an expression against a RecordBatch after first applying a /// validity array fn evaluate_selection( @@ -333,6 +347,24 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + DynEq + DynHash { // This is a safe default behavior. Ok(None) } + + /// Returns the generation of this `PhysicalExpr` for snapshotting purposes. + /// The generation is an arbitrary u64 that can be used to track changes + /// in the state of the `PhysicalExpr` over time without having to do an exhaustive comparison. + /// This is useful to avoid unecessary computation or serialization if there are no changes to the expression. + /// In particular, dynamic expressions that may change over time; this allows cheap checks for changes. + /// Static expressions that do not change over time should return 0, as does the default implementation. + /// You should not call this method directly as it does not handle recursion. + /// Instead use [`snapshot_generation`] to handle recursion and capture the + /// full state of the `PhysicalExpr`. + fn snapshot_generation(&self) -> u64 { + // By default, we return 0 to indicate that this PhysicalExpr does not + // have any dynamic references or state. + // Since the recursive algorithm XORs the generations of all children the overall + // generation will be 0 if no children have a non-zero generation, meaning that + // static expressions will always return 0. 
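// Editor's note, not part of this change: a dynamic expression (for example,
// a filter whose bounds tighten while the query runs) could keep an internal
// counter such as an `AtomicU64`, bump it whenever its state changes, and
// return the current value from its own `snapshot_generation` override. The
// free function `snapshot_generation` below folds the generations of a whole
// expression tree together, and `is_dynamic_physical_expr` then reports `true`
// for any tree that contains such an expression.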
+ 0 + } } /// [`PhysicalExpr`] can't be constrained by [`Eq`] directly because it must remain object @@ -434,10 +466,10 @@ where let mut iter = self.0.clone(); write!(f, "[")?; if let Some(expr) = iter.next() { - write!(f, "{}", expr)?; + write!(f, "{expr}")?; } for expr in iter { - write!(f, ", {}", expr)?; + write!(f, ", {expr}")?; } write!(f, "]")?; Ok(()) @@ -453,19 +485,21 @@ where /// ``` /// # // The boiler plate needed to create a `PhysicalExpr` for the example /// # use std::any::Any; +/// use std::collections::HashMap; /// # use std::fmt::Formatter; /// # use std::sync::Arc; /// # use arrow::array::RecordBatch; -/// # use arrow::datatypes::{DataType, Schema}; +/// # use arrow::datatypes::{DataType, Field, FieldRef, Schema}; /// # use datafusion_common::Result; /// # use datafusion_expr_common::columnar_value::ColumnarValue; /// # use datafusion_physical_expr_common::physical_expr::{fmt_sql, DynEq, PhysicalExpr}; /// # #[derive(Debug, Hash, PartialOrd, PartialEq)] -/// # struct MyExpr {}; +/// # struct MyExpr {} /// # impl PhysicalExpr for MyExpr {fn as_any(&self) -> &dyn Any { unimplemented!() } /// # fn data_type(&self, input_schema: &Schema) -> Result { unimplemented!() } /// # fn nullable(&self, input_schema: &Schema) -> Result { unimplemented!() } /// # fn evaluate(&self, batch: &RecordBatch) -> Result { unimplemented!() } +/// # fn return_field(&self, input_schema: &Schema) -> Result { unimplemented!() } /// # fn children(&self) -> Vec<&Arc>{ unimplemented!() } /// # fn with_new_children(self: Arc, children: Vec>) -> Result> { unimplemented!() } /// # fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "CASE a > b THEN 1 ELSE 0 END") } @@ -523,3 +557,31 @@ pub fn snapshot_physical_expr( }) .data() } + +/// Check the generation of this `PhysicalExpr`. +/// Dynamic `PhysicalExpr`s may have a generation that is incremented +/// every time the state of the `PhysicalExpr` changes. +/// If the generation changes that means this `PhysicalExpr` or one of its children +/// has changed since the last time it was evaluated. +/// +/// This algorithm will not produce collisions as long as the structure of the +/// `PhysicalExpr` does not change and no `PhysicalExpr` decrements its own generation. +pub fn snapshot_generation(expr: &Arc) -> u64 { + let mut generation = 0u64; + expr.apply(|e| { + // Add the current generation of the `PhysicalExpr` to our global generation. + generation = generation.wrapping_add(e.snapshot_generation()); + Ok(TreeNodeRecursion::Continue) + }) + .expect("this traversal is infallible"); + + generation +} + +/// Check if the given `PhysicalExpr` is dynamic. +/// Internally this calls [`snapshot_generation`] to check if the generation is non-zero, +/// any dynamic `PhysicalExpr` should have a non-zero generation. +pub fn is_dynamic_physical_expr(expr: &Arc) -> bool { + // If the generation is non-zero, then this `PhysicalExpr` is dynamic. + snapshot_generation(expr) != 0 +} diff --git a/datafusion/physical-expr-common/src/sort_expr.rs b/datafusion/physical-expr-common/src/sort_expr.rs index 3a54b5b40399..07edfb70f4aa 100644 --- a/datafusion/physical-expr-common/src/sort_expr.rs +++ b/datafusion/physical-expr-common/src/sort_expr.rs @@ -17,33 +17,34 @@ //! 
Sort expressions -use crate::physical_expr::{fmt_sql, PhysicalExpr}; -use std::fmt; -use std::fmt::{Display, Formatter}; +use std::cmp::Ordering; +use std::fmt::{self, Display, Formatter}; use std::hash::{Hash, Hasher}; -use std::ops::{Deref, Index, Range, RangeFrom, RangeTo}; -use std::sync::{Arc, LazyLock}; +use std::ops::{Deref, DerefMut}; +use std::sync::Arc; use std::vec::IntoIter; +use crate::physical_expr::{fmt_sql, PhysicalExpr}; + use arrow::compute::kernels::sort::{SortColumn, SortOptions}; use arrow::datatypes::Schema; use arrow::record_batch::RecordBatch; -use datafusion_common::Result; +use datafusion_common::{HashSet, Result}; use datafusion_expr_common::columnar_value::ColumnarValue; -use itertools::Itertools; /// Represents Sort operation for a column in a RecordBatch /// /// Example: /// ``` /// # use std::any::Any; +/// # use std::collections::HashMap; /// # use std::fmt::{Display, Formatter}; /// # use std::hash::Hasher; /// # use std::sync::Arc; /// # use arrow::array::RecordBatch; /// # use datafusion_common::Result; /// # use arrow::compute::SortOptions; -/// # use arrow::datatypes::{DataType, Schema}; +/// # use arrow::datatypes::{DataType, Field, FieldRef, Schema}; /// # use datafusion_expr_common::columnar_value::ColumnarValue; /// # use datafusion_physical_expr_common::physical_expr::PhysicalExpr; /// # use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; @@ -56,6 +57,7 @@ use itertools::Itertools; /// # fn data_type(&self, input_schema: &Schema) -> Result {todo!()} /// # fn nullable(&self, input_schema: &Schema) -> Result {todo!() } /// # fn evaluate(&self, batch: &RecordBatch) -> Result {todo!() } +/// # fn return_field(&self, input_schema: &Schema) -> Result { unimplemented!() } /// # fn children(&self) -> Vec<&Arc> {todo!()} /// # fn with_new_children(self: Arc, children: Vec>) -> Result> {todo!()} /// # fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { todo!() } @@ -75,7 +77,7 @@ use itertools::Itertools; /// .nulls_last(); /// assert_eq!(sort_expr.to_string(), "a DESC NULLS LAST"); /// ``` -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq)] pub struct PhysicalSortExpr { /// Physical expression representing the column to sort pub expr: Arc, @@ -94,6 +96,15 @@ impl PhysicalSortExpr { Self::new(expr, SortOptions::default()) } + /// Reverses the sort expression. For instance, `[a ASC NULLS LAST]` turns + /// into `[a DESC NULLS FIRST]`. Such reversals are useful in planning, e.g. + /// when constructing equivalent window expressions. + pub fn reverse(&self) -> Self { + let mut result = self.clone(); + result.options = !result.options; + result + } + /// Set the sort sort options to ASC pub fn asc(mut self) -> Self { self.options.descending = false; @@ -127,23 +138,58 @@ impl PhysicalSortExpr { to_str(&self.options) ) } -} -/// Access the PhysicalSortExpr as a PhysicalExpr -impl AsRef for PhysicalSortExpr { - fn as_ref(&self) -> &(dyn PhysicalExpr + 'static) { - self.expr.as_ref() + /// Evaluates the sort expression into a `SortColumn` that can be passed + /// into the arrow sort kernel. + pub fn evaluate_to_sort_column(&self, batch: &RecordBatch) -> Result { + let array_to_sort = match self.expr.evaluate(batch)? { + ColumnarValue::Array(array) => array, + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(batch.num_rows())?, + }; + Ok(SortColumn { + values: array_to_sort, + options: Some(self.options), + }) + } + + /// Checks whether this sort expression satisfies the given `requirement`. 
+ /// If sort options are unspecified in `requirement`, only expressions are + /// compared for inequality. See [`options_compatible`] for details on + /// how sort options compare with one another. + pub fn satisfy( + &self, + requirement: &PhysicalSortRequirement, + schema: &Schema, + ) -> bool { + self.expr.eq(&requirement.expr) + && requirement.options.is_none_or(|opts| { + options_compatible( + &self.options, + &opts, + self.expr.nullable(schema).unwrap_or(true), + ) + }) + } + + /// Checks whether this sort expression satisfies the given `sort_expr`. + /// See [`options_compatible`] for details on how sort options compare with + /// one another. + pub fn satisfy_expr(&self, sort_expr: &Self, schema: &Schema) -> bool { + self.expr.eq(&sort_expr.expr) + && options_compatible( + &self.options, + &sort_expr.options, + self.expr.nullable(schema).unwrap_or(true), + ) } } impl PartialEq for PhysicalSortExpr { - fn eq(&self, other: &PhysicalSortExpr) -> bool { + fn eq(&self, other: &Self) -> bool { self.options == other.options && self.expr.eq(&other.expr) } } -impl Eq for PhysicalSortExpr {} - impl Hash for PhysicalSortExpr { fn hash(&self, state: &mut H) { self.expr.hash(state); @@ -157,38 +203,20 @@ impl Display for PhysicalSortExpr { } } -impl PhysicalSortExpr { - /// evaluate the sort expression into SortColumn that can be passed into arrow sort kernel - pub fn evaluate_to_sort_column(&self, batch: &RecordBatch) -> Result { - let value_to_sort = self.expr.evaluate(batch)?; - let array_to_sort = match value_to_sort { - ColumnarValue::Array(array) => array, - ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(batch.num_rows())?, - }; - Ok(SortColumn { - values: array_to_sort, - options: Some(self.options), - }) - } - - /// Checks whether this sort expression satisfies the given `requirement`. - /// If sort options are unspecified in `requirement`, only expressions are - /// compared for inequality. - pub fn satisfy( - &self, - requirement: &PhysicalSortRequirement, - schema: &Schema, - ) -> bool { +/// Returns whether the given two [`SortOptions`] are compatible. Here, +/// compatibility means that they are either exactly equal, or they differ only +/// in whether NULL values come in first/last, which is immaterial because the +/// column in question is not nullable (specified by the `nullable` parameter). +pub fn options_compatible( + options_lhs: &SortOptions, + options_rhs: &SortOptions, + nullable: bool, +) -> bool { + if nullable { + options_lhs == options_rhs + } else { // If the column is not nullable, NULLS FIRST/LAST is not important. - let nullable = self.expr.nullable(schema).unwrap_or(true); - self.expr.eq(&requirement.expr) - && if nullable { - requirement.options.is_none_or(|opts| self.options == opts) - } else { - requirement - .options - .is_none_or(|opts| self.options.descending == opts.descending) - } + options_lhs.descending == options_rhs.descending } } @@ -220,28 +248,8 @@ pub struct PhysicalSortRequirement { pub options: Option, } -impl From for PhysicalSortExpr { - /// If options is `None`, the default sort options `ASC, NULLS LAST` is used. 
- /// - /// The default is picked to be consistent with - /// PostgreSQL: - fn from(value: PhysicalSortRequirement) -> Self { - let options = value.options.unwrap_or(SortOptions { - descending: false, - nulls_first: false, - }); - PhysicalSortExpr::new(value.expr, options) - } -} - -impl From for PhysicalSortRequirement { - fn from(value: PhysicalSortExpr) -> Self { - PhysicalSortRequirement::new(value.expr, Some(value.options)) - } -} - impl PartialEq for PhysicalSortRequirement { - fn eq(&self, other: &PhysicalSortRequirement) -> bool { + fn eq(&self, other: &Self) -> bool { self.options == other.options && self.expr.eq(&other.expr) } } @@ -265,10 +273,10 @@ pub fn format_physical_sort_requirement_list( let mut iter = self.0.iter(); write!(f, "[")?; if let Some(expr) = iter.next() { - write!(f, "{}", expr)?; + write!(f, "{expr}")?; } for expr in iter { - write!(f, ", {}", expr)?; + write!(f, ", {expr}")?; } write!(f, "]")?; Ok(()) @@ -291,37 +299,16 @@ impl PhysicalSortRequirement { Self { expr, options } } - /// Replace the required expression for this requirement with the new one - pub fn with_expr(mut self, expr: Arc) -> Self { - self.expr = expr; - self - } - /// Returns whether this requirement is equal or more specific than `other`. - pub fn compatible(&self, other: &PhysicalSortRequirement) -> bool { + pub fn compatible(&self, other: &Self) -> bool { self.expr.eq(&other.expr) && other .options .is_none_or(|other_opts| self.options == Some(other_opts)) } - - #[deprecated(since = "43.0.0", note = "use LexRequirement::from_lex_ordering")] - pub fn from_sort_exprs<'a>( - ordering: impl IntoIterator, - ) -> LexRequirement { - let ordering = ordering.into_iter().cloned().collect(); - LexRequirement::from_lex_ordering(ordering) - } - #[deprecated(since = "43.0.0", note = "use LexOrdering::from_lex_requirement")] - pub fn to_sort_exprs( - requirements: impl IntoIterator, - ) -> LexOrdering { - let requirements = requirements.into_iter().collect(); - LexOrdering::from_lex_requirement(requirements) - } } -/// Returns the SQL string representation of the given [SortOptions] object. +/// Returns the SQL string representation of the given [`SortOptions`] object. #[inline] fn to_str(options: &SortOptions) -> &str { match (options.descending, options.nulls_first) { @@ -332,162 +319,135 @@ fn to_str(options: &SortOptions) -> &str { } } -///`LexOrdering` contains a `Vec`, which represents -/// a lexicographical ordering. +// Cross-conversion utilities between `PhysicalSortExpr` and `PhysicalSortRequirement` +impl From for PhysicalSortRequirement { + fn from(value: PhysicalSortExpr) -> Self { + Self::new(value.expr, Some(value.options)) + } +} + +impl From for PhysicalSortExpr { + /// The default sort options `ASC, NULLS LAST` when the requirement does + /// not specify sort options. This default is consistent with PostgreSQL. + /// + /// Reference: + fn from(value: PhysicalSortRequirement) -> Self { + let options = value + .options + .unwrap_or_else(|| SortOptions::new(false, false)); + Self::new(value.expr, options) + } +} + +/// This object represents a lexicographical ordering and contains a vector +/// of `PhysicalSortExpr` objects. /// -/// For example, `vec![a ASC, b DESC]` represents a lexicographical ordering +/// For example, a `vec![a ASC, b DESC]` represents a lexicographical ordering /// that first sorts by column `a` in ascending order, then by column `b` in /// descending order. 
-#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)] +/// +/// # Invariants +/// +/// The following always hold true for a `LexOrdering`: +/// +/// 1. It is non-degenerate, meaning it contains at least one element. +/// 2. It is duplicate-free, meaning it does not contain multiple entries for +/// the same column. +#[derive(Debug, Clone, PartialEq, Eq)] pub struct LexOrdering { - inner: Vec, -} - -impl AsRef for LexOrdering { - fn as_ref(&self) -> &LexOrdering { - self - } + /// Vector of sort expressions representing the lexicographical ordering. + exprs: Vec, + /// Set of expressions in the lexicographical ordering, used to ensure + /// that the ordering is duplicate-free. Note that the elements in this + /// set are the same underlying physical expressions as in `exprs`. + set: HashSet>, } impl LexOrdering { - /// Creates a new [`LexOrdering`] from a vector - pub fn new(inner: Vec) -> Self { - Self { inner } - } - - /// Return an empty LexOrdering (no expressions) - pub fn empty() -> &'static LexOrdering { - static EMPTY_ORDER: LazyLock = LazyLock::new(LexOrdering::default); - &EMPTY_ORDER - } - - /// Returns the number of elements that can be stored in the LexOrdering - /// without reallocating. - pub fn capacity(&self) -> usize { - self.inner.capacity() - } - - /// Clears the LexOrdering, removing all elements. - pub fn clear(&mut self) { - self.inner.clear() - } - - /// Takes ownership of the actual vector of `PhysicalSortExpr`s in the LexOrdering. - pub fn take_exprs(self) -> Vec { - self.inner - } - - /// Returns `true` if the LexOrdering contains `expr` - pub fn contains(&self, expr: &PhysicalSortExpr) -> bool { - self.inner.contains(expr) - } - - /// Add all elements from `iter` to the LexOrdering. - pub fn extend>(&mut self, iter: I) { - self.inner.extend(iter) - } - - /// Remove all elements from the LexOrdering where `f` evaluates to `false`. - pub fn retain(&mut self, f: F) - where - F: FnMut(&PhysicalSortExpr) -> bool, - { - self.inner.retain(f) - } - - /// Returns `true` if the LexOrdering contains no elements. - pub fn is_empty(&self) -> bool { - self.inner.is_empty() - } - - /// Returns an iterator over each `&PhysicalSortExpr` in the LexOrdering. - pub fn iter(&self) -> core::slice::Iter { - self.inner.iter() + /// Creates a new [`LexOrdering`] from the given vector of sort expressions. + /// If the vector is empty, returns `None`. + pub fn new(exprs: impl IntoIterator) -> Option { + let (non_empty, ordering) = Self::construct(exprs); + non_empty.then_some(ordering) } - /// Returns the number of elements in the LexOrdering. - pub fn len(&self) -> usize { - self.inner.len() - } - - /// Removes the last element from the LexOrdering and returns it, or `None` if it is empty. - pub fn pop(&mut self) -> Option { - self.inner.pop() - } - - /// Appends an element to the back of the LexOrdering. - pub fn push(&mut self, physical_sort_expr: PhysicalSortExpr) { - self.inner.push(physical_sort_expr) + /// Appends an element to the back of the `LexOrdering`. + pub fn push(&mut self, sort_expr: PhysicalSortExpr) { + if self.set.insert(Arc::clone(&sort_expr.expr)) { + self.exprs.push(sort_expr); + } } - /// Truncates the LexOrdering, keeping only the first `len` elements. - pub fn truncate(&mut self, len: usize) { - self.inner.truncate(len) + /// Add all elements from `iter` to the `LexOrdering`. 
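// Editor's sketch, not part of the diff: typical use of the reworked
// `LexOrdering` API. Here `a` and `b` stand for `Arc<dyn PhysicalExpr>` column
// expressions built elsewhere; `new` returns `None` for empty input, and
// `push` silently ignores an expression that is already present.
//
//     let mut ordering = LexOrdering::new([
//         PhysicalSortExpr::new_default(Arc::clone(&a)).asc(),
//         PhysicalSortExpr::new_default(Arc::clone(&b)).desc(),
//     ])
//     .expect("input was non-empty");
//     // Re-adding `a`, even with different sort options, is a no-op and
//     // preserves the duplicate-free invariant:
//     ordering.push(PhysicalSortExpr::new_default(Arc::clone(&a)).desc());
//     assert_eq!(ordering.len(), 2);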
+ pub fn extend(&mut self, sort_exprs: impl IntoIterator) { + for sort_expr in sort_exprs { + self.push(sort_expr); + } } - /// Merge the contents of `other` into `self`, removing duplicates. - pub fn merge(mut self, other: LexOrdering) -> Self { - self.inner = self.inner.into_iter().chain(other).unique().collect(); - self + /// Returns the leading `PhysicalSortExpr` of the `LexOrdering`. Note that + /// this function does not return an `Option`, as a `LexOrdering` is always + /// non-degenerate (i.e. it contains at least one element). + pub fn first(&self) -> &PhysicalSortExpr { + // Can safely `unwrap` because `LexOrdering` is non-degenerate: + self.exprs.first().unwrap() } - /// Converts a `LexRequirement` into a `LexOrdering`. - /// - /// This function converts [`PhysicalSortRequirement`] to [`PhysicalSortExpr`] - /// for each entry in the input. - /// - /// If the required ordering is `None` for an entry in `requirement`, the - /// default ordering `ASC, NULLS LAST` is used (see - /// [`PhysicalSortExpr::from`]). - pub fn from_lex_requirement(requirement: LexRequirement) -> LexOrdering { - requirement - .into_iter() - .map(PhysicalSortExpr::from) - .collect() + /// Returns the number of elements that can be stored in the `LexOrdering` + /// without reallocating. + pub fn capacity(&self) -> usize { + self.exprs.capacity() } - /// Collapse a `LexOrdering` into a new duplicate-free `LexOrdering` based on expression. - /// - /// This function filters duplicate entries that have same physical - /// expression inside, ignoring [`SortOptions`]. For example: - /// - /// `vec![a ASC, a DESC]` collapses to `vec![a ASC]`. - pub fn collapse(self) -> Self { - let mut output = LexOrdering::default(); - for item in self { - if !output.iter().any(|req| req.expr.eq(&item.expr)) { - output.push(item); - } + /// Truncates the `LexOrdering`, keeping only the first `len` elements. + /// Returns `true` if truncation made a change, `false` otherwise. Negative + /// cases happen in two scenarios: (1) When `len` is greater than or equal + /// to the number of expressions inside this `LexOrdering`, making truncation + /// a no-op, or (2) when `len` is `0`, making truncation impossible. + pub fn truncate(&mut self, len: usize) -> bool { + if len == 0 || len >= self.exprs.len() { + return false; } - output - } - - /// Transforms each `PhysicalSortExpr` in the `LexOrdering` - /// in place using the provided closure `f`. - pub fn transform(&mut self, f: F) - where - F: FnMut(&mut PhysicalSortExpr), - { - self.inner.iter_mut().for_each(f); + for PhysicalSortExpr { expr, .. } in self.exprs[len..].iter() { + self.set.remove(expr); + } + self.exprs.truncate(len); + true } -} -impl From> for LexOrdering { - fn from(value: Vec) -> Self { - Self::new(value) + /// Constructs a new `LexOrdering` from the given sort requirements w/o + /// enforcing non-degeneracy. This function is used internally and is not + /// meant (or safe) for external use. + fn construct(exprs: impl IntoIterator) -> (bool, Self) { + let mut set = HashSet::new(); + let exprs = exprs + .into_iter() + .filter_map(|s| set.insert(Arc::clone(&s.expr)).then_some(s)) + .collect(); + (!set.is_empty(), Self { exprs, set }) } } -impl From for LexOrdering { - fn from(value: LexRequirement) -> Self { - Self::from_lex_requirement(value) +impl PartialOrd for LexOrdering { + /// There is a partial ordering among `LexOrdering` objects. For example, the + /// ordering `[a ASC]` is coarser (less) than ordering `[a ASC, b ASC]`. 
+ /// If two orderings do not share a prefix, they are incomparable. + fn partial_cmp(&self, other: &Self) -> Option { + self.iter() + .zip(other.iter()) + .all(|(lhs, rhs)| lhs == rhs) + .then(|| self.len().cmp(&other.len())) } } -/// Convert a `LexOrdering` into a `Arc[]` for fast copies -impl From for Arc<[PhysicalSortExpr]> { - fn from(value: LexOrdering) -> Self { - value.inner.into() +impl From<[PhysicalSortExpr; N]> for LexOrdering { + fn from(value: [PhysicalSortExpr; N]) -> Self { + // TODO: Replace this assertion with a condition on the generic parameter + // when Rust supports it. + assert!(N > 0); + let (non_empty, ordering) = Self::construct(value); + debug_assert!(non_empty); + ordering } } @@ -495,181 +455,269 @@ impl Deref for LexOrdering { type Target = [PhysicalSortExpr]; fn deref(&self) -> &Self::Target { - self.inner.as_slice() + self.exprs.as_slice() } } impl Display for LexOrdering { fn fmt(&self, f: &mut Formatter) -> fmt::Result { let mut first = true; - for sort_expr in &self.inner { + for sort_expr in &self.exprs { if first { first = false; } else { write!(f, ", ")?; } - write!(f, "{}", sort_expr)?; + write!(f, "{sort_expr}")?; } Ok(()) } } -impl FromIterator for LexOrdering { - fn from_iter>(iter: T) -> Self { - let mut lex_ordering = LexOrdering::default(); - - for i in iter { - lex_ordering.push(i); - } +impl IntoIterator for LexOrdering { + type Item = PhysicalSortExpr; + type IntoIter = IntoIter; - lex_ordering + fn into_iter(self) -> Self::IntoIter { + self.exprs.into_iter() } } -impl Index for LexOrdering { - type Output = PhysicalSortExpr; +impl<'a> IntoIterator for &'a LexOrdering { + type Item = &'a PhysicalSortExpr; + type IntoIter = std::slice::Iter<'a, PhysicalSortExpr>; - fn index(&self, index: usize) -> &Self::Output { - &self.inner[index] + fn into_iter(self) -> Self::IntoIter { + self.exprs.iter() } } -impl Index> for LexOrdering { - type Output = [PhysicalSortExpr]; - - fn index(&self, range: Range) -> &Self::Output { - &self.inner[range] +impl From for Vec { + fn from(ordering: LexOrdering) -> Self { + ordering.exprs } } -impl Index> for LexOrdering { - type Output = [PhysicalSortExpr]; +/// This object represents a lexicographical ordering requirement and contains +/// a vector of `PhysicalSortRequirement` objects. +/// +/// For example, a `vec![a Some(ASC), b None]` represents a lexicographical +/// requirement that firsts imposes an ordering by column `a` in ascending +/// order, then by column `b` in *any* (ascending or descending) order. The +/// ordering is non-degenerate, meaning it contains at least one element, and +/// it is duplicate-free, meaning it does not contain multiple entries for the +/// same column. +/// +/// Note that a `LexRequirement` need not enforce the uniqueness of its sort +/// expressions after construction like a `LexOrdering` does, because it provides +/// no mutation methods. If such methods become necessary, we will need to +/// enforce uniqueness like the latter object. +#[derive(Debug, Clone, PartialEq)] +pub struct LexRequirement { + reqs: Vec, +} - fn index(&self, range_from: RangeFrom) -> &Self::Output { - &self.inner[range_from] +impl LexRequirement { + /// Creates a new [`LexRequirement`] from the given vector of sort expressions. + /// If the vector is empty, returns `None`. + pub fn new(reqs: impl IntoIterator) -> Option { + let (non_empty, requirements) = Self::construct(reqs); + non_empty.then_some(requirements) + } + + /// Returns the leading `PhysicalSortRequirement` of the `LexRequirement`. 
+ /// Note that this function does not return an `Option`, as a `LexRequirement` + /// is always non-degenerate (i.e. it contains at least one element). + pub fn first(&self) -> &PhysicalSortRequirement { + // Can safely `unwrap` because `LexRequirement` is non-degenerate: + self.reqs.first().unwrap() + } + + /// Constructs a new `LexRequirement` from the given sort requirements w/o + /// enforcing non-degeneracy. This function is used internally and is not + /// meant (or safe) for external use. + fn construct( + reqs: impl IntoIterator, + ) -> (bool, Self) { + let mut set = HashSet::new(); + let reqs = reqs + .into_iter() + .filter_map(|r| set.insert(Arc::clone(&r.expr)).then_some(r)) + .collect(); + (!set.is_empty(), Self { reqs }) } } -impl Index> for LexOrdering { - type Output = [PhysicalSortExpr]; - - fn index(&self, range_to: RangeTo) -> &Self::Output { - &self.inner[range_to] +impl From<[PhysicalSortRequirement; N]> for LexRequirement { + fn from(value: [PhysicalSortRequirement; N]) -> Self { + // TODO: Replace this assertion with a condition on the generic parameter + // when Rust supports it. + assert!(N > 0); + let (non_empty, requirement) = Self::construct(value); + debug_assert!(non_empty); + requirement } } -impl IntoIterator for LexOrdering { - type Item = PhysicalSortExpr; - type IntoIter = IntoIter; +impl Deref for LexRequirement { + type Target = [PhysicalSortRequirement]; - fn into_iter(self) -> Self::IntoIter { - self.inner.into_iter() + fn deref(&self) -> &Self::Target { + self.reqs.as_slice() } } -///`LexOrderingRef` is an alias for the type &`[PhysicalSortExpr]`, which represents -/// a reference to a lexicographical ordering. -#[deprecated(since = "43.0.0", note = "use &LexOrdering instead")] -pub type LexOrderingRef<'a> = &'a [PhysicalSortExpr]; +impl IntoIterator for LexRequirement { + type Item = PhysicalSortRequirement; + type IntoIter = IntoIter; -///`LexRequirement` is an struct containing a `Vec`, which -/// represents a lexicographical ordering requirement. 
-#[derive(Debug, Default, Clone, PartialEq)] -pub struct LexRequirement { - pub inner: Vec, + fn into_iter(self) -> Self::IntoIter { + self.reqs.into_iter() + } } -impl LexRequirement { - pub fn new(inner: Vec) -> Self { - Self { inner } - } +impl<'a> IntoIterator for &'a LexRequirement { + type Item = &'a PhysicalSortRequirement; + type IntoIter = std::slice::Iter<'a, PhysicalSortRequirement>; - pub fn is_empty(&self) -> bool { - self.inner.is_empty() + fn into_iter(self) -> Self::IntoIter { + self.reqs.iter() } +} - pub fn iter(&self) -> impl Iterator { - self.inner.iter() +impl From for Vec { + fn from(requirement: LexRequirement) -> Self { + requirement.reqs } +} - pub fn push(&mut self, physical_sort_requirement: PhysicalSortRequirement) { - self.inner.push(physical_sort_requirement) +// Cross-conversion utilities between `LexOrdering` and `LexRequirement` +impl From for LexRequirement { + fn from(value: LexOrdering) -> Self { + // Can construct directly as `value` is non-degenerate: + let (non_empty, requirements) = + Self::construct(value.into_iter().map(Into::into)); + debug_assert!(non_empty); + requirements } +} - /// Create a new [`LexRequirement`] from a [`LexOrdering`] - /// - /// Returns [`LexRequirement`] that requires the exact - /// sort of the [`PhysicalSortExpr`]s in `ordering` - pub fn from_lex_ordering(ordering: LexOrdering) -> Self { - Self::new( - ordering - .into_iter() - .map(PhysicalSortRequirement::from) - .collect(), - ) +impl From for LexOrdering { + fn from(value: LexRequirement) -> Self { + // Can construct directly as `value` is non-degenerate: + let (non_empty, ordering) = Self::construct(value.into_iter().map(Into::into)); + debug_assert!(non_empty); + ordering } +} - /// Constructs a duplicate-free `LexOrderingReq` by filtering out - /// duplicate entries that have same physical expression inside. - /// - /// For example, `vec![a Some(ASC), a Some(DESC)]` collapses to `vec![a - /// Some(ASC)]`. - pub fn collapse(self) -> Self { - let mut output = Vec::::new(); - for item in self { - if !output.iter().any(|req| req.expr.eq(&item.expr)) { - output.push(item); +/// Represents a plan's input ordering requirements. Vector elements represent +/// alternative ordering requirements in the order of preference. The list of +/// alternatives can be either hard or soft, depending on whether the operator +/// can work without an input ordering. +/// +/// # Invariants +/// +/// The following always hold true for a `OrderingRequirements`: +/// +/// 1. It is non-degenerate, meaning it contains at least one ordering. The +/// absence of an input ordering requirement is represented by a `None` value +/// in `ExecutionPlan` APIs, which return an `Option`. +#[derive(Debug, Clone, PartialEq)] +pub enum OrderingRequirements { + /// The operator is not able to work without one of these requirements. + Hard(Vec), + /// The operator can benefit from these input orderings when available, + /// but can still work in the absence of any input ordering. + Soft(Vec), +} + +impl OrderingRequirements { + /// Creates a new instance from the given alternatives. If an empty list of + /// alternatives are given, returns `None`. 
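// Editor's sketch, not part of the diff: how an operator states its input
// ordering needs. `req_a` and `req_ab` stand for `LexRequirement` values built
// elsewhere (e.g. via `LexRequirement::new`); a hard requirement means the
// operator cannot run without one of the listed orderings, while a soft one
// only means it benefits from them when they are available.
//
//     // Single hard requirement:
//     let hard = OrderingRequirements::new(req_ab.clone());
//     // Soft alternatives, most preferred first:
//     let soft = OrderingRequirements::new_alternatives([req_ab, req_a], true)
//         .expect("at least one alternative was given");
//     assert!(matches!(soft, OrderingRequirements::Soft(_)));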
+ pub fn new_alternatives( + alternatives: impl IntoIterator, + soft: bool, + ) -> Option { + let alternatives = alternatives.into_iter().collect::>(); + (!alternatives.is_empty()).then(|| { + if soft { + Self::Soft(alternatives) + } else { + Self::Hard(alternatives) } - } - LexRequirement::new(output) + }) } -} -impl From for LexRequirement { - fn from(value: LexOrdering) -> Self { - Self::from_lex_ordering(value) + /// Creates a new instance with a single hard requirement. + pub fn new(requirement: LexRequirement) -> Self { + Self::Hard(vec![requirement]) } -} -impl Deref for LexRequirement { - type Target = [PhysicalSortRequirement]; + /// Creates a new instance with a single soft requirement. + pub fn new_soft(requirement: LexRequirement) -> Self { + Self::Soft(vec![requirement]) + } - fn deref(&self) -> &Self::Target { - self.inner.as_slice() + /// Adds an alternative requirement to the list of alternatives. + pub fn add_alternative(&mut self, requirement: LexRequirement) { + match self { + Self::Hard(alts) | Self::Soft(alts) => alts.push(requirement), + } } -} -impl FromIterator for LexRequirement { - fn from_iter>(iter: T) -> Self { - let mut lex_requirement = LexRequirement::new(vec![]); + /// Returns the first (i.e. most preferred) `LexRequirement` among + /// alternative requirements. + pub fn into_single(self) -> LexRequirement { + match self { + Self::Hard(mut alts) | Self::Soft(mut alts) => alts.swap_remove(0), + } + } - for i in iter { - lex_requirement.inner.push(i); + /// Returns a reference to the first (i.e. most preferred) `LexRequirement` + /// among alternative requirements. + pub fn first(&self) -> &LexRequirement { + match self { + Self::Hard(alts) | Self::Soft(alts) => &alts[0], } + } - lex_requirement + /// Returns all alternatives as a vector of `LexRequirement` objects and a + /// boolean value indicating softness/hardness of the requirements. + pub fn into_alternatives(self) -> (Vec, bool) { + match self { + Self::Hard(alts) => (alts, false), + Self::Soft(alts) => (alts, true), + } } } -impl IntoIterator for LexRequirement { - type Item = PhysicalSortRequirement; - type IntoIter = IntoIter; +impl From for OrderingRequirements { + fn from(requirement: LexRequirement) -> Self { + Self::new(requirement) + } +} - fn into_iter(self) -> Self::IntoIter { - self.inner.into_iter() +impl From for OrderingRequirements { + fn from(ordering: LexOrdering) -> Self { + Self::new(ordering.into()) } } -impl<'a> IntoIterator for &'a LexOrdering { - type Item = &'a PhysicalSortExpr; - type IntoIter = std::slice::Iter<'a, PhysicalSortExpr>; +impl Deref for OrderingRequirements { + type Target = [LexRequirement]; - fn into_iter(self) -> Self::IntoIter { - self.inner.iter() + fn deref(&self) -> &Self::Target { + match &self { + Self::Hard(alts) | Self::Soft(alts) => alts.as_slice(), + } } } -///`LexRequirementRef` is an alias for the type &`[PhysicalSortRequirement]`, which -/// represents a reference to a lexicographical ordering requirement. 
-/// #[deprecated(since = "43.0.0", note = "use &LexRequirement instead")] -pub type LexRequirementRef<'a> = &'a [PhysicalSortRequirement]; +impl DerefMut for OrderingRequirements { + fn deref_mut(&mut self) -> &mut Self::Target { + match self { + Self::Hard(alts) | Self::Soft(alts) => alts.as_mut_slice(), + } + } +} diff --git a/datafusion/physical-expr-common/src/utils.rs b/datafusion/physical-expr-common/src/utils.rs index 114007bfa6af..05b216ab75eb 100644 --- a/datafusion/physical-expr-common/src/utils.rs +++ b/datafusion/physical-expr-common/src/utils.rs @@ -17,16 +17,14 @@ use std::sync::Arc; +use crate::physical_expr::PhysicalExpr; +use crate::tree_node::ExprContext; + use arrow::array::{make_array, Array, ArrayRef, BooleanArray, MutableArrayData}; use arrow::compute::{and_kleene, is_not_null, SlicesIterator}; - use datafusion_common::Result; use datafusion_expr_common::sort_properties::ExprProperties; -use crate::physical_expr::PhysicalExpr; -use crate::sort_expr::{LexOrdering, PhysicalSortExpr}; -use crate::tree_node::ExprContext; - /// Represents a [`PhysicalExpr`] node with associated properties (order and /// range) in a context where properties are tracked. pub type ExprPropertiesNode = ExprContext; @@ -93,16 +91,6 @@ pub fn scatter(mask: &BooleanArray, truthy: &dyn Array) -> Result { Ok(make_array(data)) } -/// Reverses the ORDER BY expression, which is useful during equivalent window -/// expression construction. For instance, 'ORDER BY a ASC, NULLS LAST' turns into -/// 'ORDER BY a DESC, NULLS FIRST'. -pub fn reverse_order_bys(order_bys: &LexOrdering) -> LexOrdering { - order_bys - .iter() - .map(|e| PhysicalSortExpr::new(Arc::clone(&e.expr), !e.options)) - .collect() -} - #[cfg(test)] mod tests { use std::sync::Arc; diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index 47e3291e5cb4..881969ef32ad 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -51,7 +51,7 @@ indexmap = { workspace = true } itertools = { workspace = true, features = ["use_std"] } log = { workspace = true } paste = "^1.0" -petgraph = "0.7.1" +petgraph = "0.8.2" [dev-dependencies] arrow = { workspace = true, features = ["test_utils"] } diff --git a/datafusion/physical-expr/README.md b/datafusion/physical-expr/README.md index 424256c77e7e..b99f3c4946ce 100644 --- a/datafusion/physical-expr/README.md +++ b/datafusion/physical-expr/README.md @@ -23,4 +23,9 @@ This crate is a submodule of DataFusion that provides data types and utilities for physical expressions. +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. 
+ [df]: https://crates.io/crates/datafusion +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/physical-expr/benches/binary_op.rs b/datafusion/physical-expr/benches/binary_op.rs index 59a602df053c..5b0f700fdb8a 100644 --- a/datafusion/physical-expr/benches/binary_op.rs +++ b/datafusion/physical-expr/benches/binary_op.rs @@ -126,14 +126,25 @@ fn generate_boolean_cases( )); } + // Scenario 7: Test all true or all false in AND/OR + // This situation won't cause a short circuit, but it can skip the bool calculation + if TEST_ALL_FALSE { + let all_true = vec![true; len]; + cases.push(("all_true_in_and".to_string(), BooleanArray::from(all_true))); + } else { + let all_false = vec![false; len]; + cases.push(("all_false_in_or".to_string(), BooleanArray::from(all_false))); + } + cases } /// Benchmarks AND/OR operator short-circuiting by evaluating complex regex conditions. /// -/// Creates 6 test scenarios per operator: +/// Creates 7 test scenarios per operator: /// 1. All values enable short-circuit (all_true/all_false) /// 2. 2-6 Single true/false value at different positions to measure early exit +/// 3. Test all true or all false in AND/OR /// /// You can run this benchmark with: /// ```sh @@ -203,16 +214,16 @@ fn benchmark_binary_op_in_short_circuit(c: &mut Criterion) { // Each scenario when the test operator is `and` { - for (name, batch) in batches_and { - c.bench_function(&format!("short_circuit/and/{}", name), |b| { + for (name, batch) in batches_and.into_iter() { + c.bench_function(&format!("short_circuit/and/{name}"), |b| { b.iter(|| expr_and.evaluate(black_box(&batch)).unwrap()) }); } } // Each scenario when the test operator is `or` { - for (name, batch) in batches_or { - c.bench_function(&format!("short_circuit/or/{}", name), |b| { + for (name, batch) in batches_or.into_iter() { + c.bench_function(&format!("short_circuit/or/{name}"), |b| { b.iter(|| expr_or.evaluate(black_box(&batch)).unwrap()) }); } diff --git a/datafusion/physical-expr/benches/in_list.rs b/datafusion/physical-expr/benches/in_list.rs index 90bfc5efb61e..e91e8d1f137c 100644 --- a/datafusion/physical-expr/benches/in_list.rs +++ b/datafusion/physical-expr/benches/in_list.rs @@ -21,7 +21,7 @@ use arrow::record_batch::RecordBatch; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_common::ScalarValue; use datafusion_physical_expr::expressions::{col, in_list, lit}; -use rand::distributions::Alphanumeric; +use rand::distr::Alphanumeric; use rand::prelude::*; use std::sync::Arc; @@ -51,7 +51,7 @@ fn do_benches( for string_length in [5, 10, 20] { let values: StringArray = (0..array_length) .map(|_| { - rng.gen_bool(null_percent) + rng.random_bool(null_percent) .then(|| random_string(&mut rng, string_length)) }) .collect(); @@ -71,11 +71,11 @@ fn do_benches( } let values: Float32Array = (0..array_length) - .map(|_| rng.gen_bool(null_percent).then(|| rng.gen())) + .map(|_| rng.random_bool(null_percent).then(|| rng.random())) .collect(); let in_list: Vec<_> = (0..in_list_length) - .map(|_| ScalarValue::Float32(Some(rng.gen()))) + .map(|_| ScalarValue::Float32(Some(rng.random()))) .collect(); do_bench( @@ -86,11 +86,11 @@ fn do_benches( ); let values: Int32Array = (0..array_length) - .map(|_| rng.gen_bool(null_percent).then(|| rng.gen())) + .map(|_| rng.random_bool(null_percent).then(|| rng.random())) .collect(); let in_list: Vec<_> = (0..in_list_length) - .map(|_| ScalarValue::Int32(Some(rng.gen()))) + .map(|_| ScalarValue::Int32(Some(rng.random()))) .collect(); 
do_bench( diff --git a/datafusion/physical-expr/src/aggregate.rs b/datafusion/physical-expr/src/aggregate.rs index 49912954ac81..9175c01274cb 100644 --- a/datafusion/physical-expr/src/aggregate.rs +++ b/datafusion/physical-expr/src/aggregate.rs @@ -41,7 +41,7 @@ use std::sync::Arc; use crate::expressions::Column; use arrow::compute::SortOptions; -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::datatypes::{DataType, FieldRef, Schema, SchemaRef}; use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue}; use datafusion_expr::{AggregateUDF, ReversedUDAF, SetMonotonicity}; use datafusion_expr_common::accumulator::Accumulator; @@ -52,8 +52,7 @@ use datafusion_functions_aggregate_common::accumulator::{ }; use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; -use datafusion_physical_expr_common::utils::reverse_order_bys; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; /// Builder for physical [`AggregateFunctionExpr`] /// @@ -70,7 +69,7 @@ pub struct AggregateExprBuilder { /// Arrow Schema for the aggregate function schema: SchemaRef, /// The physical order by expressions - ordering_req: LexOrdering, + order_bys: Vec, /// Whether to ignore null values ignore_nulls: bool, /// Whether is distinct aggregate function @@ -87,7 +86,7 @@ impl AggregateExprBuilder { alias: None, human_display: String::default(), schema: Arc::new(Schema::empty()), - ordering_req: LexOrdering::default(), + order_bys: vec![], ignore_nulls: false, is_distinct: false, is_reversed: false, @@ -106,7 +105,7 @@ impl AggregateExprBuilder { /// ``` /// # use std::any::Any; /// # use std::sync::Arc; - /// # use arrow::datatypes::DataType; + /// # use arrow::datatypes::{DataType, FieldRef}; /// # use datafusion_common::{Result, ScalarValue}; /// # use datafusion_expr::{col, ColumnarValue, Documentation, Signature, Volatility, Expr}; /// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::{AccumulatorArgs, StateFieldsArgs}}; @@ -128,25 +127,28 @@ impl AggregateExprBuilder { /// # impl AggregateUDFImpl for FirstValueUdf { /// # fn as_any(&self) -> &dyn Any { /// # unimplemented!() - /// # } + /// # } + /// # /// # fn name(&self) -> &str { /// # unimplemented!() - /// } + /// # } + /// # /// # fn signature(&self) -> &Signature { /// # unimplemented!() - /// # } + /// # } + /// # /// # fn return_type(&self, args: &[DataType]) -> Result { /// # unimplemented!() /// # } - /// # + /// # /// # fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { /// # unimplemented!() /// # } /// # - /// # fn state_fields(&self, args: StateFieldsArgs) -> Result> { + /// # fn state_fields(&self, args: StateFieldsArgs) -> Result> { /// # unimplemented!() /// # } - /// # + /// # /// # fn documentation(&self) -> Option<&Documentation> { /// # unimplemented!() /// # } @@ -169,16 +171,16 @@ impl AggregateExprBuilder { /// }]; /// /// let first_value = AggregateUDF::from(FirstValueUdf::new()); - /// + /// /// let aggregate_expr = AggregateExprBuilder::new( /// Arc::new(first_value), /// args /// ) - /// .order_by(order_by.into()) + /// .order_by(order_by) /// .alias("first_a_by_x") /// .ignore_nulls() /// .build()?; - /// + /// /// Ok(()) /// } /// ``` @@ -192,7 +194,7 @@ impl AggregateExprBuilder { alias, human_display, schema, - ordering_req, + order_bys, ignore_nulls, is_distinct, is_reversed, @@ 
-201,30 +203,25 @@ impl AggregateExprBuilder { return internal_err!("args should not be empty"); } - let mut ordering_fields = vec![]; - - if !ordering_req.is_empty() { - let ordering_types = ordering_req - .iter() - .map(|e| e.expr.data_type(&schema)) - .collect::>>()?; + let ordering_types = order_bys + .iter() + .map(|e| e.expr.data_type(&schema)) + .collect::>>()?; - ordering_fields = - utils::ordering_fields(ordering_req.as_ref(), &ordering_types); - } + let ordering_fields = utils::ordering_fields(&order_bys, &ordering_types); - let input_exprs_types = args + let input_exprs_fields = args .iter() - .map(|arg| arg.data_type(&schema)) + .map(|arg| arg.return_field(&schema)) .collect::>>()?; check_arg_count( fun.name(), - &input_exprs_types, + &input_exprs_fields, &fun.signature().type_signature, )?; - let data_type = fun.return_type(&input_exprs_types)?; + let return_field = fun.return_field(&input_exprs_fields)?; let is_nullable = fun.is_nullable(); let name = match alias { None => { @@ -238,15 +235,15 @@ impl AggregateExprBuilder { Ok(AggregateFunctionExpr { fun: Arc::unwrap_or_clone(fun), args, - data_type, + return_field, name, human_display, schema: Arc::unwrap_or_clone(schema), - ordering_req, + order_bys, ignore_nulls, ordering_fields, is_distinct, - input_types: input_exprs_types, + input_fields: input_exprs_fields, is_reversed, is_nullable, }) @@ -267,8 +264,8 @@ impl AggregateExprBuilder { self } - pub fn order_by(mut self, order_by: LexOrdering) -> Self { - self.ordering_req = order_by; + pub fn order_by(mut self, order_bys: Vec) -> Self { + self.order_bys = order_bys; self } @@ -310,22 +307,22 @@ impl AggregateExprBuilder { pub struct AggregateFunctionExpr { fun: AggregateUDF, args: Vec>, - /// Output / return type of this aggregate - data_type: DataType, + /// Output / return field of this aggregate + return_field: FieldRef, /// Output column name that this expression creates name: String, /// Simplified name for `tree` explain. human_display: String, schema: Schema, // The physical order by expressions - ordering_req: LexOrdering, + order_bys: Vec, // Whether to ignore null values ignore_nulls: bool, // fields used for order sensitive aggregation functions - ordering_fields: Vec, + ordering_fields: Vec, is_distinct: bool, is_reversed: bool, - input_types: Vec, + input_fields: Vec, is_nullable: bool, } @@ -372,8 +369,12 @@ impl AggregateFunctionExpr { } /// the field of the final result of this aggregation. - pub fn field(&self) -> Field { - Field::new(&self.name, self.data_type.clone(), self.is_nullable) + pub fn field(&self) -> FieldRef { + self.return_field + .as_ref() + .clone() + .with_name(&self.name) + .into() } /// the accumulator used to accumulate values from the expressions. @@ -381,10 +382,10 @@ impl AggregateFunctionExpr { /// return states with the same description as `state_fields` pub fn create_accumulator(&self) -> Result> { let acc_args = AccumulatorArgs { - return_type: &self.data_type, + return_field: Arc::clone(&self.return_field), schema: &self.schema, ignore_nulls: self.ignore_nulls, - ordering_req: self.ordering_req.as_ref(), + order_bys: self.order_bys.as_ref(), is_distinct: self.is_distinct, name: &self.name, is_reversed: self.is_reversed, @@ -395,11 +396,11 @@ impl AggregateFunctionExpr { } /// the field of the final result of this aggregation. 
- pub fn state_fields(&self) -> Result> { + pub fn state_fields(&self) -> Result> { let args = StateFieldsArgs { name: &self.name, - input_types: &self.input_types, - return_type: &self.data_type, + input_fields: &self.input_fields, + return_field: Arc::clone(&self.return_field), ordering_fields: &self.ordering_fields, is_distinct: self.is_distinct, }; @@ -407,31 +408,24 @@ impl AggregateFunctionExpr { self.fun.state_fields(args) } - /// Order by requirements for the aggregate function - /// By default it is `None` (there is no requirement) - /// Order-sensitive aggregators, such as `FIRST_VALUE(x ORDER BY y)` should implement this - pub fn order_bys(&self) -> Option<&LexOrdering> { - if self.ordering_req.is_empty() { - return None; - } - - if !self.order_sensitivity().is_insensitive() { - return Some(self.ordering_req.as_ref()); + /// Returns the ORDER BY expressions for the aggregate function. + pub fn order_bys(&self) -> &[PhysicalSortExpr] { + if self.order_sensitivity().is_insensitive() { + &[] + } else { + &self.order_bys } - - None } /// Indicates whether aggregator can produce the correct result with any /// arbitrary input ordering. By default, we assume that aggregate expressions /// are order insensitive. pub fn order_sensitivity(&self) -> AggregateOrderSensitivity { - if !self.ordering_req.is_empty() { - // If there is requirement, use the sensitivity of the implementation - self.fun.order_sensitivity() - } else { - // If no requirement, aggregator is order insensitive + if self.order_bys.is_empty() { AggregateOrderSensitivity::Insensitive + } else { + // If there is an ORDER BY clause, use the sensitivity of the implementation: + self.fun.order_sensitivity() } } @@ -459,7 +453,7 @@ impl AggregateFunctionExpr { }; AggregateExprBuilder::new(Arc::new(updated_fn), self.args.to_vec()) - .order_by(self.ordering_req.clone()) + .order_by(self.order_bys.clone()) .schema(Arc::new(self.schema.clone())) .alias(self.name().to_string()) .with_ignore_nulls(self.ignore_nulls) @@ -472,10 +466,10 @@ impl AggregateFunctionExpr { /// Creates accumulator implementation that supports retract pub fn create_sliding_accumulator(&self) -> Result> { let args = AccumulatorArgs { - return_type: &self.data_type, + return_field: Arc::clone(&self.return_field), schema: &self.schema, ignore_nulls: self.ignore_nulls, - ordering_req: self.ordering_req.as_ref(), + order_bys: self.order_bys.as_ref(), is_distinct: self.is_distinct, name: &self.name, is_reversed: self.is_reversed, @@ -541,10 +535,10 @@ impl AggregateFunctionExpr { /// `[Self::create_groups_accumulator`] will be called. pub fn groups_accumulator_supported(&self) -> bool { let args = AccumulatorArgs { - return_type: &self.data_type, + return_field: Arc::clone(&self.return_field), schema: &self.schema, ignore_nulls: self.ignore_nulls, - ordering_req: self.ordering_req.as_ref(), + order_bys: self.order_bys.as_ref(), is_distinct: self.is_distinct, name: &self.name, is_reversed: self.is_reversed, @@ -560,10 +554,10 @@ impl AggregateFunctionExpr { /// implemented in addition to [`Accumulator`]. 
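// Editor's sketch, not part of the diff: building an aggregate with the new
// `order_by(Vec<PhysicalSortExpr>)` builder method. Assumes a function that
// returns `Result<()>`, an `agg_udf: Arc<AggregateUDF>`, an input
// `schema: Schema` with columns "a" and "ts", and `col` from
// `datafusion_physical_expr::expressions`.
//
//     let agg = AggregateExprBuilder::new(agg_udf, vec![col("a", &schema)?])
//         .schema(Arc::new(schema.clone()))
//         .alias("first_a_by_ts")
//         .order_by(vec![PhysicalSortExpr::new_default(col("ts", &schema)?).asc()])
//         .build()?;
//     // `order_bys` now yields a plain slice, so reversing the ordering is a
//     // per-expression `reverse()` (the same pattern the reversed-UDAF
//     // handling below uses):
//     let _reversed: Vec<_> = agg.order_bys().iter().map(|e| e.reverse()).collect();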
pub fn create_groups_accumulator(&self) -> Result> { let args = AccumulatorArgs { - return_type: &self.data_type, + return_field: Arc::clone(&self.return_field), schema: &self.schema, ignore_nulls: self.ignore_nulls, - ordering_req: self.ordering_req.as_ref(), + order_bys: self.order_bys.as_ref(), is_distinct: self.is_distinct, name: &self.name, is_reversed: self.is_reversed, @@ -581,18 +575,16 @@ impl AggregateFunctionExpr { ReversedUDAF::NotSupported => None, ReversedUDAF::Identical => Some(self.clone()), ReversedUDAF::Reversed(reverse_udf) => { - let reverse_ordering_req = reverse_order_bys(self.ordering_req.as_ref()); let mut name = self.name().to_string(); // If the function is changed, we need to reverse order_by clause as well // i.e. First(a order by b asc null first) -> Last(a order by b desc null last) - if self.fun().name() == reverse_udf.name() { - } else { + if self.fun().name() != reverse_udf.name() { replace_order_by_clause(&mut name); } replace_fn_name_clause(&mut name, self.fun.name(), reverse_udf.name()); AggregateExprBuilder::new(reverse_udf, self.args.to_vec()) - .order_by(reverse_ordering_req) + .order_by(self.order_bys.iter().map(|e| e.reverse()).collect()) .schema(Arc::new(self.schema.clone())) .alias(name) .with_ignore_nulls(self.ignore_nulls) @@ -608,14 +600,11 @@ impl AggregateFunctionExpr { /// These expressions are (1)function arguments, (2) order by expressions. pub fn all_expressions(&self) -> AggregatePhysicalExpressions { let args = self.expressions(); - let order_bys = self + let order_by_exprs = self .order_bys() - .cloned() - .unwrap_or_else(LexOrdering::default); - let order_by_exprs = order_bys .iter() .map(|sort_expr| Arc::clone(&sort_expr.expr)) - .collect::>(); + .collect(); AggregatePhysicalExpressions { args, order_by_exprs, @@ -640,7 +629,7 @@ impl AggregateFunctionExpr { /// output_field is the name of the column produced by this aggregate /// /// Note: this is used to use special aggregate implementations in certain conditions - pub fn get_minmax_desc(&self) -> Option<(Field, bool)> { + pub fn get_minmax_desc(&self) -> Option<(FieldRef, bool)> { self.fun.is_descending().map(|flag| (self.field(), flag)) } @@ -685,7 +674,7 @@ pub struct AggregatePhysicalExpressions { impl PartialEq for AggregateFunctionExpr { fn eq(&self, other: &Self) -> bool { self.name == other.name - && self.data_type == other.data_type + && self.return_field == other.return_field && self.fun == other.fun && self.args.len() == other.args.len() && self diff --git a/datafusion/physical-expr/src/analysis.rs b/datafusion/physical-expr/src/analysis.rs index 5abd50f6d1b4..1d59dab8fd6d 100644 --- a/datafusion/physical-expr/src/analysis.rs +++ b/datafusion/physical-expr/src/analysis.rs @@ -100,7 +100,7 @@ impl ExprBoundaries { ) -> Result { let field = schema.fields().get(col_index).ok_or_else(|| { internal_datafusion_err!( - "Could not create `ExprBoundaries`: in `try_from_column` `col_index` + "Could not create `ExprBoundaries`: in `try_from_column` `col_index` has gone out of bounds with a value of {col_index}, the schema has {} columns.", schema.fields.len() ) @@ -112,7 +112,7 @@ impl ExprBoundaries { .min_value .get_value() .cloned() - .unwrap_or(empty_field.clone()), + .unwrap_or_else(|| empty_field.clone()), col_stats .max_value .get_value() @@ -425,7 +425,7 @@ mod tests { fn test_analyze_invalid_boundary_exprs() { let schema = Arc::new(Schema::new(vec![make_field("a", DataType::Int32)])); let expr = col("a").lt(lit(10)).or(col("a").gt(lit(20))); - let expected_error = 
"Interval arithmetic does not support the operator OR"; + let expected_error = "OR operator cannot yet propagate true intervals"; let boundaries = ExprBoundaries::try_new_unbounded(&schema).unwrap(); let df_schema = DFSchema::try_from(Arc::clone(&schema)).unwrap(); let physical_expr = diff --git a/datafusion/physical-expr/src/async_scalar_function.rs b/datafusion/physical-expr/src/async_scalar_function.rs new file mode 100644 index 000000000000..547b9c13da62 --- /dev/null +++ b/datafusion/physical-expr/src/async_scalar_function.rs @@ -0,0 +1,250 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::ScalarFunctionExpr; +use arrow::array::{make_array, MutableArrayData, RecordBatch}; +use arrow::datatypes::{DataType, Field, FieldRef, Schema}; +use datafusion_common::config::ConfigOptions; +use datafusion_common::Result; +use datafusion_common::{internal_err, not_impl_err}; +use datafusion_expr::async_udf::AsyncScalarUDF; +use datafusion_expr::ScalarFunctionArgs; +use datafusion_expr_common::columnar_value::ColumnarValue; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use std::any::Any; +use std::fmt::Display; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +/// Wrapper around a scalar function that can be evaluated asynchronously +#[derive(Debug, Clone, Eq)] +pub struct AsyncFuncExpr { + /// The name of the output column this function will generate + pub name: String, + /// The actual function (always `ScalarFunctionExpr`) + pub func: Arc, + /// The field that this function will return + return_field: FieldRef, +} + +impl Display for AsyncFuncExpr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "async_expr(name={}, expr={})", self.name, self.func) + } +} + +impl PartialEq for AsyncFuncExpr { + fn eq(&self, other: &Self) -> bool { + self.name == other.name && self.func == Arc::clone(&other.func) + } +} + +impl Hash for AsyncFuncExpr { + fn hash(&self, state: &mut H) { + self.name.hash(state); + self.func.as_ref().hash(state); + } +} + +impl AsyncFuncExpr { + /// create a new AsyncFuncExpr + pub fn try_new( + name: impl Into, + func: Arc, + schema: &Schema, + ) -> Result { + let Some(_) = func.as_any().downcast_ref::() else { + return internal_err!( + "unexpected function type, expected ScalarFunctionExpr, got: {:?}", + func + ); + }; + + let return_field = func.return_field(schema)?; + Ok(Self { + name: name.into(), + func, + return_field, + }) + } + + /// return the name of the output column + pub fn name(&self) -> &str { + &self.name + } + + /// Return the output field generated by evaluating this function + pub fn field(&self, input_schema: &Schema) -> Result { + Ok(Field::new( + &self.name, + self.func.data_type(input_schema)?, + self.func.nullable(input_schema)?, + )) + 
} + + /// Return the ideal batch size for this function + pub fn ideal_batch_size(&self) -> Result> { + if let Some(expr) = self.func.as_any().downcast_ref::() { + if let Some(udf) = + expr.fun().inner().as_any().downcast_ref::() + { + return Ok(udf.ideal_batch_size()); + } + } + not_impl_err!("Can't get ideal_batch_size from {:?}", self.func) + } + + /// This (async) function is called for each record batch to evaluate the LLM expressions + /// + /// The output is the output of evaluating the async expression and the input record batch + pub async fn invoke_with_args( + &self, + batch: &RecordBatch, + option: &ConfigOptions, + ) -> Result { + let Some(scalar_function_expr) = + self.func.as_any().downcast_ref::() + else { + return internal_err!( + "unexpected function type, expected ScalarFunctionExpr, got: {:?}", + self.func + ); + }; + + let Some(async_udf) = scalar_function_expr + .fun() + .inner() + .as_any() + .downcast_ref::() + else { + return not_impl_err!( + "Don't know how to evaluate async function: {:?}", + scalar_function_expr + ); + }; + + let arg_fields = scalar_function_expr + .args() + .iter() + .map(|e| e.return_field(batch.schema_ref())) + .collect::>>()?; + + let mut result_batches = vec![]; + if let Some(ideal_batch_size) = self.ideal_batch_size()? { + let mut remainder = batch.clone(); + while remainder.num_rows() > 0 { + let size = if ideal_batch_size > remainder.num_rows() { + remainder.num_rows() + } else { + ideal_batch_size + }; + + let current_batch = remainder.slice(0, size); // get next 10 rows + remainder = remainder.slice(size, remainder.num_rows() - size); + let args = scalar_function_expr + .args() + .iter() + .map(|e| e.evaluate(¤t_batch)) + .collect::>>()?; + result_batches.push( + async_udf + .invoke_async_with_args( + ScalarFunctionArgs { + args, + arg_fields: arg_fields.clone(), + number_rows: current_batch.num_rows(), + return_field: Arc::clone(&self.return_field), + }, + option, + ) + .await?, + ); + } + } else { + let args = scalar_function_expr + .args() + .iter() + .map(|e| e.evaluate(batch)) + .collect::>>()?; + + result_batches.push( + async_udf + .invoke_async_with_args( + ScalarFunctionArgs { + args: args.to_vec(), + arg_fields, + number_rows: batch.num_rows(), + return_field: Arc::clone(&self.return_field), + }, + option, + ) + .await?, + ); + } + + let datas = result_batches + .iter() + .map(|b| b.to_data()) + .collect::>(); + let total_len = datas.iter().map(|d| d.len()).sum(); + let mut mutable = MutableArrayData::new(datas.iter().collect(), false, total_len); + datas.iter().enumerate().for_each(|(i, data)| { + mutable.extend(i, 0, data.len()); + }); + let array_ref = make_array(mutable.freeze()); + Ok(ColumnarValue::Array(array_ref)) + } +} + +impl PhysicalExpr for AsyncFuncExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> Result { + self.func.data_type(input_schema) + } + + fn nullable(&self, input_schema: &Schema) -> Result { + self.func.nullable(input_schema) + } + + fn evaluate(&self, _batch: &RecordBatch) -> Result { + // TODO: implement this for scalar value input + not_impl_err!("AsyncFuncExpr.evaluate") + } + + fn children(&self) -> Vec<&Arc> { + self.func.children() + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + let new_func = Arc::clone(&self.func).with_new_children(children)?; + Ok(Arc::new(AsyncFuncExpr { + name: self.name.clone(), + func: new_func, + return_field: Arc::clone(&self.return_field), + })) + } + + fn fmt_sql(&self, f: &mut 
std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.func) + } +} diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index 13a3c79a47a2..8af6f3be0389 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -15,30 +15,61 @@ // specific language governing permissions and limitations // under the License. -use super::{add_offset_to_expr, ProjectionMapping}; -use crate::{ - expressions::Column, LexOrdering, LexRequirement, PhysicalExpr, PhysicalExprRef, - PhysicalSortExpr, PhysicalSortRequirement, -}; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{JoinType, ScalarValue}; -use datafusion_physical_expr_common::physical_expr::format_physical_expr_list; use std::fmt::Display; +use std::ops::Deref; use std::sync::Arc; use std::vec::IntoIter; +use super::projection::ProjectionTargets; +use super::ProjectionMapping; +use crate::expressions::Literal; +use crate::physical_expr::add_offset_to_expr; +use crate::{PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, PhysicalSortRequirement}; + +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::{HashMap, JoinType, Result, ScalarValue}; +use datafusion_physical_expr_common::physical_expr::format_physical_expr_list; + use indexmap::{IndexMap, IndexSet}; -/// A structure representing a expression known to be constant in a physical execution plan. +/// Represents whether a constant expression's value is uniform or varies across +/// partitions. Has two variants: +/// - `Heterogeneous`: The constant expression may have different values for +/// different partitions. +/// - `Uniform(Option)`: The constant expression has the same value +/// across all partitions, or is `None` if the value is unknown. +#[derive(Clone, Debug, Default, Eq, PartialEq)] +pub enum AcrossPartitions { + #[default] + Heterogeneous, + Uniform(Option), +} + +impl Display for AcrossPartitions { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + AcrossPartitions::Heterogeneous => write!(f, "(heterogeneous)"), + AcrossPartitions::Uniform(value) => { + if let Some(val) = value { + write!(f, "(uniform: {val})") + } else { + write!(f, "(uniform: unknown)") + } + } + } + } +} + +/// A structure representing a expression known to be constant in a physical +/// execution plan. /// -/// The `ConstExpr` struct encapsulates an expression that is constant during the execution -/// of a query. For example if a predicate like `A = 5` applied earlier in the plan `A` would -/// be known constant +/// The `ConstExpr` struct encapsulates an expression that is constant during +/// the execution of a query. For example if a filter like `A = 5` appears +/// earlier in the plan, `A` would become a constant in subsequent operations. /// /// # Fields /// /// - `expr`: Constant expression for a node in the physical plan. -/// /// - `across_partitions`: A boolean flag indicating whether the constant /// expression is the same across partitions. If set to `true`, the constant /// expression has same value for all partitions. 
If set to `false`, the @@ -50,108 +81,37 @@ use indexmap::{IndexMap, IndexSet}; /// # use datafusion_physical_expr::ConstExpr; /// # use datafusion_physical_expr::expressions::lit; /// let col = lit(5); -/// // Create a constant expression from a physical expression ref -/// let const_expr = ConstExpr::from(&col); -/// // create a constant expression from a physical expression +/// // Create a constant expression from a physical expression: /// let const_expr = ConstExpr::from(col); /// ``` -// TODO: Consider refactoring the `across_partitions` and `value` fields into an enum: -// -// ``` -// enum PartitionValues { -// Uniform(Option), // Same value across all partitions -// Heterogeneous(Vec>) // Different values per partition -// } -// ``` -// -// This would provide more flexible representation of partition values. -// Note: This is a breaking change for the equivalence API and should be -// addressed in a separate issue/PR. -#[derive(Debug, Clone)] +#[derive(Clone, Debug)] pub struct ConstExpr { - /// The expression that is known to be constant (e.g. a `Column`) - expr: Arc, - /// Does the constant have the same value across all partitions? See - /// struct docs for more details - across_partitions: AcrossPartitions, -} - -#[derive(PartialEq, Clone, Debug)] -/// Represents whether a constant expression's value is uniform or varies across partitions. -/// -/// The `AcrossPartitions` enum is used to describe the nature of a constant expression -/// in a physical execution plan: -/// -/// - `Heterogeneous`: The constant expression may have different values for different partitions. -/// - `Uniform(Option)`: The constant expression has the same value across all partitions, -/// or is `None` if the value is not specified. -pub enum AcrossPartitions { - Heterogeneous, - Uniform(Option), -} - -impl Default for AcrossPartitions { - fn default() -> Self { - Self::Heterogeneous - } -} - -impl PartialEq for ConstExpr { - fn eq(&self, other: &Self) -> bool { - self.across_partitions == other.across_partitions && self.expr.eq(&other.expr) - } + /// The expression that is known to be constant (e.g. a `Column`). + pub expr: Arc, + /// Indicates whether the constant have the same value across all partitions. + pub across_partitions: AcrossPartitions, } +// TODO: The `ConstExpr` definition above can be in an inconsistent state where +// `expr` is a literal but `across_partitions` is not `Uniform`. Consider +// a refactor to ensure that `ConstExpr` is always in a consistent state +// (either by changing type definition, or by API constraints). impl ConstExpr { - /// Create a new constant expression from a physical expression. + /// Create a new constant expression from a physical expression, specifying + /// whether the constant expression is the same across partitions. /// - /// Note you can also use `ConstExpr::from` to create a constant expression - /// from a reference as well - pub fn new(expr: Arc) -> Self { - Self { - expr, - // By default, assume constant expressions are not same across partitions. - across_partitions: Default::default(), + /// Note that you can also use `ConstExpr::from` to create a constant + /// expression from just a physical expression, with the *safe* assumption + /// of heterogenous values across partitions unless the expression is a + /// literal. + pub fn new(expr: Arc, across_partitions: AcrossPartitions) -> Self { + let mut result = ConstExpr::from(expr); + // Override the across partitions specification if the expression is not + // a literal. 
+ if result.across_partitions == AcrossPartitions::Heterogeneous { + result.across_partitions = across_partitions; } - } - - /// Set the `across_partitions` flag - /// - /// See struct docs for more details - pub fn with_across_partitions(mut self, across_partitions: AcrossPartitions) -> Self { - self.across_partitions = across_partitions; - self - } - - /// Is the expression the same across all partitions? - /// - /// See struct docs for more details - pub fn across_partitions(&self) -> AcrossPartitions { - self.across_partitions.clone() - } - - pub fn expr(&self) -> &Arc { - &self.expr - } - - pub fn owned_expr(self) -> Arc { - self.expr - } - - pub fn map(&self, f: F) -> Option - where - F: Fn(&Arc) -> Option>, - { - let maybe_expr = f(&self.expr); - maybe_expr.map(|expr| Self { - expr, - across_partitions: self.across_partitions.clone(), - }) - } - - /// Returns true if this constant expression is equal to the given expression - pub fn eq_expr(&self, other: impl AsRef) -> bool { - self.expr.as_ref() == other.as_ref() + result } /// Returns a [`Display`]able list of `ConstExpr`. @@ -166,7 +126,7 @@ impl ConstExpr { } else { write!(f, ",")?; } - write!(f, "{}", const_expr)?; + write!(f, "{const_expr}")?; } Ok(()) } @@ -175,47 +135,36 @@ impl ConstExpr { } } +impl PartialEq for ConstExpr { + fn eq(&self, other: &Self) -> bool { + self.across_partitions == other.across_partitions && self.expr.eq(&other.expr) + } +} + impl Display for ConstExpr { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.expr)?; - match &self.across_partitions { - AcrossPartitions::Heterogeneous => { - write!(f, "(heterogeneous)")?; - } - AcrossPartitions::Uniform(value) => { - if let Some(val) = value { - write!(f, "(uniform: {})", val)?; - } else { - write!(f, "(uniform: unknown)")?; - } - } - } - Ok(()) + write!(f, "{}", self.across_partitions) } } impl From> for ConstExpr { fn from(expr: Arc) -> Self { - Self::new(expr) - } -} - -impl From<&Arc> for ConstExpr { - fn from(expr: &Arc) -> Self { - Self::new(Arc::clone(expr)) + // By default, assume constant expressions are not same across partitions. + // However, if we have a literal, it will have a single value that is the + // same across all partitions. + let across = if let Some(lit) = expr.as_any().downcast_ref::() { + AcrossPartitions::Uniform(Some(lit.value().clone())) + } else { + AcrossPartitions::Heterogeneous + }; + Self { + expr, + across_partitions: across, + } } } -/// Checks whether `expr` is among in the `const_exprs`. -pub fn const_exprs_contains( - const_exprs: &[ConstExpr], - expr: &Arc, -) -> bool { - const_exprs - .iter() - .any(|const_expr| const_expr.expr.eq(expr)) -} - /// An `EquivalenceClass` is a set of [`Arc`]s that are known /// to have the same value for all tuples in a relation. These are generated by /// equality predicates (e.g. `a = b`), typically equi-join conditions and @@ -223,259 +172,361 @@ pub fn const_exprs_contains( /// /// Two `EquivalenceClass`es are equal if they contains the same expressions in /// without any ordering. -#[derive(Debug, Clone)] +#[derive(Clone, Debug, Default, Eq, PartialEq)] pub struct EquivalenceClass { - /// The expressions in this equivalence class. The order doesn't - /// matter for equivalence purposes - /// - exprs: IndexSet>, -} - -impl PartialEq for EquivalenceClass { - /// Returns true if other is equal in the sense - /// of bags (multi-sets), disregarding their orderings. 
- fn eq(&self, other: &Self) -> bool { - self.exprs.eq(&other.exprs) - } + /// The expressions in this equivalence class. The order doesn't matter for + /// equivalence purposes. + pub(crate) exprs: IndexSet>, + /// Indicates whether the expressions in this equivalence class have a + /// constant value. A `Some` value indicates constant-ness. + pub(crate) constant: Option, } impl EquivalenceClass { - /// Create a new empty equivalence class - pub fn new_empty() -> Self { - Self { - exprs: IndexSet::new(), - } - } - - // Create a new equivalence class from a pre-existing `Vec` - pub fn new(exprs: Vec>) -> Self { - Self { - exprs: exprs.into_iter().collect(), + // Create a new equivalence class from a pre-existing collection. + pub fn new(exprs: impl IntoIterator>) -> Self { + let mut class = Self::default(); + for expr in exprs { + class.push(expr); } - } - - /// Return the inner vector of expressions - pub fn into_vec(self) -> Vec> { - self.exprs.into_iter().collect() + class } /// Return the "canonical" expression for this class (the first element) - /// if any - fn canonical_expr(&self) -> Option> { - self.exprs.iter().next().cloned() + /// if non-empty. + pub fn canonical_expr(&self) -> Option<&Arc> { + self.exprs.iter().next() } /// Insert the expression into this class, meaning it is known to be equal to - /// all other expressions in this class + /// all other expressions in this class. pub fn push(&mut self, expr: Arc) { + if let Some(lit) = expr.as_any().downcast_ref::() { + let expr_across = AcrossPartitions::Uniform(Some(lit.value().clone())); + if let Some(across) = self.constant.as_mut() { + // TODO: Return an error if constant values do not agree. + if *across == AcrossPartitions::Heterogeneous { + *across = expr_across; + } + } else { + self.constant = Some(expr_across); + } + } self.exprs.insert(expr); } - /// Inserts all the expressions from other into this class + /// Inserts all the expressions from other into this class. pub fn extend(&mut self, other: Self) { - for expr in other.exprs { - // use push so entries are deduplicated - self.push(expr); + self.exprs.extend(other.exprs); + match (&self.constant, &other.constant) { + (Some(across), Some(_)) => { + // TODO: Return an error if constant values do not agree. + if across == &AcrossPartitions::Heterogeneous { + self.constant = other.constant; + } + } + (None, Some(_)) => self.constant = other.constant, + (_, None) => {} } } - /// Returns true if this equivalence class contains t expression - pub fn contains(&self, expr: &Arc) -> bool { - self.exprs.contains(expr) - } - - /// Returns true if this equivalence class has any entries in common with `other` + /// Returns whether this equivalence class has any entries in common with + /// `other`. pub fn contains_any(&self, other: &Self) -> bool { - self.exprs.iter().any(|e| other.contains(e)) - } - - /// return the number of items in this class - pub fn len(&self) -> usize { - self.exprs.len() - } - - /// return true if this class is empty - pub fn is_empty(&self) -> bool { - self.exprs.is_empty() + self.exprs.intersection(&other.exprs).next().is_some() } - /// Iterate over all elements in this class, in some arbitrary order - pub fn iter(&self) -> impl Iterator> { - self.exprs.iter() + /// Returns whether this equivalence class is trivial, meaning that it is + /// either empty, or contains a single expression that is not a constant. + /// Such classes are not useful, and can be removed from equivalence groups. 
+ pub fn is_trivial(&self) -> bool { + self.exprs.is_empty() || (self.exprs.len() == 1 && self.constant.is_none()) } - /// Return a new equivalence class that have the specified offset added to - /// each expression (used when schemas are appended such as in joins) - pub fn with_offset(&self, offset: usize) -> Self { - let new_exprs = self + /// Adds the given offset to all columns in the expressions inside this + /// class. This is used when schemas are appended, e.g. in joins. + pub fn try_with_offset(&self, offset: isize) -> Result { + let mut cls = Self::default(); + for expr_result in self .exprs .iter() .cloned() .map(|e| add_offset_to_expr(e, offset)) - .collect(); - Self::new(new_exprs) + { + cls.push(expr_result?); + } + Ok(cls) + } +} + +impl Deref for EquivalenceClass { + type Target = IndexSet>; + + fn deref(&self) -> &Self::Target { + &self.exprs + } +} + +impl IntoIterator for EquivalenceClass { + type Item = Arc; + type IntoIter = as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.exprs.into_iter() } } impl Display for EquivalenceClass { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "[{}]", format_physical_expr_list(&self.exprs)) + write!(f, "{{")?; + write!(f, "members: {}", format_physical_expr_list(&self.exprs))?; + if let Some(across) = &self.constant { + write!(f, ", constant: {across}")?; + } + write!(f, "}}") + } +} + +impl From for Vec> { + fn from(cls: EquivalenceClass) -> Self { + cls.exprs.into_iter().collect() } } -/// A collection of distinct `EquivalenceClass`es -#[derive(Debug, Clone)] +type AugmentedMapping<'a> = IndexMap< + &'a Arc, + (&'a ProjectionTargets, Option<&'a EquivalenceClass>), +>; + +/// A collection of distinct `EquivalenceClass`es. This object supports fast +/// lookups of expressions and their equivalence classes. +#[derive(Clone, Debug, Default)] pub struct EquivalenceGroup { + /// A mapping from expressions to their equivalence class key. + map: HashMap, usize>, + /// The equivalence classes in this group. classes: Vec, } impl EquivalenceGroup { - /// Creates an empty equivalence group. - pub fn empty() -> Self { - Self { classes: vec![] } - } - /// Creates an equivalence group from the given equivalence classes. - pub fn new(classes: Vec) -> Self { - let mut result = Self { classes }; - result.remove_redundant_entries(); - result - } - - /// Returns how many equivalence classes there are in this group. - pub fn len(&self) -> usize { - self.classes.len() + pub fn new(classes: impl IntoIterator) -> Self { + classes.into_iter().collect::>().into() } - /// Checks whether this equivalence group is empty. - pub fn is_empty(&self) -> bool { - self.len() == 0 + /// Adds `expr` as a constant expression to this equivalence group. + pub fn add_constant(&mut self, const_expr: ConstExpr) { + // If the expression is already in an equivalence class, we should + // adjust the constant-ness of the class if necessary: + if let Some(idx) = self.map.get(&const_expr.expr) { + let cls = &mut self.classes[*idx]; + if let Some(across) = cls.constant.as_mut() { + // TODO: Return an error if constant values do not agree. 
+ if *across == AcrossPartitions::Heterogeneous { + *across = const_expr.across_partitions; + } + } else { + cls.constant = Some(const_expr.across_partitions); + } + return; + } + // If the expression is not in any equivalence class, but has the same + // constant value with some class, add it to that class: + if let AcrossPartitions::Uniform(_) = &const_expr.across_partitions { + for (idx, cls) in self.classes.iter_mut().enumerate() { + if cls + .constant + .as_ref() + .is_some_and(|across| const_expr.across_partitions.eq(across)) + { + self.map.insert(Arc::clone(&const_expr.expr), idx); + cls.push(const_expr.expr); + return; + } + } + } + // Otherwise, create a new class with the expression as the only member: + let mut new_class = EquivalenceClass::new(std::iter::once(const_expr.expr)); + if new_class.constant.is_none() { + new_class.constant = Some(const_expr.across_partitions); + } + Self::update_lookup_table(&mut self.map, &new_class, self.classes.len()); + self.classes.push(new_class); } - /// Returns an iterator over the equivalence classes in this group. - pub fn iter(&self) -> impl Iterator { - self.classes.iter() + /// Removes constant expressions that may change across partitions. + /// This method should be used when merging data from different partitions. + /// Returns whether any change was made to the equivalence group. + pub fn clear_per_partition_constants(&mut self) -> bool { + let (mut idx, mut change) = (0, false); + while idx < self.classes.len() { + let cls = &mut self.classes[idx]; + if let Some(AcrossPartitions::Heterogeneous) = cls.constant { + change = true; + if cls.len() == 1 { + // If this class becomes trivial, remove it entirely: + self.remove_class_at_idx(idx); + continue; + } else { + cls.constant = None; + } + } + idx += 1; + } + change } - /// Adds the equality `left` = `right` to this equivalence group. - /// New equality conditions often arise after steps like `Filter(a = b)`, - /// `Alias(a, a as b)` etc. + /// Adds the equality `left` = `right` to this equivalence group. New + /// equality conditions often arise after steps like `Filter(a = b)`, + /// `Alias(a, a as b)` etc. Returns whether the given equality defines + /// a new equivalence class. pub fn add_equal_conditions( &mut self, - left: &Arc, - right: &Arc, - ) { - let mut first_class = None; - let mut second_class = None; - for (idx, cls) in self.classes.iter().enumerate() { - if cls.contains(left) { - first_class = Some(idx); - } - if cls.contains(right) { - second_class = Some(idx); - } - } + left: Arc, + right: Arc, + ) -> bool { + let first_class = self.map.get(&left).copied(); + let second_class = self.map.get(&right).copied(); match (first_class, second_class) { (Some(mut first_idx), Some(mut second_idx)) => { // If the given left and right sides belong to different classes, // we should unify/bridge these classes. - if first_idx != second_idx { - // By convention, make sure `second_idx` is larger than `first_idx`. - if first_idx > second_idx { - (first_idx, second_idx) = (second_idx, first_idx); + match first_idx.cmp(&second_idx) { + // The equality is already known, return and signal this: + std::cmp::Ordering::Equal => return false, + // Swap indices to ensure `first_idx` is the lesser index. + std::cmp::Ordering::Greater => { + std::mem::swap(&mut first_idx, &mut second_idx); } - // Remove the class at `second_idx` and merge its values with - // the class at `first_idx`. The convention above makes sure - // that `first_idx` is still valid after removing `second_idx`. 
- let other_class = self.classes.swap_remove(second_idx); - self.classes[first_idx].extend(other_class); + _ => {} } + // Remove the class at `second_idx` and merge its values with + // the class at `first_idx`. The convention above makes sure + // that `first_idx` is still valid after removing `second_idx`. + let other_class = self.remove_class_at_idx(second_idx); + // Update the lookup table for the second class: + Self::update_lookup_table(&mut self.map, &other_class, first_idx); + self.classes[first_idx].extend(other_class); } (Some(group_idx), None) => { // Right side is new, extend left side's class: - self.classes[group_idx].push(Arc::clone(right)); + self.map.insert(Arc::clone(&right), group_idx); + self.classes[group_idx].push(right); } (None, Some(group_idx)) => { // Left side is new, extend right side's class: - self.classes[group_idx].push(Arc::clone(left)); + self.map.insert(Arc::clone(&left), group_idx); + self.classes[group_idx].push(left); } (None, None) => { // None of the expressions is among existing classes. // Create a new equivalence class and extend the group. - self.classes.push(EquivalenceClass::new(vec![ - Arc::clone(left), - Arc::clone(right), - ])); + let class = EquivalenceClass::new([left, right]); + Self::update_lookup_table(&mut self.map, &class, self.classes.len()); + self.classes.push(class); + return true; } } + false } - /// Removes redundant entries from this group. - fn remove_redundant_entries(&mut self) { - // Remove duplicate entries from each equivalence class: - self.classes.retain_mut(|cls| { - // Keep groups that have at least two entries as singleton class is - // meaningless (i.e. it contains no non-trivial information): - cls.len() > 1 - }); - // Unify/bridge groups that have common expressions: - self.bridge_classes() + /// Removes the equivalence class at the given index from this group. + fn remove_class_at_idx(&mut self, idx: usize) -> EquivalenceClass { + // Remove the class at the given index: + let cls = self.classes.swap_remove(idx); + // Remove its entries from the lookup table: + for expr in cls.iter() { + self.map.remove(expr); + } + // Update the lookup table for the moved class: + if idx < self.classes.len() { + Self::update_lookup_table(&mut self.map, &self.classes[idx], idx); + } + cls + } + + /// Updates the entry in lookup table for the given equivalence class with + /// the given index. + fn update_lookup_table( + map: &mut HashMap, usize>, + cls: &EquivalenceClass, + idx: usize, + ) { + for expr in cls.iter() { + map.insert(Arc::clone(expr), idx); + } + } + + /// Removes redundant entries from this group. Returns whether any change + /// was made to the equivalence group. + fn remove_redundant_entries(&mut self) -> bool { + // First, remove trivial equivalence classes: + let mut change = false; + for idx in (0..self.classes.len()).rev() { + if self.classes[idx].is_trivial() { + self.remove_class_at_idx(idx); + change = true; + } + } + // Then, unify/bridge groups that have common expressions: + self.bridge_classes() || change } /// This utility function unifies/bridges classes that have common expressions. /// For example, assume that we have [`EquivalenceClass`]es `[a, b]` and `[b, c]`. /// Since both classes contain `b`, columns `a`, `b` and `c` are actually all /// equal and belong to one class. This utility converts merges such classes. 
- fn bridge_classes(&mut self) { - let mut idx = 0; - while idx < self.classes.len() { - let mut next_idx = idx + 1; - let start_size = self.classes[idx].len(); - while next_idx < self.classes.len() { - if self.classes[idx].contains_any(&self.classes[next_idx]) { - let extension = self.classes.swap_remove(next_idx); + /// Returns whether any change was made to the equivalence group. + fn bridge_classes(&mut self) -> bool { + let (mut idx, mut change) = (0, false); + 'scan: while idx < self.classes.len() { + for other_idx in (idx + 1..self.classes.len()).rev() { + if self.classes[idx].contains_any(&self.classes[other_idx]) { + let extension = self.remove_class_at_idx(other_idx); + Self::update_lookup_table(&mut self.map, &extension, idx); self.classes[idx].extend(extension); - } else { - next_idx += 1; + change = true; + continue 'scan; } } - if self.classes[idx].len() > start_size { - continue; - } idx += 1; } + change } /// Extends this equivalence group with the `other` equivalence group. - pub fn extend(&mut self, other: Self) { + /// Returns whether any equivalence classes were unified/bridged as a + /// result of the extension process. + pub fn extend(&mut self, other: Self) -> bool { + for (idx, cls) in other.classes.iter().enumerate() { + // Update the lookup table for the new class: + Self::update_lookup_table(&mut self.map, cls, idx); + } self.classes.extend(other.classes); - self.remove_redundant_entries(); + self.bridge_classes() } - /// Normalizes the given physical expression according to this group. - /// The expression is replaced with the first expression in the equivalence - /// class it matches with (if any). + /// Normalizes the given physical expression according to this group. The + /// expression is replaced with the first (canonical) expression in the + /// equivalence class it matches with (if any). pub fn normalize_expr(&self, expr: Arc) -> Arc { expr.transform(|expr| { - for cls in self.iter() { - if cls.contains(&expr) { - // The unwrap below is safe because the guard above ensures - // that the class is not empty. - return Ok(Transformed::yes(cls.canonical_expr().unwrap())); - } - } - Ok(Transformed::no(expr)) + let cls = self.get_equivalence_class(&expr); + let Some(canonical) = cls.and_then(|cls| cls.canonical_expr()) else { + return Ok(Transformed::no(expr)); + }; + Ok(Transformed::yes(Arc::clone(canonical))) }) .data() .unwrap() // The unwrap above is safe because the closure always returns `Ok`. } - /// Normalizes the given sort expression according to this group. - /// The underlying physical expression is replaced with the first expression - /// in the equivalence class it matches with (if any). If the underlying - /// expression does not belong to any equivalence class in this group, returns - /// the sort expression as is. + /// Normalizes the given sort expression according to this group. The + /// underlying physical expression is replaced with the first expression in + /// the equivalence class it matches with (if any). If the underlying + /// expression does not belong to any equivalence class in this group, + /// returns the sort expression as is. pub fn normalize_sort_expr( &self, mut sort_expr: PhysicalSortExpr, @@ -484,11 +535,29 @@ impl EquivalenceGroup { sort_expr } - /// Normalizes the given sort requirement according to this group. - /// The underlying physical expression is replaced with the first expression - /// in the equivalence class it matches with (if any). 
If the underlying - /// expression does not belong to any equivalence class in this group, returns - /// the given sort requirement as is. + /// Normalizes the given sort expressions (i.e. `sort_exprs`) by: + /// - Replacing sections that belong to some equivalence class in the + /// with the first entry in the matching equivalence class. + /// - Removing expressions that have a constant value. + /// + /// If columns `a` and `b` are known to be equal, `d` is known to be a + /// constant, and `sort_exprs` is `[b ASC, d DESC, c ASC, a ASC]`, this + /// function would return `[a ASC, c ASC, a ASC]`. + pub fn normalize_sort_exprs<'a>( + &'a self, + sort_exprs: impl IntoIterator + 'a, + ) -> impl Iterator + 'a { + sort_exprs + .into_iter() + .map(|sort_expr| self.normalize_sort_expr(sort_expr)) + .filter(|sort_expr| self.is_expr_constant(&sort_expr.expr).is_none()) + } + + /// Normalizes the given sort requirement according to this group. The + /// underlying physical expression is replaced with the first expression in + /// the equivalence class it matches with (if any). If the underlying + /// expression does not belong to any equivalence class in this group, + /// returns the given sort requirement as is. pub fn normalize_sort_requirement( &self, mut sort_requirement: PhysicalSortRequirement, @@ -497,44 +566,76 @@ impl EquivalenceGroup { sort_requirement } - /// This function applies the `normalize_expr` function for all expressions - /// in `exprs` and returns the corresponding normalized physical expressions. - pub fn normalize_exprs( - &self, - exprs: impl IntoIterator>, - ) -> Vec> { - exprs + /// Normalizes the given sort requirements (i.e. `sort_reqs`) by: + /// - Replacing sections that belong to some equivalence class in the + /// with the first entry in the matching equivalence class. + /// - Removing expressions that have a constant value. + /// + /// If columns `a` and `b` are known to be equal, `d` is known to be a + /// constant, and `sort_reqs` is `[b ASC, d DESC, c ASC, a ASC]`, this + /// function would return `[a ASC, c ASC, a ASC]`. + pub fn normalize_sort_requirements<'a>( + &'a self, + sort_reqs: impl IntoIterator + 'a, + ) -> impl Iterator + 'a { + sort_reqs .into_iter() - .map(|expr| self.normalize_expr(expr)) - .collect() + .map(|req| self.normalize_sort_requirement(req)) + .filter(|req| self.is_expr_constant(&req.expr).is_none()) } - /// This function applies the `normalize_sort_expr` function for all sort - /// expressions in `sort_exprs` and returns the corresponding normalized - /// sort expressions. - pub fn normalize_sort_exprs(&self, sort_exprs: &LexOrdering) -> LexOrdering { - // Convert sort expressions to sort requirements: - let sort_reqs = LexRequirement::from(sort_exprs.clone()); - // Normalize the requirements: - let normalized_sort_reqs = self.normalize_sort_requirements(&sort_reqs); - // Convert sort requirements back to sort expressions: - LexOrdering::from(normalized_sort_reqs) + /// Perform an indirect projection of `expr` by consulting the equivalence + /// classes. + fn project_expr_indirect( + aug_mapping: &AugmentedMapping, + expr: &Arc, + ) -> Option> { + // The given expression is not inside the mapping, so we try to project + // indirectly using equivalence classes. + for (targets, eq_class) in aug_mapping.values() { + // If we match an equivalent expression to a source expression in + // the mapping, then we can project. 
For example, if we have the + // mapping `(a as a1, a + c)` and the equivalence `a == b`, + // expression `b` projects to `a1`. + if eq_class.as_ref().is_some_and(|cls| cls.contains(expr)) { + let (target, _) = targets.first(); + return Some(Arc::clone(target)); + } + } + // Project a non-leaf expression by projecting its children. + let children = expr.children(); + if children.is_empty() { + // A leaf expression should be inside the mapping. + return None; + } + children + .into_iter() + .map(|child| { + // First, we try to project children with an exact match. If + // we are unable to do this, we consult equivalence classes. + if let Some((targets, _)) = aug_mapping.get(child) { + // If we match the source, we can project directly: + let (target, _) = targets.first(); + Some(Arc::clone(target)) + } else { + Self::project_expr_indirect(aug_mapping, child) + } + }) + .collect::>>() + .map(|children| Arc::clone(expr).with_new_children(children).unwrap()) } - /// This function applies the `normalize_sort_requirement` function for all - /// requirements in `sort_reqs` and returns the corresponding normalized - /// sort requirements. - pub fn normalize_sort_requirements( - &self, - sort_reqs: &LexRequirement, - ) -> LexRequirement { - LexRequirement::new( - sort_reqs - .iter() - .map(|sort_req| self.normalize_sort_requirement(sort_req.clone())) - .collect(), - ) - .collapse() + fn augment_projection_mapping<'a>( + &'a self, + mapping: &'a ProjectionMapping, + ) -> AugmentedMapping<'a> { + mapping + .iter() + .map(|(k, v)| { + let eq_class = self.get_equivalence_class(k); + (k, (v, eq_class)) + }) + .collect() } /// Projects `expr` according to the given projection mapping. @@ -544,81 +645,118 @@ impl EquivalenceGroup { mapping: &ProjectionMapping, expr: &Arc, ) -> Option> { - // First, we try to project expressions with an exact match. If we are - // unable to do this, we consult equivalence classes. - if let Some(target) = mapping.target_expr(expr) { + if let Some(targets) = mapping.get(expr) { // If we match the source, we can project directly: - return Some(target); + let (target, _) = targets.first(); + Some(Arc::clone(target)) } else { - // If the given expression is not inside the mapping, try to project - // expressions considering the equivalence classes. - for (source, target) in mapping.iter() { - // If we match an equivalent expression to `source`, then we can - // project. For example, if we have the mapping `(a as a1, a + c)` - // and the equivalence class `(a, b)`, expression `b` projects to `a1`. - if self - .get_equivalence_class(source) - .is_some_and(|group| group.contains(expr)) - { - return Some(Arc::clone(target)); - } - } + let aug_mapping = self.augment_projection_mapping(mapping); + Self::project_expr_indirect(&aug_mapping, expr) } - // Project a non-leaf expression by projecting its children. - let children = expr.children(); - if children.is_empty() { - // Leaf expression should be inside mapping. - return None; - } - children - .into_iter() - .map(|child| self.project_expr(mapping, child)) - .collect::>>() - .map(|children| Arc::clone(expr).with_new_children(children).unwrap()) + } + + /// Projects `expressions` according to the given projection mapping. + /// This function is similar to [`Self::project_expr`], but projects multiple + /// expressions at once more efficiently than calling `project_expr` for each + /// expression. 
+ pub fn project_expressions<'a>( + &'a self, + mapping: &'a ProjectionMapping, + expressions: impl IntoIterator> + 'a, + ) -> impl Iterator>> + 'a { + let mut aug_mapping = None; + expressions.into_iter().map(move |expr| { + if let Some(targets) = mapping.get(expr) { + // If we match the source, we can project directly: + let (target, _) = targets.first(); + Some(Arc::clone(target)) + } else { + let aug_mapping = aug_mapping + .get_or_insert_with(|| self.augment_projection_mapping(mapping)); + Self::project_expr_indirect(aug_mapping, expr) + } + }) } /// Projects this equivalence group according to the given projection mapping. pub fn project(&self, mapping: &ProjectionMapping) -> Self { - let projected_classes = self.iter().filter_map(|cls| { - let new_class = cls - .iter() - .filter_map(|expr| self.project_expr(mapping, expr)) - .collect::>(); - (new_class.len() > 1).then_some(EquivalenceClass::new(new_class)) + let projected_classes = self.iter().map(|cls| { + let new_exprs = self.project_expressions(mapping, cls.iter()); + EquivalenceClass::new(new_exprs.flatten()) }); // The key is the source expression, and the value is the equivalence // class that contains the corresponding target expression. - let mut new_classes: IndexMap<_, _> = IndexMap::new(); - for (source, target) in mapping.iter() { + let mut new_constants = vec![]; + let mut new_classes = IndexMap::<_, EquivalenceClass>::new(); + for (source, targets) in mapping.iter() { // We need to find equivalent projected expressions. For example, // consider a table with columns `[a, b, c]` with `a` == `b`, and // projection `[a + c, b + c]`. To conclude that `a + c == b + c`, // we first normalize all source expressions in the mapping, then // merge all equivalent expressions into the classes. let normalized_expr = self.normalize_expr(Arc::clone(source)); - new_classes - .entry(normalized_expr) - .or_insert_with(EquivalenceClass::new_empty) - .push(Arc::clone(target)); + let cls = new_classes.entry(normalized_expr).or_default(); + for (target, _) in targets.iter() { + cls.push(Arc::clone(target)); + } + // Save new constants arising from the projection: + if let Some(across) = self.is_expr_constant(source) { + for (target, _) in targets.iter() { + let const_expr = ConstExpr::new(Arc::clone(target), across.clone()); + new_constants.push(const_expr); + } + } } - // Only add equivalence classes with at least two members as singleton - // equivalence classes are meaningless. - let new_classes = new_classes - .into_iter() - .filter_map(|(_, cls)| (cls.len() > 1).then_some(cls)); - let classes = projected_classes.chain(new_classes).collect(); - Self::new(classes) + // Union projected classes with new classes to make up the result: + let classes = projected_classes + .chain(new_classes.into_values()) + .filter(|cls| !cls.is_trivial()); + let mut result = Self::new(classes); + // Add new constants arising from the projection to the equivalence group: + for constant in new_constants { + result.add_constant(constant); + } + result + } + + /// Returns a `Some` value if the expression is constant according to + /// equivalence group, and `None` otherwise. The `Some` variant contains + /// an `AcrossPartitions` value indicating whether the expression is + /// constant across partitions, and its actual value (if available). 
+ pub fn is_expr_constant( + &self, + expr: &Arc, + ) -> Option { + if let Some(lit) = expr.as_any().downcast_ref::() { + return Some(AcrossPartitions::Uniform(Some(lit.value().clone()))); + } + if let Some(cls) = self.get_equivalence_class(expr) { + if cls.constant.is_some() { + return cls.constant.clone(); + } + } + // TODO: This function should be able to return values of non-literal + // complex constants as well; e.g. it should return `8` for the + // expression `3 + 5`, not an unknown `heterogenous` value. + let children = expr.children(); + if children.is_empty() { + return None; + } + for child in children { + self.is_expr_constant(child)?; + } + Some(AcrossPartitions::Heterogeneous) } /// Returns the equivalence class containing `expr`. If no equivalence class /// contains `expr`, returns `None`. - fn get_equivalence_class( + pub fn get_equivalence_class( &self, expr: &Arc, ) -> Option<&EquivalenceClass> { - self.iter().find(|cls| cls.contains(expr)) + self.map.get(expr).map(|idx| &self.classes[*idx]) } /// Combine equivalence groups of the given join children. @@ -628,18 +766,16 @@ impl EquivalenceGroup { join_type: &JoinType, left_size: usize, on: &[(PhysicalExprRef, PhysicalExprRef)], - ) -> Self { - match join_type { + ) -> Result { + let group = match join_type { JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { let mut result = Self::new( - self.iter() - .cloned() - .chain( - right_equivalences - .iter() - .map(|cls| cls.with_offset(left_size)), - ) - .collect(), + self.iter().cloned().chain( + right_equivalences + .iter() + .map(|cls| cls.try_with_offset(left_size as _)) + .collect::>>()?, + ), ); // In we have an inner join, expressions in the "on" condition // are equal in the resulting table. @@ -647,36 +783,25 @@ impl EquivalenceGroup { for (lhs, rhs) in on.iter() { let new_lhs = Arc::clone(lhs); // Rewrite rhs to point to the right side of the join: - let new_rhs = Arc::clone(rhs) - .transform(|expr| { - if let Some(column) = - expr.as_any().downcast_ref::() - { - let new_column = Arc::new(Column::new( - column.name(), - column.index() + left_size, - )) - as _; - return Ok(Transformed::yes(new_column)); - } - - Ok(Transformed::no(expr)) - }) - .data() - .unwrap(); - result.add_equal_conditions(&new_lhs, &new_rhs); + let new_rhs = + add_offset_to_expr(Arc::clone(rhs), left_size as _)?; + result.add_equal_conditions(new_lhs, new_rhs); } } result } JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => self.clone(), - JoinType::RightSemi | JoinType::RightAnti => right_equivalences.clone(), - } + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => { + right_equivalences.clone() + } + }; + Ok(group) } - /// Checks if two expressions are equal either directly or through equivalence classes. - /// For complex expressions (e.g. a + b), checks that the expression trees are structurally - /// identical and their leaf nodes are equivalent either directly or through equivalence classes. + /// Checks if two expressions are equal directly or through equivalence + /// classes. For complex expressions (e.g. `a + b`), checks that the + /// expression trees are structurally identical and their leaf nodes are + /// equivalent either directly or through equivalence classes. 
pub fn exprs_equal( &self, left: &Arc, @@ -726,16 +851,19 @@ impl EquivalenceGroup { .zip(right_children) .all(|(left_child, right_child)| self.exprs_equal(left_child, right_child)) } +} + +impl Deref for EquivalenceGroup { + type Target = [EquivalenceClass]; - /// Return the inner classes of this equivalence group. - pub fn into_inner(self) -> Vec { - self.classes + fn deref(&self) -> &Self::Target { + &self.classes } } impl IntoIterator for EquivalenceGroup { type Item = EquivalenceClass; - type IntoIter = IntoIter; + type IntoIter = IntoIter; fn into_iter(self) -> Self::IntoIter { self.classes.into_iter() @@ -747,20 +875,37 @@ impl Display for EquivalenceGroup { write!(f, "[")?; let mut iter = self.iter(); if let Some(cls) = iter.next() { - write!(f, "{}", cls)?; + write!(f, "{cls}")?; } for cls in iter { - write!(f, ", {}", cls)?; + write!(f, ", {cls}")?; } write!(f, "]") } } +impl From> for EquivalenceGroup { + fn from(classes: Vec) -> Self { + let mut result = Self { + map: classes + .iter() + .enumerate() + .flat_map(|(idx, cls)| { + cls.iter().map(move |expr| (Arc::clone(expr), idx)) + }) + .collect(), + classes, + }; + result.remove_redundant_entries(); + result + } +} + #[cfg(test)] mod tests { use super::*; use crate::equivalence::tests::create_test_params; - use crate::expressions::{binary, col, lit, BinaryExpr, Literal}; + use crate::expressions::{binary, col, lit, BinaryExpr, Column, Literal}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{Result, ScalarValue}; @@ -786,24 +931,32 @@ mod tests { for (entries, expected) in test_cases { let entries = entries .into_iter() - .map(|entry| entry.into_iter().map(lit).collect::>()) + .map(|entry| { + entry.into_iter().map(|idx| { + let c = Column::new(format!("col_{idx}").as_str(), idx); + Arc::new(c) as _ + }) + }) .map(EquivalenceClass::new) .collect::>(); let expected = expected .into_iter() - .map(|entry| entry.into_iter().map(lit).collect::>()) + .map(|entry| { + entry.into_iter().map(|idx| { + let c = Column::new(format!("col_{idx}").as_str(), idx); + Arc::new(c) as _ + }) + }) .map(EquivalenceClass::new) .collect::>(); - let mut eq_groups = EquivalenceGroup::new(entries.clone()); - eq_groups.bridge_classes(); + let eq_groups: EquivalenceGroup = entries.clone().into(); let eq_groups = eq_groups.classes; let err_msg = format!( - "error in test entries: {:?}, expected: {:?}, actual:{:?}", - entries, expected, eq_groups + "error in test entries: {entries:?}, expected: {expected:?}, actual:{eq_groups:?}" ); - assert_eq!(eq_groups.len(), expected.len(), "{}", err_msg); + assert_eq!(eq_groups.len(), expected.len(), "{err_msg}"); for idx in 0..eq_groups.len() { - assert_eq!(&eq_groups[idx], &expected[idx], "{}", err_msg); + assert_eq!(&eq_groups[idx], &expected[idx], "{err_msg}"); } } Ok(()) @@ -811,58 +964,45 @@ mod tests { #[test] fn test_remove_redundant_entries_eq_group() -> Result<()> { + let c = |idx| Arc::new(Column::new(format!("col_{idx}").as_str(), idx)) as _; let entries = [ - EquivalenceClass::new(vec![lit(1), lit(1), lit(2)]), - // This group is meaningless should be removed - EquivalenceClass::new(vec![lit(3), lit(3)]), - EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]), + EquivalenceClass::new([c(1), c(1), lit(20)]), + EquivalenceClass::new([lit(30), lit(30)]), + EquivalenceClass::new([c(2), c(3), c(4)]), ]; // Given equivalences classes are not in succinct form. // Expected form is the most plain representation that is functionally same. 
let expected = [ - EquivalenceClass::new(vec![lit(1), lit(2)]), - EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]), + EquivalenceClass::new([c(1), lit(20)]), + EquivalenceClass::new([lit(30)]), + EquivalenceClass::new([c(2), c(3), c(4)]), ]; - let mut eq_groups = EquivalenceGroup::new(entries.to_vec()); - eq_groups.remove_redundant_entries(); - - let eq_groups = eq_groups.classes; - assert_eq!(eq_groups.len(), expected.len()); - assert_eq!(eq_groups.len(), 2); - - assert_eq!(eq_groups[0], expected[0]); - assert_eq!(eq_groups[1], expected[1]); + let eq_groups = EquivalenceGroup::new(entries); + assert_eq!(eq_groups.classes, expected); Ok(()) } #[test] fn test_schema_normalize_expr_with_equivalence() -> Result<()> { - let col_a = &Column::new("a", 0); - let col_b = &Column::new("b", 1); - let col_c = &Column::new("c", 2); + let col_a = Arc::new(Column::new("a", 0)) as Arc; + let col_b = Arc::new(Column::new("b", 1)) as _; + let col_c = Arc::new(Column::new("c", 2)) as _; // Assume that column a and c are aliases. - let (_test_schema, eq_properties) = create_test_params()?; - - let col_a_expr = Arc::new(col_a.clone()) as Arc; - let col_b_expr = Arc::new(col_b.clone()) as Arc; - let col_c_expr = Arc::new(col_c.clone()) as Arc; - // Test cases for equivalence normalization, - // First entry in the tuple is argument, second entry is expected result after normalization. + let (_, eq_properties) = create_test_params()?; + // Test cases for equivalence normalization. First entry in the tuple is + // the argument, second entry is expected result after normalization. let expressions = vec![ // Normalized version of the column a and c should go to a // (by convention all the expressions inside equivalence class are mapped to the first entry // in this case a is the first entry in the equivalence class.) 
- (&col_a_expr, &col_a_expr), - (&col_c_expr, &col_a_expr), + (Arc::clone(&col_a), Arc::clone(&col_a)), + (col_c, col_a), // Cannot normalize column b - (&col_b_expr, &col_b_expr), + (Arc::clone(&col_b), Arc::clone(&col_b)), ]; let eq_group = eq_properties.eq_group(); for (expr, expected_eq) in expressions { - assert!( - expected_eq.eq(&eq_group.normalize_expr(Arc::clone(expr))), - "error in test: expr: {expr:?}" - ); + assert!(expected_eq.eq(&eq_group.normalize_expr(expr))); } Ok(()) @@ -870,21 +1010,15 @@ mod tests { #[test] fn test_contains_any() { - let lit_true = Arc::new(Literal::new(ScalarValue::Boolean(Some(true)))) - as Arc; - let lit_false = Arc::new(Literal::new(ScalarValue::Boolean(Some(false)))) - as Arc; - let lit2 = - Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc; - let lit1 = - Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc; - let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; - - let cls1 = - EquivalenceClass::new(vec![Arc::clone(&lit_true), Arc::clone(&lit_false)]); - let cls2 = - EquivalenceClass::new(vec![Arc::clone(&lit_true), Arc::clone(&col_b_expr)]); - let cls3 = EquivalenceClass::new(vec![Arc::clone(&lit2), Arc::clone(&lit1)]); + let lit_true = Arc::new(Literal::new(ScalarValue::from(true))) as _; + let lit_false = Arc::new(Literal::new(ScalarValue::from(false))) as _; + let col_a_expr = Arc::new(Column::new("a", 0)) as _; + let col_b_expr = Arc::new(Column::new("b", 1)) as _; + let col_c_expr = Arc::new(Column::new("c", 2)) as _; + + let cls1 = EquivalenceClass::new([Arc::clone(&lit_true), col_a_expr]); + let cls2 = EquivalenceClass::new([lit_true, col_b_expr]); + let cls3 = EquivalenceClass::new([col_c_expr, lit_false]); // lit_true is common assert!(cls1.contains_any(&cls2)); @@ -903,21 +1037,19 @@ mod tests { } // Create test columns - let col_a = Arc::new(Column::new("a", 0)) as Arc; - let col_b = Arc::new(Column::new("b", 1)) as Arc; - let col_x = Arc::new(Column::new("x", 2)) as Arc; - let col_y = Arc::new(Column::new("y", 3)) as Arc; + let col_a = Arc::new(Column::new("a", 0)) as _; + let col_b = Arc::new(Column::new("b", 1)) as _; + let col_x = Arc::new(Column::new("x", 2)) as _; + let col_y = Arc::new(Column::new("y", 3)) as _; // Create test literals - let lit_1 = - Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc; - let lit_2 = - Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc; + let lit_1 = Arc::new(Literal::new(ScalarValue::from(1))) as _; + let lit_2 = Arc::new(Literal::new(ScalarValue::from(2))) as _; // Create equivalence group with classes (a = x) and (b = y) - let eq_group = EquivalenceGroup::new(vec![ - EquivalenceClass::new(vec![Arc::clone(&col_a), Arc::clone(&col_x)]), - EquivalenceClass::new(vec![Arc::clone(&col_b), Arc::clone(&col_y)]), + let eq_group = EquivalenceGroup::new([ + EquivalenceClass::new([Arc::clone(&col_a), Arc::clone(&col_x)]), + EquivalenceClass::new([Arc::clone(&col_b), Arc::clone(&col_y)]), ]); let test_cases = vec![ @@ -967,12 +1099,12 @@ mod tests { Arc::clone(&col_a), Operator::Plus, Arc::clone(&col_b), - )) as Arc, + )) as _, right: Arc::new(BinaryExpr::new( Arc::clone(&col_x), Operator::Plus, Arc::clone(&col_y), - )) as Arc, + )) as _, expected: true, description: "Binary expressions with equivalent operands should be equal", @@ -982,12 +1114,12 @@ mod tests { Arc::clone(&col_a), Operator::Plus, Arc::clone(&col_b), - )) as Arc, + )) as _, right: Arc::new(BinaryExpr::new( Arc::clone(&col_x), Operator::Plus, Arc::clone(&col_a), - )) as Arc, + )) as _, expected: false, 
description: "Binary expressions with non-equivalent operands should not be equal", @@ -997,12 +1129,12 @@ mod tests { Arc::clone(&col_a), Operator::Plus, Arc::clone(&lit_1), - )) as Arc, + )) as _, right: Arc::new(BinaryExpr::new( Arc::clone(&col_x), Operator::Plus, Arc::clone(&lit_1), - )) as Arc, + )) as _, expected: true, description: "Binary expressions with equivalent column and same literal should be equal", }, @@ -1015,7 +1147,7 @@ mod tests { )), Operator::Multiply, Arc::clone(&lit_1), - )) as Arc, + )) as _, right: Arc::new(BinaryExpr::new( Arc::new(BinaryExpr::new( Arc::clone(&col_x), @@ -1024,7 +1156,7 @@ mod tests { )), Operator::Multiply, Arc::clone(&lit_1), - )) as Arc, + )) as _, expected: true, description: "Nested binary expressions with equivalent operands should be equal", }, @@ -1040,8 +1172,7 @@ mod tests { let actual = eq_group.exprs_equal(&left, &right); assert_eq!( actual, expected, - "{}: Failed comparing {:?} and {:?}, expected {}, got {}", - description, left, right, expected, actual + "{description}: Failed comparing {left:?} and {right:?}, expected {expected}, got {actual}" ); } @@ -1059,36 +1190,36 @@ mod tests { Field::new("b", DataType::Int32, false), Field::new("c", DataType::Int32, false), ])); - let mut group = EquivalenceGroup::empty(); - group.add_equal_conditions(&col("a", &schema)?, &col("b", &schema)?); + let mut group = EquivalenceGroup::default(); + group.add_equal_conditions(col("a", &schema)?, col("b", &schema)?); let projected_schema = Arc::new(Schema::new(vec![ Field::new("a+c", DataType::Int32, false), Field::new("b+c", DataType::Int32, false), ])); - let mapping = ProjectionMapping { - map: vec![ - ( - binary( - col("a", &schema)?, - Operator::Plus, - col("c", &schema)?, - &schema, - )?, - col("a+c", &projected_schema)?, - ), - ( - binary( - col("b", &schema)?, - Operator::Plus, - col("c", &schema)?, - &schema, - )?, - col("b+c", &projected_schema)?, - ), - ], - }; + let mapping = [ + ( + binary( + col("a", &schema)?, + Operator::Plus, + col("c", &schema)?, + &schema, + )?, + vec![(col("a+c", &projected_schema)?, 0)].into(), + ), + ( + binary( + col("b", &schema)?, + Operator::Plus, + col("c", &schema)?, + &schema, + )?, + vec![(col("b+c", &projected_schema)?, 1)].into(), + ), + ] + .into_iter() + .collect::(); let projected = group.project(&mapping); diff --git a/datafusion/physical-expr/src/equivalence/mod.rs b/datafusion/physical-expr/src/equivalence/mod.rs index e94d2bad5712..ecb73be256d4 100644 --- a/datafusion/physical-expr/src/equivalence/mod.rs +++ b/datafusion/physical-expr/src/equivalence/mod.rs @@ -15,12 +15,13 @@ // specific language governing permissions and limitations // under the License. +use std::borrow::Borrow; use std::sync::Arc; -use crate::expressions::Column; -use crate::{LexRequirement, PhysicalExpr}; +use crate::PhysicalExpr; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use arrow::compute::SortOptions; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; mod class; mod ordering; @@ -34,50 +35,34 @@ pub use properties::{ calculate_union, join_equivalence_properties, EquivalenceProperties, }; -/// This function constructs a duplicate-free `LexOrderingReq` by filtering out -/// duplicate entries that have same physical expression inside. For example, -/// `vec![a Some(ASC), a Some(DESC)]` collapses to `vec![a Some(ASC)]`. 
-/// -/// It will also filter out entries that are ordered if the next entry is; -/// for instance, `vec![floor(a) Some(ASC), a Some(ASC)]` will be collapsed to -/// `vec![a Some(ASC)]`. -#[deprecated(since = "45.0.0", note = "Use LexRequirement::collapse")] -pub fn collapse_lex_req(input: LexRequirement) -> LexRequirement { - input.collapse() +// Convert each tuple to a `PhysicalSortExpr` and construct a vector. +pub fn convert_to_sort_exprs>>( + args: &[(T, SortOptions)], +) -> Vec { + args.iter() + .map(|(expr, options)| PhysicalSortExpr::new(Arc::clone(expr.borrow()), *options)) + .collect() } -/// Adds the `offset` value to `Column` indices inside `expr`. This function is -/// generally used during the update of the right table schema in join operations. -pub fn add_offset_to_expr( - expr: Arc, - offset: usize, -) -> Arc { - expr.transform_down(|e| match e.as_any().downcast_ref::() { - Some(col) => Ok(Transformed::yes(Arc::new(Column::new( - col.name(), - offset + col.index(), - )))), - None => Ok(Transformed::no(e)), - }) - .data() - .unwrap() - // Note that we can safely unwrap here since our transform always returns - // an `Ok` value. +// Convert each vector of tuples to a `LexOrdering`. +pub fn convert_to_orderings>>( + args: &[Vec<(T, SortOptions)>], +) -> Vec { + args.iter() + .filter_map(|sort_exprs| LexOrdering::new(convert_to_sort_exprs(sort_exprs))) + .collect() } #[cfg(test)] mod tests { - use super::*; - use crate::expressions::col; - use crate::PhysicalSortExpr; + use crate::expressions::{col, Column}; + use crate::{LexRequirement, PhysicalSortExpr}; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; - use datafusion_common::{plan_datafusion_err, Result}; - use datafusion_physical_expr_common::sort_expr::{ - LexOrdering, PhysicalSortRequirement, - }; + use datafusion_common::{plan_err, Result}; + use datafusion_physical_expr_common::sort_expr::PhysicalSortRequirement; /// Converts a string to a physical sort expression /// @@ -97,8 +82,7 @@ mod tests { "ASC" => sort_expr.asc(), "DESC" => sort_expr.desc(), _ => panic!( - "unknown sort options. Expected 'ASC' or 'DESC', got {}", - options + "unknown sort options. Expected 'ASC' or 'DESC', got {options}" ), } } @@ -115,27 +99,21 @@ mod tests { mapping: &ProjectionMapping, input_schema: &Arc, ) -> Result { - // Calculate output schema - let fields: Result> = mapping - .iter() - .map(|(source, target)| { - let name = target - .as_any() - .downcast_ref::() - .ok_or_else(|| plan_datafusion_err!("Expects to have column"))? - .name(); - let field = Field::new( - name, - source.data_type(input_schema)?, - source.nullable(input_schema)?, - ); - - Ok(field) - }) - .collect(); + // Calculate output schema: + let mut fields = vec![]; + for (source, targets) in mapping.iter() { + let data_type = source.data_type(input_schema)?; + let nullable = source.nullable(input_schema)?; + for (target, _) in targets.iter() { + let Some(column) = target.as_any().downcast_ref::() else { + return plan_err!("Expects to have column"); + }; + fields.push(Field::new(column.name(), data_type.clone(), nullable)); + } + } let output_schema = Arc::new(Schema::new_with_metadata( - fields?, + fields, input_schema.metadata().clone(), )); @@ -164,15 +142,15 @@ mod tests { /// Column [a=c] (e.g they are aliases). 
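The new free helpers (`convert_to_sort_exprs` / `convert_to_orderings`) are generic over a `Borrow`ed expression type, so a single function accepts both owned and borrowed expression tuples and the old `_owned` variants can go away. A minimal, self-contained sketch of that `Borrow` pattern, using toy stand-in types (`Expr`, `SortExpr`, `SortOptions` here are illustrative placeholders, not DataFusion's types):

```rust
use std::borrow::Borrow;
use std::sync::Arc;

// Toy stand-ins for `PhysicalExpr` / `PhysicalSortExpr`; illustration only.
type Expr = Arc<str>;

#[derive(Clone, Copy, Debug)]
struct SortOptions {
    descending: bool,
}

#[derive(Debug)]
struct SortExpr {
    expr: Expr,
    options: SortOptions,
}

// Generic over `T: Borrow<Expr>`, so both owned `Expr` and `&Expr` tuples work,
// which is why one `Borrow`-based helper can replace the old borrowed/owned
// pairs of conversion functions.
fn to_sort_exprs<T: Borrow<Expr>>(args: &[(T, SortOptions)]) -> Vec<SortExpr> {
    args.iter()
        .map(|(expr, options)| SortExpr {
            expr: Arc::clone(expr.borrow()),
            options: *options,
        })
        .collect()
}

fn main() {
    let a: Expr = Arc::from("a");
    let asc = SortOptions { descending: false };
    let from_owned = to_sort_exprs(&[(Arc::clone(&a), asc)]);
    let from_borrowed = to_sort_exprs(&[(&a, asc)]);
    println!("{from_owned:?} / {from_borrowed:?}");
}
```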
pub fn create_test_params() -> Result<(SchemaRef, EquivalenceProperties)> { let test_schema = create_test_schema()?; - let col_a = &col("a", &test_schema)?; - let col_b = &col("b", &test_schema)?; - let col_c = &col("c", &test_schema)?; - let col_d = &col("d", &test_schema)?; - let col_e = &col("e", &test_schema)?; - let col_f = &col("f", &test_schema)?; - let col_g = &col("g", &test_schema)?; + let col_a = col("a", &test_schema)?; + let col_b = col("b", &test_schema)?; + let col_c = col("c", &test_schema)?; + let col_d = col("d", &test_schema)?; + let col_e = col("e", &test_schema)?; + let col_f = col("f", &test_schema)?; + let col_g = col("g", &test_schema)?; let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); - eq_properties.add_equal_conditions(col_a, col_c)?; + eq_properties.add_equal_conditions(Arc::clone(&col_a), Arc::clone(&col_c))?; let option_asc = SortOptions { descending: false, @@ -195,68 +173,19 @@ mod tests { ], ]; let orderings = convert_to_orderings(&orderings); - eq_properties.add_new_orderings(orderings); + eq_properties.add_orderings(orderings); Ok((test_schema, eq_properties)) } - // Convert each tuple to PhysicalSortRequirement + // Convert each tuple to a `PhysicalSortRequirement` and construct a + // a `LexRequirement` from them. pub fn convert_to_sort_reqs( - in_data: &[(&Arc, Option)], + args: &[(&Arc, Option)], ) -> LexRequirement { - in_data - .iter() - .map(|(expr, options)| { - PhysicalSortRequirement::new(Arc::clone(*expr), *options) - }) - .collect() - } - - // Convert each tuple to PhysicalSortExpr - pub fn convert_to_sort_exprs( - in_data: &[(&Arc, SortOptions)], - ) -> LexOrdering { - in_data - .iter() - .map(|(expr, options)| PhysicalSortExpr { - expr: Arc::clone(*expr), - options: *options, - }) - .collect() - } - - // Convert each inner tuple to PhysicalSortExpr - pub fn convert_to_orderings( - orderings: &[Vec<(&Arc, SortOptions)>], - ) -> Vec { - orderings - .iter() - .map(|sort_exprs| convert_to_sort_exprs(sort_exprs)) - .collect() - } - - // Convert each tuple to PhysicalSortExpr - pub fn convert_to_sort_exprs_owned( - in_data: &[(Arc, SortOptions)], - ) -> LexOrdering { - LexOrdering::new( - in_data - .iter() - .map(|(expr, options)| PhysicalSortExpr { - expr: Arc::clone(expr), - options: *options, - }) - .collect(), - ) - } - - // Convert each inner tuple to PhysicalSortExpr - pub fn convert_to_orderings_owned( - orderings: &[Vec<(Arc, SortOptions)>], - ) -> Vec { - orderings - .iter() - .map(|sort_exprs| convert_to_sort_exprs_owned(sort_exprs)) - .collect() + let exprs = args.iter().map(|(expr, options)| { + PhysicalSortRequirement::new(Arc::clone(*expr), *options) + }); + LexRequirement::new(exprs).unwrap() } #[test] @@ -270,49 +199,49 @@ mod tests { ])); let mut eq_properties = EquivalenceProperties::new(schema); - let col_a_expr = Arc::new(Column::new("a", 0)) as Arc; - let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; - let col_c_expr = Arc::new(Column::new("c", 2)) as Arc; - let col_x_expr = Arc::new(Column::new("x", 3)) as Arc; - let col_y_expr = Arc::new(Column::new("y", 4)) as Arc; + let col_a = Arc::new(Column::new("a", 0)) as _; + let col_b = Arc::new(Column::new("b", 1)) as _; + let col_c = Arc::new(Column::new("c", 2)) as _; + let col_x = Arc::new(Column::new("x", 3)) as _; + let col_y = Arc::new(Column::new("y", 4)) as _; // a and b are aliases - eq_properties.add_equal_conditions(&col_a_expr, &col_b_expr)?; + eq_properties.add_equal_conditions(Arc::clone(&col_a), Arc::clone(&col_b))?; 
assert_eq!(eq_properties.eq_group().len(), 1); // This new entry is redundant, size shouldn't increase - eq_properties.add_equal_conditions(&col_b_expr, &col_a_expr)?; + eq_properties.add_equal_conditions(Arc::clone(&col_b), Arc::clone(&col_a))?; assert_eq!(eq_properties.eq_group().len(), 1); let eq_groups = eq_properties.eq_group().iter().next().unwrap(); assert_eq!(eq_groups.len(), 2); - assert!(eq_groups.contains(&col_a_expr)); - assert!(eq_groups.contains(&col_b_expr)); + assert!(eq_groups.contains(&col_a)); + assert!(eq_groups.contains(&col_b)); // b and c are aliases. Existing equivalence class should expand, // however there shouldn't be any new equivalence class - eq_properties.add_equal_conditions(&col_b_expr, &col_c_expr)?; + eq_properties.add_equal_conditions(Arc::clone(&col_b), Arc::clone(&col_c))?; assert_eq!(eq_properties.eq_group().len(), 1); let eq_groups = eq_properties.eq_group().iter().next().unwrap(); assert_eq!(eq_groups.len(), 3); - assert!(eq_groups.contains(&col_a_expr)); - assert!(eq_groups.contains(&col_b_expr)); - assert!(eq_groups.contains(&col_c_expr)); + assert!(eq_groups.contains(&col_a)); + assert!(eq_groups.contains(&col_b)); + assert!(eq_groups.contains(&col_c)); // This is a new set of equality. Hence equivalent class count should be 2. - eq_properties.add_equal_conditions(&col_x_expr, &col_y_expr)?; + eq_properties.add_equal_conditions(Arc::clone(&col_x), Arc::clone(&col_y))?; assert_eq!(eq_properties.eq_group().len(), 2); // This equality bridges distinct equality sets. // Hence equivalent class count should decrease from 2 to 1. - eq_properties.add_equal_conditions(&col_x_expr, &col_a_expr)?; + eq_properties.add_equal_conditions(Arc::clone(&col_x), Arc::clone(&col_a))?; assert_eq!(eq_properties.eq_group().len(), 1); let eq_groups = eq_properties.eq_group().iter().next().unwrap(); assert_eq!(eq_groups.len(), 5); - assert!(eq_groups.contains(&col_a_expr)); - assert!(eq_groups.contains(&col_b_expr)); - assert!(eq_groups.contains(&col_c_expr)); - assert!(eq_groups.contains(&col_x_expr)); - assert!(eq_groups.contains(&col_y_expr)); + assert!(eq_groups.contains(&col_a)); + assert!(eq_groups.contains(&col_b)); + assert!(eq_groups.contains(&col_c)); + assert!(eq_groups.contains(&col_x)); + assert!(eq_groups.contains(&col_y)); Ok(()) } diff --git a/datafusion/physical-expr/src/equivalence/ordering.rs b/datafusion/physical-expr/src/equivalence/ordering.rs index 0efd46ad912e..875c2a76e5eb 100644 --- a/datafusion/physical-expr/src/equivalence/ordering.rs +++ b/datafusion/physical-expr/src/equivalence/ordering.rs @@ -16,115 +16,83 @@ // under the License. use std::fmt::Display; -use std::hash::Hash; +use std::ops::Deref; use std::sync::Arc; use std::vec::IntoIter; -use crate::equivalence::add_offset_to_expr; -use crate::{LexOrdering, PhysicalExpr}; +use crate::expressions::with_new_schema; +use crate::{add_offset_to_physical_sort_exprs, LexOrdering, PhysicalExpr}; use arrow::compute::SortOptions; -use datafusion_common::HashSet; +use arrow::datatypes::SchemaRef; +use datafusion_common::{HashSet, Result}; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; -/// An `OrderingEquivalenceClass` object keeps track of different alternative -/// orderings than can describe a schema. For example, consider the following table: +/// An `OrderingEquivalenceClass` keeps track of distinct alternative orderings +/// than can describe a table. 
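The `add_equal_conditions` test above exercises how equality conditions grow and merge equivalence classes: a redundant pair is ignored, a new member extends an existing class, and a pair that bridges two classes collapses them into one (class count drops from 2 to 1). A simplified sketch of that merge logic over plain column names; `EqGroup` and its method are toy stand-ins, not the real `EquivalenceGroup`:

```rust
use std::collections::BTreeSet;

/// Toy equivalence group over column names; a stand-in for `EquivalenceGroup`.
#[derive(Debug, Default)]
struct EqGroup {
    classes: Vec<BTreeSet<&'static str>>,
}

impl EqGroup {
    /// Records `left = right`, extending or merging classes as needed.
    fn add_equal_conditions(&mut self, left: &'static str, right: &'static str) {
        let l = self.classes.iter().position(|c| c.contains(left));
        let r = self.classes.iter().position(|c| c.contains(right));
        match (l, r) {
            // The condition bridges two distinct classes: merge them.
            (Some(i), Some(j)) if i != j => {
                let merged = self.classes.swap_remove(i.max(j));
                self.classes[i.min(j)].extend(merged);
            }
            // Already in the same class: redundant condition, nothing to do.
            (Some(_), Some(_)) => {}
            // One side is new: it joins the other side's class.
            (Some(i), None) => {
                self.classes[i].insert(right);
            }
            (None, Some(j)) => {
                self.classes[j].insert(left);
            }
            // Both sides are new: open a fresh class.
            (None, None) => self.classes.push(BTreeSet::from([left, right])),
        }
    }
}

fn main() {
    let mut group = EqGroup::default();
    group.add_equal_conditions("a", "b"); // {a, b}
    group.add_equal_conditions("b", "c"); // {a, b, c}
    group.add_equal_conditions("x", "y"); // {a, b, c}, {x, y}
    assert_eq!(group.classes.len(), 2);
    group.add_equal_conditions("x", "a"); // bridges: {a, b, c, x, y}
    assert_eq!(group.classes.len(), 1);
}
```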
For example, consider the following table: /// /// ```text -/// |a|b|c|d| -/// |1|4|3|1| -/// |2|3|3|2| -/// |3|1|2|2| -/// |3|2|1|3| +/// ┌───┬───┬───┬───┐ +/// │ a │ b │ c │ d │ +/// ├───┼───┼───┼───┤ +/// │ 1 │ 4 │ 3 │ 1 │ +/// │ 2 │ 3 │ 3 │ 2 │ +/// │ 3 │ 1 │ 2 │ 2 │ +/// │ 3 │ 2 │ 1 │ 3 │ +/// └───┴───┴───┴───┘ /// ``` /// -/// Here, both `vec![a ASC, b ASC]` and `vec![c DESC, d ASC]` describe the table +/// Here, both `[a ASC, b ASC]` and `[c DESC, d ASC]` describe the table /// ordering. In this case, we say that these orderings are equivalent. -#[derive(Debug, Clone, Eq, PartialEq, Hash, Default)] +/// +/// An `OrderingEquivalenceClass` is a set of such equivalent orderings, which +/// is represented by a vector of `LexOrdering`s. The set does not store any +/// redundant information by enforcing the invariant that no suffix of an +/// ordering in the equivalence class is a prefix of another ordering in the +/// equivalence class. The set can be empty, which means that there are no +/// orderings that describe the table. +#[derive(Clone, Debug, Default, Eq, PartialEq)] pub struct OrderingEquivalenceClass { orderings: Vec, } impl OrderingEquivalenceClass { - /// Creates new empty ordering equivalence class. - pub fn empty() -> Self { - Default::default() - } - /// Clears (empties) this ordering equivalence class. pub fn clear(&mut self) { self.orderings.clear(); } - /// Creates new ordering equivalence class from the given orderings - /// - /// Any redundant entries are removed - pub fn new(orderings: Vec) -> Self { - let mut result = Self { orderings }; + /// Creates a new ordering equivalence class from the given orderings + /// and removes any redundant entries (if given). + pub fn new( + orderings: impl IntoIterator>, + ) -> Self { + let mut result = Self { + orderings: orderings.into_iter().filter_map(LexOrdering::new).collect(), + }; result.remove_redundant_entries(); result } - /// Converts this OrderingEquivalenceClass to a vector of orderings. - pub fn into_inner(self) -> Vec { - self.orderings - } - - /// Checks whether `ordering` is a member of this equivalence class. - pub fn contains(&self, ordering: &LexOrdering) -> bool { - self.orderings.contains(ordering) - } - - /// Adds `ordering` to this equivalence class. - #[allow(dead_code)] - #[deprecated( - since = "45.0.0", - note = "use OrderingEquivalenceClass::add_new_ordering instead" - )] - fn push(&mut self, ordering: LexOrdering) { - self.add_new_ordering(ordering) - } - - /// Checks whether this ordering equivalence class is empty. - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// Returns an iterator over the equivalent orderings in this class. - /// - /// Note this class also implements [`IntoIterator`] to return an iterator - /// over owned [`LexOrdering`]s. - pub fn iter(&self) -> impl Iterator { - self.orderings.iter() - } - - /// Returns how many equivalent orderings there are in this class. - pub fn len(&self) -> usize { - self.orderings.len() - } - - /// Extend this ordering equivalence class with the `other` class. - pub fn extend(&mut self, other: Self) { - self.orderings.extend(other.orderings); + /// Extend this ordering equivalence class with the given orderings. 
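The doc comment above states that both `[a ASC, b ASC]` and `[c DESC, d ASC]` describe the sample table, which is exactly what makes them members of the same ordering equivalence class. A quick self-contained check of that claim on plain row tuples and lexicographic comparators (no DataFusion types involved):

```rust
use std::cmp::Ordering;

type Row = (i32, i32, i32, i32); // (a, b, c, d)

/// Returns true if `rows` is non-decreasing under the given comparator.
fn is_sorted_by(rows: &[Row], mut cmp: impl FnMut(&Row, &Row) -> Ordering) -> bool {
    rows.windows(2).all(|w| cmp(&w[0], &w[1]) != Ordering::Greater)
}

fn main() {
    // The sample table from the doc comment above.
    let rows = [(1, 4, 3, 1), (2, 3, 3, 2), (3, 1, 2, 2), (3, 2, 1, 3)];

    // [a ASC, b ASC]
    let by_a_b = |x: &Row, y: &Row| (x.0, x.1).cmp(&(y.0, y.1));
    // [c DESC, d ASC]
    let by_c_desc_d = |x: &Row, y: &Row| y.2.cmp(&x.2).then(x.3.cmp(&y.3));

    // Both orderings describe the same table, i.e. they are equivalent here.
    assert!(is_sorted_by(&rows, by_a_b));
    assert!(is_sorted_by(&rows, by_c_desc_d));
}
```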
+ pub fn extend(&mut self, orderings: impl IntoIterator) { + self.orderings.extend(orderings); // Make sure that there are no redundant orderings: self.remove_redundant_entries(); } - /// Adds new orderings into this ordering equivalence class - pub fn add_new_orderings( + /// Adds new orderings into this ordering equivalence class. + pub fn add_orderings( &mut self, - orderings: impl IntoIterator, + sort_exprs: impl IntoIterator>, ) { - self.orderings.extend(orderings); + self.orderings + .extend(sort_exprs.into_iter().filter_map(LexOrdering::new)); // Make sure that there are no redundant orderings: self.remove_redundant_entries(); } - /// Adds a single ordering to the existing ordering equivalence class. - pub fn add_new_ordering(&mut self, ordering: LexOrdering) { - self.add_new_orderings([ordering]); - } - - /// Removes redundant orderings from this equivalence class. + /// Removes redundant orderings from this ordering equivalence class. /// /// For instance, if we already have the ordering `[a ASC, b ASC, c DESC]`, /// then there is no need to keep ordering `[a ASC, b ASC]` in the state. @@ -133,82 +101,72 @@ impl OrderingEquivalenceClass { while work { work = false; let mut idx = 0; - while idx < self.orderings.len() { + 'outer: while idx < self.orderings.len() { let mut ordering_idx = idx + 1; - let mut removal = self.orderings[idx].is_empty(); while ordering_idx < self.orderings.len() { - work |= self.resolve_overlap(idx, ordering_idx); - if self.orderings[idx].is_empty() { - removal = true; - break; + if let Some(remove) = self.resolve_overlap(idx, ordering_idx) { + work = true; + if remove { + self.orderings.swap_remove(idx); + continue 'outer; + } } - work |= self.resolve_overlap(ordering_idx, idx); - if self.orderings[ordering_idx].is_empty() { - self.orderings.swap_remove(ordering_idx); - } else { - ordering_idx += 1; + if let Some(remove) = self.resolve_overlap(ordering_idx, idx) { + work = true; + if remove { + self.orderings.swap_remove(ordering_idx); + continue; + } } + ordering_idx += 1; } - if removal { - self.orderings.swap_remove(idx); - } else { - idx += 1; - } + idx += 1; } } } /// Trims `orderings[idx]` if some suffix of it overlaps with a prefix of - /// `orderings[pre_idx]`. Returns `true` if there is any overlap, `false` otherwise. + /// `orderings[pre_idx]`. If there is any overlap, returns a `Some(true)` + /// if any trimming took place, and `Some(false)` otherwise. If there is + /// no overlap, returns `None`. /// /// For example, if `orderings[idx]` is `[a ASC, b ASC, c DESC]` and /// `orderings[pre_idx]` is `[b ASC, c DESC]`, then the function will trim /// `orderings[idx]` to `[a ASC]`. - fn resolve_overlap(&mut self, idx: usize, pre_idx: usize) -> bool { + fn resolve_overlap(&mut self, idx: usize, pre_idx: usize) -> Option { let length = self.orderings[idx].len(); let other_length = self.orderings[pre_idx].len(); for overlap in 1..=length.min(other_length) { if self.orderings[idx][length - overlap..] == self.orderings[pre_idx][..overlap] { - self.orderings[idx].truncate(length - overlap); - return true; + return Some(!self.orderings[idx].truncate(length - overlap)); } } - false + None } /// Returns the concatenation of all the orderings. This enables merge /// operations to preserve all equivalent orderings simultaneously. 
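`remove_redundant_entries` and `resolve_overlap` above maintain the invariant that no suffix of one stored ordering is a prefix of another: the overlap is trimmed, and orderings that end up empty are dropped. A simplified sketch of the overlap-trimming rule on string "orderings"; it mirrors the idea, not the exact `Option<bool>`-returning implementation:

```rust
/// Trims `orderings[idx]` if one of its suffixes matches a prefix of
/// `orderings[pre_idx]`; returns true if anything was trimmed.
fn resolve_overlap(orderings: &mut [Vec<&str>], idx: usize, pre_idx: usize) -> bool {
    let length = orderings[idx].len();
    let other = orderings[pre_idx].clone();
    for overlap in 1..=length.min(other.len()) {
        if orderings[idx][length - overlap..] == other[..overlap] {
            orderings[idx].truncate(length - overlap);
            return true;
        }
    }
    false
}

fn main() {
    // Suffix `[b ASC, c DESC]` of the first ordering is a prefix of the second,
    // so the first ordering shrinks to `[a ASC]`.
    let mut orderings = vec![vec!["a ASC", "b ASC", "c DESC"], vec!["b ASC", "c DESC"]];
    assert!(resolve_overlap(&mut orderings, 0, 1));
    assert_eq!(orderings[0], vec!["a ASC"]);

    // `[a ASC, b ASC]` is entirely a prefix of `[a ASC, b ASC, c DESC]`: trimming
    // empties it, and the real implementation then drops the empty ordering.
    let mut orderings = vec![vec!["a ASC", "b ASC"], vec!["a ASC", "b ASC", "c DESC"]];
    assert!(resolve_overlap(&mut orderings, 0, 1));
    assert!(orderings[0].is_empty());
}
```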
pub fn output_ordering(&self) -> Option { - let output_ordering = self - .orderings - .iter() - .flatten() - .cloned() - .collect::() - .collapse(); - (!output_ordering.is_empty()).then_some(output_ordering) + self.orderings.iter().cloned().reduce(|mut cat, o| { + cat.extend(o); + cat + }) } - // Append orderings in `other` to all existing orderings in this equivalence - // class. + // Append orderings in `other` to all existing orderings in this ordering + // equivalence class. pub fn join_suffix(mut self, other: &Self) -> Self { let n_ordering = self.orderings.len(); - // Replicate entries before cross product + // Replicate entries before cross product: let n_cross = std::cmp::max(n_ordering, other.len() * n_ordering); - self.orderings = self - .orderings - .iter() - .cloned() - .cycle() - .take(n_cross) - .collect(); - // Suffix orderings of other to the current orderings. + self.orderings = self.orderings.into_iter().cycle().take(n_cross).collect(); + // Append sort expressions of `other` to the current orderings: for (outer_idx, ordering) in other.iter().enumerate() { - for idx in 0..n_ordering { - // Calculate cross product index - let idx = outer_idx * n_ordering + idx; + let base = outer_idx * n_ordering; + // Use the cross product index: + for idx in base..(base + n_ordering) { self.orderings[idx].extend(ordering.iter().cloned()); } } @@ -217,12 +175,40 @@ impl OrderingEquivalenceClass { /// Adds `offset` value to the index of each expression inside this /// ordering equivalence class. - pub fn add_offset(&mut self, offset: usize) { - for ordering in self.orderings.iter_mut() { - ordering.transform(|sort_expr| { - sort_expr.expr = add_offset_to_expr(Arc::clone(&sort_expr.expr), offset); - }) + pub fn add_offset(&mut self, offset: isize) -> Result<()> { + let orderings = std::mem::take(&mut self.orderings); + for ordering_result in orderings + .into_iter() + .map(|o| add_offset_to_physical_sort_exprs(o, offset)) + { + self.orderings.extend(LexOrdering::new(ordering_result?)); } + Ok(()) + } + + /// Transforms this `OrderingEquivalenceClass` by mapping columns in the + /// original schema to columns in the new schema by index. The new schema + /// and the original schema needs to be aligned; i.e. they should have the + /// same number of columns, and fields at the same index have the same type + /// in both schemas. + pub fn with_new_schema(mut self, schema: &SchemaRef) -> Result { + self.orderings = self + .orderings + .into_iter() + .map(|ordering| { + ordering + .into_iter() + .map(|mut sort_expr| { + sort_expr.expr = with_new_schema(sort_expr.expr, schema)?; + Ok(sort_expr) + }) + .collect::>>() + // The following `unwrap` is safe because the vector will always + // be non-empty. + .map(|v| LexOrdering::new(v).unwrap()) + }) + .collect::>()?; + Ok(self) } /// Gets sort options associated with this expression if it is a leading @@ -257,31 +243,6 @@ impl OrderingEquivalenceClass { /// added as a constant during `ordering_satisfy_requirement()` iterations /// after the corresponding prefix requirement is satisfied. /// - /// ### Example Scenarios - /// - /// In these scenarios, we assume that all expressions share the same sort - /// properties. - /// - /// #### Case 1: Sort Requirement `[a, c]` - /// - /// **Existing Orderings:** `[[a, b, c], [a, d]]`, **Constants:** `[]` - /// 1. `ordering_satisfy_single()` returns `true` because the requirement - /// `a` is satisfied by `[a, b, c].first()`. - /// 2. `a` is added as a constant for the next iteration. - /// 3. 
The normalized orderings become `[[b, c], [d]]`. - /// 4. `ordering_satisfy_single()` returns `false` for `c`, as neither - /// `[b, c]` nor `[d]` satisfies `c`. - /// - /// #### Case 2: Sort Requirement `[a, d]` - /// - /// **Existing Orderings:** `[[a, b, c], [a, d]]`, **Constants:** `[]` - /// 1. `ordering_satisfy_single()` returns `true` because the requirement - /// `a` is satisfied by `[a, b, c].first()`. - /// 2. `a` is added as a constant for the next iteration. - /// 3. The normalized orderings become `[[b, c], [d]]`. - /// 4. `ordering_satisfy_single()` returns `true` for `d`, as `[d]` satisfies - /// `d`. - /// /// ### Future Improvements /// /// This function may become unnecessary if any of the following improvements @@ -296,15 +257,14 @@ impl OrderingEquivalenceClass { ]; for ordering in self.iter() { - if let Some(leading_ordering) = ordering.first() { - if leading_ordering.expr.eq(expr) { - let opt = ( - leading_ordering.options.descending, - leading_ordering.options.nulls_first, - ); - constantness_defining_pairs[0].remove(&opt); - constantness_defining_pairs[1].remove(&opt); - } + let leading_ordering = ordering.first(); + if leading_ordering.expr.eq(expr) { + let opt = ( + leading_ordering.options.descending, + leading_ordering.options.nulls_first, + ); + constantness_defining_pairs[0].remove(&opt); + constantness_defining_pairs[1].remove(&opt); } } @@ -314,10 +274,26 @@ impl OrderingEquivalenceClass { } } -/// Convert the `OrderingEquivalenceClass` into an iterator of LexOrderings +impl Deref for OrderingEquivalenceClass { + type Target = [LexOrdering]; + + fn deref(&self) -> &Self::Target { + self.orderings.as_slice() + } +} + +impl From> for OrderingEquivalenceClass { + fn from(orderings: Vec) -> Self { + let mut result = Self { orderings }; + result.remove_redundant_entries(); + result + } +} + +/// Convert the `OrderingEquivalenceClass` into an iterator of `LexOrdering`s. 
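`join_suffix` (shown earlier in this hunk) replicates the current orderings and appends each ordering of `other`, so the result contains every `self × other` concatenation. A self-contained sketch of that cross product on string orderings, using the same replicate-then-append scheme:

```rust
/// Cross-product concatenation: every ordering in `lhs` suffixed with every
/// ordering in `rhs`, mirroring the idea behind `join_suffix`.
fn join_suffix(lhs: &[Vec<&'static str>], rhs: &[Vec<&'static str>]) -> Vec<Vec<&'static str>> {
    let n = lhs.len();
    // Replicate the left-hand orderings once per right-hand ordering.
    let n_cross = n.max(rhs.len() * n);
    let mut out: Vec<Vec<&'static str>> =
        lhs.iter().cloned().cycle().take(n_cross).collect();
    // Append the right-hand sort expressions block by block.
    for (outer, ordering) in rhs.iter().enumerate() {
        let base = outer * n;
        for idx in base..(base + n) {
            out[idx].extend(ordering.iter().cloned());
        }
    }
    out
}

fn main() {
    let lhs = vec![vec!["a ASC"], vec!["b DESC"]];
    let rhs = vec![vec!["x ASC"], vec!["y ASC"]];
    let joined = join_suffix(&lhs, &rhs);
    assert_eq!(
        joined,
        vec![
            vec!["a ASC", "x ASC"],
            vec!["b DESC", "x ASC"],
            vec!["a ASC", "y ASC"],
            vec!["b DESC", "y ASC"],
        ]
    );
}
```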
impl IntoIterator for OrderingEquivalenceClass { type Item = LexOrdering; - type IntoIter = IntoIter; + type IntoIter = IntoIter; fn into_iter(self) -> Self::IntoIter { self.orderings.into_iter() @@ -329,13 +305,18 @@ impl Display for OrderingEquivalenceClass { write!(f, "[")?; let mut iter = self.orderings.iter(); if let Some(ordering) = iter.next() { - write!(f, "[{}]", ordering)?; + write!(f, "[{ordering}]")?; } for ordering in iter { - write!(f, ", [{}]", ordering)?; + write!(f, ", [{ordering}]")?; } - write!(f, "]")?; - Ok(()) + write!(f, "]") + } +} + +impl From for Vec { + fn from(oeq_class: OrderingEquivalenceClass) -> Self { + oeq_class.orderings } } @@ -343,12 +324,10 @@ impl Display for OrderingEquivalenceClass { mod tests { use std::sync::Arc; - use crate::equivalence::tests::{ - convert_to_orderings, convert_to_sort_exprs, create_test_schema, - }; + use crate::equivalence::tests::create_test_schema; use crate::equivalence::{ - EquivalenceClass, EquivalenceGroup, EquivalenceProperties, - OrderingEquivalenceClass, + convert_to_orderings, convert_to_sort_exprs, EquivalenceClass, EquivalenceGroup, + EquivalenceProperties, OrderingEquivalenceClass, }; use crate::expressions::{col, BinaryExpr, Column}; use crate::utils::tests::TestScalarUDF; @@ -361,7 +340,6 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::Result; use datafusion_expr::{Operator, ScalarUDF}; - use datafusion_physical_expr_common::sort_expr::LexOrdering; #[test] fn test_ordering_satisfy() -> Result<()> { @@ -369,11 +347,11 @@ mod tests { Field::new("a", DataType::Int64, true), Field::new("b", DataType::Int64, true), ])); - let crude = LexOrdering::new(vec![PhysicalSortExpr { + let crude = vec![PhysicalSortExpr { expr: Arc::new(Column::new("a", 0)), options: SortOptions::default(), - }]); - let finer = LexOrdering::new(vec![ + }]; + let finer = vec![ PhysicalSortExpr { expr: Arc::new(Column::new("a", 0)), options: SortOptions::default(), @@ -382,20 +360,18 @@ mod tests { expr: Arc::new(Column::new("b", 1)), options: SortOptions::default(), }, - ]); + ]; // finer ordering satisfies, crude ordering should return true let eq_properties_finer = EquivalenceProperties::new_with_orderings( Arc::clone(&input_schema), - &[finer.clone()], + [finer.clone()], ); - assert!(eq_properties_finer.ordering_satisfy(crude.as_ref())); + assert!(eq_properties_finer.ordering_satisfy(crude.clone())?); // Crude ordering doesn't satisfy finer ordering. 
should return false - let eq_properties_crude = EquivalenceProperties::new_with_orderings( - Arc::clone(&input_schema), - &[crude.clone()], - ); - assert!(!eq_properties_crude.ordering_satisfy(finer.as_ref())); + let eq_properties_crude = + EquivalenceProperties::new_with_orderings(Arc::clone(&input_schema), [crude]); + assert!(!eq_properties_crude.ordering_satisfy(finer)?); Ok(()) } @@ -663,30 +639,20 @@ mod tests { format!("error in test orderings: {orderings:?}, eq_group: {eq_group:?}, constants: {constants:?}, reqs: {reqs:?}, expected: {expected:?}"); let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); let orderings = convert_to_orderings(&orderings); - eq_properties.add_new_orderings(orderings); - let eq_group = eq_group + eq_properties.add_orderings(orderings); + let classes = eq_group .into_iter() - .map(|eq_class| { - let eq_classes = eq_class.into_iter().cloned().collect::>(); - EquivalenceClass::new(eq_classes) - }) - .collect::>(); - let eq_group = EquivalenceGroup::new(eq_group); - eq_properties.add_equivalence_group(eq_group); + .map(|eq_class| EquivalenceClass::new(eq_class.into_iter().cloned())); + let eq_group = EquivalenceGroup::new(classes); + eq_properties.add_equivalence_group(eq_group)?; let constants = constants.into_iter().map(|expr| { - ConstExpr::from(expr) - .with_across_partitions(AcrossPartitions::Uniform(None)) + ConstExpr::new(Arc::clone(expr), AcrossPartitions::Uniform(None)) }); - eq_properties = eq_properties.with_constants(constants); + eq_properties.add_constants(constants)?; let reqs = convert_to_sort_exprs(&reqs); - assert_eq!( - eq_properties.ordering_satisfy(reqs.as_ref()), - expected, - "{}", - err_msg - ); + assert_eq!(eq_properties.ordering_satisfy(reqs)?, expected, "{err_msg}"); } Ok(()) @@ -707,7 +673,7 @@ mod tests { }; // a=c (e.g they are aliases). let mut eq_properties = EquivalenceProperties::new(test_schema); - eq_properties.add_equal_conditions(col_a, col_c)?; + eq_properties.add_equal_conditions(Arc::clone(col_a), Arc::clone(col_c))?; let orderings = vec![ vec![(col_a, options)], @@ -717,7 +683,7 @@ mod tests { let orderings = convert_to_orderings(&orderings); // Column [a ASC], [e ASC], [d ASC, f ASC] are all valid orderings for the schema. - eq_properties.add_new_orderings(orderings); + eq_properties.add_orderings(orderings); // First entry in the tuple is required ordering, second entry is the expected flag // that indicates whether this required ordering is satisfied. @@ -739,14 +705,9 @@ mod tests { for (reqs, expected) in test_cases { let err_msg = - format!("error in test reqs: {:?}, expected: {:?}", reqs, expected,); + format!("error in test reqs: {reqs:?}, expected: {expected:?}",); let reqs = convert_to_sort_exprs(&reqs); - assert_eq!( - eq_properties.ordering_satisfy(reqs.as_ref()), - expected, - "{}", - err_msg - ); + assert_eq!(eq_properties.ordering_satisfy(reqs)?, expected, "{err_msg}"); } Ok(()) @@ -856,7 +817,7 @@ mod tests { // ------- TEST CASE 5 --------- // Empty ordering ( - vec![vec![]], + vec![], // No ordering in the state (empty ordering is ignored). 
vec![], ), @@ -975,13 +936,11 @@ mod tests { for (orderings, expected) in test_cases { let orderings = convert_to_orderings(&orderings); let expected = convert_to_orderings(&expected); - let actual = OrderingEquivalenceClass::new(orderings.clone()); - let actual = actual.orderings; + let actual = OrderingEquivalenceClass::from(orderings.clone()); let err_msg = format!( - "orderings: {:?}, expected: {:?}, actual :{:?}", - orderings, expected, actual + "orderings: {orderings:?}, expected: {expected:?}, actual :{actual:?}" ); - assert_eq!(actual.len(), expected.len(), "{}", err_msg); + assert_eq!(actual.len(), expected.len(), "{err_msg}"); for elem in actual { assert!(expected.contains(&elem), "{}", err_msg); } diff --git a/datafusion/physical-expr/src/equivalence/projection.rs b/datafusion/physical-expr/src/equivalence/projection.rs index a33339091c85..38bb1fef8074 100644 --- a/datafusion/physical-expr/src/equivalence/projection.rs +++ b/datafusion/physical-expr/src/equivalence/projection.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::ops::Deref; use std::sync::Arc; use crate::expressions::Column; @@ -24,13 +25,52 @@ use arrow::datatypes::SchemaRef; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{internal_err, Result}; +use indexmap::IndexMap; + +/// Stores target expressions, along with their indices, that associate with a +/// source expression in a projection mapping. +#[derive(Clone, Debug, Default)] +pub struct ProjectionTargets { + /// A non-empty vector of pairs of target expressions and their indices. + /// Consider using a special non-empty collection type in the future (e.g. + /// if Rust provides one in the standard library). + exprs_indices: Vec<(Arc, usize)>, +} + +impl ProjectionTargets { + /// Returns the first target expression and its index. + pub fn first(&self) -> &(Arc, usize) { + // Since the vector is non-empty, we can safely unwrap: + self.exprs_indices.first().unwrap() + } + + /// Adds a target expression and its index to the list of targets. + pub fn push(&mut self, target: (Arc, usize)) { + self.exprs_indices.push(target); + } +} + +impl Deref for ProjectionTargets { + type Target = [(Arc, usize)]; + + fn deref(&self) -> &Self::Target { + &self.exprs_indices + } +} + +impl From, usize)>> for ProjectionTargets { + fn from(exprs_indices: Vec<(Arc, usize)>) -> Self { + Self { exprs_indices } + } +} + /// Stores the mapping between source expressions and target expressions for a /// projection. -#[derive(Debug, Clone)] +#[derive(Clone, Debug)] pub struct ProjectionMapping { /// Mapping between source expressions and target expressions. /// Vector indices correspond to the indices after projection. - pub map: Vec<(Arc, Arc)>, + map: IndexMap, ProjectionTargets>, } impl ProjectionMapping { @@ -42,44 +82,46 @@ impl ProjectionMapping { /// projection mapping would be: /// /// ```text - /// [0]: (c + d, col("c + d")) - /// [1]: (a + b, col("a + b")) + /// [0]: (c + d, [(col("c + d"), 0)]) + /// [1]: (a + b, [(col("a + b"), 1)]) /// ``` /// /// where `col("c + d")` means the column named `"c + d"`. 
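The reworked `ProjectionMapping` associates each source expression with a `ProjectionTargets` list of `(target expression, output index)` pairs, because a single source can appear several times in a projection (e.g. `a AS a1, a AS a2`). A toy sketch of that one-to-many grouping keyed by source name; plain `Vec`s and `String`s stand in for the real `IndexMap`-backed mapping over `PhysicalExpr`s:

```rust
/// Toy projection mapping: each distinct source expression maps to every
/// `(target name, output index)` pair it is projected to.
fn build_mapping(projection: &[(&str, &str)]) -> Vec<(String, Vec<(String, usize)>)> {
    let mut map: Vec<(String, Vec<(String, usize)>)> = Vec::new();
    for (idx, (source, name)) in projection.iter().enumerate() {
        let target = (name.to_string(), idx);
        if let Some(pos) = map.iter().position(|(src, _)| src == source) {
            // The source expression is already present: append another target.
            map[pos].1.push(target);
        } else {
            map.push((source.to_string(), vec![target]));
        }
    }
    map
}

fn main() {
    // SELECT a AS a1, a AS a2, b + c AS sum
    let mapping = build_mapping(&[("a", "a1"), ("a", "a2"), ("b + c", "sum")]);
    assert_eq!(mapping.len(), 2); // two distinct source expressions
    assert_eq!(mapping[0].1, [("a1".to_string(), 0), ("a2".to_string(), 1)]);
    assert_eq!(mapping[1].1, [("sum".to_string(), 2)]);
}
```

In the real mapping the key is the rewritten source `PhysicalExpr` and the targets carry projected expressions rather than names, but the one-to-many shape is the same.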
pub fn try_new( - expr: &[(Arc, String)], + expr: impl IntoIterator, String)>, input_schema: &SchemaRef, ) -> Result { // Construct a map from the input expressions to the output expression of the projection: - expr.iter() - .enumerate() - .map(|(expr_idx, (expression, name))| { - let target_expr = Arc::new(Column::new(name, expr_idx)) as _; - Arc::clone(expression) - .transform_down(|e| match e.as_any().downcast_ref::() { - Some(col) => { - // Sometimes, an expression and its name in the input_schema - // doesn't match. This can cause problems, so we make sure - // that the expression name matches with the name in `input_schema`. - // Conceptually, `source_expr` and `expression` should be the same. - let idx = col.index(); - let matching_input_field = input_schema.field(idx); - if col.name() != matching_input_field.name() { - return internal_err!("Input field name {} does not match with the projection expression {}", - matching_input_field.name(),col.name()) - } - let matching_input_column = - Column::new(matching_input_field.name(), idx); - Ok(Transformed::yes(Arc::new(matching_input_column))) - } - None => Ok(Transformed::no(e)), - }) - .data() - .map(|source_expr| (source_expr, target_expr)) + let mut map = IndexMap::<_, ProjectionTargets>::new(); + for (expr_idx, (expr, name)) in expr.into_iter().enumerate() { + let target_expr = Arc::new(Column::new(&name, expr_idx)) as _; + let source_expr = expr.transform_down(|e| match e.as_any().downcast_ref::() { + Some(col) => { + // Sometimes, an expression and its name in the input_schema + // doesn't match. This can cause problems, so we make sure + // that the expression name matches with the name in `input_schema`. + // Conceptually, `source_expr` and `expression` should be the same. + let idx = col.index(); + let matching_field = input_schema.field(idx); + let matching_name = matching_field.name(); + if col.name() != matching_name { + return internal_err!( + "Input field name {} does not match with the projection expression {}", + matching_name, + col.name() + ); + } + let matching_column = Column::new(matching_name, idx); + Ok(Transformed::yes(Arc::new(matching_column))) + } + None => Ok(Transformed::no(e)), }) - .collect::>>() - .map(|map| Self { map }) + .data()?; + map.entry(source_expr) + .or_default() + .push((target_expr, expr_idx)); + } + Ok(Self { map }) } /// Constructs a subset mapping using the provided indices. @@ -87,61 +129,38 @@ impl ProjectionMapping { /// This is used when the output is a subset of the input without any /// other transformations. The indices are for columns in the schema. pub fn from_indices(indices: &[usize], schema: &SchemaRef) -> Result { - let projection_exprs = project_index_to_exprs(indices, schema); - ProjectionMapping::try_new(&projection_exprs, schema) + let projection_exprs = indices.iter().map(|index| { + let field = schema.field(*index); + let column = Arc::new(Column::new(field.name(), *index)); + (column as _, field.name().clone()) + }); + ProjectionMapping::try_new(projection_exprs, schema) } +} - /// Iterate over pairs of (source, target) expressions - pub fn iter( - &self, - ) -> impl Iterator, Arc)> + '_ { - self.map.iter() - } +impl Deref for ProjectionMapping { + type Target = IndexMap, ProjectionTargets>; - /// This function returns the target expression for a given source expression. - /// - /// # Arguments - /// - /// * `expr` - Source physical expression. 
- /// - /// # Returns - /// - /// An `Option` containing the target for the given source expression, - /// where a `None` value means that `expr` is not inside the mapping. - pub fn target_expr( - &self, - expr: &Arc, - ) -> Option> { - self.map - .iter() - .find(|(source, _)| source.eq(expr)) - .map(|(_, target)| Arc::clone(target)) + fn deref(&self) -> &Self::Target { + &self.map } } -fn project_index_to_exprs( - projection_index: &[usize], - schema: &SchemaRef, -) -> Vec<(Arc, String)> { - projection_index - .iter() - .map(|index| { - let field = schema.field(*index); - ( - Arc::new(Column::new(field.name(), *index)) as Arc, - field.name().to_owned(), - ) - }) - .collect::>() +impl FromIterator<(Arc, ProjectionTargets)> for ProjectionMapping { + fn from_iter, ProjectionTargets)>>( + iter: T, + ) -> Self { + Self { + map: IndexMap::from_iter(iter), + } + } } #[cfg(test)] mod tests { use super::*; - use crate::equivalence::tests::{ - convert_to_orderings, convert_to_orderings_owned, output_schema, - }; - use crate::equivalence::EquivalenceProperties; + use crate::equivalence::tests::output_schema; + use crate::equivalence::{convert_to_orderings, EquivalenceProperties}; use crate::expressions::{col, BinaryExpr}; use crate::utils::tests::TestScalarUDF; use crate::{PhysicalExprRef, ScalarFunctionExpr}; @@ -608,13 +627,12 @@ mod tests { let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); let orderings = convert_to_orderings(&orderings); - eq_properties.add_new_orderings(orderings); + eq_properties.add_orderings(orderings); let proj_exprs = proj_exprs .into_iter() - .map(|(expr, name)| (Arc::clone(expr), name)) - .collect::>(); - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; + .map(|(expr, name)| (Arc::clone(expr), name)); + let projection_mapping = ProjectionMapping::try_new(proj_exprs, &schema)?; let output_schema = output_schema(&projection_mapping, &schema)?; let expected = expected @@ -628,17 +646,16 @@ mod tests { .collect::>() }) .collect::>(); - let expected = convert_to_orderings_owned(&expected); + let expected = convert_to_orderings(&expected); let projected_eq = eq_properties.project(&projection_mapping, output_schema); let orderings = projected_eq.oeq_class(); let err_msg = format!( - "test_idx: {:?}, actual: {:?}, expected: {:?}, projection_mapping: {:?}", - idx, orderings, expected, projection_mapping + "test_idx: {idx:?}, actual: {orderings:?}, expected: {expected:?}, projection_mapping: {projection_mapping:?}" ); - assert_eq!(orderings.len(), expected.len(), "{}", err_msg); + assert_eq!(orderings.len(), expected.len(), "{err_msg}"); for expected_ordering in &expected { assert!(orderings.contains(expected_ordering), "{}", err_msg) } @@ -687,9 +704,8 @@ mod tests { ]; let proj_exprs = proj_exprs .into_iter() - .map(|(expr, name)| (Arc::clone(expr), name)) - .collect::>(); - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; + .map(|(expr, name)| (Arc::clone(expr), name)); + let projection_mapping = ProjectionMapping::try_new(proj_exprs, &schema)?; let output_schema = output_schema(&projection_mapping, &schema)?; let col_a_new = &col("a_new", &output_schema)?; @@ -813,7 +829,7 @@ mod tests { let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); let orderings = convert_to_orderings(orderings); - eq_properties.add_new_orderings(orderings); + eq_properties.add_orderings(orderings); let expected = convert_to_orderings(expected); @@ -822,11 +838,10 @@ mod tests { let orderings = 
projected_eq.oeq_class(); let err_msg = format!( - "test idx: {:?}, actual: {:?}, expected: {:?}, projection_mapping: {:?}", - idx, orderings, expected, projection_mapping + "test idx: {idx:?}, actual: {orderings:?}, expected: {expected:?}, projection_mapping: {projection_mapping:?}" ); - assert_eq!(orderings.len(), expected.len(), "{}", err_msg); + assert_eq!(orderings.len(), expected.len(), "{err_msg}"); for expected_ordering in &expected { assert!(orderings.contains(expected_ordering), "{}", err_msg) } @@ -868,9 +883,8 @@ mod tests { ]; let proj_exprs = proj_exprs .into_iter() - .map(|(expr, name)| (Arc::clone(expr), name)) - .collect::>(); - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; + .map(|(expr, name)| (Arc::clone(expr), name)); + let projection_mapping = ProjectionMapping::try_new(proj_exprs, &schema)?; let output_schema = output_schema(&projection_mapping, &schema)?; let col_a_plus_b_new = &col("a+b", &output_schema)?; @@ -955,11 +969,11 @@ mod tests { for (orderings, equal_columns, expected) in test_cases { let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); for (lhs, rhs) in equal_columns { - eq_properties.add_equal_conditions(lhs, rhs)?; + eq_properties.add_equal_conditions(Arc::clone(lhs), Arc::clone(rhs))?; } let orderings = convert_to_orderings(&orderings); - eq_properties.add_new_orderings(orderings); + eq_properties.add_orderings(orderings); let expected = convert_to_orderings(&expected); @@ -968,11 +982,10 @@ mod tests { let orderings = projected_eq.oeq_class(); let err_msg = format!( - "actual: {:?}, expected: {:?}, projection_mapping: {:?}", - orderings, expected, projection_mapping + "actual: {orderings:?}, expected: {expected:?}, projection_mapping: {projection_mapping:?}" ); - assert_eq!(orderings.len(), expected.len(), "{}", err_msg); + assert_eq!(orderings.len(), expected.len(), "{err_msg}"); for expected_ordering in &expected { assert!(orderings.contains(expected_ordering), "{}", err_msg) } diff --git a/datafusion/physical-expr/src/equivalence/properties/dependency.rs b/datafusion/physical-expr/src/equivalence/properties/dependency.rs index 9eba295e562e..4554e36f766d 100644 --- a/datafusion/physical-expr/src/equivalence/properties/dependency.rs +++ b/datafusion/physical-expr/src/equivalence/properties/dependency.rs @@ -16,71 +16,67 @@ // under the License. use std::fmt::{self, Display}; +use std::ops::{Deref, DerefMut}; use std::sync::Arc; +use super::expr_refers; use crate::{LexOrdering, PhysicalSortExpr}; + use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use indexmap::IndexSet; -use indexmap::IndexMap; +use indexmap::{IndexMap, IndexSet}; use itertools::Itertools; -use super::{expr_refers, ExprWrapper}; - // A list of sort expressions that can be calculated from a known set of /// dependencies. #[derive(Debug, Default, Clone, PartialEq, Eq)] pub struct Dependencies { - inner: IndexSet, + sort_exprs: IndexSet, } impl Display for Dependencies { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "[")?; - let mut iter = self.inner.iter(); + let mut iter = self.sort_exprs.iter(); if let Some(dep) = iter.next() { - write!(f, "{}", dep)?; + write!(f, "{dep}")?; } for dep in iter { - write!(f, ", {}", dep)?; + write!(f, ", {dep}")?; } write!(f, "]") } } impl Dependencies { - /// Create a new empty `Dependencies` instance. - fn new() -> Self { + // Creates a new `Dependencies` instance from the given sort expressions. 
+ pub fn new(sort_exprs: impl IntoIterator) -> Self { Self { - inner: IndexSet::new(), + sort_exprs: sort_exprs.into_iter().collect(), } } +} - /// Create a new `Dependencies` from an iterator of `PhysicalSortExpr`. - pub fn new_from_iter(iter: impl IntoIterator) -> Self { - Self { - inner: iter.into_iter().collect(), - } - } +impl Deref for Dependencies { + type Target = IndexSet; - /// Insert a new dependency into the set. - pub fn insert(&mut self, sort_expr: PhysicalSortExpr) { - self.inner.insert(sort_expr); + fn deref(&self) -> &Self::Target { + &self.sort_exprs } +} - /// Iterator over dependencies in the set - pub fn iter(&self) -> impl Iterator + Clone { - self.inner.iter() +impl DerefMut for Dependencies { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.sort_exprs } +} - /// Return the inner set of dependencies - pub fn into_inner(self) -> IndexSet { - self.inner - } +impl IntoIterator for Dependencies { + type Item = PhysicalSortExpr; + type IntoIter = as IntoIterator>::IntoIter; - /// Returns true if there are no dependencies - fn is_empty(&self) -> bool { - self.inner.is_empty() + fn into_iter(self) -> Self::IntoIter { + self.sort_exprs.into_iter() } } @@ -133,26 +129,25 @@ impl<'a> DependencyEnumerator<'a> { let node = dependency_map .get(referred_sort_expr) .expect("`referred_sort_expr` should be inside `dependency_map`"); - // Since we work on intermediate nodes, we are sure `val.target_sort_expr` - // exists. - let target_sort_expr = node.target_sort_expr.as_ref().unwrap(); + // Since we work on intermediate nodes, we are sure `node.target` exists. + let target = node.target.as_ref().unwrap(); // An empty dependency means the referred_sort_expr represents a global ordering. // Return its projected version, which is the target_expression. if node.dependencies.is_empty() { - return vec![LexOrdering::new(vec![target_sort_expr.clone()])]; + return vec![[target.clone()].into()]; }; node.dependencies .iter() .flat_map(|dep| { - let mut orderings = if self.insert(target_sort_expr, dep) { + let mut orderings = if self.insert(target, dep) { self.construct_orderings(dep, dependency_map) } else { vec![] }; for ordering in orderings.iter_mut() { - ordering.push(target_sort_expr.clone()) + ordering.push(target.clone()); } orderings }) @@ -178,70 +173,55 @@ impl<'a> DependencyEnumerator<'a> { /// # Note on IndexMap Rationale /// /// Using `IndexMap` (which preserves insert order) to ensure consistent results -/// across different executions for the same query. We could have used -/// `HashSet`, `HashMap` in place of them without any loss of functionality. +/// across different executions for the same query. We could have used `HashSet` +/// and `HashMap` instead without any loss of functionality. /// /// As an example, if existing orderings are /// 1. `[a ASC, b ASC]` -/// 2. `[c ASC]` for +/// 2. `[c ASC]` /// /// Then both the following output orderings are valid /// 1. `[a ASC, b ASC, c ASC]` /// 2. `[c ASC, a ASC, b ASC]` /// -/// (this are both valid as they are concatenated versions of the alternative -/// orderings). When using `HashSet`, `HashMap` it is not guaranteed to generate -/// consistent result, among the possible 2 results in the example above. -#[derive(Debug)] +/// These are both valid as they are concatenated versions of the alternative +/// orderings. Had we used `HashSet`/`HashMap`, we couldn't guarantee to generate +/// the same result among the possible two results in the example above. 
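The note above motivates `IndexMap` by its deterministic, insertion-ordered iteration. A small sketch of that property using the `indexmap` crate, which this module already imports; with a `HashMap`, the key order asserted below would not be guaranteed:

```rust
use indexmap::IndexMap;

fn main() {
    // Dependency-map-like structure keyed by sort expression name.
    let mut deps: IndexMap<&str, Vec<&str>> = IndexMap::new();
    deps.insert("a ASC", vec![]);
    deps.insert("b ASC", vec!["a ASC"]);
    deps.insert("c ASC", vec![]);

    // Iteration follows insertion order, so any orderings derived from the map
    // come out the same on every run of the same query.
    let keys: Vec<_> = deps.keys().copied().collect();
    assert_eq!(keys, ["a ASC", "b ASC", "c ASC"]);
}
```

The same determinism argument applies to the `IndexSet` used for `Dependencies`.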
+#[derive(Debug, Default)] pub struct DependencyMap { - inner: IndexMap, + map: IndexMap, } impl DependencyMap { - pub fn new() -> Self { - Self { - inner: IndexMap::new(), - } - } - - /// Insert a new dependency `sort_expr` --> `dependency` into the map. - /// - /// If `target_sort_expr` is none, a new entry is created with empty dependencies. + /// Insert a new dependency of `sort_expr` (i.e. `dependency`) into the map + /// along with its target sort expression. pub fn insert( &mut self, - sort_expr: &PhysicalSortExpr, - target_sort_expr: Option<&PhysicalSortExpr>, - dependency: Option<&PhysicalSortExpr>, + sort_expr: PhysicalSortExpr, + target_sort_expr: Option, + dependency: Option, ) { - self.inner - .entry(sort_expr.clone()) - .or_insert_with(|| DependencyNode { - target_sort_expr: target_sort_expr.cloned(), - dependencies: Dependencies::new(), - }) - .insert_dependency(dependency) - } - - /// Iterator over (sort_expr, DependencyNode) pairs - pub fn iter(&self) -> impl Iterator { - self.inner.iter() + let entry = self.map.entry(sort_expr); + let node = entry.or_insert_with(|| DependencyNode { + target: target_sort_expr, + dependencies: Dependencies::default(), + }); + node.dependencies.extend(dependency); } +} - /// iterator over all sort exprs - pub fn sort_exprs(&self) -> impl Iterator { - self.inner.keys() - } +impl Deref for DependencyMap { + type Target = IndexMap; - /// Return the dependency node for the given sort expression, if any - pub fn get(&self, sort_expr: &PhysicalSortExpr) -> Option<&DependencyNode> { - self.inner.get(sort_expr) + fn deref(&self) -> &Self::Target { + &self.map } } impl Display for DependencyMap { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!(f, "DependencyMap: {{")?; - for (sort_expr, node) in self.inner.iter() { + for (sort_expr, node) in self.map.iter() { writeln!(f, " {sort_expr} --> {node}")?; } writeln!(f, "}}") @@ -256,30 +236,21 @@ impl Display for DependencyMap { /// /// # Fields /// -/// - `target_sort_expr`: An optional `PhysicalSortExpr` representing the target -/// sort expression associated with the node. It is `None` if the sort expression +/// - `target`: An optional `PhysicalSortExpr` representing the target sort +/// expression associated with the node. It is `None` if the sort expression /// cannot be projected. /// - `dependencies`: A [`Dependencies`] containing dependencies on other sort /// expressions that are referred to by the target sort expression. #[derive(Debug, Clone, PartialEq, Eq)] pub struct DependencyNode { - pub target_sort_expr: Option, - pub dependencies: Dependencies, -} - -impl DependencyNode { - /// Insert dependency to the state (if exists). 
- fn insert_dependency(&mut self, dependency: Option<&PhysicalSortExpr>) { - if let Some(dep) = dependency { - self.dependencies.insert(dep.clone()); - } - } + pub(crate) target: Option, + pub(crate) dependencies: Dependencies, } impl Display for DependencyNode { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - if let Some(target) = &self.target_sort_expr { - write!(f, "(target: {}, ", target)?; + if let Some(target) = &self.target { + write!(f, "(target: {target}, ")?; } else { write!(f, "(")?; } @@ -307,12 +278,12 @@ pub fn referred_dependencies( source: &Arc, ) -> Vec { // Associate `PhysicalExpr`s with `PhysicalSortExpr`s that contain them: - let mut expr_to_sort_exprs = IndexMap::::new(); + let mut expr_to_sort_exprs = IndexMap::<_, Dependencies>::new(); for sort_expr in dependency_map - .sort_exprs() + .keys() .filter(|sort_expr| expr_refers(source, &sort_expr.expr)) { - let key = ExprWrapper(Arc::clone(&sort_expr.expr)); + let key = Arc::clone(&sort_expr.expr); expr_to_sort_exprs .entry(key) .or_default() @@ -322,16 +293,10 @@ pub fn referred_dependencies( // Generate all valid dependencies for the source. For example, if the source // is `a + b` and the map is `[a -> (a ASC, a DESC), b -> (b ASC)]`, we get // `vec![HashSet(a ASC, b ASC), HashSet(a DESC, b ASC)]`. - let dependencies = expr_to_sort_exprs + expr_to_sort_exprs .into_values() - .map(Dependencies::into_inner) - .collect::>(); - dependencies - .iter() .multi_cartesian_product() - .map(|referred_deps| { - Dependencies::new_from_iter(referred_deps.into_iter().cloned()) - }) + .map(Dependencies::new) .collect() } @@ -378,46 +343,39 @@ pub fn construct_prefix_orderings( /// # Parameters /// /// * `dependencies` - Set of relevant expressions. -/// * `dependency_map` - Map of dependencies for expressions that may appear in `dependencies` +/// * `dependency_map` - Map of dependencies for expressions that may appear in +/// `dependencies`. /// /// # Returns /// -/// A vector of lexical orderings (`Vec`) representing all valid orderings -/// based on the given dependencies. +/// A vector of lexical orderings (`Vec`) representing all valid +/// orderings based on the given dependencies. pub fn generate_dependency_orderings( dependencies: &Dependencies, dependency_map: &DependencyMap, ) -> Vec { // Construct all the valid prefix orderings for each expression appearing - // in the projection: - let relevant_prefixes = dependencies + // in the projection. Note that if relevant prefixes are empty, there is no + // dependency, meaning that dependent is a leading ordering. + dependencies .iter() - .flat_map(|dep| { + .filter_map(|dep| { let prefixes = construct_prefix_orderings(dep, dependency_map); (!prefixes.is_empty()).then_some(prefixes) }) - .collect::>(); - - // No dependency, dependent is a leading ordering. 
- if relevant_prefixes.is_empty() { - // Return an empty ordering: - return vec![LexOrdering::default()]; - } - - relevant_prefixes - .into_iter() + // Generate all possible valid orderings: .multi_cartesian_product() .flat_map(|prefix_orderings| { + let length = prefix_orderings.len(); prefix_orderings - .iter() - .permutations(prefix_orderings.len()) - .map(|prefixes| { - prefixes - .into_iter() - .flat_map(|ordering| ordering.clone()) - .collect() + .into_iter() + .permutations(length) + .filter_map(|prefixes| { + prefixes.into_iter().reduce(|mut acc, ordering| { + acc.extend(ordering); + acc + }) }) - .collect::>() }) .collect() } @@ -429,10 +387,10 @@ mod tests { use super::*; use crate::equivalence::tests::{ - convert_to_sort_exprs, convert_to_sort_reqs, create_test_params, - create_test_schema, output_schema, parse_sort_expr, + convert_to_sort_reqs, create_test_params, create_test_schema, output_schema, + parse_sort_expr, }; - use crate::equivalence::ProjectionMapping; + use crate::equivalence::{convert_to_sort_exprs, ProjectionMapping}; use crate::expressions::{col, BinaryExpr, CastExpr, Column}; use crate::{ConstExpr, EquivalenceProperties, ScalarFunctionExpr}; @@ -441,9 +399,11 @@ mod tests { use datafusion_common::{Constraint, Constraints, Result}; use datafusion_expr::sort_properties::SortProperties; use datafusion_expr::Operator; - use datafusion_functions::string::concat; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + use datafusion_physical_expr_common::sort_expr::{ + LexRequirement, PhysicalSortRequirement, + }; #[test] fn project_equivalence_properties_test() -> Result<()> { @@ -463,7 +423,7 @@ mod tests { (Arc::clone(&col_a), "a3".to_string()), (Arc::clone(&col_a), "a4".to_string()), ]; - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; + let projection_mapping = ProjectionMapping::try_new(proj_exprs, &input_schema)?; let out_schema = output_schema(&projection_mapping, &input_schema)?; // a as a1, a as a2, a as a3, a as a3 @@ -473,7 +433,7 @@ mod tests { (Arc::clone(&col_a), "a3".to_string()), (Arc::clone(&col_a), "a4".to_string()), ]; - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; + let projection_mapping = ProjectionMapping::try_new(proj_exprs, &input_schema)?; // a as a1, a as a2, a as a3, a as a3 let col_a1 = &col("a1", &out_schema)?; @@ -506,20 +466,20 @@ mod tests { let mut input_properties = EquivalenceProperties::new(Arc::clone(&input_schema)); // add equivalent ordering [a, b, c, d] - input_properties.add_new_ordering(LexOrdering::new(vec![ + input_properties.add_ordering([ parse_sort_expr("a", &input_schema), parse_sort_expr("b", &input_schema), parse_sort_expr("c", &input_schema), parse_sort_expr("d", &input_schema), - ])); + ]); // add equivalent ordering [a, c, b, d] - input_properties.add_new_ordering(LexOrdering::new(vec![ + input_properties.add_ordering([ parse_sort_expr("a", &input_schema), parse_sort_expr("c", &input_schema), parse_sort_expr("b", &input_schema), // NB b and c are swapped parse_sort_expr("d", &input_schema), - ])); + ]); // simply project all the columns in order let proj_exprs = vec![ @@ -528,7 +488,7 @@ mod tests { (col("c", &input_schema)?, "c".to_string()), (col("d", &input_schema)?, "d".to_string()), ]; - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; + let projection_mapping = ProjectionMapping::try_new(proj_exprs, &input_schema)?; let out_properties = input_properties.project(&projection_mapping, 
input_schema); assert_eq!( @@ -541,8 +501,6 @@ mod tests { #[test] fn test_normalize_ordering_equivalence_classes() -> Result<()> { - let sort_options = SortOptions::default(); - let schema = Schema::new(vec![ Field::new("a", DataType::Int32, true), Field::new("b", DataType::Int32, true), @@ -553,35 +511,19 @@ mod tests { let col_c_expr = col("c", &schema)?; let mut eq_properties = EquivalenceProperties::new(Arc::new(schema.clone())); - eq_properties.add_equal_conditions(&col_a_expr, &col_c_expr)?; - let others = vec![ - LexOrdering::new(vec![PhysicalSortExpr { - expr: Arc::clone(&col_b_expr), - options: sort_options, - }]), - LexOrdering::new(vec![PhysicalSortExpr { - expr: Arc::clone(&col_c_expr), - options: sort_options, - }]), - ]; - eq_properties.add_new_orderings(others); + eq_properties.add_equal_conditions(col_a_expr, Arc::clone(&col_c_expr))?; + eq_properties.add_orderings([ + vec![PhysicalSortExpr::new_default(Arc::clone(&col_b_expr))], + vec![PhysicalSortExpr::new_default(Arc::clone(&col_c_expr))], + ]); let mut expected_eqs = EquivalenceProperties::new(Arc::new(schema)); - expected_eqs.add_new_orderings([ - LexOrdering::new(vec![PhysicalSortExpr { - expr: Arc::clone(&col_b_expr), - options: sort_options, - }]), - LexOrdering::new(vec![PhysicalSortExpr { - expr: Arc::clone(&col_c_expr), - options: sort_options, - }]), + expected_eqs.add_orderings([ + vec![PhysicalSortExpr::new_default(col_b_expr)], + vec![PhysicalSortExpr::new_default(col_c_expr)], ]); - let oeq_class = eq_properties.oeq_class().clone(); - let expected = expected_eqs.oeq_class(); - assert!(oeq_class.eq(expected)); - + assert!(eq_properties.oeq_class().eq(expected_eqs.oeq_class())); Ok(()) } @@ -594,34 +536,22 @@ mod tests { Field::new("a", DataType::Int32, true), Field::new("b", DataType::Int32, true), ]); - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let required_columns = [Arc::clone(col_b), Arc::clone(col_a)]; + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let required_columns = [Arc::clone(&col_b), Arc::clone(&col_a)]; let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); - eq_properties.add_new_orderings([LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: sort_options_not, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: sort_options, - }, - ])]); - let (result, idxs) = eq_properties.find_longest_permutation(&required_columns); + eq_properties.add_ordering([ + PhysicalSortExpr::new(Arc::new(Column::new("b", 1)), sort_options_not), + PhysicalSortExpr::new(Arc::new(Column::new("a", 0)), sort_options), + ]); + let (result, idxs) = eq_properties.find_longest_permutation(&required_columns)?; assert_eq!(idxs, vec![0, 1]); assert_eq!( result, - LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::clone(col_b), - options: sort_options_not - }, - PhysicalSortExpr { - expr: Arc::clone(col_a), - options: sort_options - } - ]) + vec![ + PhysicalSortExpr::new(col_b, sort_options_not), + PhysicalSortExpr::new(col_a, sort_options), + ] ); let schema = Schema::new(vec![ @@ -629,40 +559,28 @@ mod tests { Field::new("b", DataType::Int32, true), Field::new("c", DataType::Int32, true), ]); - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let required_columns = [Arc::clone(col_b), Arc::clone(col_a)]; + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let required_columns = [Arc::clone(&col_b), Arc::clone(&col_a)]; let mut eq_properties = 
EquivalenceProperties::new(Arc::new(schema)); - eq_properties.add_new_orderings([ - LexOrdering::new(vec![PhysicalSortExpr { - expr: Arc::new(Column::new("c", 2)), - options: sort_options, - }]), - LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: sort_options_not, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: sort_options, - }, - ]), + eq_properties.add_orderings([ + vec![PhysicalSortExpr::new( + Arc::new(Column::new("c", 2)), + sort_options, + )], + vec![ + PhysicalSortExpr::new(Arc::new(Column::new("b", 1)), sort_options_not), + PhysicalSortExpr::new(Arc::new(Column::new("a", 0)), sort_options), + ], ]); - let (result, idxs) = eq_properties.find_longest_permutation(&required_columns); + let (result, idxs) = eq_properties.find_longest_permutation(&required_columns)?; assert_eq!(idxs, vec![0, 1]); assert_eq!( result, - LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::clone(col_b), - options: sort_options_not - }, - PhysicalSortExpr { - expr: Arc::clone(col_a), - options: sort_options - } - ]) + vec![ + PhysicalSortExpr::new(col_b, sort_options_not), + PhysicalSortExpr::new(col_a, sort_options), + ] ); let required_columns = [ @@ -677,21 +595,12 @@ mod tests { let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); // not satisfied orders - eq_properties.add_new_orderings([LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: sort_options_not, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("c", 2)), - options: sort_options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: sort_options, - }, - ])]); - let (_, idxs) = eq_properties.find_longest_permutation(&required_columns); + eq_properties.add_ordering([ + PhysicalSortExpr::new(Arc::new(Column::new("b", 1)), sort_options_not), + PhysicalSortExpr::new(Arc::new(Column::new("c", 2)), sort_options), + PhysicalSortExpr::new(Arc::new(Column::new("a", 0)), sort_options), + ]); + let (_, idxs) = eq_properties.find_longest_permutation(&required_columns)?; assert_eq!(idxs, vec![0]); Ok(()) @@ -707,49 +616,35 @@ mod tests { ]); let mut eq_properties = EquivalenceProperties::new(Arc::new(schema.clone())); - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let col_c = &col("c", &schema)?; - let col_d = &col("d", &schema)?; + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let col_c = col("c", &schema)?; + let col_d = col("d", &schema)?; let option_asc = SortOptions { descending: false, nulls_first: false, }; // b=a (e.g they are aliases) - eq_properties.add_equal_conditions(col_b, col_a)?; + eq_properties.add_equal_conditions(Arc::clone(&col_b), Arc::clone(&col_a))?; // [b ASC], [d ASC] - eq_properties.add_new_orderings(vec![ - LexOrdering::new(vec![PhysicalSortExpr { - expr: Arc::clone(col_b), - options: option_asc, - }]), - LexOrdering::new(vec![PhysicalSortExpr { - expr: Arc::clone(col_d), - options: option_asc, - }]), + eq_properties.add_orderings([ + vec![PhysicalSortExpr::new(Arc::clone(&col_b), option_asc)], + vec![PhysicalSortExpr::new(Arc::clone(&col_d), option_asc)], ]); let test_cases = vec![ // d + b ( - Arc::new(BinaryExpr::new( - Arc::clone(col_d), - Operator::Plus, - Arc::clone(col_b), - )) as Arc, + Arc::new(BinaryExpr::new(col_d, Operator::Plus, Arc::clone(&col_b))) as _, SortProperties::Ordered(option_asc), ), // b - (Arc::clone(col_b), SortProperties::Ordered(option_asc)), + (col_b, SortProperties::Ordered(option_asc)), // a 
- (Arc::clone(col_a), SortProperties::Ordered(option_asc)), + (Arc::clone(&col_a), SortProperties::Ordered(option_asc)), // a + c ( - Arc::new(BinaryExpr::new( - Arc::clone(col_a), - Operator::Plus, - Arc::clone(col_c), - )), + Arc::new(BinaryExpr::new(col_a, Operator::Plus, col_c)), SortProperties::Unordered, ), ]; @@ -757,14 +652,14 @@ mod tests { let leading_orderings = eq_properties .oeq_class() .iter() - .flat_map(|ordering| ordering.first().cloned()) + .map(|ordering| ordering.first().clone()) .collect::>(); let expr_props = eq_properties.get_expr_properties(Arc::clone(&expr)); let err_msg = format!( "expr:{:?}, expected: {:?}, actual: {:?}, leading_orderings: {leading_orderings:?}", expr, expected, expr_props.sort_properties ); - assert_eq!(expr_props.sort_properties, expected, "{}", err_msg); + assert_eq!(expr_props.sort_properties, expected, "{err_msg}"); } Ok(()) @@ -790,7 +685,7 @@ mod tests { Arc::clone(col_a), Operator::Plus, Arc::clone(col_d), - )) as Arc; + )) as _; let option_asc = SortOptions { descending: false, @@ -801,16 +696,10 @@ mod tests { nulls_first: true, }; // [d ASC, h DESC] also satisfies schema. - eq_properties.add_new_orderings([LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::clone(col_d), - options: option_asc, - }, - PhysicalSortExpr { - expr: Arc::clone(col_h), - options: option_desc, - }, - ])]); + eq_properties.add_ordering([ + PhysicalSortExpr::new(Arc::clone(col_d), option_asc), + PhysicalSortExpr::new(Arc::clone(col_h), option_desc), + ]); let test_cases = vec![ // TEST CASE 1 (vec![col_a], vec![(col_a, option_asc)]), @@ -878,7 +767,7 @@ mod tests { for (exprs, expected) in test_cases { let exprs = exprs.into_iter().cloned().collect::>(); let expected = convert_to_sort_exprs(&expected); - let (actual, _) = eq_properties.find_longest_permutation(&exprs); + let (actual, _) = eq_properties.find_longest_permutation(&exprs)?; assert_eq!(actual, expected); } @@ -896,7 +785,7 @@ mod tests { let col_h = &col("h", &test_schema)?; // Add column h as constant - eq_properties = eq_properties.with_constants(vec![ConstExpr::from(col_h)]); + eq_properties.add_constants(vec![ConstExpr::from(Arc::clone(col_h))])?; let test_cases = vec![ // TEST CASE 1 @@ -907,72 +796,13 @@ mod tests { for (exprs, expected) in test_cases { let exprs = exprs.into_iter().cloned().collect::>(); let expected = convert_to_sort_exprs(&expected); - let (actual, _) = eq_properties.find_longest_permutation(&exprs); + let (actual, _) = eq_properties.find_longest_permutation(&exprs)?; assert_eq!(actual, expected); } Ok(()) } - #[test] - fn test_get_finer() -> Result<()> { - let schema = create_test_schema()?; - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let col_c = &col("c", &schema)?; - let eq_properties = EquivalenceProperties::new(schema); - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - let option_desc = SortOptions { - descending: true, - nulls_first: true, - }; - // First entry, and second entry are the physical sort requirement that are argument for get_finer_requirement. - // Third entry is the expected result. 
- let tests_cases = vec![ - // Get finer requirement between [a Some(ASC)] and [a None, b Some(ASC)] - // result should be [a Some(ASC), b Some(ASC)] - ( - vec![(col_a, Some(option_asc))], - vec![(col_a, None), (col_b, Some(option_asc))], - Some(vec![(col_a, Some(option_asc)), (col_b, Some(option_asc))]), - ), - // Get finer requirement between [a Some(ASC), b Some(ASC), c Some(ASC)] and [a Some(ASC), b Some(ASC)] - // result should be [a Some(ASC), b Some(ASC), c Some(ASC)] - ( - vec![ - (col_a, Some(option_asc)), - (col_b, Some(option_asc)), - (col_c, Some(option_asc)), - ], - vec![(col_a, Some(option_asc)), (col_b, Some(option_asc))], - Some(vec![ - (col_a, Some(option_asc)), - (col_b, Some(option_asc)), - (col_c, Some(option_asc)), - ]), - ), - // Get finer requirement between [a Some(ASC), b Some(ASC)] and [a Some(ASC), b Some(DESC)] - // result should be None - ( - vec![(col_a, Some(option_asc)), (col_b, Some(option_asc))], - vec![(col_a, Some(option_asc)), (col_b, Some(option_desc))], - None, - ), - ]; - for (lhs, rhs, expected) in tests_cases { - let lhs = convert_to_sort_reqs(&lhs); - let rhs = convert_to_sort_reqs(&rhs); - let expected = expected.map(|expected| convert_to_sort_reqs(&expected)); - let finer = eq_properties.get_finer_requirement(&lhs, &rhs); - assert_eq!(finer, expected) - } - - Ok(()) - } - #[test] fn test_normalize_sort_reqs() -> Result<()> { // Schema satisfies following properties @@ -1040,7 +870,7 @@ mod tests { let expected_normalized = convert_to_sort_reqs(&expected_normalized); assert_eq!( - eq_properties.normalize_sort_requirements(&req), + eq_properties.normalize_sort_requirements(req).unwrap(), expected_normalized ); } @@ -1073,8 +903,9 @@ mod tests { for (reqs, expected) in test_cases.into_iter() { let reqs = convert_to_sort_reqs(&reqs); let expected = convert_to_sort_reqs(&expected); - - let normalized = eq_properties.normalize_sort_requirements(&reqs); + let normalized = eq_properties + .normalize_sort_requirements(reqs.clone()) + .unwrap(); assert!( expected.eq(&normalized), "error in test: reqs: {reqs:?}, expected: {expected:?}, normalized: {normalized:?}" @@ -1091,21 +922,12 @@ mod tests { Field::new("b", DataType::Utf8, true), Field::new("c", DataType::Timestamp(TimeUnit::Nanosecond, None), true), ])); - let base_properties = EquivalenceProperties::new(Arc::clone(&schema)) - .with_reorder(LexOrdering::new( - ["a", "b", "c"] - .into_iter() - .map(|c| { - col(c, schema.as_ref()).map(|expr| PhysicalSortExpr { - expr, - options: SortOptions { - descending: false, - nulls_first: true, - }, - }) - }) - .collect::>>()?, - )); + let mut base_properties = EquivalenceProperties::new(Arc::clone(&schema)); + base_properties.reorder( + ["a", "b", "c"] + .into_iter() + .map(|c| PhysicalSortExpr::new_default(col(c, schema.as_ref()).unwrap())), + )?; struct TestCase { name: &'static str, @@ -1118,17 +940,14 @@ mod tests { let col_a = col("a", schema.as_ref())?; let col_b = col("b", schema.as_ref())?; let col_c = col("c", schema.as_ref())?; - let cast_c = Arc::new(CastExpr::new(col_c, DataType::Date32, None)); + let cast_c = Arc::new(CastExpr::new(col_c, DataType::Date32, None)) as _; let cases = vec![ TestCase { name: "(a, b, c) -> (c)", // b is constant, so it should be removed from the sort order constants: vec![Arc::clone(&col_b)], - equal_conditions: vec![[ - Arc::clone(&cast_c) as Arc, - Arc::clone(&col_a), - ]], + equal_conditions: vec![[Arc::clone(&cast_c), Arc::clone(&col_a)]], sort_columns: &["c"], should_satisfy_ordering: true, }, @@ -1138,10 +957,7 @@ 
mod tests { name: "(a, b, c) -> (c)", // b is constant, so it should be removed from the sort order constants: vec![col_b], - equal_conditions: vec![[ - Arc::clone(&col_a), - Arc::clone(&cast_c) as Arc, - ]], + equal_conditions: vec![[Arc::clone(&col_a), Arc::clone(&cast_c)]], sort_columns: &["c"], should_satisfy_ordering: true, }, @@ -1150,10 +966,7 @@ mod tests { // b is not constant anymore constants: vec![], // a and c are still compatible, but this is irrelevant since the original ordering is (a, b, c) - equal_conditions: vec![[ - Arc::clone(&cast_c) as Arc, - Arc::clone(&col_a), - ]], + equal_conditions: vec![[Arc::clone(&cast_c), Arc::clone(&col_a)]], sort_columns: &["c"], should_satisfy_ordering: false, }, @@ -1167,19 +980,21 @@ mod tests { // Equal conditions before constants { let mut properties = base_properties.clone(); - for [left, right] in &case.equal_conditions { + for [left, right] in case.equal_conditions.clone() { properties.add_equal_conditions(left, right)? } - properties.with_constants( + properties.add_constants( case.constants.iter().cloned().map(ConstExpr::from), - ) + )?; + properties }, // Constants before equal conditions { - let mut properties = base_properties.clone().with_constants( + let mut properties = base_properties.clone(); + properties.add_constants( case.constants.iter().cloned().map(ConstExpr::from), - ); - for [left, right] in &case.equal_conditions { + )?; + for [left, right] in case.equal_conditions { properties.add_equal_conditions(left, right)? } properties @@ -1188,16 +1003,11 @@ mod tests { let sort = case .sort_columns .iter() - .map(|&name| { - col(name, &schema).map(|col| PhysicalSortExpr { - expr: col, - options: SortOptions::default(), - }) - }) - .collect::>()?; + .map(|&name| col(name, &schema).map(PhysicalSortExpr::new_default)) + .collect::>>()?; assert_eq!( - properties.ordering_satisfy(sort.as_ref()), + properties.ordering_satisfy(sort)?, case.should_satisfy_ordering, "failed test '{}'", case.name @@ -1224,31 +1034,29 @@ mod tests { "concat", concat(), vec![Arc::clone(&col_a), Arc::clone(&col_b)], - DataType::Utf8, + Field::new("f", DataType::Utf8, true).into(), )); // Assume existing ordering is [c ASC, a ASC, b ASC] let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); - eq_properties.add_new_ordering(LexOrdering::from(vec![ + eq_properties.add_ordering([ PhysicalSortExpr::new_default(Arc::clone(&col_c)).asc(), PhysicalSortExpr::new_default(Arc::clone(&col_a)).asc(), PhysicalSortExpr::new_default(Arc::clone(&col_b)).asc(), - ])); + ]); // Add equality condition c = concat(a, b) - eq_properties.add_equal_conditions(&col_c, &a_concat_b)?; + eq_properties.add_equal_conditions(Arc::clone(&col_c), a_concat_b)?; let orderings = eq_properties.oeq_class(); - let expected_ordering1 = - LexOrdering::from(vec![ - PhysicalSortExpr::new_default(Arc::clone(&col_c)).asc() - ]); - let expected_ordering2 = LexOrdering::from(vec![ - PhysicalSortExpr::new_default(Arc::clone(&col_a)).asc(), - PhysicalSortExpr::new_default(Arc::clone(&col_b)).asc(), - ]); + let expected_ordering1 = [PhysicalSortExpr::new_default(col_c).asc()].into(); + let expected_ordering2 = [ + PhysicalSortExpr::new_default(col_a).asc(), + PhysicalSortExpr::new_default(col_b).asc(), + ] + .into(); // The ordering should be [c ASC] and [a ASC, b ASC] assert_eq!(orderings.len(), 2); @@ -1270,25 +1078,26 @@ mod tests { let col_b = col("b", &schema)?; let col_c = col("c", &schema)?; - let a_times_b: Arc = Arc::new(BinaryExpr::new( + let a_times_b = 
Arc::new(BinaryExpr::new( Arc::clone(&col_a), Operator::Multiply, Arc::clone(&col_b), - )); + )) as _; // Assume existing ordering is [c ASC, a ASC, b ASC] let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); - let initial_ordering = LexOrdering::from(vec![ + let initial_ordering: LexOrdering = [ PhysicalSortExpr::new_default(Arc::clone(&col_c)).asc(), - PhysicalSortExpr::new_default(Arc::clone(&col_a)).asc(), - PhysicalSortExpr::new_default(Arc::clone(&col_b)).asc(), - ]); + PhysicalSortExpr::new_default(col_a).asc(), + PhysicalSortExpr::new_default(col_b).asc(), + ] + .into(); - eq_properties.add_new_ordering(initial_ordering.clone()); + eq_properties.add_ordering(initial_ordering.clone()); // Add equality condition c = a * b - eq_properties.add_equal_conditions(&col_c, &a_times_b)?; + eq_properties.add_equal_conditions(col_c, a_times_b)?; let orderings = eq_properties.oeq_class(); @@ -1311,37 +1120,35 @@ mod tests { let col_b = col("b", &schema)?; let col_c = col("c", &schema)?; - let a_concat_b: Arc = Arc::new(ScalarFunctionExpr::new( + let a_concat_b = Arc::new(ScalarFunctionExpr::new( "concat", concat(), vec![Arc::clone(&col_a), Arc::clone(&col_b)], - DataType::Utf8, - )); + Field::new("f", DataType::Utf8, true).into(), + )) as _; // Assume existing ordering is [concat(a, b) ASC, a ASC, b ASC] let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); - eq_properties.add_new_ordering(LexOrdering::from(vec![ + eq_properties.add_ordering([ PhysicalSortExpr::new_default(Arc::clone(&a_concat_b)).asc(), PhysicalSortExpr::new_default(Arc::clone(&col_a)).asc(), PhysicalSortExpr::new_default(Arc::clone(&col_b)).asc(), - ])); + ]); // Add equality condition c = concat(a, b) - eq_properties.add_equal_conditions(&col_c, &a_concat_b)?; + eq_properties.add_equal_conditions(col_c, Arc::clone(&a_concat_b))?; let orderings = eq_properties.oeq_class(); - let expected_ordering1 = LexOrdering::from(vec![PhysicalSortExpr::new_default( - Arc::clone(&a_concat_b), - ) - .asc()]); - let expected_ordering2 = LexOrdering::from(vec![ - PhysicalSortExpr::new_default(Arc::clone(&col_a)).asc(), - PhysicalSortExpr::new_default(Arc::clone(&col_b)).asc(), - ]); + let expected_ordering1 = [PhysicalSortExpr::new_default(a_concat_b).asc()].into(); + let expected_ordering2 = [ + PhysicalSortExpr::new_default(col_a).asc(), + PhysicalSortExpr::new_default(col_b).asc(), + ] + .into(); - // The ordering should be [concat(a, b) ASC] and [a ASC, b ASC] + // The ordering should be [c ASC] and [a ASC, b ASC] assert_eq!(orderings.len(), 2); assert!(orderings.contains(&expected_ordering1)); assert!(orderings.contains(&expected_ordering2)); @@ -1349,6 +1156,35 @@ mod tests { Ok(()) } + #[test] + fn test_requirements_compatible() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + ])); + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let col_c = col("c", &schema)?; + + let eq_properties = EquivalenceProperties::new(schema); + let lex_a: LexRequirement = + [PhysicalSortRequirement::new(Arc::clone(&col_a), None)].into(); + let lex_a_b: LexRequirement = [ + PhysicalSortRequirement::new(col_a, None), + PhysicalSortRequirement::new(col_b, None), + ] + .into(); + let lex_c = [PhysicalSortRequirement::new(col_c, None)].into(); + + assert!(eq_properties.requirements_compatible(lex_a.clone(), lex_a.clone())); + 
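The checks in the new `test_requirements_compatible` boil down to a prefix comparison between two requirements. A minimal standalone sketch of that idea, using plain strings in place of `PhysicalSortRequirement`; `requirements_compatible_model` is an illustrative name, not a DataFusion API, and sort options and normalization are deliberately ignored:

```rust
/// Simplified model: `given` is compatible with `reference` when
/// `reference` is a prefix of `given` (the real check additionally
/// normalizes both sides and compares sort options).
fn requirements_compatible_model(given: &[&str], reference: &[&str]) -> bool {
    given.len() >= reference.len()
        && given.iter().zip(reference).all(|(g, r)| g == r)
}

fn main() {
    // Mirrors the assertions in `test_requirements_compatible`.
    assert!(requirements_compatible_model(&["a"], &["a"]));
    assert!(!requirements_compatible_model(&["a"], &["a", "b"]));
    assert!(requirements_compatible_model(&["a", "b"], &["a"]));
    assert!(!requirements_compatible_model(&["c"], &["a"]));
}
```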
assert!(!eq_properties.requirements_compatible(lex_a.clone(), lex_a_b.clone())); + assert!(eq_properties.requirements_compatible(lex_a_b, lex_a.clone())); + assert!(!eq_properties.requirements_compatible(lex_c, lex_a)); + + Ok(()) + } + #[test] fn test_with_reorder_constant_filtering() -> Result<()> { let schema = create_test_schema()?; @@ -1357,26 +1193,21 @@ mod tests { // Setup constant columns let col_a = col("a", &schema)?; let col_b = col("b", &schema)?; - eq_properties = eq_properties.with_constants([ConstExpr::from(&col_a)]); + eq_properties.add_constants([ConstExpr::from(Arc::clone(&col_a))])?; - let sort_exprs = LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::clone(&col_a), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::clone(&col_b), - options: SortOptions::default(), - }, - ]); + let sort_exprs = vec![ + PhysicalSortExpr::new_default(Arc::clone(&col_a)), + PhysicalSortExpr::new_default(Arc::clone(&col_b)), + ]; - let result = eq_properties.with_reorder(sort_exprs); + let change = eq_properties.reorder(sort_exprs)?; + assert!(change); - // Should only contain b since a is constant - assert_eq!(result.oeq_class().len(), 1); - let ordering = result.oeq_class().iter().next().unwrap(); - assert_eq!(ordering.len(), 1); - assert!(ordering[0].expr.eq(&col_b)); + assert_eq!(eq_properties.oeq_class().len(), 1); + let ordering = eq_properties.oeq_class().iter().next().unwrap(); + assert_eq!(ordering.len(), 2); + assert!(ordering[0].expr.eq(&col_a)); + assert!(ordering[1].expr.eq(&col_b)); Ok(()) } @@ -1397,32 +1228,21 @@ mod tests { }; // Initial ordering: [a ASC, b DESC, c ASC] - eq_properties.add_new_orderings([LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::clone(&col_a), - options: asc, - }, - PhysicalSortExpr { - expr: Arc::clone(&col_b), - options: desc, - }, - PhysicalSortExpr { - expr: Arc::clone(&col_c), - options: asc, - }, - ])]); + eq_properties.add_ordering([ + PhysicalSortExpr::new(Arc::clone(&col_a), asc), + PhysicalSortExpr::new(Arc::clone(&col_b), desc), + PhysicalSortExpr::new(Arc::clone(&col_c), asc), + ]); // New ordering: [a ASC] - let new_order = LexOrdering::new(vec![PhysicalSortExpr { - expr: Arc::clone(&col_a), - options: asc, - }]); + let new_order = vec![PhysicalSortExpr::new(Arc::clone(&col_a), asc)]; - let result = eq_properties.with_reorder(new_order); + let change = eq_properties.reorder(new_order)?; + assert!(!change); // Should only contain [a ASC, b DESC, c ASC] - assert_eq!(result.oeq_class().len(), 1); - let ordering = result.oeq_class().iter().next().unwrap(); + assert_eq!(eq_properties.oeq_class().len(), 1); + let ordering = eq_properties.oeq_class().iter().next().unwrap(); assert_eq!(ordering.len(), 3); assert!(ordering[0].expr.eq(&col_a)); assert!(ordering[0].options.eq(&asc)); @@ -1444,37 +1264,28 @@ mod tests { let col_c = col("c", &schema)?; // Make a and b equivalent - eq_properties.add_equal_conditions(&col_a, &col_b)?; - - let asc = SortOptions::default(); + eq_properties.add_equal_conditions(Arc::clone(&col_a), Arc::clone(&col_b))?; // Initial ordering: [a ASC, c ASC] - eq_properties.add_new_orderings([LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::clone(&col_a), - options: asc, - }, - PhysicalSortExpr { - expr: Arc::clone(&col_c), - options: asc, - }, - ])]); + eq_properties.add_ordering([ + PhysicalSortExpr::new_default(Arc::clone(&col_a)), + PhysicalSortExpr::new_default(Arc::clone(&col_c)), + ]); // New ordering: [b ASC] - let new_order = LexOrdering::new(vec![PhysicalSortExpr { - 
expr: Arc::clone(&col_b), - options: asc, - }]); + let new_order = vec![PhysicalSortExpr::new_default(Arc::clone(&col_b))]; - let result = eq_properties.with_reorder(new_order); + let change = eq_properties.reorder(new_order)?; - // Should only contain [b ASC, c ASC] - assert_eq!(result.oeq_class().len(), 1); + assert!(!change); + // Should only contain [a/b ASC, c ASC] + assert_eq!(eq_properties.oeq_class().len(), 1); // Verify orderings - let ordering = result.oeq_class().iter().next().unwrap(); + let asc = SortOptions::default(); + let ordering = eq_properties.oeq_class().iter().next().unwrap(); assert_eq!(ordering.len(), 2); - assert!(ordering[0].expr.eq(&col_b)); + assert!(ordering[0].expr.eq(&col_a) || ordering[0].expr.eq(&col_b)); assert!(ordering[0].options.eq(&asc)); assert!(ordering[1].expr.eq(&col_c)); assert!(ordering[1].options.eq(&asc)); @@ -1497,29 +1308,21 @@ mod tests { }; // Initial ordering: [a ASC, b DESC] - eq_properties.add_new_orderings([LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::clone(&col_a), - options: asc, - }, - PhysicalSortExpr { - expr: Arc::clone(&col_b), - options: desc, - }, - ])]); + eq_properties.add_ordering([ + PhysicalSortExpr::new(Arc::clone(&col_a), asc), + PhysicalSortExpr::new(Arc::clone(&col_b), desc), + ]); // New ordering: [a DESC] - let new_order = LexOrdering::new(vec![PhysicalSortExpr { - expr: Arc::clone(&col_a), - options: desc, - }]); + let new_order = vec![PhysicalSortExpr::new(Arc::clone(&col_a), desc)]; - let result = eq_properties.with_reorder(new_order.clone()); + let change = eq_properties.reorder(new_order.clone())?; + assert!(change); // Should only contain the new ordering since options don't match - assert_eq!(result.oeq_class().len(), 1); - let ordering = result.oeq_class().iter().next().unwrap(); - assert_eq!(ordering, &new_order); + assert_eq!(eq_properties.oeq_class().len(), 1); + let ordering = eq_properties.oeq_class().iter().next().unwrap(); + assert_eq!(ordering.to_vec(), new_order); Ok(()) } @@ -1535,62 +1338,32 @@ mod tests { let col_d = col("d", &schema)?; let col_e = col("e", &schema)?; - let asc = SortOptions::default(); - // Constants: c is constant - eq_properties = eq_properties.with_constants([ConstExpr::from(&col_c)]); + eq_properties.add_constants([ConstExpr::from(Arc::clone(&col_c))])?; // Equality: b = d - eq_properties.add_equal_conditions(&col_b, &col_d)?; + eq_properties.add_equal_conditions(Arc::clone(&col_b), Arc::clone(&col_d))?; // Orderings: [d ASC, a ASC], [e ASC] - eq_properties.add_new_orderings([ - LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::clone(&col_d), - options: asc, - }, - PhysicalSortExpr { - expr: Arc::clone(&col_a), - options: asc, - }, - ]), - LexOrdering::new(vec![PhysicalSortExpr { - expr: Arc::clone(&col_e), - options: asc, - }]), + eq_properties.add_orderings([ + vec![ + PhysicalSortExpr::new_default(Arc::clone(&col_d)), + PhysicalSortExpr::new_default(Arc::clone(&col_a)), + ], + vec![PhysicalSortExpr::new_default(Arc::clone(&col_e))], ]); - // Initial ordering: [b ASC, c ASC] - let new_order = LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::clone(&col_b), - options: asc, - }, - PhysicalSortExpr { - expr: Arc::clone(&col_c), - options: asc, - }, - ]); - - let result = eq_properties.with_reorder(new_order); - - // Should preserve the original [d ASC, a ASC] ordering - assert_eq!(result.oeq_class().len(), 1); - let ordering = result.oeq_class().iter().next().unwrap(); - assert_eq!(ordering.len(), 2); - - // First expression should be either b or d 
(they're equivalent) - assert!( - ordering[0].expr.eq(&col_b) || ordering[0].expr.eq(&col_d), - "Expected b or d as first expression, got {:?}", - ordering[0].expr - ); - assert!(ordering[0].options.eq(&asc)); + // New ordering: [b ASC, c ASC] + let new_order = vec![ + PhysicalSortExpr::new_default(Arc::clone(&col_b)), + PhysicalSortExpr::new_default(Arc::clone(&col_c)), + ]; - // Second expression should be a - assert!(ordering[1].expr.eq(&col_a)); - assert!(ordering[1].options.eq(&asc)); + let old_orderings = eq_properties.oeq_class().clone(); + let change = eq_properties.reorder(new_order)?; + // Original orderings should be preserved: + assert!(!change); + assert_eq!(eq_properties.oeq_class, old_orderings); Ok(()) } @@ -1691,81 +1464,62 @@ mod tests { { let mut eq_properties = EquivalenceProperties::new(Arc::clone(schema)); - // Convert base ordering - let base_ordering = LexOrdering::new( - base_order - .iter() - .map(|col_name| PhysicalSortExpr { - expr: col(col_name, schema).unwrap(), - options: SortOptions::default(), - }) - .collect(), - ); - // Convert string column names to orderings - let satisfied_orderings: Vec = satisfied_orders + let satisfied_orderings: Vec<_> = satisfied_orders .iter() .map(|cols| { - LexOrdering::new( - cols.iter() - .map(|col_name| PhysicalSortExpr { - expr: col(col_name, schema).unwrap(), - options: SortOptions::default(), - }) - .collect(), - ) + cols.iter() + .map(|col_name| { + PhysicalSortExpr::new_default(col(col_name, schema).unwrap()) + }) + .collect::>() }) .collect(); - let unsatisfied_orderings: Vec = unsatisfied_orders + let unsatisfied_orderings: Vec<_> = unsatisfied_orders .iter() .map(|cols| { - LexOrdering::new( - cols.iter() - .map(|col_name| PhysicalSortExpr { - expr: col(col_name, schema).unwrap(), - options: SortOptions::default(), - }) - .collect(), - ) + cols.iter() + .map(|col_name| { + PhysicalSortExpr::new_default(col(col_name, schema).unwrap()) + }) + .collect::>() }) .collect(); // Test that orderings are not satisfied before adding constraints - for ordering in &satisfied_orderings { - assert!( - !eq_properties.ordering_satisfy(ordering), - "{}: ordering {:?} should not be satisfied before adding constraints", - name, - ordering + for ordering in satisfied_orderings.clone() { + let err_msg = format!( + "{name}: ordering {ordering:?} should not be satisfied before adding constraints", ); + assert!(!eq_properties.ordering_satisfy(ordering)?, "{err_msg}"); } // Add base ordering - eq_properties.add_new_ordering(base_ordering); + let base_ordering = base_order.iter().map(|col_name| PhysicalSortExpr { + expr: col(col_name, schema).unwrap(), + options: SortOptions::default(), + }); + eq_properties.add_ordering(base_ordering); // Add constraints eq_properties = eq_properties.with_constraints(Constraints::new_unverified(constraints)); // Test that expected orderings are now satisfied - for ordering in &satisfied_orderings { - assert!( - eq_properties.ordering_satisfy(ordering), - "{}: ordering {:?} should be satisfied after adding constraints", - name, - ordering + for ordering in satisfied_orderings { + let err_msg = format!( + "{name}: ordering {ordering:?} should be satisfied after adding constraints", ); + assert!(eq_properties.ordering_satisfy(ordering)?, "{err_msg}"); } // Test that unsatisfied orderings remain unsatisfied - for ordering in &unsatisfied_orderings { - assert!( - !eq_properties.ordering_satisfy(ordering), - "{}: ordering {:?} should not be satisfied after adding constraints", - name, - ordering + for ordering in 
unsatisfied_orderings { + let err_msg = format!( + "{name}: ordering {ordering:?} should not be satisfied after adding constraints", ); + assert!(!eq_properties.ordering_satisfy(ordering)?, "{err_msg}"); } } diff --git a/datafusion/physical-expr/src/equivalence/properties/joins.rs b/datafusion/physical-expr/src/equivalence/properties/joins.rs index 7944e89d0305..9329ce56b7d6 100644 --- a/datafusion/physical-expr/src/equivalence/properties/joins.rs +++ b/datafusion/physical-expr/src/equivalence/properties/joins.rs @@ -15,11 +15,11 @@ // specific language governing permissions and limitations // under the License. +use super::EquivalenceProperties; use crate::{equivalence::OrderingEquivalenceClass, PhysicalExprRef}; -use arrow::datatypes::SchemaRef; -use datafusion_common::{JoinSide, JoinType}; -use super::EquivalenceProperties; +use arrow::datatypes::SchemaRef; +use datafusion_common::{JoinSide, JoinType, Result}; /// Calculate ordering equivalence properties for the given join operation. pub fn join_equivalence_properties( @@ -30,7 +30,7 @@ pub fn join_equivalence_properties( maintains_input_order: &[bool], probe_side: Option, on: &[(PhysicalExprRef, PhysicalExprRef)], -) -> EquivalenceProperties { +) -> Result { let left_size = left.schema.fields.len(); let mut result = EquivalenceProperties::new(join_schema); result.add_equivalence_group(left.eq_group().join( @@ -38,15 +38,13 @@ pub fn join_equivalence_properties( join_type, left_size, on, - )); + )?)?; let EquivalenceProperties { - constants: left_constants, oeq_class: left_oeq_class, .. } = left; let EquivalenceProperties { - constants: right_constants, oeq_class: mut right_oeq_class, .. } = right; @@ -59,7 +57,7 @@ pub fn join_equivalence_properties( &mut right_oeq_class, join_type, left_size, - ); + )?; // Right side ordering equivalence properties should be prepended // with those of the left side while constructing output ordering @@ -70,9 +68,9 @@ pub fn join_equivalence_properties( // then we should add `a ASC, b ASC` to the ordering equivalences // of the join output. let out_oeq_class = left_oeq_class.join_suffix(&right_oeq_class); - result.add_ordering_equivalence_class(out_oeq_class); + result.add_orderings(out_oeq_class); } else { - result.add_ordering_equivalence_class(left_oeq_class); + result.add_orderings(left_oeq_class); } } [false, true] => { @@ -80,7 +78,7 @@ pub fn join_equivalence_properties( &mut right_oeq_class, join_type, left_size, - ); + )?; // In this special case, left side ordering can be prefixed with // the right side ordering. if let (Some(JoinSide::Right), JoinType::Inner) = (probe_side, join_type) { @@ -93,25 +91,16 @@ pub fn join_equivalence_properties( // then we should add `b ASC, a ASC` to the ordering equivalences // of the join output. 
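To make the `join_suffix` step referenced here concrete, below is a simplified, self-contained model of how one side's orderings are extended with the other side's. Strings stand in for sort expressions and `join_suffix_model` is an illustrative name, not the real `OrderingEquivalenceClass` method:

```rust
/// Simplified model of suffixing: every ordering of the prefix side is
/// extended with every ordering of the suffix side.
fn join_suffix_model(
    prefix_side: &[Vec<&str>],
    suffix_side: &[Vec<&str>],
) -> Vec<Vec<String>> {
    let mut out = Vec::new();
    for prefix in prefix_side {
        for suffix in suffix_side {
            let mut ordering: Vec<String> =
                prefix.iter().map(|e| e.to_string()).collect();
            ordering.extend(suffix.iter().map(|e| e.to_string()));
            out.push(ordering);
        }
    }
    out
}

fn main() {
    // Inner join, probe side = Right: the right ordering [b ASC] is
    // prepended to the left ordering [a ASC], yielding [b ASC, a ASC].
    let right = vec![vec!["b ASC"]];
    let left = vec![vec!["a ASC"]];
    assert_eq!(
        join_suffix_model(&right, &left),
        vec![vec!["b ASC".to_string(), "a ASC".to_string()]]
    );
}
```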
let out_oeq_class = right_oeq_class.join_suffix(&left_oeq_class); - result.add_ordering_equivalence_class(out_oeq_class); + result.add_orderings(out_oeq_class); } else { - result.add_ordering_equivalence_class(right_oeq_class); + result.add_orderings(right_oeq_class); } } [false, false] => {} [true, true] => unreachable!("Cannot maintain ordering of both sides"), _ => unreachable!("Join operators can not have more than two children"), } - match join_type { - JoinType::LeftAnti | JoinType::LeftSemi => { - result = result.with_constants(left_constants); - } - JoinType::RightAnti | JoinType::RightSemi => { - result = result.with_constants(right_constants); - } - _ => {} - } - result + Ok(result) } /// In the context of a join, update the right side `OrderingEquivalenceClass` @@ -125,28 +114,29 @@ pub fn updated_right_ordering_equivalence_class( right_oeq_class: &mut OrderingEquivalenceClass, join_type: &JoinType, left_size: usize, -) { +) -> Result<()> { if matches!( join_type, JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right ) { - right_oeq_class.add_offset(left_size); + right_oeq_class.add_offset(left_size as _)?; } + Ok(()) } #[cfg(test)] mod tests { - use std::sync::Arc; use super::*; - use crate::equivalence::add_offset_to_expr; - use crate::equivalence::tests::{convert_to_orderings, create_test_schema}; + use crate::equivalence::convert_to_orderings; + use crate::equivalence::tests::create_test_schema; use crate::expressions::col; - use datafusion_common::Result; + use crate::physical_expr::add_offset_to_expr; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Fields, Schema}; + use datafusion_common::Result; #[test] fn test_join_equivalence_properties() -> Result<()> { @@ -154,9 +144,9 @@ mod tests { let col_a = &col("a", &schema)?; let col_b = &col("b", &schema)?; let col_c = &col("c", &schema)?; - let offset = schema.fields.len(); - let col_a2 = &add_offset_to_expr(Arc::clone(col_a), offset); - let col_b2 = &add_offset_to_expr(Arc::clone(col_b), offset); + let offset = schema.fields.len() as _; + let col_a2 = &add_offset_to_expr(Arc::clone(col_a), offset)?; + let col_b2 = &add_offset_to_expr(Arc::clone(col_b), offset)?; let option_asc = SortOptions { descending: false, nulls_first: false, @@ -205,8 +195,8 @@ mod tests { let left_orderings = convert_to_orderings(&left_orderings); let right_orderings = convert_to_orderings(&right_orderings); let expected = convert_to_orderings(&expected); - left_eq_properties.add_new_orderings(left_orderings); - right_eq_properties.add_new_orderings(right_orderings); + left_eq_properties.add_orderings(left_orderings); + right_eq_properties.add_orderings(right_orderings); let join_eq = join_equivalence_properties( left_eq_properties, right_eq_properties, @@ -215,16 +205,14 @@ mod tests { &[true, false], Some(JoinSide::Left), &[], - ); + )?; let err_msg = format!("expected: {:?}, actual:{:?}", expected, &join_eq.oeq_class); - assert_eq!(join_eq.oeq_class.len(), expected.len(), "{}", err_msg); + assert_eq!(join_eq.oeq_class.len(), expected.len(), "{err_msg}"); for ordering in join_eq.oeq_class { assert!( expected.contains(&ordering), - "{}, ordering: {:?}", - err_msg, - ordering + "{err_msg}, ordering: {ordering:?}" ); } } @@ -255,7 +243,7 @@ mod tests { ]; let orderings = convert_to_orderings(&orderings); // Right child ordering equivalences - let mut right_oeq_class = OrderingEquivalenceClass::new(orderings); + let mut right_oeq_class = OrderingEquivalenceClass::from(orderings); let left_columns_len = 4; @@ 
-266,24 +254,24 @@ mod tests { // Join Schema let schema = Schema::new(fields); - let col_a = &col("a", &schema)?; - let col_d = &col("d", &schema)?; - let col_x = &col("x", &schema)?; - let col_y = &col("y", &schema)?; - let col_z = &col("z", &schema)?; - let col_w = &col("w", &schema)?; + let col_a = col("a", &schema)?; + let col_d = col("d", &schema)?; + let col_x = col("x", &schema)?; + let col_y = col("y", &schema)?; + let col_z = col("z", &schema)?; + let col_w = col("w", &schema)?; let mut join_eq_properties = EquivalenceProperties::new(Arc::new(schema)); // a=x and d=w - join_eq_properties.add_equal_conditions(col_a, col_x)?; - join_eq_properties.add_equal_conditions(col_d, col_w)?; + join_eq_properties.add_equal_conditions(col_a, Arc::clone(&col_x))?; + join_eq_properties.add_equal_conditions(col_d, Arc::clone(&col_w))?; updated_right_ordering_equivalence_class( &mut right_oeq_class, &join_type, left_columns_len, - ); - join_eq_properties.add_ordering_equivalence_class(right_oeq_class); + )?; + join_eq_properties.add_orderings(right_oeq_class); let result = join_eq_properties.oeq_class().clone(); // [x ASC, y ASC], [z ASC, w ASC] @@ -292,7 +280,7 @@ mod tests { vec![(col_z, option_asc), (col_w, option_asc)], ]; let orderings = convert_to_orderings(&orderings); - let expected = OrderingEquivalenceClass::new(orderings); + let expected = OrderingEquivalenceClass::from(orderings); assert_eq!(result, expected); diff --git a/datafusion/physical-expr/src/equivalence/properties/mod.rs b/datafusion/physical-expr/src/equivalence/properties/mod.rs index 5b34a02a9142..6d18d34ca4de 100644 --- a/datafusion/physical-expr/src/equivalence/properties/mod.rs +++ b/datafusion/physical-expr/src/equivalence/properties/mod.rs @@ -19,47 +19,43 @@ mod dependency; // Submodule containing DependencyMap and Dependencies mod joins; // Submodule containing join_equivalence_properties mod union; // Submodule containing calculate_union -use dependency::{ - construct_prefix_orderings, generate_dependency_orderings, referred_dependencies, - Dependencies, DependencyMap, -}; pub use joins::*; pub use union::*; -use std::fmt::Display; -use std::hash::{Hash, Hasher}; +use std::fmt::{self, Display}; +use std::mem; use std::sync::Arc; -use std::{fmt, mem}; -use crate::equivalence::class::{const_exprs_contains, AcrossPartitions}; +use self::dependency::{ + construct_prefix_orderings, generate_dependency_orderings, referred_dependencies, + Dependencies, DependencyMap, +}; use crate::equivalence::{ - EquivalenceClass, EquivalenceGroup, OrderingEquivalenceClass, ProjectionMapping, + AcrossPartitions, EquivalenceGroup, OrderingEquivalenceClass, ProjectionMapping, }; use crate::expressions::{with_new_schema, CastExpr, Column, Literal}; use crate::{ - physical_exprs_contains, ConstExpr, LexOrdering, LexRequirement, PhysicalExpr, - PhysicalSortExpr, PhysicalSortRequirement, + ConstExpr, LexOrdering, LexRequirement, PhysicalExpr, PhysicalSortExpr, + PhysicalSortRequirement, }; -use arrow::compute::SortOptions; use arrow::datatypes::SchemaRef; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{plan_err, Constraint, Constraints, HashMap, Result}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; +use datafusion_physical_expr_common::sort_expr::options_compatible; use datafusion_physical_expr_common::utils::ExprPropertiesNode; use indexmap::IndexSet; use itertools::Itertools; -/// `EquivalenceProperties` 
stores information about the output -/// of a plan node, that can be used to optimize the plan. -/// -/// Currently, it keeps track of: -/// - Sort expressions (orderings) -/// - Equivalent expressions: expressions that are known to have same value. -/// - Constants expressions: expressions that are known to contain a single -/// constant value. +/// `EquivalenceProperties` stores information about the output of a plan node +/// that can be used to optimize the plan. Currently, it keeps track of: +/// - Sort expressions (orderings), +/// - Equivalent expressions; i.e. expressions known to have the same value. +/// - Constants expressions; i.e. expressions known to contain a single constant +/// value. /// /// Please see the [Using Ordering for Better Plans] blog for more details. /// @@ -81,8 +77,8 @@ use itertools::Itertools; /// ``` /// /// In this case, both `a ASC` and `b DESC` can describe the table ordering. -/// `EquivalenceProperties`, tracks these different valid sort expressions and -/// treat `a ASC` and `b DESC` on an equal footing. For example if the query +/// `EquivalenceProperties` tracks these different valid sort expressions and +/// treat `a ASC` and `b DESC` on an equal footing. For example, if the query /// specifies the output sorted by EITHER `a ASC` or `b DESC`, the sort can be /// avoided. /// @@ -101,12 +97,11 @@ use itertools::Itertools; /// └---┴---┘ /// ``` /// -/// In this case, columns `a` and `b` always have the same value, which can of -/// such equivalences inside this object. With this information, Datafusion can -/// optimize operations such as. For example, if the partition requirement is -/// `Hash(a)` and output partitioning is `Hash(b)`, then DataFusion avoids -/// repartitioning the data as the existing partitioning satisfies the -/// requirement. +/// In this case, columns `a` and `b` always have the same value. With this +/// information, Datafusion can optimize various operations. For example, if +/// the partition requirement is `Hash(a)` and output partitioning is +/// `Hash(b)`, then DataFusion avoids repartitioning the data as the existing +/// partitioning satisfies the requirement. /// /// # Code Example /// ``` @@ -125,40 +120,85 @@ use itertools::Itertools; /// # let col_c = col("c", &schema).unwrap(); /// // This object represents data that is sorted by a ASC, c DESC /// // with a single constant value of b -/// let mut eq_properties = EquivalenceProperties::new(schema) -/// .with_constants(vec![ConstExpr::from(col_b)]); -/// eq_properties.add_new_ordering(LexOrdering::new(vec![ +/// let mut eq_properties = EquivalenceProperties::new(schema); +/// eq_properties.add_constants(vec![ConstExpr::from(col_b)]); +/// eq_properties.add_ordering([ /// PhysicalSortExpr::new_default(col_a).asc(), /// PhysicalSortExpr::new_default(col_c).desc(), -/// ])); +/// ]); /// -/// assert_eq!(eq_properties.to_string(), "order: [[a@0 ASC, c@2 DESC]], const: [b@1(heterogeneous)]") +/// assert_eq!(eq_properties.to_string(), "order: [[a@0 ASC, c@2 DESC]], eq: [{members: [b@1], constant: (heterogeneous)}]"); /// ``` -#[derive(Debug, Clone)] +#[derive(Clone, Debug)] pub struct EquivalenceProperties { - /// Distinct equivalence classes (exprs known to have the same expressions) + /// Distinct equivalence classes (i.e. expressions with the same value). eq_group: EquivalenceGroup, - /// Equivalent sort expressions + /// Equivalent sort expressions (i.e. those define the same ordering). 
oeq_class: OrderingEquivalenceClass, - /// Expressions whose values are constant - /// - /// TODO: We do not need to track constants separately, they can be tracked - /// inside `eq_group` as `Literal` expressions. - constants: Vec, - /// Table constraints + /// Cache storing equivalent sort expressions in normal form (i.e. without + /// constants/duplicates and in standard form) and a map associating leading + /// terms with full sort expressions. + oeq_cache: OrderingEquivalenceCache, + /// Table constraints that factor in equivalence calculations. constraints: Constraints, /// Schema associated with this object. schema: SchemaRef, } +/// This object serves as a cache for storing equivalent sort expressions +/// in normal form, and a map associating leading sort expressions with +/// full lexicographical orderings. With this information, DataFusion can +/// efficiently determine whether a given ordering is satisfied by the +/// existing orderings, and discover new orderings based on the existing +/// equivalence properties. +#[derive(Clone, Debug, Default)] +struct OrderingEquivalenceCache { + /// Equivalent sort expressions in normal form. + normal_cls: OrderingEquivalenceClass, + /// Map associating leading sort expressions with full lexicographical + /// orderings. Values are indices into `normal_cls`. + leading_map: HashMap, Vec>, +} + +impl OrderingEquivalenceCache { + /// Creates a new `OrderingEquivalenceCache` object with the given + /// equivalent orderings, which should be in normal form. + pub fn new( + orderings: impl IntoIterator>, + ) -> Self { + let mut cache = Self { + normal_cls: OrderingEquivalenceClass::new(orderings), + leading_map: HashMap::new(), + }; + cache.update_map(); + cache + } + + /// Updates/reconstructs the leading expression map according to the normal + /// ordering equivalence class within. + pub fn update_map(&mut self) { + self.leading_map.clear(); + for (idx, ordering) in self.normal_cls.iter().enumerate() { + let expr = Arc::clone(&ordering.first().expr); + self.leading_map.entry(expr).or_default().push(idx); + } + } + + /// Clears the cache, removing all orderings and leading expressions. + pub fn clear(&mut self) { + self.normal_cls.clear(); + self.leading_map.clear(); + } +} + impl EquivalenceProperties { /// Creates an empty `EquivalenceProperties` object. pub fn new(schema: SchemaRef) -> Self { Self { - eq_group: EquivalenceGroup::empty(), - oeq_class: OrderingEquivalenceClass::empty(), - constants: vec![], - constraints: Constraints::empty(), + eq_group: EquivalenceGroup::default(), + oeq_class: OrderingEquivalenceClass::default(), + oeq_cache: OrderingEquivalenceCache::default(), + constraints: Constraints::default(), schema, } } @@ -170,12 +210,23 @@ impl EquivalenceProperties { } /// Creates a new `EquivalenceProperties` object with the given orderings. - pub fn new_with_orderings(schema: SchemaRef, orderings: &[LexOrdering]) -> Self { + pub fn new_with_orderings( + schema: SchemaRef, + orderings: impl IntoIterator>, + ) -> Self { + let eq_group = EquivalenceGroup::default(); + let oeq_class = OrderingEquivalenceClass::new(orderings); + // Here, we can avoid performing a full normalization, and get by with + // only removing constants because the equivalence group is empty. 
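The role of the new `OrderingEquivalenceCache` is easier to see with a small standalone model of its leading-expression map. Strings stand in for physical expressions and `build_leading_map` is an illustrative helper, not part of DataFusion:

```rust
use std::collections::HashMap;

/// Build the leading-expression map kept by the cache: the first sort
/// expression of every normalized ordering points at the indices of the
/// orderings that start with it.
fn build_leading_map<'a>(orderings: &[Vec<&'a str>]) -> HashMap<&'a str, Vec<usize>> {
    let mut leading_map: HashMap<&'a str, Vec<usize>> = HashMap::new();
    for (idx, ordering) in orderings.iter().enumerate() {
        if let Some(&first) = ordering.first() {
            leading_map.entry(first).or_default().push(idx);
        }
    }
    leading_map
}

fn main() {
    // Normalized orderings: [a, b], [a, c] and [d].
    let orderings = vec![vec!["a", "b"], vec!["a", "c"], vec!["d"]];
    let leading = build_leading_map(&orderings);
    // Only orderings 0 and 1 need to be inspected for requests led by `a`.
    assert_eq!(leading["a"], vec![0, 1]);
    assert_eq!(leading["d"], vec![2]);
}
```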
+ let normal_orderings = oeq_class.iter().cloned().map(|o| { + o.into_iter() + .filter(|sort_expr| eq_group.is_expr_constant(&sort_expr.expr).is_none()) + }); Self { - eq_group: EquivalenceGroup::empty(), - oeq_class: OrderingEquivalenceClass::new(orderings.to_vec()), - constants: vec![], - constraints: Constraints::empty(), + oeq_cache: OrderingEquivalenceCache::new(normal_orderings), + oeq_class, + eq_group, + constraints: Constraints::default(), schema, } } @@ -190,91 +241,126 @@ impl EquivalenceProperties { &self.oeq_class } - /// Return the inner OrderingEquivalenceClass, consuming self - pub fn into_oeq_class(self) -> OrderingEquivalenceClass { - self.oeq_class - } - /// Returns a reference to the equivalence group within. pub fn eq_group(&self) -> &EquivalenceGroup { &self.eq_group } - /// Returns a reference to the constant expressions - pub fn constants(&self) -> &[ConstExpr] { - &self.constants - } - + /// Returns a reference to the constraints within. pub fn constraints(&self) -> &Constraints { &self.constraints } - /// Returns the output ordering of the properties. - pub fn output_ordering(&self) -> Option { - let constants = self.constants(); - let mut output_ordering = self.oeq_class().output_ordering().unwrap_or_default(); - // Prune out constant expressions - output_ordering - .retain(|sort_expr| !const_exprs_contains(constants, &sort_expr.expr)); - (!output_ordering.is_empty()).then_some(output_ordering) + /// Returns all the known constants expressions. + pub fn constants(&self) -> Vec { + self.eq_group + .iter() + .filter_map(|c| { + c.constant.as_ref().and_then(|across| { + c.canonical_expr() + .map(|expr| ConstExpr::new(Arc::clone(expr), across.clone())) + }) + }) + .collect() } - /// Returns the normalized version of the ordering equivalence class within. - /// Normalization removes constants and duplicates as well as standardizing - /// expressions according to the equivalence group within. - pub fn normalized_oeq_class(&self) -> OrderingEquivalenceClass { - OrderingEquivalenceClass::new( - self.oeq_class - .iter() - .map(|ordering| self.normalize_sort_exprs(ordering)) - .collect(), - ) + /// Returns the output ordering of the properties. + pub fn output_ordering(&self) -> Option { + let concat = self.oeq_class.iter().flat_map(|o| o.iter().cloned()); + self.normalize_sort_exprs(concat) } /// Extends this `EquivalenceProperties` with the `other` object. - pub fn extend(mut self, other: Self) -> Self { - self.eq_group.extend(other.eq_group); - self.oeq_class.extend(other.oeq_class); - self.with_constants(other.constants) + pub fn extend(mut self, other: Self) -> Result { + self.constraints.extend(other.constraints); + self.add_equivalence_group(other.eq_group)?; + self.add_orderings(other.oeq_class); + Ok(self) } /// Clears (empties) the ordering equivalence class within this object. /// Call this method when existing orderings are invalidated. pub fn clear_orderings(&mut self) { self.oeq_class.clear(); + self.oeq_cache.clear(); } /// Removes constant expressions that may change across partitions. - /// This method should be used when data from different partitions are merged. + /// This method should be used when merging data from different partitions. pub fn clear_per_partition_constants(&mut self) { - self.constants.retain(|item| { - matches!(item.across_partitions(), AcrossPartitions::Uniform(_)) - }) - } - - /// Extends this `EquivalenceProperties` by adding the orderings inside the - /// ordering equivalence class `other`. 
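Since constants no longer live in a separate field, the new `constants()` accessor reconstructs them from equivalence classes that carry a constant value. A toy model of that shape; the `ClassModel` struct and the string constant markers below are illustrative, not the real `EquivalenceClass`:

```rust
/// Toy stand-in for an equivalence class that may carry a constant value.
struct ClassModel {
    members: Vec<&'static str>,
    constant: Option<&'static str>, // e.g. "(uniform: 5)" or "(heterogeneous)"
}

/// Mirrors the shape of the new `constants()` accessor: walk the classes
/// and emit (canonical expression, constantness) pairs for classes that
/// are known to be constant.
fn constants(classes: &[ClassModel]) -> Vec<(&'static str, &'static str)> {
    classes
        .iter()
        .filter_map(|class| {
            class
                .constant
                .and_then(|across| class.members.first().map(|expr| (*expr, across)))
        })
        .collect()
}

fn main() {
    let classes = vec![
        ClassModel { members: vec!["a", "b"], constant: None },
        ClassModel { members: vec!["c"], constant: Some("(heterogeneous)") },
    ];
    assert_eq!(constants(&classes), vec![("c", "(heterogeneous)")]);
}
```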
- pub fn add_ordering_equivalence_class(&mut self, other: OrderingEquivalenceClass) { - self.oeq_class.extend(other); + if self.eq_group.clear_per_partition_constants() { + // Renormalize orderings if the equivalence group changes: + let normal_orderings = self + .oeq_class + .iter() + .cloned() + .map(|o| self.eq_group.normalize_sort_exprs(o)); + self.oeq_cache = OrderingEquivalenceCache::new(normal_orderings); + } } /// Adds new orderings into the existing ordering equivalence class. - pub fn add_new_orderings( + pub fn add_orderings( &mut self, - orderings: impl IntoIterator, + orderings: impl IntoIterator>, ) { - self.oeq_class.add_new_orderings(orderings); + let orderings: Vec<_> = + orderings.into_iter().filter_map(LexOrdering::new).collect(); + let normal_orderings: Vec<_> = orderings + .iter() + .cloned() + .filter_map(|o| self.normalize_sort_exprs(o)) + .collect(); + if !normal_orderings.is_empty() { + self.oeq_class.extend(orderings); + // Normalize given orderings to update the cache: + self.oeq_cache.normal_cls.extend(normal_orderings); + // TODO: If no ordering is found to be redunant during extension, we + // can use a shortcut algorithm to update the leading map. + self.oeq_cache.update_map(); + } } /// Adds a single ordering to the existing ordering equivalence class. - pub fn add_new_ordering(&mut self, ordering: LexOrdering) { - self.add_new_orderings([ordering]); + pub fn add_ordering(&mut self, ordering: impl IntoIterator) { + self.add_orderings(std::iter::once(ordering)); } /// Incorporates the given equivalence group to into the existing /// equivalence group within. - pub fn add_equivalence_group(&mut self, other_eq_group: EquivalenceGroup) { - self.eq_group.extend(other_eq_group); + pub fn add_equivalence_group( + &mut self, + other_eq_group: EquivalenceGroup, + ) -> Result<()> { + if !other_eq_group.is_empty() { + self.eq_group.extend(other_eq_group); + // Renormalize orderings if the equivalence group changes: + let normal_cls = mem::take(&mut self.oeq_cache.normal_cls); + let normal_orderings = normal_cls + .into_iter() + .map(|o| self.eq_group.normalize_sort_exprs(o)); + self.oeq_cache.normal_cls = OrderingEquivalenceClass::new(normal_orderings); + self.oeq_cache.update_map(); + // Discover any new orderings based on the new equivalence classes: + let leading_exprs: Vec<_> = + self.oeq_cache.leading_map.keys().cloned().collect(); + for expr in leading_exprs { + self.discover_new_orderings(expr)?; + } + } + Ok(()) + } + + /// Returns the ordering equivalence class within in normal form. + /// Normalization standardizes expressions according to the equivalence + /// group within, and removes constants/duplicates. + pub fn normalized_oeq_class(&self) -> OrderingEquivalenceClass { + self.oeq_class + .iter() + .cloned() + .filter_map(|ordering| self.normalize_sort_exprs(ordering)) + .collect::>() + .into() } /// Adds a new equality condition into the existing equivalence group. @@ -282,290 +368,236 @@ impl EquivalenceProperties { /// equivalence class to the equivalence group. 
pub fn add_equal_conditions( &mut self, - left: &Arc, - right: &Arc, + left: Arc, + right: Arc, ) -> Result<()> { - // Discover new constants in light of new the equality: - if self.is_expr_constant(left) { - // Left expression is constant, add right as constant - if !const_exprs_contains(&self.constants, right) { - let const_expr = ConstExpr::from(right) - .with_across_partitions(self.get_expr_constant_value(left)); - self.constants.push(const_expr); - } - } else if self.is_expr_constant(right) { - // Right expression is constant, add left as constant - if !const_exprs_contains(&self.constants, left) { - let const_expr = ConstExpr::from(left) - .with_across_partitions(self.get_expr_constant_value(right)); - self.constants.push(const_expr); - } + // Add equal expressions to the state: + if self.eq_group.add_equal_conditions(Arc::clone(&left), right) { + // Renormalize orderings if the equivalence group changes: + let normal_cls = mem::take(&mut self.oeq_cache.normal_cls); + let normal_orderings = normal_cls + .into_iter() + .map(|o| self.eq_group.normalize_sort_exprs(o)); + self.oeq_cache.normal_cls = OrderingEquivalenceClass::new(normal_orderings); + self.oeq_cache.update_map(); + // Discover any new orderings: + self.discover_new_orderings(left)?; } - - // Add equal expressions to the state - self.eq_group.add_equal_conditions(left, right); - - // Discover any new orderings - self.discover_new_orderings(left)?; Ok(()) } /// Track/register physical expressions with constant values. - #[deprecated(since = "43.0.0", note = "Use [`with_constants`] instead")] - pub fn add_constants(self, constants: impl IntoIterator) -> Self { - self.with_constants(constants) - } - - /// Remove the specified constant - pub fn remove_constant(mut self, c: &ConstExpr) -> Self { - self.constants.retain(|existing| existing != c); - self - } - - /// Track/register physical expressions with constant values. 
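The renormalization performed by `add_equal_conditions` whenever the equivalence group actually changes can be pictured as rewriting each sort key to its class representative and dropping repeats. A simplified, string-based sketch under that assumption; the helper names are ours, not DataFusion's:

```rust
use std::collections::HashSet;

/// Rewrite every sort key to the canonical (first) member of its alias
/// class and drop repeats; this is roughly what renormalizing a cached
/// ordering does after the equivalence group gains a new equality.
fn normalize_ordering(ordering: &[&str], classes: &[Vec<&str>]) -> Vec<String> {
    fn canonical(col: &str, classes: &[Vec<&str>]) -> String {
        classes
            .iter()
            .find(|class| class.contains(&col))
            .map_or_else(|| col.to_string(), |class| class[0].to_string())
    }
    let mut seen = HashSet::new();
    ordering
        .iter()
        .map(|col| canonical(col, classes))
        .filter(|col| seen.insert(col.clone()))
        .collect()
}

fn main() {
    // With `a = b` registered, the cached ordering [b, c, a] normalizes
    // to [a, c]: `b` rewrites to `a`, and the trailing `a` is redundant.
    let classes = vec![vec!["a", "b"]];
    assert_eq!(normalize_ordering(&["b", "c", "a"], &classes), vec!["a", "c"]);
}
```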
- pub fn with_constants( - mut self, + pub fn add_constants( + &mut self, constants: impl IntoIterator, - ) -> Self { - let normalized_constants = constants - .into_iter() - .filter_map(|c| { - let across_partitions = c.across_partitions(); - let expr = c.owned_expr(); - let normalized_expr = self.eq_group.normalize_expr(expr); - - if const_exprs_contains(&self.constants, &normalized_expr) { - return None; - } - - let const_expr = ConstExpr::from(normalized_expr) - .with_across_partitions(across_partitions); - - Some(const_expr) + ) -> Result<()> { + // Add the new constant to the equivalence group: + for constant in constants { + self.eq_group.add_constant(constant); + } + // Renormalize the orderings after adding new constants by removing + // the constants from existing orderings: + let normal_cls = mem::take(&mut self.oeq_cache.normal_cls); + let normal_orderings = normal_cls.into_iter().map(|ordering| { + ordering.into_iter().filter(|sort_expr| { + self.eq_group.is_expr_constant(&sort_expr.expr).is_none() }) - .collect::>(); - - // Add all new normalized constants - self.constants.extend(normalized_constants); - - // Discover any new orderings based on the constants - for ordering in self.normalized_oeq_class().iter() { - if let Err(e) = self.discover_new_orderings(&ordering[0].expr) { - log::debug!("error discovering new orderings: {e}"); - } + }); + self.oeq_cache.normal_cls = OrderingEquivalenceClass::new(normal_orderings); + self.oeq_cache.update_map(); + // Discover any new orderings based on the constants: + let leading_exprs: Vec<_> = self.oeq_cache.leading_map.keys().cloned().collect(); + for expr in leading_exprs { + self.discover_new_orderings(expr)?; } - - self + Ok(()) } - // Discover new valid orderings in light of a new equality. - // Accepts a single argument (`expr`) which is used to determine - // which orderings should be updated. - // When constants or equivalence classes are changed, there may be new orderings - // that can be discovered with the new equivalence properties. - // For a discussion, see: https://github.com/apache/datafusion/issues/9812 - fn discover_new_orderings(&mut self, expr: &Arc) -> Result<()> { - let normalized_expr = self.eq_group().normalize_expr(Arc::clone(expr)); + /// Discover new valid orderings in light of a new equality. Accepts a single + /// argument (`expr`) which is used to determine the orderings to update. + /// When constants or equivalence classes change, there may be new orderings + /// that can be discovered with the new equivalence properties. 
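The constant-stripping part of that renormalization has the following effect, shown as a standalone sketch with strings standing in for expressions; `strip_constants` is an illustrative name:

```rust
use std::collections::HashSet;

/// Drop sort keys that are known constants, keeping the relative order
/// of the remaining keys; this is the effect `add_constants` has on the
/// cached normal orderings.
fn strip_constants(ordering: &[&str], constants: &HashSet<&str>) -> Vec<String> {
    ordering
        .iter()
        .filter(|col| !constants.contains(*col))
        .map(|col| col.to_string())
        .collect()
}

fn main() {
    // With `b` registered as constant, [a, b, c] reduces to [a, c]:
    // a constant column neither invalidates nor refines an ordering.
    let constants: HashSet<&str> = ["b"].into_iter().collect();
    assert_eq!(strip_constants(&["a", "b", "c"], &constants), vec!["a", "c"]);
}
```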
+ /// For a discussion, see: + fn discover_new_orderings( + &mut self, + normal_expr: Arc, + ) -> Result<()> { + let Some(ordering_idxs) = self.oeq_cache.leading_map.get(&normal_expr) else { + return Ok(()); + }; let eq_class = self .eq_group - .iter() - .find_map(|class| { - class - .contains(&normalized_expr) - .then(|| class.clone().into_vec()) - }) - .unwrap_or_else(|| vec![Arc::clone(&normalized_expr)]); - - let mut new_orderings: Vec = vec![]; - for ordering in self.normalized_oeq_class().iter() { - if !ordering[0].expr.eq(&normalized_expr) { - continue; - } + .get_equivalence_class(&normal_expr) + .map_or_else(|| vec![normal_expr], |class| class.clone().into()); + let mut new_orderings = vec![]; + for idx in ordering_idxs { + let ordering = &self.oeq_cache.normal_cls[*idx]; let leading_ordering_options = ordering[0].options; - for equivalent_expr in &eq_class { + 'exprs: for equivalent_expr in &eq_class { let children = equivalent_expr.children(); if children.is_empty() { continue; } - - // Check if all children match the next expressions in the ordering - let mut all_children_match = true; + // Check if all children match the next expressions in the ordering: let mut child_properties = vec![]; - - // Build properties for each child based on the next expressions - for (i, child) in children.iter().enumerate() { - if let Some(next) = ordering.get(i + 1) { - if !child.as_ref().eq(next.expr.as_ref()) { - all_children_match = false; - break; - } - child_properties.push(ExprProperties { - sort_properties: SortProperties::Ordered(next.options), - range: Interval::make_unbounded( - &child.data_type(&self.schema)?, - )?, - preserves_lex_ordering: true, - }); - } else { - all_children_match = false; - break; + // Build properties for each child based on the next expression: + for (i, child) in children.into_iter().enumerate() { + let Some(next) = ordering.get(i + 1) else { + break 'exprs; + }; + if !next.expr.eq(child) { + break 'exprs; } + let data_type = child.data_type(&self.schema)?; + child_properties.push(ExprProperties { + sort_properties: SortProperties::Ordered(next.options), + range: Interval::make_unbounded(&data_type)?, + preserves_lex_ordering: true, + }); } - - if all_children_match { - // Check if the expression is monotonic in all arguments - if let Ok(expr_properties) = - equivalent_expr.get_properties(&child_properties) - { - if expr_properties.preserves_lex_ordering - && SortProperties::Ordered(leading_ordering_options) - == expr_properties.sort_properties - { - // Assume existing ordering is [c ASC, a ASC, b ASC] - // When equality c = f(a,b) is given, if we know that given ordering `[a ASC, b ASC]`, - // ordering `[f(a,b) ASC]` is valid, then we can deduce that ordering `[a ASC, b ASC]` is also valid. - // Hence, ordering `[a ASC, b ASC]` can be added to the state as a valid ordering. - // (e.g. existing ordering where leading ordering is removed) - new_orderings.push(LexOrdering::new(ordering[1..].to_vec())); - break; - } - } + // Check if the expression is monotonic in all arguments: + let expr_properties = + equivalent_expr.get_properties(&child_properties)?; + if expr_properties.preserves_lex_ordering + && expr_properties.sort_properties + == SortProperties::Ordered(leading_ordering_options) + { + // Assume that `[c ASC, a ASC, b ASC]` is among existing + // orderings. If equality `c = f(a, b)` is given, ordering + // `[a ASC, b ASC]` implies the ordering `[c ASC]`. Thus, + // ordering `[a ASC, b ASC]` is also a valid ordering. 
+ new_orderings.push(ordering[1..].to_vec()); + break; } } } - self.oeq_class.add_new_orderings(new_orderings); - Ok(()) - } - - /// Updates the ordering equivalence group within assuming that the table - /// is re-sorted according to the argument `sort_exprs`. Note that constants - /// and equivalence classes are unchanged as they are unaffected by a re-sort. - /// If the given ordering is already satisfied, the function does nothing. - pub fn with_reorder(mut self, sort_exprs: LexOrdering) -> Self { - // Filter out constant expressions as they don't affect ordering - let filtered_exprs = LexOrdering::new( - sort_exprs - .into_iter() - .filter(|expr| !self.is_expr_constant(&expr.expr)) - .collect(), - ); - - if filtered_exprs.is_empty() { - return self; - } - - let mut new_orderings = vec![filtered_exprs.clone()]; - - // Preserve valid suffixes from existing orderings - let oeq_class = mem::take(&mut self.oeq_class); - for existing in oeq_class { - if self.is_prefix_of(&filtered_exprs, &existing) { - let mut extended = filtered_exprs.clone(); - extended.extend(existing.into_iter().skip(filtered_exprs.len())); - new_orderings.push(extended); - } + if !new_orderings.is_empty() { + self.add_orderings(new_orderings); } - - self.oeq_class = OrderingEquivalenceClass::new(new_orderings); - self + Ok(()) } - /// Checks if the new ordering matches a prefix of the existing ordering - /// (considering expression equivalences) - fn is_prefix_of(&self, new_order: &LexOrdering, existing: &LexOrdering) -> bool { - // Check if new order is longer than existing - can't be a prefix - if new_order.len() > existing.len() { - return false; + /// Updates the ordering equivalence class within assuming that the table + /// is re-sorted according to the argument `ordering`, and returns whether + /// this operation resulted in any change. Note that equivalence classes + /// (and constants) do not change as they are unaffected by a re-sort. If + /// the given ordering is already satisfied, the function does nothing. + pub fn reorder( + &mut self, + ordering: impl IntoIterator, + ) -> Result { + let (ordering, ordering_tee) = ordering.into_iter().tee(); + // First, standardize the given ordering: + let Some(normal_ordering) = self.normalize_sort_exprs(ordering) else { + // If the ordering vanishes after normalization, it is satisfied: + return Ok(false); + }; + if normal_ordering.len() != self.common_sort_prefix_length(&normal_ordering)? { + // If the ordering is unsatisfied, replace existing orderings: + self.clear_orderings(); + self.add_ordering(ordering_tee); + return Ok(true); } - - // Check if new order matches existing prefix (considering equivalences) - new_order.iter().zip(existing).all(|(new, existing)| { - self.eq_group.exprs_equal(&new.expr, &existing.expr) - && new.options == existing.options - }) + Ok(false) } /// Normalizes the given sort expressions (i.e. `sort_exprs`) using the - /// equivalence group and the ordering equivalence class within. - /// - /// Assume that `self.eq_group` states column `a` and `b` are aliases. - /// Also assume that `self.oeq_class` states orderings `d ASC` and `a ASC, c ASC` - /// are equivalent (in the sense that both describe the ordering of the table). - /// If the `sort_exprs` argument were `vec![b ASC, c ASC, a ASC]`, then this - /// function would return `vec![a ASC, c ASC]`. Internally, it would first - /// normalize to `vec![a ASC, c ASC, a ASC]` and end up with the final result - /// after deduplication. 
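The normalization example in the doc comment above (aliases `a = b`, input `[b ASC, c ASC, a ASC]`, result `[a ASC, c ASC]`) boils down to two steps: rewrite every key to its canonical equivalent, then drop repeated keys. Below is a minimal, self-contained sketch of that idea, using plain strings instead of the crate's `PhysicalSortExpr`/`EquivalenceGroup` types; the names are illustrative only and sort options are ignored.

```rust
use std::collections::{HashMap, HashSet};

/// Rewrite each sort key to a canonical representative and drop duplicates,
/// keeping only the first occurrence (later occurrences add no information
/// in a lexicographical ordering).
fn normalize_ordering(ordering: &[&str], aliases: &HashMap<&str, &str>) -> Vec<String> {
    let mut seen = HashSet::new();
    let mut result = Vec::new();
    for &key in ordering {
        let canonical = aliases.get(key).copied().unwrap_or(key);
        // A key that already appeared cannot refine the ordering further.
        if seen.insert(canonical.to_string()) {
            result.push(canonical.to_string());
        }
    }
    result
}

fn main() {
    // `b` is an alias of `a`, mirroring the example in the doc comment.
    let aliases = HashMap::from([("b", "a")]);
    let normalized = normalize_ordering(&["b", "c", "a"], &aliases);
    assert_eq!(normalized, vec!["a", "c"]);
    println!("{normalized:?}");
}
```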
- fn normalize_sort_exprs(&self, sort_exprs: &LexOrdering) -> LexOrdering { - // Convert sort expressions to sort requirements: - let sort_reqs = LexRequirement::from(sort_exprs.clone()); - // Normalize the requirements: - let normalized_sort_reqs = self.normalize_sort_requirements(&sort_reqs); - // Convert sort requirements back to sort expressions: - LexOrdering::from(normalized_sort_reqs) + /// equivalence group within. Returns a `LexOrdering` instance if the + /// expressions define a proper lexicographical ordering. For more details, + /// see [`EquivalenceGroup::normalize_sort_exprs`]. + pub fn normalize_sort_exprs( + &self, + sort_exprs: impl IntoIterator, + ) -> Option { + LexOrdering::new(self.eq_group.normalize_sort_exprs(sort_exprs)) } /// Normalizes the given sort requirements (i.e. `sort_reqs`) using the - /// equivalence group and the ordering equivalence class within. It works by: - /// - Removing expressions that have a constant value from the given requirement. - /// - Replacing sections that belong to some equivalence class in the equivalence - /// group with the first entry in the matching equivalence class. - /// - /// Assume that `self.eq_group` states column `a` and `b` are aliases. - /// Also assume that `self.oeq_class` states orderings `d ASC` and `a ASC, c ASC` - /// are equivalent (in the sense that both describe the ordering of the table). - /// If the `sort_reqs` argument were `vec![b ASC, c ASC, a ASC]`, then this - /// function would return `vec![a ASC, c ASC]`. Internally, it would first - /// normalize to `vec![a ASC, c ASC, a ASC]` and end up with the final result - /// after deduplication. - fn normalize_sort_requirements(&self, sort_reqs: &LexRequirement) -> LexRequirement { - let normalized_sort_reqs = self.eq_group.normalize_sort_requirements(sort_reqs); - let mut constant_exprs = vec![]; - constant_exprs.extend( - self.constants - .iter() - .map(|const_expr| Arc::clone(const_expr.expr())), - ); - let constants_normalized = self.eq_group.normalize_exprs(constant_exprs); - // Prune redundant sections in the requirement: - normalized_sort_reqs - .iter() - .filter(|&order| !physical_exprs_contains(&constants_normalized, &order.expr)) - .cloned() - .collect::() - .collapse() + /// equivalence group within. Returns a `LexRequirement` instance if the + /// expressions define a proper lexicographical requirement. For more + /// details, see [`EquivalenceGroup::normalize_sort_exprs`]. + pub fn normalize_sort_requirements( + &self, + sort_reqs: impl IntoIterator, + ) -> Option { + LexRequirement::new(self.eq_group.normalize_sort_requirements(sort_reqs)) } - /// Checks whether the given ordering is satisfied by any of the existing - /// orderings. - pub fn ordering_satisfy(&self, given: &LexOrdering) -> bool { - // Convert the given sort expressions to sort requirements: - let sort_requirements = LexRequirement::from(given.clone()); - self.ordering_satisfy_requirement(&sort_requirements) + /// Iteratively checks whether the given ordering is satisfied by any of + /// the existing orderings. See [`Self::ordering_satisfy_requirement`] for + /// more details and examples. + pub fn ordering_satisfy( + &self, + given: impl IntoIterator, + ) -> Result { + // First, standardize the given ordering: + let Some(normal_ordering) = self.normalize_sort_exprs(given) else { + // If the ordering vanishes after normalization, it is satisfied: + return Ok(true); + }; + Ok(normal_ordering.len() == self.common_sort_prefix_length(&normal_ordering)?) 
} - /// Returns the number of consecutive requirements (starting from the left) - /// that are satisfied by the plan ordering. - fn compute_common_sort_prefix_length( + /// Iteratively checks whether the given sort requirement is satisfied by + /// any of the existing orderings. + /// + /// ### Example Scenarios + /// + /// In these scenarios, assume that all expressions share the same sort + /// properties. + /// + /// #### Case 1: Sort Requirement `[a, c]` + /// + /// **Existing orderings:** `[[a, b, c], [a, d]]`, **constants:** `[]` + /// 1. The function first checks the leading requirement `a`, which is + /// satisfied by `[a, b, c].first()`. + /// 2. `a` is added as a constant for the next iteration. + /// 3. Normal orderings become `[[b, c], [d]]`. + /// 4. The function fails for `c` in the second iteration, as neither + /// `[b, c]` nor `[d]` satisfies `c`. + /// + /// #### Case 2: Sort Requirement `[a, d]` + /// + /// **Existing orderings:** `[[a, b, c], [a, d]]`, **constants:** `[]` + /// 1. The function first checks the leading requirement `a`, which is + /// satisfied by `[a, b, c].first()`. + /// 2. `a` is added as a constant for the next iteration. + /// 3. Normal orderings become `[[b, c], [d]]`. + /// 4. The function returns `true` as `[d]` satisfies `d`. + pub fn ordering_satisfy_requirement( &self, - normalized_reqs: &LexRequirement, - ) -> usize { - // Check whether given ordering is satisfied by constraints first - if self.satisfied_by_constraints(normalized_reqs) { - // If the constraints satisfy all requirements, return the full normalized requirements length - return normalized_reqs.len(); + given: impl IntoIterator, + ) -> Result { + // First, standardize the given requirement: + let Some(normal_reqs) = self.normalize_sort_requirements(given) else { + // If the requirement vanishes after normalization, it is satisfied: + return Ok(true); + }; + // Then, check whether given requirement is satisfied by constraints: + if self.satisfied_by_constraints(&normal_reqs) { + return Ok(true); } - + let schema = self.schema(); let mut eq_properties = self.clone(); - - for (i, normalized_req) in normalized_reqs.iter().enumerate() { - // Check whether given ordering is satisfied - if !eq_properties.ordering_satisfy_single(normalized_req) { - // As soon as one requirement is not satisfied, return - // how many we've satisfied so far - return i; + for element in normal_reqs { + // Check whether given requirement is satisfied: + let ExprProperties { + sort_properties, .. + } = eq_properties.get_expr_properties(Arc::clone(&element.expr)); + let satisfy = match sort_properties { + SortProperties::Ordered(options) => element.options.is_none_or(|opts| { + let nullable = element.expr.nullable(schema).unwrap_or(true); + options_compatible(&options, &opts, nullable) + }), + // Singleton expressions satisfy any requirement. + SortProperties::Singleton => true, + SortProperties::Unordered => false, + }; + if !satisfy { + return Ok(false); } // Treat satisfied keys as constants in subsequent iterations. We // can do this because the "next" key only matters in a lexicographical @@ -579,288 +611,263 @@ impl EquivalenceProperties { // From the analysis above, we know that `[a ASC]` is satisfied. Then, // we add column `a` as constant to the algorithm state. This enables us // to deduce that `(b + c) ASC` is satisfied, given `a` is constant. 
- eq_properties = eq_properties.with_constants(std::iter::once( - ConstExpr::from(Arc::clone(&normalized_req.expr)), - )); + let const_expr = ConstExpr::from(element.expr); + eq_properties.add_constants(std::iter::once(const_expr))?; } + Ok(true) + } - // All requirements are satisfied. - normalized_reqs.len() + /// Returns the number of consecutive sort expressions (starting from the + /// left) that are satisfied by the existing ordering. + fn common_sort_prefix_length(&self, normal_ordering: &LexOrdering) -> Result { + let full_length = normal_ordering.len(); + // Check whether the given ordering is satisfied by constraints: + if self.satisfied_by_constraints_ordering(normal_ordering) { + // If constraints satisfy all sort expressions, return the full + // length: + return Ok(full_length); + } + let schema = self.schema(); + let mut eq_properties = self.clone(); + for (idx, element) in normal_ordering.into_iter().enumerate() { + // Check whether given ordering is satisfied: + let ExprProperties { + sort_properties, .. + } = eq_properties.get_expr_properties(Arc::clone(&element.expr)); + let satisfy = match sort_properties { + SortProperties::Ordered(options) => options_compatible( + &options, + &element.options, + element.expr.nullable(schema).unwrap_or(true), + ), + // Singleton expressions satisfy any ordering. + SortProperties::Singleton => true, + SortProperties::Unordered => false, + }; + if !satisfy { + // As soon as one sort expression is unsatisfied, return how + // many we've satisfied so far: + return Ok(idx); + } + // Treat satisfied keys as constants in subsequent iterations. We + // can do this because the "next" key only matters in a lexicographical + // ordering when the keys to its left have the same values. + // + // Note that these expressions are not properly "constants". This is just + // an implementation strategy confined to this function. + // + // For example, assume that the requirement is `[a ASC, (b + c) ASC]`, + // and existing equivalent orderings are `[a ASC, b ASC]` and `[c ASC]`. + // From the analysis above, we know that `[a ASC]` is satisfied. Then, + // we add column `a` as constant to the algorithm state. This enables us + // to deduce that `(b + c) ASC` is satisfied, given `a` is constant. + let const_expr = ConstExpr::from(Arc::clone(&element.expr)); + eq_properties.add_constants(std::iter::once(const_expr))? + } + // All sort expressions are satisfied, return full length: + Ok(full_length) } - /// Determines the longest prefix of `reqs` that is satisfied by the existing ordering. - /// Returns that prefix as a new `LexRequirement`, and a boolean indicating if all the requirements are satisfied. + /// Determines the longest normal prefix of `ordering` satisfied by the + /// existing ordering. Returns that prefix as a new `LexOrdering`, and a + /// boolean indicating whether all the sort expressions are satisfied. pub fn extract_common_sort_prefix( &self, - reqs: &LexRequirement, - ) -> (LexRequirement, bool) { - // First, standardize the given requirement: - let normalized_reqs = self.normalize_sort_requirements(reqs); - - let prefix_len = self.compute_common_sort_prefix_length(&normalized_reqs); - ( - LexRequirement::new(normalized_reqs[..prefix_len].to_vec()), - prefix_len == normalized_reqs.len(), - ) - } - - /// Checks whether the given sort requirements are satisfied by any of the - /// existing orderings. 
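The two scenarios documented for `ordering_satisfy_requirement` rely on one trick: once a leading key is known to be satisfied, it can be treated as a constant, so the next key of every existing ordering becomes the effective leading key. Here is a simplified sketch of that loop with string keys, ignoring sort options, nullability, and constraints; it is a toy model, not the crate's implementation.

```rust
use std::collections::HashSet;

/// Returns how many leading keys of `required` are satisfied by `orderings`,
/// treating each satisfied key as a constant for the following iterations.
fn common_prefix_length(required: &[&str], orderings: &[Vec<&str>]) -> usize {
    let mut constants: HashSet<&str> = HashSet::new();
    for (idx, &key) in required.iter().enumerate() {
        let satisfied = constants.contains(key)
            || orderings.iter().any(|ordering| {
                // The first non-constant key of an ordering is the one that
                // currently determines the sort within equal prefixes.
                ordering
                    .iter()
                    .copied()
                    .find(|k| !constants.contains(*k))
                    .is_some_and(|leading| leading == key)
            });
        if !satisfied {
            return idx;
        }
        constants.insert(key);
    }
    required.len()
}

fn main() {
    let orderings = vec![vec!["a", "b", "c"], vec!["a", "d"]];
    // Case 1 from the doc comment: `[a, c]` fails at `c`.
    assert_eq!(common_prefix_length(&["a", "c"], &orderings), 1);
    // Case 2: `[a, d]` is fully satisfied.
    assert_eq!(common_prefix_length(&["a", "d"], &orderings), 2);
}
```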
- pub fn ordering_satisfy_requirement(&self, reqs: &LexRequirement) -> bool { - self.extract_common_sort_prefix(reqs).1 + ordering: LexOrdering, + ) -> Result<(Vec, bool)> { + // First, standardize the given ordering: + let Some(normal_ordering) = self.normalize_sort_exprs(ordering) else { + // If the ordering vanishes after normalization, it is satisfied: + return Ok((vec![], true)); + }; + let prefix_len = self.common_sort_prefix_length(&normal_ordering)?; + let flag = prefix_len == normal_ordering.len(); + let mut sort_exprs: Vec<_> = normal_ordering.into(); + if !flag { + sort_exprs.truncate(prefix_len); + } + Ok((sort_exprs, flag)) } - /// Checks if the sort requirements are satisfied by any of the table constraints (primary key or unique). - /// Returns true if any constraint fully satisfies the requirements. - fn satisfied_by_constraints( + /// Checks if the sort expressions are satisfied by any of the table + /// constraints (primary key or unique). Returns true if any constraint + /// fully satisfies the expressions (i.e. constraint indices form a valid + /// prefix of an existing ordering that matches the expressions). For + /// unique constraints, also verifies nullable columns. + fn satisfied_by_constraints_ordering( &self, - normalized_reqs: &[PhysicalSortRequirement], + normal_exprs: &[PhysicalSortExpr], ) -> bool { self.constraints.iter().any(|constraint| match constraint { - Constraint::PrimaryKey(indices) | Constraint::Unique(indices) => self - .satisfied_by_constraint( - normalized_reqs, - indices, - matches!(constraint, Constraint::Unique(_)), - ), - }) - } - - /// Checks if sort requirements are satisfied by a constraint (primary key or unique). - /// Returns true if the constraint indices form a valid prefix of an existing ordering - /// that matches the requirements. For unique constraints, also verifies nullable columns. 
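The constraint-based shortcut is easier to see on a toy model: a primary-key or unique constraint lets a short ordering satisfy a longer requirement, because once all key columns are bound the remaining columns are functionally determined. The sketch below makes simplifying assumptions (column indices only, no sort options or nullability checks) and is not the crate's logic verbatim.

```rust
use std::collections::HashSet;

/// A simplified constraint check: some existing ordering must (a) contain
/// every key column and (b) be a prefix of the requirement.
fn constraint_satisfies(
    requirement: &[usize],
    orderings: &[Vec<usize>],
    key_cols: &HashSet<usize>,
) -> bool {
    key_cols.len() <= requirement.len()
        && orderings.iter().any(|ordering| {
            ordering.len() <= requirement.len()
                && key_cols.iter().all(|col| ordering.contains(col))
                && requirement[..ordering.len()] == ordering[..]
        })
}

fn main() {
    // Column 0 is a primary key and the data is ordered by `[0, 1]`.
    let orderings = vec![vec![0, 1]];
    let key: HashSet<usize> = HashSet::from([0]);
    // `[0, 1, 2]` is satisfied: `[0, 1]` is a matching prefix containing the key.
    assert!(constraint_satisfies(&[0, 1, 2], &orderings, &key));
    // `[2, 0]` is not: no existing ordering is a prefix of it.
    assert!(!constraint_satisfies(&[2, 0], &orderings, &key));
}
```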
- fn satisfied_by_constraint( - &self, - normalized_reqs: &[PhysicalSortRequirement], - indices: &[usize], - check_null: bool, - ) -> bool { - // Requirements must contain indices - if indices.len() > normalized_reqs.len() { - return false; - } - - // Iterate over all orderings - self.oeq_class.iter().any(|ordering| { - if indices.len() > ordering.len() { - return false; - } - - // Build a map of column positions in the ordering - let mut col_positions = HashMap::with_capacity(ordering.len()); - for (pos, req) in ordering.iter().enumerate() { - if let Some(col) = req.expr.as_any().downcast_ref::() { - col_positions.insert( - col.index(), - (pos, col.nullable(&self.schema).unwrap_or(true)), - ); - } - } - - // Check if all constraint indices appear in valid positions - if !indices.iter().all(|&idx| { - col_positions - .get(&idx) - .map(|&(pos, nullable)| { - // For unique constraints, verify column is not nullable if it's first/last - !check_null - || (pos != 0 && pos != ordering.len() - 1) - || !nullable + Constraint::PrimaryKey(indices) | Constraint::Unique(indices) => { + let check_null = matches!(constraint, Constraint::Unique(_)); + let normalized_size = normal_exprs.len(); + indices.len() <= normalized_size + && self.oeq_class.iter().any(|ordering| { + let length = ordering.len(); + if indices.len() > length || normalized_size < length { + return false; + } + // Build a map of column positions in the ordering: + let mut col_positions = HashMap::with_capacity(length); + for (pos, req) in ordering.iter().enumerate() { + if let Some(col) = req.expr.as_any().downcast_ref::() + { + let nullable = col.nullable(&self.schema).unwrap_or(true); + col_positions.insert(col.index(), (pos, nullable)); + } + } + // Check if all constraint indices appear in valid positions: + if !indices.iter().all(|idx| { + col_positions.get(idx).is_some_and(|&(pos, nullable)| { + // For unique constraints, verify column is not nullable if it's first/last: + !check_null + || !nullable + || (pos != 0 && pos != length - 1) + }) + }) { + return false; + } + // Check if this ordering matches the prefix: + normal_exprs.iter().zip(ordering).all(|(given, existing)| { + existing.satisfy_expr(given, &self.schema) + }) }) - .unwrap_or(false) - }) { - return false; } - - // Check if this ordering matches requirements prefix - let ordering_len = ordering.len(); - normalized_reqs.len() >= ordering_len - && normalized_reqs[..ordering_len].iter().zip(ordering).all( - |(req, existing)| { - req.expr.eq(&existing.expr) - && req - .options - .is_none_or(|req_opts| req_opts == existing.options) - }, - ) }) } - /// Determines whether the ordering specified by the given sort requirement - /// is satisfied based on the orderings within, equivalence classes, and - /// constant expressions. - /// - /// # Parameters - /// - /// - `req`: A reference to a `PhysicalSortRequirement` for which the ordering - /// satisfaction check will be done. - /// - /// # Returns - /// - /// Returns `true` if the specified ordering is satisfied, `false` otherwise. - fn ordering_satisfy_single(&self, req: &PhysicalSortRequirement) -> bool { - let ExprProperties { - sort_properties, .. - } = self.get_expr_properties(Arc::clone(&req.expr)); - match sort_properties { - SortProperties::Ordered(options) => { - let sort_expr = PhysicalSortExpr { - expr: Arc::clone(&req.expr), - options, - }; - sort_expr.satisfy(req, self.schema()) + /// Checks if the sort requirements are satisfied by any of the table + /// constraints (primary key or unique). 
Returns true if any constraint + /// fully satisfies the requirements (i.e. constraint indices form a valid + /// prefix of an existing ordering that matches the requirements). For + /// unique constraints, also verifies nullable columns. + fn satisfied_by_constraints(&self, normal_reqs: &[PhysicalSortRequirement]) -> bool { + self.constraints.iter().any(|constraint| match constraint { + Constraint::PrimaryKey(indices) | Constraint::Unique(indices) => { + let check_null = matches!(constraint, Constraint::Unique(_)); + let normalized_size = normal_reqs.len(); + indices.len() <= normalized_size + && self.oeq_class.iter().any(|ordering| { + let length = ordering.len(); + if indices.len() > length || normalized_size < length { + return false; + } + // Build a map of column positions in the ordering: + let mut col_positions = HashMap::with_capacity(length); + for (pos, req) in ordering.iter().enumerate() { + if let Some(col) = req.expr.as_any().downcast_ref::() + { + let nullable = col.nullable(&self.schema).unwrap_or(true); + col_positions.insert(col.index(), (pos, nullable)); + } + } + // Check if all constraint indices appear in valid positions: + if !indices.iter().all(|idx| { + col_positions.get(idx).is_some_and(|&(pos, nullable)| { + // For unique constraints, verify column is not nullable if it's first/last: + !check_null + || !nullable + || (pos != 0 && pos != length - 1) + }) + }) { + return false; + } + // Check if this ordering matches the prefix: + normal_reqs.iter().zip(ordering).all(|(given, existing)| { + existing.satisfy(given, &self.schema) + }) + }) } - // Singleton expressions satisfies any ordering. - SortProperties::Singleton => true, - SortProperties::Unordered => false, - } + }) } /// Checks whether the `given` sort requirements are equal or more specific /// than the `reference` sort requirements. pub fn requirements_compatible( &self, - given: &LexRequirement, - reference: &LexRequirement, + given: LexRequirement, + reference: LexRequirement, ) -> bool { - let normalized_given = self.normalize_sort_requirements(given); - let normalized_reference = self.normalize_sort_requirements(reference); - - (normalized_reference.len() <= normalized_given.len()) - && normalized_reference + let Some(normal_given) = self.normalize_sort_requirements(given) else { + return true; + }; + let Some(normal_reference) = self.normalize_sort_requirements(reference) else { + return true; + }; + + (normal_reference.len() <= normal_given.len()) + && normal_reference .into_iter() - .zip(normalized_given) + .zip(normal_given) .all(|(reference, given)| given.compatible(&reference)) } - /// Returns the finer ordering among the orderings `lhs` and `rhs`, breaking - /// any ties by choosing `lhs`. - /// - /// The finer ordering is the ordering that satisfies both of the orderings. - /// If the orderings are incomparable, returns `None`. - /// - /// For example, the finer ordering among `[a ASC]` and `[a ASC, b ASC]` is - /// the latter. - pub fn get_finer_ordering( - &self, - lhs: &LexOrdering, - rhs: &LexOrdering, - ) -> Option { - // Convert the given sort expressions to sort requirements: - let lhs = LexRequirement::from(lhs.clone()); - let rhs = LexRequirement::from(rhs.clone()); - let finer = self.get_finer_requirement(&lhs, &rhs); - // Convert the chosen sort requirements back to sort expressions: - finer.map(LexOrdering::from) - } - - /// Returns the finer ordering among the requirements `lhs` and `rhs`, - /// breaking any ties by choosing `lhs`. 
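`requirements_compatible` asks whether `given` is at least as specific as `reference`: `reference` must be a prefix, and any direction it pins must be pinned identically in `given`. A small sketch of that comparison with illustrative types (not the crate's `PhysicalSortRequirement`):

```rust
/// A sort requirement: a key plus an optional direction (`None` means "any
/// direction is acceptable").
struct Req<'a> {
    key: &'a str,
    descending: Option<bool>,
}

/// `given` is compatible with `reference` when it is at least as specific:
/// `reference` must be a prefix of `given`, and wherever `reference` pins a
/// direction, `given` must pin the same one.
fn requirements_compatible(given: &[Req], reference: &[Req]) -> bool {
    reference.len() <= given.len()
        && reference.iter().zip(given).all(|(r, g)| {
            r.key == g.key && (r.descending.is_none() || r.descending == g.descending)
        })
}

fn main() {
    let given = [
        Req { key: "a", descending: Some(false) },
        Req { key: "b", descending: Some(true) },
    ];
    let reference = [Req { key: "a", descending: None }];
    assert!(requirements_compatible(&given, &reference));
    assert!(!requirements_compatible(&reference, &given));
}
```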
+ /// Modify existing orderings by substituting sort expressions with appropriate + /// targets from the projection mapping. We substitute a sort expression when + /// its physical expression has a one-to-one functional relationship with a + /// target expression in the mapping. /// - /// The finer requirements are the ones that satisfy both of the given - /// requirements. If the requirements are incomparable, returns `None`. + /// After substitution, we may generate more than one `LexOrdering` for each + /// existing equivalent ordering. For example, `[a ASC, b ASC]` will turn + /// into `[CAST(a) ASC, b ASC]` and `[a ASC, b ASC]` when applying projection + /// expressions `a, b, CAST(a)`. /// - /// For example, the finer requirements among `[a ASC]` and `[a ASC, b ASC]` - /// is the latter. - pub fn get_finer_requirement( - &self, - req1: &LexRequirement, - req2: &LexRequirement, - ) -> Option { - let mut lhs = self.normalize_sort_requirements(req1); - let mut rhs = self.normalize_sort_requirements(req2); - lhs.inner - .iter_mut() - .zip(rhs.inner.iter_mut()) - .all(|(lhs, rhs)| { - lhs.expr.eq(&rhs.expr) - && match (lhs.options, rhs.options) { - (Some(lhs_opt), Some(rhs_opt)) => lhs_opt == rhs_opt, - (Some(options), None) => { - rhs.options = Some(options); - true - } - (None, Some(options)) => { - lhs.options = Some(options); - true - } - (None, None) => true, - } - }) - .then_some(if lhs.len() >= rhs.len() { lhs } else { rhs }) - } - - /// we substitute the ordering according to input expression type, this is a simplified version - /// In this case, we just substitute when the expression satisfy the following condition: - /// I. just have one column and is a CAST expression - /// TODO: Add one-to-ones analysis for monotonic ScalarFunctions. - /// TODO: we could precompute all the scenario that is computable, for example: atan(x + 1000) should also be substituted if - /// x is DESC or ASC - /// After substitution, we may generate more than 1 `LexOrdering`. As an example, - /// `[a ASC, b ASC]` will turn into `[a ASC, b ASC], [CAST(a) ASC, b ASC]` when projection expressions `a, b, CAST(a)` is applied. - pub fn substitute_ordering_component( - &self, + /// TODO: Handle all scenarios that allow substitution; e.g. when `x` is + /// sorted, `atan(x + 1000)` should also be substituted. For now, we + /// only consider single-column `CAST` expressions. + fn substitute_oeq_class( + schema: &SchemaRef, mapping: &ProjectionMapping, - sort_expr: &LexOrdering, - ) -> Result> { - let new_orderings = sort_expr - .iter() - .map(|sort_expr| { - let referring_exprs: Vec<_> = mapping - .iter() - .map(|(source, _target)| source) - .filter(|source| expr_refers(source, &sort_expr.expr)) - .cloned() - .collect(); - let mut res = LexOrdering::new(vec![sort_expr.clone()]); - // TODO: Add one-to-ones analysis for ScalarFunctions. 
- for r_expr in referring_exprs { - // we check whether this expression is substitutable or not - if let Some(cast_expr) = r_expr.as_any().downcast_ref::() { - // we need to know whether the Cast Expr matches or not - let expr_type = sort_expr.expr.data_type(&self.schema)?; - if cast_expr.expr.eq(&sort_expr.expr) - && cast_expr.is_bigger_cast(expr_type) + oeq_class: OrderingEquivalenceClass, + ) -> OrderingEquivalenceClass { + let new_orderings = oeq_class.into_iter().flat_map(|order| { + // Modify/expand existing orderings by substituting sort + // expressions with appropriate targets from the mapping: + order + .into_iter() + .map(|sort_expr| { + let referring_exprs = mapping + .iter() + .map(|(source, _target)| source) + .filter(|source| expr_refers(source, &sort_expr.expr)) + .cloned(); + let mut result = vec![]; + // The sort expression comes from this schema, so the + // following call to `unwrap` is safe. + let expr_type = sort_expr.expr.data_type(schema).unwrap(); + // TODO: Add one-to-one analysis for ScalarFunctions. + for r_expr in referring_exprs { + // We check whether this expression is substitutable. + if let Some(cast_expr) = + r_expr.as_any().downcast_ref::() { - res.push(PhysicalSortExpr { - expr: Arc::clone(&r_expr), - options: sort_expr.options, - }); + // For casts, we need to know whether the cast + // expression matches: + if cast_expr.expr.eq(&sort_expr.expr) + && cast_expr.is_bigger_cast(&expr_type) + { + result.push(PhysicalSortExpr::new( + r_expr, + sort_expr.options, + )); + } } } - } - Ok(res) - }) - .collect::>>()?; - // Generate all valid orderings, given substituted expressions. - let res = new_orderings - .into_iter() - .multi_cartesian_product() - .map(LexOrdering::new) - .collect::>(); - Ok(res) + result.push(sort_expr); + result + }) + // Generate all valid orderings given substituted expressions: + .multi_cartesian_product() + }); + OrderingEquivalenceClass::new(new_orderings) } - /// In projection, supposed we have a input function 'A DESC B DESC' and the output shares the same expression - /// with A and B, we could surely use the ordering of the original ordering, However, if the A has been changed, - /// for example, A-> Cast(A, Int64) or any other form, it is invalid if we continue using the original ordering - /// Since it would cause bug in dependency constructions, we should substitute the input order in order to get correct - /// dependency map, happen in issue 8838: - pub fn substitute_oeq_class(&mut self, mapping: &ProjectionMapping) -> Result<()> { - let new_order = self - .oeq_class - .iter() - .map(|order| self.substitute_ordering_component(mapping, order)) - .collect::>>()?; - let new_order = new_order.into_iter().flatten().collect(); - self.oeq_class = OrderingEquivalenceClass::new(new_order); - Ok(()) - } - /// Projects argument `expr` according to `projection_mapping`, taking - /// equivalences into account. + /// Projects argument `expr` according to the projection described by + /// `mapping`, taking equivalences into account. /// /// For example, assume that columns `a` and `c` are always equal, and that - /// `projection_mapping` encodes following mapping: + /// the projection described by `mapping` encodes the following: /// /// ```text /// a -> a1 @@ -868,13 +875,25 @@ impl EquivalenceProperties { /// ``` /// /// Then, this function projects `a + b` to `Some(a1 + b1)`, `c + b` to - /// `Some(a1 + b1)` and `d` to `None`, meaning that it cannot be projected. 
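The substitution above produces, for every position of an ordering, a set of order-preserving alternatives (the original expression plus, e.g., a widening cast of it) and then enumerates their combinations, which is what the `multi_cartesian_product` call achieves. Below is a std-only sketch of that expansion with string placeholders; it is an assumption-laden toy, not the crate's code.

```rust
/// For each sort key, list the order-preserving alternatives available after
/// a projection, then enumerate every combination. `[a, b]` with alternative
/// `CAST(a)` yields `[a, b]` and `[CAST(a), b]`.
fn expand_orderings(per_key_alternatives: &[Vec<&str>]) -> Vec<Vec<String>> {
    // A small cartesian product via fold, so the sketch stays std-only.
    per_key_alternatives.iter().fold(vec![vec![]], |acc, alts| {
        acc.iter()
            .flat_map(|prefix| {
                alts.iter().map(move |alt| {
                    let mut next = prefix.clone();
                    next.push(alt.to_string());
                    next
                })
            })
            .collect()
    })
}

fn main() {
    // Ordering `[a, b]`; the projection also exposes `CAST(a)`.
    let alternatives = vec![vec!["a", "CAST(a)"], vec!["b"]];
    let expanded = expand_orderings(&alternatives);
    assert_eq!(expanded.len(), 2);
    assert!(expanded.contains(&vec!["a".to_string(), "b".to_string()]));
    assert!(expanded.contains(&vec!["CAST(a)".to_string(), "b".to_string()]));
}
```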
+ /// `Some(a1 + b1)` and `d` to `None`, meaning that it is not projectable. pub fn project_expr( &self, expr: &Arc, - projection_mapping: &ProjectionMapping, + mapping: &ProjectionMapping, ) -> Option> { - self.eq_group.project_expr(projection_mapping, expr) + self.eq_group.project_expr(mapping, expr) + } + + /// Projects the given `expressions` according to the projection described + /// by `mapping`, taking equivalences into account. This function is similar + /// to [`Self::project_expr`], but projects multiple expressions at once + /// more efficiently than calling `project_expr` for each expression. + pub fn project_expressions<'a>( + &'a self, + expressions: impl IntoIterator> + 'a, + mapping: &'a ProjectionMapping, + ) -> impl Iterator>> + 'a { + self.eq_group.project_expressions(mapping, expressions) } /// Constructs a dependency map based on existing orderings referred to in @@ -906,71 +925,85 @@ impl EquivalenceProperties { /// b ASC: Node {Some(b_new ASC), HashSet{a ASC}} /// c ASC: Node {None, HashSet{a ASC}} /// ``` - fn construct_dependency_map(&self, mapping: &ProjectionMapping) -> DependencyMap { - let mut dependency_map = DependencyMap::new(); - for ordering in self.normalized_oeq_class().iter() { - for (idx, sort_expr) in ordering.iter().enumerate() { - let target_sort_expr = - self.project_expr(&sort_expr.expr, mapping).map(|expr| { - PhysicalSortExpr { - expr, - options: sort_expr.options, - } - }); - let is_projected = target_sort_expr.is_some(); - if is_projected - || mapping - .iter() - .any(|(source, _)| expr_refers(source, &sort_expr.expr)) - { - // Previous ordering is a dependency. Note that there is no, - // dependency for a leading ordering (i.e. the first sort - // expression). - let dependency = idx.checked_sub(1).map(|a| &ordering[a]); - // Add sort expressions that can be projected or referred to - // by any of the projection expressions to the dependency map: - dependency_map.insert( - sort_expr, - target_sort_expr.as_ref(), - dependency, - ); - } - if !is_projected { - // If we can not project, stop constructing the dependency - // map as remaining dependencies will be invalid after projection. + fn construct_dependency_map( + &self, + oeq_class: OrderingEquivalenceClass, + mapping: &ProjectionMapping, + ) -> DependencyMap { + let mut map = DependencyMap::default(); + for ordering in oeq_class.into_iter() { + // Previous expression is a dependency. Note that there is no + // dependency for the leading expression. + if !self.insert_to_dependency_map( + mapping, + ordering[0].clone(), + None, + &mut map, + ) { + continue; + } + for (dependency, sort_expr) in ordering.into_iter().tuple_windows() { + if !self.insert_to_dependency_map( + mapping, + sort_expr, + Some(dependency), + &mut map, + ) { + // If we can't project, stop constructing the dependency map + // as remaining dependencies will be invalid post projection. break; } } } - dependency_map + map } - /// Returns a new `ProjectionMapping` where source expressions are normalized. - /// - /// This normalization ensures that source expressions are transformed into a - /// consistent representation. This is beneficial for algorithms that rely on - /// exact equalities, as it allows for more precise and reliable comparisons. + /// Projects the sort expression according to the projection mapping and + /// inserts it into the dependency map with the given dependency. Returns + /// a boolean flag indicating whether the given expression is projectable. 
+ fn insert_to_dependency_map( + &self, + mapping: &ProjectionMapping, + sort_expr: PhysicalSortExpr, + dependency: Option, + map: &mut DependencyMap, + ) -> bool { + let target_sort_expr = self + .project_expr(&sort_expr.expr, mapping) + .map(|expr| PhysicalSortExpr::new(expr, sort_expr.options)); + let projectable = target_sort_expr.is_some(); + if projectable + || mapping + .iter() + .any(|(source, _)| expr_refers(source, &sort_expr.expr)) + { + // Add sort expressions that can be projected or referred to + // by any of the projection expressions to the dependency map: + map.insert(sort_expr, target_sort_expr, dependency); + } + projectable + } + + /// Returns a new `ProjectionMapping` where source expressions are in normal + /// form. Normalization ensures that source expressions are transformed into + /// a consistent representation, which is beneficial for algorithms that rely + /// on exact equalities, as it allows for more precise and reliable comparisons. /// /// # Parameters /// - /// - `mapping`: A reference to the original `ProjectionMapping` to be normalized. + /// - `mapping`: A reference to the original `ProjectionMapping` to normalize. /// /// # Returns /// - /// A new `ProjectionMapping` with normalized source expressions. - fn normalized_mapping(&self, mapping: &ProjectionMapping) -> ProjectionMapping { - // Construct the mapping where source expressions are normalized. In this way - // In the algorithms below we can work on exact equalities - ProjectionMapping { - map: mapping - .iter() - .map(|(source, target)| { - let normalized_source = - self.eq_group.normalize_expr(Arc::clone(source)); - (normalized_source, Arc::clone(target)) - }) - .collect(), - } + /// A new `ProjectionMapping` with source expressions in normal form. + fn normalize_mapping(&self, mapping: &ProjectionMapping) -> ProjectionMapping { + mapping + .iter() + .map(|(source, target)| { + let normal_source = self.eq_group.normalize_expr(Arc::clone(source)); + (normal_source, target.clone()) + }) + .collect() } /// Computes projected orderings based on a given projection mapping. @@ -984,42 +1017,55 @@ impl EquivalenceProperties { /// /// - `mapping`: A reference to the `ProjectionMapping` that defines the /// relationship between source and target expressions. + /// - `oeq_class`: The `OrderingEquivalenceClass` containing the orderings + /// to project. /// /// # Returns /// - /// A vector of `LexOrdering` containing all valid orderings after projection. - fn projected_orderings(&self, mapping: &ProjectionMapping) -> Vec { - let mapping = self.normalized_mapping(mapping); - + /// A vector of all valid (but not in normal form) orderings after projection. 
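The dependency map records, for each sort expression, the expression that must precede it; projected orderings are then rebuilt by walking those dependencies back to a leading expression. A toy sketch of that chain-walk, with string keys standing in for the crate's `DependencyMap` entries (names are illustrative):

```rust
use std::collections::HashMap;

/// A toy dependency map: each sort key points at the key that must precede it
/// (its "dependency"), `None` for a leading key. A valid ordering for a key is
/// then its dependency chain followed by the key itself.
fn ordering_for<'a>(
    key: &'a str,
    deps: &HashMap<&'a str, Option<&'a str>>,
) -> Vec<&'a str> {
    let mut chain = vec![key];
    let mut current = key;
    while let Some(Some(prev)) = deps.get(current).copied() {
        chain.push(prev);
        current = prev;
    }
    chain.reverse();
    chain
}

fn main() {
    // Mirrors the `[a ASC, b ASC, c ASC]` example: `b` depends on `a`,
    // `c` depends on `b`, and `a` has no dependency.
    let deps = HashMap::from([("a", None), ("b", Some("a")), ("c", Some("b"))]);
    assert_eq!(ordering_for("c", &deps), vec!["a", "b", "c"]);
    assert_eq!(ordering_for("a", &deps), vec!["a"]);
}
```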
+ fn projected_orderings( + &self, + mapping: &ProjectionMapping, + mut oeq_class: OrderingEquivalenceClass, + ) -> Vec { + // Normalize source expressions in the mapping: + let mapping = self.normalize_mapping(mapping); // Get dependency map for existing orderings: - let dependency_map = self.construct_dependency_map(&mapping); - let orderings = mapping.iter().flat_map(|(source, target)| { + oeq_class = Self::substitute_oeq_class(&self.schema, &mapping, oeq_class); + let dependency_map = self.construct_dependency_map(oeq_class, &mapping); + let orderings = mapping.iter().flat_map(|(source, targets)| { referred_dependencies(&dependency_map, source) .into_iter() - .filter_map(|relevant_deps| { - if let Ok(SortProperties::Ordered(options)) = - get_expr_properties(source, &relevant_deps, &self.schema) - .map(|prop| prop.sort_properties) - { - Some((options, relevant_deps)) + .filter_map(|deps| { + let ep = get_expr_properties(source, &deps, &self.schema); + let sort_properties = ep.map(|prop| prop.sort_properties); + if let Ok(SortProperties::Ordered(options)) = sort_properties { + Some((options, deps)) } else { - // Do not consider unordered cases + // Do not consider unordered cases. None } }) .flat_map(|(options, relevant_deps)| { - let sort_expr = PhysicalSortExpr { - expr: Arc::clone(target), - options, - }; - // Generate dependent orderings (i.e. prefixes for `sort_expr`): - let mut dependency_orderings = + // Generate dependent orderings (i.e. prefixes for targets): + let dependency_orderings = generate_dependency_orderings(&relevant_deps, &dependency_map); - // Append `sort_expr` to the dependent orderings: - for ordering in dependency_orderings.iter_mut() { - ordering.push(sort_expr.clone()); + let sort_exprs = targets.iter().map(|(target, _)| { + PhysicalSortExpr::new(Arc::clone(target), options) + }); + if dependency_orderings.is_empty() { + sort_exprs.map(|sort_expr| [sort_expr].into()).collect() + } else { + sort_exprs + .flat_map(|sort_expr| { + let mut result = dependency_orderings.clone(); + for ordering in result.iter_mut() { + ordering.push(sort_expr.clone()); + } + result + }) + .collect::>() } - dependency_orderings }) }); @@ -1033,116 +1079,67 @@ impl EquivalenceProperties { if prefixes.is_empty() { // If prefix is empty, there is no dependency. Insert // empty ordering: - prefixes = vec![LexOrdering::default()]; - } - // Append current ordering on top its dependencies: - for ordering in prefixes.iter_mut() { - if let Some(target) = &node.target_sort_expr { - ordering.push(target.clone()) + if let Some(target) = &node.target { + prefixes.push([target.clone()].into()); + } + } else { + // Append current ordering on top its dependencies: + for ordering in prefixes.iter_mut() { + if let Some(target) = &node.target { + ordering.push(target.clone()); + } } } prefixes }); // Simplify each ordering by removing redundant sections: - orderings - .chain(projected_orderings) - .map(|lex_ordering| lex_ordering.collapse()) - .collect() - } - - /// Projects constants based on the provided `ProjectionMapping`. - /// - /// This function takes a `ProjectionMapping` and identifies/projects - /// constants based on the existing constants and the mapping. It ensures - /// that constants are appropriately propagated through the projection. - /// - /// # Parameters - /// - /// - `mapping`: A reference to a `ProjectionMapping` representing the - /// mapping of source expressions to target expressions in the projection. 
- /// - /// # Returns - /// - /// Returns a `Vec>` containing the projected constants. - fn projected_constants(&self, mapping: &ProjectionMapping) -> Vec { - // First, project existing constants. For example, assume that `a + b` - // is known to be constant. If the projection were `a as a_new`, `b as b_new`, - // then we would project constant `a + b` as `a_new + b_new`. - let mut projected_constants = self - .constants - .iter() - .flat_map(|const_expr| { - const_expr - .map(|expr| self.eq_group.project_expr(mapping, expr)) - .map(|projected_expr| { - projected_expr - .with_across_partitions(const_expr.across_partitions()) - }) - }) - .collect::>(); - - // Add projection expressions that are known to be constant: - for (source, target) in mapping.iter() { - if self.is_expr_constant(source) - && !const_exprs_contains(&projected_constants, target) - { - if self.is_expr_constant_across_partitions(source) { - projected_constants.push( - ConstExpr::from(target) - .with_across_partitions(self.get_expr_constant_value(source)), - ) - } else { - projected_constants.push( - ConstExpr::from(target) - .with_across_partitions(AcrossPartitions::Heterogeneous), - ) - } - } - } - projected_constants + orderings.chain(projected_orderings).collect() } /// Projects constraints according to the given projection mapping. /// - /// This function takes a projection mapping and extracts the column indices of the target columns. - /// It then projects the constraints to only include relationships between - /// columns that exist in the projected output. + /// This function takes a projection mapping and extracts column indices of + /// target columns. It then projects the constraints to only include + /// relationships between columns that exist in the projected output. /// - /// # Arguments + /// # Parameters /// - /// * `mapping` - A reference to `ProjectionMapping` that defines how expressions are mapped - /// in the projection operation + /// * `mapping` - A reference to the `ProjectionMapping` that defines the + /// projection operation. /// /// # Returns /// - /// Returns a new `Constraints` object containing only the constraints - /// that are valid for the projected columns. + /// Returns an optional `Constraints` object containing only the constraints + /// that are valid for the projected columns (if any exists). fn projected_constraints(&self, mapping: &ProjectionMapping) -> Option { let indices = mapping .iter() - .filter_map(|(_, target)| target.as_any().downcast_ref::()) - .map(|col| col.index()) + .flat_map(|(_, targets)| { + targets.iter().flat_map(|(target, _)| { + target.as_any().downcast_ref::().map(|c| c.index()) + }) + }) .collect::>(); - debug_assert_eq!(mapping.map.len(), indices.len()); self.constraints.project(&indices) } - /// Projects the equivalences within according to `mapping` - /// and `output_schema`. + /// Projects the equivalences within according to `mapping` and + /// `output_schema`. 
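Projecting constraints amounts to checking that every column of a constraint survives the projection and remapping its indices into the output schema; constraints touching dropped columns are discarded. A simplified sketch follows; the `(input, output)` index pairs and the helper name are illustrative assumptions, not the actual `Constraints::project` API.

```rust
use std::collections::HashMap;

/// Keep a constraint only if every one of its columns survives the
/// projection, and rewrite its column indices into the output schema.
fn project_constraint(
    constraint: &[usize],
    mapping: &[(usize, usize)], // (input column index, output column index)
) -> Option<Vec<usize>> {
    let index_map: HashMap<usize, usize> = mapping.iter().copied().collect();
    constraint
        .iter()
        .map(|col| index_map.get(col).copied())
        .collect()
}

fn main() {
    // Projection keeps columns 0 and 2, renumbering them to 0 and 1.
    let mapping = [(0, 0), (2, 1)];
    // A unique constraint on input columns (0, 2) survives as (0, 1) ...
    assert_eq!(project_constraint(&[0, 2], &mapping), Some(vec![0, 1]));
    // ... while a constraint referencing the dropped column 1 does not.
    assert_eq!(project_constraint(&[0, 1], &mapping), None);
}
```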
pub fn project(&self, mapping: &ProjectionMapping, output_schema: SchemaRef) -> Self { let eq_group = self.eq_group.project(mapping); - let oeq_class = OrderingEquivalenceClass::new(self.projected_orderings(mapping)); - let constants = self.projected_constants(mapping); - let constraints = self - .projected_constraints(mapping) - .unwrap_or_else(Constraints::empty); + let orderings = + self.projected_orderings(mapping, self.oeq_cache.normal_cls.clone()); + let normal_orderings = orderings + .iter() + .cloned() + .map(|o| eq_group.normalize_sort_exprs(o)); Self { + oeq_cache: OrderingEquivalenceCache::new(normal_orderings), + oeq_class: OrderingEquivalenceClass::new(orderings), + constraints: self.projected_constraints(mapping).unwrap_or_default(), schema: output_schema, eq_group, - oeq_class, - constants, - constraints, } } @@ -1159,7 +1156,7 @@ impl EquivalenceProperties { pub fn find_longest_permutation( &self, exprs: &[Arc], - ) -> (LexOrdering, Vec) { + ) -> Result<(Vec, Vec)> { let mut eq_properties = self.clone(); let mut result = vec![]; // The algorithm is as follows: @@ -1172,32 +1169,23 @@ impl EquivalenceProperties { // This algorithm should reach a fixed point in at most `exprs.len()` // iterations. let mut search_indices = (0..exprs.len()).collect::>(); - for _idx in 0..exprs.len() { + for _ in 0..exprs.len() { // Get ordered expressions with their indices. let ordered_exprs = search_indices .iter() - .flat_map(|&idx| { + .filter_map(|&idx| { let ExprProperties { sort_properties, .. } = eq_properties.get_expr_properties(Arc::clone(&exprs[idx])); match sort_properties { - SortProperties::Ordered(options) => Some(( - PhysicalSortExpr { - expr: Arc::clone(&exprs[idx]), - options, - }, - idx, - )), + SortProperties::Ordered(options) => { + let expr = Arc::clone(&exprs[idx]); + Some((PhysicalSortExpr::new(expr, options), idx)) + } SortProperties::Singleton => { - // Assign default ordering to constant expressions - let options = SortOptions::default(); - Some(( - PhysicalSortExpr { - expr: Arc::clone(&exprs[idx]), - options, - }, - idx, - )) + // Assign default ordering to constant expressions: + let expr = Arc::clone(&exprs[idx]); + Some((PhysicalSortExpr::new_default(expr), idx)) } SortProperties::Unordered => None, } @@ -1215,44 +1203,20 @@ impl EquivalenceProperties { // Note that these expressions are not properly "constants". This is just // an implementation strategy confined to this function. for (PhysicalSortExpr { expr, .. }, idx) in &ordered_exprs { - eq_properties = - eq_properties.with_constants(std::iter::once(ConstExpr::from(expr))); + let const_expr = ConstExpr::from(Arc::clone(expr)); + eq_properties.add_constants(std::iter::once(const_expr))?; search_indices.shift_remove(idx); } // Add new ordered section to the state. result.extend(ordered_exprs); } - let (left, right) = result.into_iter().unzip(); - (LexOrdering::new(left), right) - } - - /// This function determines whether the provided expression is constant - /// based on the known constants. - /// - /// # Parameters - /// - /// - `expr`: A reference to a `Arc` representing the - /// expression to be checked. - /// - /// # Returns - /// - /// Returns `true` if the expression is constant according to equivalence - /// group, `false` otherwise. - pub fn is_expr_constant(&self, expr: &Arc) -> bool { - // As an example, assume that we know columns `a` and `b` are constant. - // Then, `a`, `b` and `a + b` will all return `true` whereas `c` will - // return `false`. 
- let const_exprs = self - .constants - .iter() - .map(|const_expr| Arc::clone(const_expr.expr())); - let normalized_constants = self.eq_group.normalize_exprs(const_exprs); - let normalized_expr = self.eq_group.normalize_expr(Arc::clone(expr)); - is_constant_recurse(&normalized_constants, &normalized_expr) + Ok(result.into_iter().unzip()) } /// This function determines whether the provided expression is constant - /// across partitions based on the known constants. + /// based on the known constants. For example, if columns `a` and `b` are + /// constant, then expressions `a`, `b` and `a + b` will all return `true` + /// whereas expression `c` will return `false`. /// /// # Parameters /// @@ -1261,87 +1225,15 @@ impl EquivalenceProperties { /// /// # Returns /// - /// Returns `true` if the expression is constant across all partitions according - /// to equivalence group, `false` otherwise - #[deprecated( - since = "45.0.0", - note = "Use [`is_expr_constant_across_partitions`] instead" - )] - pub fn is_expr_constant_accross_partitions( + /// Returns a `Some` value if the expression is constant according to + /// equivalence group, and `None` otherwise. The `Some` variant contains + /// an `AcrossPartitions` value indicating whether the expression is + /// constant across partitions, and its actual value (if available). + pub fn is_expr_constant( &self, expr: &Arc, - ) -> bool { - self.is_expr_constant_across_partitions(expr) - } - - /// This function determines whether the provided expression is constant - /// across partitions based on the known constants. - /// - /// # Parameters - /// - /// - `expr`: A reference to a `Arc` representing the - /// expression to be checked. - /// - /// # Returns - /// - /// Returns `true` if the expression is constant across all partitions according - /// to equivalence group, `false` otherwise. - pub fn is_expr_constant_across_partitions( - &self, - expr: &Arc, - ) -> bool { - // As an example, assume that we know columns `a` and `b` are constant. - // Then, `a`, `b` and `a + b` will all return `true` whereas `c` will - // return `false`. - let const_exprs = self - .constants - .iter() - .filter_map(|const_expr| { - if matches!( - const_expr.across_partitions(), - AcrossPartitions::Uniform { .. } - ) { - Some(Arc::clone(const_expr.expr())) - } else { - None - } - }) - .collect::>(); - let normalized_constants = self.eq_group.normalize_exprs(const_exprs); - let normalized_expr = self.eq_group.normalize_expr(Arc::clone(expr)); - is_constant_recurse(&normalized_constants, &normalized_expr) - } - - /// Retrieves the constant value of a given physical expression, if it exists. - /// - /// Normalizes the input expression and checks if it matches any known constants - /// in the current context. Returns whether the expression has a uniform value, - /// varies across partitions, or is not constant. - /// - /// # Parameters - /// - `expr`: A reference to the physical expression to evaluate. - /// - /// # Returns - /// - `AcrossPartitions::Uniform(value)`: If the expression has the same value across partitions. - /// - `AcrossPartitions::Heterogeneous`: If the expression varies across partitions. - /// - `None`: If the expression is not recognized as constant. 
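The constancy check sketched in the removed `is_constant_recurse` (and now folded into the equivalence group) is a simple bottom-up rule: literals are constant, known-constant columns are constant, and any expression whose children are all constant is constant. A toy model with a three-variant expression tree, purely for illustration:

```rust
use std::collections::HashSet;

/// A tiny expression tree, just enough to show the recursion.
enum Expr {
    Column(&'static str),
    Literal(i64),
    Add(Box<Expr>, Box<Expr>),
}

/// An expression is constant if it is a literal, a known-constant column, or
/// built entirely from constant sub-expressions.
fn is_constant(expr: &Expr, constants: &HashSet<&str>) -> bool {
    match expr {
        Expr::Literal(_) => true,
        Expr::Column(name) => constants.contains(name),
        Expr::Add(lhs, rhs) => is_constant(lhs, constants) && is_constant(rhs, constants),
    }
}

fn main() {
    use Expr::*;
    let constants = HashSet::from(["a", "b"]);
    assert!(is_constant(&Literal(1), &constants));
    // `a + b` is constant because both inputs are ...
    assert!(is_constant(&Add(Box::new(Column("a")), Box::new(Column("b"))), &constants));
    // ... but `a + c` is not.
    assert!(!is_constant(&Add(Box::new(Column("a")), Box::new(Column("c"))), &constants));
}
```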
- pub fn get_expr_constant_value( - &self, - expr: &Arc, - ) -> AcrossPartitions { - let normalized_expr = self.eq_group.normalize_expr(Arc::clone(expr)); - - if let Some(lit) = normalized_expr.as_any().downcast_ref::() { - return AcrossPartitions::Uniform(Some(lit.value().clone())); - } - - for const_expr in self.constants.iter() { - if normalized_expr.eq(const_expr.expr()) { - return const_expr.across_partitions(); - } - } - - AcrossPartitions::Heterogeneous + ) -> Option { + self.eq_group.is_expr_constant(expr) } /// Retrieves the properties for a given physical expression. @@ -1364,13 +1256,12 @@ impl EquivalenceProperties { .transform_up(|expr| update_properties(expr, self)) .data() .map(|node| node.data) - .unwrap_or(ExprProperties::new_unknown()) + .unwrap_or_else(|_| ExprProperties::new_unknown()) } - /// Transforms this `EquivalenceProperties` into a new `EquivalenceProperties` - /// by mapping columns in the original schema to columns in the new schema - /// by index. - pub fn with_new_schema(self, schema: SchemaRef) -> Result { + /// Transforms this `EquivalenceProperties` by mapping columns in the + /// original schema to columns in the new schema by index. + pub fn with_new_schema(mut self, schema: SchemaRef) -> Result { // The new schema and the original schema is aligned when they have the // same number of columns, and fields at the same index have the same // type in both schemas. @@ -1385,54 +1276,49 @@ impl EquivalenceProperties { // Rewriting equivalence properties in terms of new schema is not // safe when schemas are not aligned: return plan_err!( - "Cannot rewrite old_schema:{:?} with new schema: {:?}", + "Schemas have to be aligned to rewrite equivalences:\n Old schema: {:?}\n New schema: {:?}", self.schema, schema ); } - // Rewrite constants according to new schema: - let new_constants = self - .constants - .into_iter() - .map(|const_expr| { - let across_partitions = const_expr.across_partitions(); - let new_const_expr = with_new_schema(const_expr.owned_expr(), &schema)?; - Ok(ConstExpr::new(new_const_expr) - .with_across_partitions(across_partitions)) - }) - .collect::>>()?; - - // Rewrite orderings according to new schema: - let mut new_orderings = vec![]; - for ordering in self.oeq_class { - let new_ordering = ordering - .into_iter() - .map(|mut sort_expr| { - sort_expr.expr = with_new_schema(sort_expr.expr, &schema)?; - Ok(sort_expr) - }) - .collect::>()?; - new_orderings.push(new_ordering); - } // Rewrite equivalence classes according to the new schema: let mut eq_classes = vec![]; - for eq_class in self.eq_group { - let new_eq_exprs = eq_class - .into_vec() + for mut eq_class in self.eq_group { + // Rewrite the expressions in the equivalence class: + eq_class.exprs = eq_class + .exprs .into_iter() .map(|expr| with_new_schema(expr, &schema)) .collect::>()?; - eq_classes.push(EquivalenceClass::new(new_eq_exprs)); + // Rewrite the constant value (if available and known): + let data_type = eq_class + .canonical_expr() + .map(|e| e.data_type(&schema)) + .transpose()?; + if let (Some(data_type), Some(AcrossPartitions::Uniform(Some(value)))) = + (data_type, &mut eq_class.constant) + { + *value = value.cast_to(&data_type)?; + } + eq_classes.push(eq_class); } + self.eq_group = eq_classes.into(); + + // Rewrite orderings according to new schema: + self.oeq_class = self.oeq_class.with_new_schema(&schema)?; + self.oeq_cache.normal_cls = self.oeq_cache.normal_cls.with_new_schema(&schema)?; + + // Update the schema: + self.schema = schema; - // Construct the resulting 
equivalence properties: - let mut result = EquivalenceProperties::new(schema); - result.constants = new_constants; - result.add_new_orderings(new_orderings); - result.add_equivalence_group(EquivalenceGroup::new(eq_classes)); + Ok(self) + } +} - Ok(result) +impl From for OrderingEquivalenceClass { + fn from(eq_properties: EquivalenceProperties) -> Self { + eq_properties.oeq_class } } @@ -1440,24 +1326,21 @@ impl EquivalenceProperties { /// /// Format: /// ```text -/// order: [[a ASC, b ASC], [a ASC, c ASC]], eq: [[a = b], [a = c]], const: [a = 1] +/// order: [[b@1 ASC NULLS LAST]], eq: [{members: [a@0], constant: (heterogeneous)}] /// ``` impl Display for EquivalenceProperties { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - if self.eq_group.is_empty() - && self.oeq_class.is_empty() - && self.constants.is_empty() - { - return write!(f, "No properties"); - } - if !self.oeq_class.is_empty() { + let empty_eq_group = self.eq_group.is_empty(); + let empty_oeq_class = self.oeq_class.is_empty(); + if empty_oeq_class && empty_eq_group { + write!(f, "No properties")?; + } else if !empty_oeq_class { write!(f, "order: {}", self.oeq_class)?; - } - if !self.eq_group.is_empty() { - write!(f, ", eq: {}", self.eq_group)?; - } - if !self.constants.is_empty() { - write!(f, ", const: [{}]", ConstExpr::format_list(&self.constants))?; + if !empty_eq_group { + write!(f, ", eq: {}", self.eq_group)?; + } + } else { + write!(f, "eq: {}", self.eq_group)?; } Ok(()) } @@ -1501,45 +1384,20 @@ fn update_properties( Interval::make_unbounded(&node.expr.data_type(eq_properties.schema())?)? } // Now, check what we know about orderings: - let normalized_expr = eq_properties + let normal_expr = eq_properties .eq_group .normalize_expr(Arc::clone(&node.expr)); - let oeq_class = eq_properties.normalized_oeq_class(); - if eq_properties.is_expr_constant(&normalized_expr) - || oeq_class.is_expr_partial_const(&normalized_expr) + let oeq_class = &eq_properties.oeq_cache.normal_cls; + if eq_properties.is_expr_constant(&normal_expr).is_some() + || oeq_class.is_expr_partial_const(&normal_expr) { node.data.sort_properties = SortProperties::Singleton; - } else if let Some(options) = oeq_class.get_options(&normalized_expr) { + } else if let Some(options) = oeq_class.get_options(&normal_expr) { node.data.sort_properties = SortProperties::Ordered(options); } Ok(Transformed::yes(node)) } -/// This function determines whether the provided expression is constant -/// based on the known constants. -/// -/// # Parameters -/// -/// - `constants`: A `&[Arc]` containing expressions known to -/// be a constant. -/// - `expr`: A reference to a `Arc` representing the expression -/// to check. -/// -/// # Returns -/// -/// Returns `true` if the expression is constant according to equivalence -/// group, `false` otherwise. -fn is_constant_recurse( - constants: &[Arc], - expr: &Arc, -) -> bool { - if physical_exprs_contains(constants, expr) || expr.as_any().is::() { - return true; - } - let children = expr.children(); - !children.is_empty() && children.iter().all(|c| is_constant_recurse(constants, c)) -} - /// This function examines whether a referring expression directly refers to a /// given referred expression or if any of its children in the expression tree /// refer to the specified expression. 
@@ -1600,7 +1458,7 @@ fn get_expr_properties( } else if let Some(literal) = expr.as_any().downcast_ref::() { Ok(ExprProperties { sort_properties: SortProperties::Singleton, - range: Interval::try_new(literal.value().clone(), literal.value().clone())?, + range: literal.value().into(), preserves_lex_ordering: true, }) } else { @@ -1614,59 +1472,3 @@ fn get_expr_properties( expr.get_properties(&child_states) } } - -/// Wrapper struct for `Arc` to use them as keys in a hash map. -#[derive(Debug, Clone)] -struct ExprWrapper(Arc); - -impl PartialEq for ExprWrapper { - fn eq(&self, other: &Self) -> bool { - self.0.eq(&other.0) - } -} - -impl Eq for ExprWrapper {} - -impl Hash for ExprWrapper { - fn hash(&self, state: &mut H) { - self.0.hash(state); - } -} - -#[cfg(test)] -mod tests { - - use super::*; - use crate::expressions::{col, BinaryExpr}; - - use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; - use datafusion_expr::Operator; - - #[test] - fn test_expr_consists_of_constants() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("c", DataType::Int32, true), - Field::new("d", DataType::Int32, true), - Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true), - ])); - let col_a = col("a", &schema)?; - let col_b = col("b", &schema)?; - let col_d = col("d", &schema)?; - let b_plus_d = Arc::new(BinaryExpr::new( - Arc::clone(&col_b), - Operator::Plus, - Arc::clone(&col_d), - )) as Arc; - - let constants = vec![Arc::clone(&col_a), Arc::clone(&col_b)]; - let expr = Arc::clone(&b_plus_d); - assert!(!is_constant_recurse(&constants, &expr)); - - let constants = vec![Arc::clone(&col_a), Arc::clone(&col_b), Arc::clone(&col_d)]; - let expr = Arc::clone(&b_plus_d); - assert!(is_constant_recurse(&constants, &expr)); - Ok(()) - } -} diff --git a/datafusion/physical-expr/src/equivalence/properties/union.rs b/datafusion/physical-expr/src/equivalence/properties/union.rs index 64ef9278e248..4f44b9b0c9d4 100644 --- a/datafusion/physical-expr/src/equivalence/properties/union.rs +++ b/datafusion/physical-expr/src/equivalence/properties/union.rs @@ -15,28 +15,26 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::{internal_err, Result}; -use datafusion_physical_expr_common::sort_expr::LexOrdering; use std::iter::Peekable; use std::sync::Arc; +use super::EquivalenceProperties; use crate::equivalence::class::AcrossPartitions; -use crate::ConstExpr; +use crate::{ConstExpr, PhysicalSortExpr}; -use super::EquivalenceProperties; -use crate::PhysicalSortExpr; use arrow::datatypes::SchemaRef; -use std::slice::Iter; +use datafusion_common::{internal_err, Result}; +use datafusion_physical_expr_common::sort_expr::LexOrdering; -/// Calculates the union (in the sense of `UnionExec`) `EquivalenceProperties` -/// of `lhs` and `rhs` according to the schema of `lhs`. +/// Computes the union (in the sense of `UnionExec`) `EquivalenceProperties` +/// of `lhs` and `rhs` according to the schema of `lhs`. /// -/// Rules: The UnionExec does not interleave its inputs: instead it passes each -/// input partition from the children as its own output. +/// Rules: The `UnionExec` does not interleave its inputs, instead it passes +/// each input partition from the children as its own output. 
/// /// Since the output equivalence properties are properties that are true for /// *all* output partitions, that is the same as being true for all *input* -/// partitions +/// partitions. fn calculate_union_binary( lhs: EquivalenceProperties, mut rhs: EquivalenceProperties, @@ -48,28 +46,21 @@ fn calculate_union_binary( // First, calculate valid constants for the union. An expression is constant // at the output of the union if it is constant in both sides with matching values. + let rhs_constants = rhs.constants(); let constants = lhs .constants() - .iter() + .into_iter() .filter_map(|lhs_const| { // Find matching constant expression in RHS - rhs.constants() + rhs_constants .iter() - .find(|rhs_const| rhs_const.expr().eq(lhs_const.expr())) + .find(|rhs_const| rhs_const.expr.eq(&lhs_const.expr)) .map(|rhs_const| { - let mut const_expr = ConstExpr::new(Arc::clone(lhs_const.expr())); - - // If both sides have matching constant values, preserve the value and set across_partitions=true - if let ( - AcrossPartitions::Uniform(Some(lhs_val)), - AcrossPartitions::Uniform(Some(rhs_val)), - ) = (lhs_const.across_partitions(), rhs_const.across_partitions()) - { - if lhs_val == rhs_val { - const_expr = const_expr.with_across_partitions( - AcrossPartitions::Uniform(Some(lhs_val)), - ) - } + let mut const_expr = lhs_const.clone(); + // If both sides have matching constant values, preserve it. + // Otherwise, set fall back to heterogeneous values. + if lhs_const.across_partitions != rhs_const.across_partitions { + const_expr.across_partitions = AcrossPartitions::Heterogeneous; } const_expr }) @@ -79,14 +70,13 @@ fn calculate_union_binary( // Next, calculate valid orderings for the union by searching for prefixes // in both sides. let mut orderings = UnionEquivalentOrderingBuilder::new(); - orderings.add_satisfied_orderings(lhs.normalized_oeq_class(), lhs.constants(), &rhs); - orderings.add_satisfied_orderings(rhs.normalized_oeq_class(), rhs.constants(), &lhs); + orderings.add_satisfied_orderings(&lhs, &rhs)?; + orderings.add_satisfied_orderings(&rhs, &lhs)?; let orderings = orderings.build(); - let mut eq_properties = - EquivalenceProperties::new(lhs.schema).with_constants(constants); - - eq_properties.add_new_orderings(orderings); + let mut eq_properties = EquivalenceProperties::new(lhs.schema); + eq_properties.add_constants(constants)?; + eq_properties.add_orderings(orderings); Ok(eq_properties) } @@ -137,135 +127,139 @@ impl UnionEquivalentOrderingBuilder { Self { orderings: vec![] } } - /// Add all orderings from `orderings` that satisfy `properties`, - /// potentially augmented with`constants`. + /// Add all orderings from `source` that satisfy `properties`, + /// potentially augmented with the constants in `source`. /// - /// Note: any column that is known to be constant can be inserted into the - /// ordering without changing its meaning + /// Note: Any column that is known to be constant can be inserted into the + /// ordering without changing its meaning. /// /// For example: - /// * `orderings` contains `[a ASC, c ASC]` and `constants` contains `b` - /// * `properties` has required ordering `[a ASC, b ASC]` + /// * Orderings in `source` contains `[a ASC, c ASC]` and constants contains + /// `b`, + /// * `properties` has the ordering `[a ASC, b ASC]`. /// /// Then this will add `[a ASC, b ASC]` to the `orderings` list (as `a` was /// in the sort order and `b` was a constant). 
fn add_satisfied_orderings( &mut self, - orderings: impl IntoIterator, - constants: &[ConstExpr], + source: &EquivalenceProperties, properties: &EquivalenceProperties, - ) { - for mut ordering in orderings.into_iter() { + ) -> Result<()> { + let constants = source.constants(); + let properties_constants = properties.constants(); + for mut ordering in source.oeq_cache.normal_cls.clone() { // Progressively shorten the ordering to search for a satisfied prefix: loop { - match self.try_add_ordering(ordering, constants, properties) { + ordering = match self.try_add_ordering( + ordering, + &constants, + properties, + &properties_constants, + )? { AddedOrdering::Yes => break, - AddedOrdering::No(o) => { - ordering = o; - ordering.pop(); + AddedOrdering::No(ordering) => { + let mut sort_exprs: Vec<_> = ordering.into(); + sort_exprs.pop(); + if let Some(ordering) = LexOrdering::new(sort_exprs) { + ordering + } else { + break; + } } } } } + Ok(()) } - /// Adds `ordering`, potentially augmented with constants, if it satisfies - /// the target `properties` properties. + /// Adds `ordering`, potentially augmented with `constants`, if it satisfies + /// the given `properties`. /// - /// Returns + /// # Returns /// - /// * [`AddedOrdering::Yes`] if the ordering was added (either directly or - /// augmented), or was empty. - /// - /// * [`AddedOrdering::No`] if the ordering was not added + /// An [`AddedOrdering::Yes`] instance if the ordering was added (either + /// directly or augmented), or was empty. An [`AddedOrdering::No`] instance + /// otherwise. fn try_add_ordering( &mut self, ordering: LexOrdering, constants: &[ConstExpr], properties: &EquivalenceProperties, - ) -> AddedOrdering { - if ordering.is_empty() { - AddedOrdering::Yes - } else if properties.ordering_satisfy(ordering.as_ref()) { + properties_constants: &[ConstExpr], + ) -> Result { + if properties.ordering_satisfy(ordering.clone())? { // If the ordering satisfies the target properties, no need to // augment it with constants. self.orderings.push(ordering); - AddedOrdering::Yes + Ok(AddedOrdering::Yes) + } else if self.try_find_augmented_ordering( + &ordering, + constants, + properties, + properties_constants, + ) { + // Augmented with constants to match the properties. + Ok(AddedOrdering::Yes) } else { - // Did not satisfy target properties, try and augment with constants - // to match the properties - if self.try_find_augmented_ordering(&ordering, constants, properties) { - AddedOrdering::Yes - } else { - AddedOrdering::No(ordering) - } + Ok(AddedOrdering::No(ordering)) } } /// Attempts to add `constants` to `ordering` to satisfy the properties. - /// - /// returns true if any orderings were added, false otherwise + /// Returns `true` if augmentation took place, `false` otherwise. 
fn try_find_augmented_ordering( &mut self, ordering: &LexOrdering, constants: &[ConstExpr], properties: &EquivalenceProperties, + properties_constants: &[ConstExpr], ) -> bool { - // can't augment if there is nothing to augment with - if constants.is_empty() { - return false; - } - let start_num_orderings = self.orderings.len(); - - // for each equivalent ordering in properties, try and augment - // `ordering` it with the constants to match - for existing_ordering in properties.oeq_class.iter() { - if let Some(augmented_ordering) = self.augment_ordering( - ordering, - constants, - existing_ordering, - &properties.constants, - ) { - if !augmented_ordering.is_empty() { - assert!(properties.ordering_satisfy(augmented_ordering.as_ref())); + let mut result = false; + // Can only augment if there are constants. + if !constants.is_empty() { + // For each equivalent ordering in properties, try and augment + // `ordering` with the constants to match `existing_ordering`: + for existing_ordering in properties.oeq_class.iter() { + if let Some(augmented_ordering) = Self::augment_ordering( + ordering, + constants, + existing_ordering, + properties_constants, + ) { self.orderings.push(augmented_ordering); + result = true; } } } - - self.orderings.len() > start_num_orderings + result } - /// Attempts to augment the ordering with constants to match the - /// `existing_ordering` - /// - /// Returns Some(ordering) if an augmented ordering was found, None otherwise + /// Attempts to augment the ordering with constants to match `existing_ordering`. + /// Returns `Some(ordering)` if an augmented ordering was found, `None` otherwise. fn augment_ordering( - &mut self, ordering: &LexOrdering, constants: &[ConstExpr], existing_ordering: &LexOrdering, existing_constants: &[ConstExpr], ) -> Option { - let mut augmented_ordering = LexOrdering::default(); - let mut sort_expr_iter = ordering.iter().peekable(); - let mut existing_sort_expr_iter = existing_ordering.iter().peekable(); - - // walk in parallel down the two orderings, trying to match them up - while sort_expr_iter.peek().is_some() || existing_sort_expr_iter.peek().is_some() - { - // If the next expressions are equal, add the next match - // otherwise try and match with a constant + let mut augmented_ordering = vec![]; + let mut sort_exprs = ordering.iter().peekable(); + let mut existing_sort_exprs = existing_ordering.iter().peekable(); + + // Walk in parallel down the two orderings, trying to match them up: + while sort_exprs.peek().is_some() || existing_sort_exprs.peek().is_some() { + // If the next expressions are equal, add the next match. Otherwise, + // try and match with a constant. 
if let Some(expr) = - advance_if_match(&mut sort_expr_iter, &mut existing_sort_expr_iter) + advance_if_match(&mut sort_exprs, &mut existing_sort_exprs) { augmented_ordering.push(expr); } else if let Some(expr) = - advance_if_matches_constant(&mut sort_expr_iter, existing_constants) + advance_if_matches_constant(&mut sort_exprs, existing_constants) { augmented_ordering.push(expr); } else if let Some(expr) = - advance_if_matches_constant(&mut existing_sort_expr_iter, constants) + advance_if_matches_constant(&mut existing_sort_exprs, constants) { augmented_ordering.push(expr); } else { @@ -274,7 +268,7 @@ impl UnionEquivalentOrderingBuilder { } } - Some(augmented_ordering) + LexOrdering::new(augmented_ordering) } fn build(self) -> Vec { @@ -282,134 +276,135 @@ impl UnionEquivalentOrderingBuilder { } } -/// Advances two iterators in parallel -/// -/// If the next expressions are equal, the iterators are advanced and returns -/// the matched expression . -/// -/// Otherwise, the iterators are left unchanged and return `None` -fn advance_if_match( - iter1: &mut Peekable>, - iter2: &mut Peekable>, +/// Advances two iterators in parallel if the next expressions are equal. +/// Otherwise, the iterators are left unchanged and returns `None`. +fn advance_if_match<'a>( + iter1: &mut Peekable>, + iter2: &mut Peekable>, ) -> Option { - if matches!((iter1.peek(), iter2.peek()), (Some(expr1), Some(expr2)) if expr1.eq(expr2)) - { - iter1.next().unwrap(); + let (expr1, expr2) = (iter1.peek()?, iter2.peek()?); + if expr1.eq(expr2) { + iter1.next(); iter2.next().cloned() } else { None } } -/// Advances the iterator with a constant -/// -/// If the next expression matches one of the constants, advances the iterator -/// returning the matched expression -/// -/// Otherwise, the iterator is left unchanged and returns `None` -fn advance_if_matches_constant( - iter: &mut Peekable>, +/// Advances the iterator with a constant if the next expression matches one of +/// the constants. Otherwise, the iterator is left unchanged and returns `None`. +fn advance_if_matches_constant<'a>( + iter: &mut Peekable>, constants: &[ConstExpr], ) -> Option { let expr = iter.peek()?; - let const_expr = constants.iter().find(|c| c.eq_expr(expr))?; - let found_expr = PhysicalSortExpr::new(Arc::clone(const_expr.expr()), expr.options); + let const_expr = constants.iter().find(|c| expr.expr.eq(&c.expr))?; + let found_expr = PhysicalSortExpr::new(Arc::clone(&const_expr.expr), expr.options); iter.next(); Some(found_expr) } #[cfg(test)] mod tests { - use super::*; - use crate::equivalence::class::const_exprs_contains; use crate::equivalence::tests::{create_test_schema, parse_sort_expr}; use crate::expressions::col; + use crate::PhysicalExpr; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::ScalarValue; use itertools::Itertools; + /// Checks whether `expr` is among in the `const_exprs`. + fn const_exprs_contains( + const_exprs: &[ConstExpr], + expr: &Arc, + ) -> bool { + const_exprs + .iter() + .any(|const_expr| const_expr.expr.eq(expr)) + } + #[test] - fn test_union_equivalence_properties_multi_children_1() { + fn test_union_equivalence_properties_multi_children_1() -> Result<()> { let schema = create_test_schema().unwrap(); let schema2 = append_fields(&schema, "1"); let schema3 = append_fields(&schema, "2"); UnionEquivalenceTest::new(&schema) // Children 1 - .with_child_sort(vec![vec!["a", "b", "c"]], &schema) + .with_child_sort(vec![vec!["a", "b", "c"]], &schema)? 
// Children 2 - .with_child_sort(vec![vec!["a1", "b1", "c1"]], &schema2) + .with_child_sort(vec![vec!["a1", "b1", "c1"]], &schema2)? // Children 3 - .with_child_sort(vec![vec!["a2", "b2"]], &schema3) - .with_expected_sort(vec![vec!["a", "b"]]) + .with_child_sort(vec![vec!["a2", "b2"]], &schema3)? + .with_expected_sort(vec![vec!["a", "b"]])? .run() } #[test] - fn test_union_equivalence_properties_multi_children_2() { + fn test_union_equivalence_properties_multi_children_2() -> Result<()> { let schema = create_test_schema().unwrap(); let schema2 = append_fields(&schema, "1"); let schema3 = append_fields(&schema, "2"); UnionEquivalenceTest::new(&schema) // Children 1 - .with_child_sort(vec![vec!["a", "b", "c"]], &schema) + .with_child_sort(vec![vec!["a", "b", "c"]], &schema)? // Children 2 - .with_child_sort(vec![vec!["a1", "b1", "c1"]], &schema2) + .with_child_sort(vec![vec!["a1", "b1", "c1"]], &schema2)? // Children 3 - .with_child_sort(vec![vec!["a2", "b2", "c2"]], &schema3) - .with_expected_sort(vec![vec!["a", "b", "c"]]) + .with_child_sort(vec![vec!["a2", "b2", "c2"]], &schema3)? + .with_expected_sort(vec![vec!["a", "b", "c"]])? .run() } #[test] - fn test_union_equivalence_properties_multi_children_3() { + fn test_union_equivalence_properties_multi_children_3() -> Result<()> { let schema = create_test_schema().unwrap(); let schema2 = append_fields(&schema, "1"); let schema3 = append_fields(&schema, "2"); UnionEquivalenceTest::new(&schema) // Children 1 - .with_child_sort(vec![vec!["a", "b"]], &schema) + .with_child_sort(vec![vec!["a", "b"]], &schema)? // Children 2 - .with_child_sort(vec![vec!["a1", "b1", "c1"]], &schema2) + .with_child_sort(vec![vec!["a1", "b1", "c1"]], &schema2)? // Children 3 - .with_child_sort(vec![vec!["a2", "b2", "c2"]], &schema3) - .with_expected_sort(vec![vec!["a", "b"]]) + .with_child_sort(vec![vec!["a2", "b2", "c2"]], &schema3)? + .with_expected_sort(vec![vec!["a", "b"]])? .run() } #[test] - fn test_union_equivalence_properties_multi_children_4() { + fn test_union_equivalence_properties_multi_children_4() -> Result<()> { let schema = create_test_schema().unwrap(); let schema2 = append_fields(&schema, "1"); let schema3 = append_fields(&schema, "2"); UnionEquivalenceTest::new(&schema) // Children 1 - .with_child_sort(vec![vec!["a", "b"]], &schema) + .with_child_sort(vec![vec!["a", "b"]], &schema)? // Children 2 - .with_child_sort(vec![vec!["a1", "b1"]], &schema2) + .with_child_sort(vec![vec!["a1", "b1"]], &schema2)? // Children 3 - .with_child_sort(vec![vec!["b2", "c2"]], &schema3) - .with_expected_sort(vec![]) + .with_child_sort(vec![vec!["b2", "c2"]], &schema3)? + .with_expected_sort(vec![])? .run() } #[test] - fn test_union_equivalence_properties_multi_children_5() { + fn test_union_equivalence_properties_multi_children_5() -> Result<()> { let schema = create_test_schema().unwrap(); let schema2 = append_fields(&schema, "1"); UnionEquivalenceTest::new(&schema) // Children 1 - .with_child_sort(vec![vec!["a", "b"], vec!["c"]], &schema) + .with_child_sort(vec![vec!["a", "b"], vec!["c"]], &schema)? // Children 2 - .with_child_sort(vec![vec!["a1", "b1"], vec!["c1"]], &schema2) - .with_expected_sort(vec![vec!["a", "b"], vec!["c"]]) + .with_child_sort(vec![vec!["a1", "b1"], vec!["c1"]], &schema2)? + .with_expected_sort(vec![vec!["a", "b"], vec!["c"]])? 
.run() } #[test] - fn test_union_equivalence_properties_constants_common_constants() { + fn test_union_equivalence_properties_constants_common_constants() -> Result<()> { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) .with_child_sort_and_const_exprs( @@ -417,23 +412,23 @@ mod tests { vec![vec!["a"]], vec!["b", "c"], &schema, - ) + )? .with_child_sort_and_const_exprs( // Second child: [b ASC], const [a, c] vec![vec!["b"]], vec!["a", "c"], &schema, - ) + )? .with_expected_sort_and_const_exprs( // Union expected orderings: [[a ASC], [b ASC]], const [c] vec![vec!["a"], vec!["b"]], vec!["c"], - ) + )? .run() } #[test] - fn test_union_equivalence_properties_constants_prefix() { + fn test_union_equivalence_properties_constants_prefix() -> Result<()> { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) .with_child_sort_and_const_exprs( @@ -441,23 +436,23 @@ mod tests { vec![vec!["a"]], vec![], &schema, - ) + )? .with_child_sort_and_const_exprs( // Second child: [a ASC, b ASC], const [] vec![vec!["a", "b"]], vec![], &schema, - ) + )? .with_expected_sort_and_const_exprs( // Union orderings: [a ASC], const [] vec![vec!["a"]], vec![], - ) + )? .run() } #[test] - fn test_union_equivalence_properties_constants_asc_desc_mismatch() { + fn test_union_equivalence_properties_constants_asc_desc_mismatch() -> Result<()> { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) .with_child_sort_and_const_exprs( @@ -465,23 +460,23 @@ mod tests { vec![vec!["a"]], vec![], &schema, - ) + )? .with_child_sort_and_const_exprs( // Second child orderings: [a DESC], const [] vec![vec!["a DESC"]], vec![], &schema, - ) + )? .with_expected_sort_and_const_exprs( // Union doesn't have any ordering or constant vec![], vec![], - ) + )? .run() } #[test] - fn test_union_equivalence_properties_constants_different_schemas() { + fn test_union_equivalence_properties_constants_different_schemas() -> Result<()> { let schema = create_test_schema().unwrap(); let schema2 = append_fields(&schema, "1"); UnionEquivalenceTest::new(&schema) @@ -490,13 +485,13 @@ mod tests { vec![vec!["a"]], vec![], &schema, - ) + )? .with_child_sort_and_const_exprs( // Second child orderings: [a1 ASC, b1 ASC], const [] vec![vec!["a1", "b1"]], vec![], &schema2, - ) + )? .with_expected_sort_and_const_exprs( // Union orderings: [a ASC] // @@ -504,12 +499,12 @@ mod tests { // corresponding schemas. vec![vec!["a"]], vec![], - ) + )? .run() } #[test] - fn test_union_equivalence_properties_constants_fill_gaps() { + fn test_union_equivalence_properties_constants_fill_gaps() -> Result<()> { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) .with_child_sort_and_const_exprs( @@ -517,13 +512,13 @@ mod tests { vec![vec!["a", "c"]], vec!["b"], &schema, - ) + )? .with_child_sort_and_const_exprs( // Second child orderings: [b ASC, c ASC], const [a] vec![vec!["b", "c"]], vec!["a"], &schema, - ) + )? .with_expected_sort_and_const_exprs( // Union orderings: [ // [a ASC, b ASC, c ASC], @@ -531,12 +526,12 @@ mod tests { // ], const [] vec![vec!["a", "b", "c"], vec!["b", "a", "c"]], vec![], - ) + )? .run() } #[test] - fn test_union_equivalence_properties_constants_no_fill_gaps() { + fn test_union_equivalence_properties_constants_no_fill_gaps() -> Result<()> { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) .with_child_sort_and_const_exprs( @@ -544,23 +539,23 @@ mod tests { vec![vec!["a", "c"]], vec!["d"], &schema, - ) + )? 
.with_child_sort_and_const_exprs( // Second child orderings: [b ASC, c ASC], const [a] vec![vec!["b", "c"]], vec!["a"], &schema, - ) + )? .with_expected_sort_and_const_exprs( // Union orderings: [[a]] (only a is constant) vec![vec!["a"]], vec![], - ) + )? .run() } #[test] - fn test_union_equivalence_properties_constants_fill_some_gaps() { + fn test_union_equivalence_properties_constants_fill_some_gaps() -> Result<()> { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) .with_child_sort_and_const_exprs( @@ -568,23 +563,24 @@ mod tests { vec![vec!["c"]], vec!["a", "b"], &schema, - ) + )? .with_child_sort_and_const_exprs( // Second child orderings: [a DESC, b], const [] vec![vec!["a DESC", "b"]], vec![], &schema, - ) + )? .with_expected_sort_and_const_exprs( // Union orderings: [[a, b]] (can fill in the a/b with constants) vec![vec!["a DESC", "b"]], vec![], - ) + )? .run() } #[test] - fn test_union_equivalence_properties_constants_fill_gaps_non_symmetric() { + fn test_union_equivalence_properties_constants_fill_gaps_non_symmetric() -> Result<()> + { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) .with_child_sort_and_const_exprs( @@ -592,13 +588,13 @@ mod tests { vec![vec!["a", "c"]], vec!["b"], &schema, - ) + )? .with_child_sort_and_const_exprs( // Second child orderings: [b ASC, c ASC], const [a] vec![vec!["b DESC", "c"]], vec!["a"], &schema, - ) + )? .with_expected_sort_and_const_exprs( // Union orderings: [ // [a ASC, b ASC, c ASC], @@ -606,12 +602,12 @@ mod tests { // ], const [] vec![vec!["a", "b DESC", "c"], vec!["b DESC", "a", "c"]], vec![], - ) + )? .run() } #[test] - fn test_union_equivalence_properties_constants_gap_fill_symmetric() { + fn test_union_equivalence_properties_constants_gap_fill_symmetric() -> Result<()> { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) .with_child_sort_and_const_exprs( @@ -619,25 +615,25 @@ mod tests { vec![vec!["a", "b", "d"]], vec!["c"], &schema, - ) + )? .with_child_sort_and_const_exprs( // Second child: [a ASC, c ASC, d ASC], const [b] vec![vec!["a", "c", "d"]], vec!["b"], &schema, - ) + )? .with_expected_sort_and_const_exprs( // Union orderings: // [a, b, c, d] // [a, c, b, d] vec![vec!["a", "c", "b", "d"], vec!["a", "b", "c", "d"]], vec![], - ) + )? .run() } #[test] - fn test_union_equivalence_properties_constants_gap_fill_and_common() { + fn test_union_equivalence_properties_constants_gap_fill_and_common() -> Result<()> { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) .with_child_sort_and_const_exprs( @@ -645,24 +641,24 @@ mod tests { vec![vec!["a DESC", "d"]], vec!["b", "c"], &schema, - ) + )? .with_child_sort_and_const_exprs( // Second child: [a DESC, c ASC, d ASC], const [b] vec![vec!["a DESC", "c", "d"]], vec!["b"], &schema, - ) + )? .with_expected_sort_and_const_exprs( // Union orderings: // [a DESC, c, d] [b] vec![vec!["a DESC", "c", "d"]], vec!["b"], - ) + )? .run() } #[test] - fn test_union_equivalence_properties_constants_middle_desc() { + fn test_union_equivalence_properties_constants_middle_desc() -> Result<()> { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) .with_child_sort_and_const_exprs( @@ -672,20 +668,20 @@ mod tests { vec![vec!["a", "b DESC", "d"]], vec!["c"], &schema, - ) + )? .with_child_sort_and_const_exprs( // Second child: [a ASC, c ASC, d ASC], const [b] vec![vec!["a", "c", "d"]], vec!["b"], &schema, - ) + )? 
.with_expected_sort_and_const_exprs( // Union orderings: // [a, b, d] (c constant) // [a, c, d] (b constant) vec![vec!["a", "c", "b DESC", "d"], vec!["a", "b DESC", "c", "d"]], vec![], - ) + )? .run() } @@ -718,10 +714,10 @@ mod tests { mut self, orderings: Vec>, schema: &SchemaRef, - ) -> Self { - let properties = self.make_props(orderings, vec![], schema); + ) -> Result { + let properties = self.make_props(orderings, vec![], schema)?; self.child_properties.push(properties); - self + Ok(self) } /// Add a union input with the specified orderings and constant @@ -734,19 +730,19 @@ mod tests { orderings: Vec>, constants: Vec<&str>, schema: &SchemaRef, - ) -> Self { - let properties = self.make_props(orderings, constants, schema); + ) -> Result { + let properties = self.make_props(orderings, constants, schema)?; self.child_properties.push(properties); - self + Ok(self) } /// Set the expected output sort order for the union of the children /// /// See [`Self::make_props`] for the format of the strings in `orderings` - fn with_expected_sort(mut self, orderings: Vec>) -> Self { - let properties = self.make_props(orderings, vec![], &self.output_schema); + fn with_expected_sort(mut self, orderings: Vec>) -> Result { + let properties = self.make_props(orderings, vec![], &self.output_schema)?; self.expected_properties = Some(properties); - self + Ok(self) } /// Set the expected output sort order and constant expressions for the @@ -758,15 +754,16 @@ mod tests { mut self, orderings: Vec>, constants: Vec<&str>, - ) -> Self { - let properties = self.make_props(orderings, constants, &self.output_schema); + ) -> Result { + let properties = + self.make_props(orderings, constants, &self.output_schema)?; self.expected_properties = Some(properties); - self + Ok(self) } /// compute the union's output equivalence properties from the child /// properties, and compare them to the expected properties - fn run(self) { + fn run(self) -> Result<()> { let Self { output_schema, child_properties, @@ -798,6 +795,7 @@ mod tests { ), ); } + Ok(()) } fn assert_eq_properties_same( @@ -808,9 +806,9 @@ mod tests { // Check whether constants are same let lhs_constants = lhs.constants(); let rhs_constants = rhs.constants(); - for rhs_constant in rhs_constants { + for rhs_constant in &rhs_constants { assert!( - const_exprs_contains(lhs_constants, rhs_constant.expr()), + const_exprs_contains(&lhs_constants, &rhs_constant.expr), "{err_msg}\nlhs: {lhs}\nrhs: {rhs}" ); } @@ -845,24 +843,19 @@ mod tests { orderings: Vec>, constants: Vec<&str>, schema: &SchemaRef, - ) -> EquivalenceProperties { - let orderings = orderings - .iter() - .map(|ordering| { - ordering - .iter() - .map(|name| parse_sort_expr(name, schema)) - .collect::() - }) - .collect::>(); + ) -> Result { + let orderings = orderings.iter().map(|ordering| { + ordering.iter().map(|name| parse_sort_expr(name, schema)) + }); let constants = constants .iter() - .map(|col_name| ConstExpr::new(col(col_name, schema).unwrap())) - .collect::>(); + .map(|col_name| ConstExpr::from(col(col_name, schema).unwrap())); - EquivalenceProperties::new_with_orderings(Arc::clone(schema), &orderings) - .with_constants(constants) + let mut props = + EquivalenceProperties::new_with_orderings(Arc::clone(schema), orderings); + props.add_constants(constants)?; + Ok(props) } } @@ -877,25 +870,29 @@ mod tests { let literal_10 = ScalarValue::Int32(Some(10)); // Create first input with a=10 - let const_expr1 = ConstExpr::new(Arc::clone(&col_a)) - 
.with_across_partitions(AcrossPartitions::Uniform(Some(literal_10.clone()))); - let input1 = EquivalenceProperties::new(Arc::clone(&schema)) - .with_constants(vec![const_expr1]); + let const_expr1 = ConstExpr::new( + Arc::clone(&col_a), + AcrossPartitions::Uniform(Some(literal_10.clone())), + ); + let mut input1 = EquivalenceProperties::new(Arc::clone(&schema)); + input1.add_constants(vec![const_expr1])?; // Create second input with a=10 - let const_expr2 = ConstExpr::new(Arc::clone(&col_a)) - .with_across_partitions(AcrossPartitions::Uniform(Some(literal_10.clone()))); - let input2 = EquivalenceProperties::new(Arc::clone(&schema)) - .with_constants(vec![const_expr2]); + let const_expr2 = ConstExpr::new( + Arc::clone(&col_a), + AcrossPartitions::Uniform(Some(literal_10.clone())), + ); + let mut input2 = EquivalenceProperties::new(Arc::clone(&schema)); + input2.add_constants(vec![const_expr2])?; // Calculate union properties let union_props = calculate_union(vec![input1, input2], schema)?; // Verify column 'a' remains constant with value 10 let const_a = &union_props.constants()[0]; - assert!(const_a.expr().eq(&col_a)); + assert!(const_a.expr.eq(&col_a)); assert_eq!( - const_a.across_partitions(), + const_a.across_partitions, AcrossPartitions::Uniform(Some(literal_10)) ); diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 84374f4a2970..798e68a459ce 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -17,19 +17,20 @@ mod kernels; -use std::hash::Hash; -use std::{any::Any, sync::Arc}; - use crate::expressions::binary::kernels::concat_elements_utf8view; use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison}; use crate::PhysicalExpr; +use std::hash::Hash; +use std::{any::Any, sync::Arc}; use arrow::array::*; use arrow::compute::kernels::boolean::{and_kleene, not, or_kleene}; use arrow::compute::kernels::cmp::*; use arrow::compute::kernels::comparison::{regexp_is_match, regexp_is_match_scalar}; use arrow::compute::kernels::concat_elements::concat_elements_utf8; -use arrow::compute::{cast, ilike, like, nilike, nlike}; +use arrow::compute::{ + cast, filter_record_batch, ilike, like, nilike, nlike, SlicesIterator, +}; use arrow::datatypes::*; use arrow::error::ArrowError; use datafusion_common::cast::as_boolean_array; @@ -358,11 +359,24 @@ impl PhysicalExpr for BinaryExpr { fn evaluate(&self, batch: &RecordBatch) -> Result { use arrow::compute::kernels::numeric::*; + // Evaluate left-hand side expression. let lhs = self.left.evaluate(batch)?; - // Optimize for short-circuiting `Operator::And` or `Operator::Or` operations and return early. - if check_short_circuit(&lhs, &self.op) { - return Ok(lhs); + // Check if we can apply short-circuit evaluation. + match check_short_circuit(&lhs, &self.op) { + ShortCircuitStrategy::None => {} + ShortCircuitStrategy::ReturnLeft => return Ok(lhs), + ShortCircuitStrategy::ReturnRight => { + let rhs = self.right.evaluate(batch)?; + return Ok(rhs); + } + ShortCircuitStrategy::PreSelection(selection) => { + // The function `evaluate_selection` was not called for filtering and calculation, + // as it takes into account cases where the selection contains null values. 
+ let batch = filter_record_batch(batch, selection)?; + let right_ret = self.right.evaluate(&batch)?; + return pre_selection_scatter(selection, right_ret); + } } let rhs = self.right.evaluate(batch)?; @@ -405,23 +419,19 @@ impl PhysicalExpr for BinaryExpr { let result_type = self.data_type(input_schema)?; - // Attempt to use special kernels if one input is scalar and the other is an array - let scalar_result = match (&lhs, &rhs) { - (ColumnarValue::Array(array), ColumnarValue::Scalar(scalar)) => { - // if left is array and right is literal(not NULL) - use scalar operations - if scalar.is_null() { - None - } else { - self.evaluate_array_scalar(array, scalar.clone())?.map(|r| { - r.and_then(|a| to_result_type_array(&self.op, a, &result_type)) - }) + // If the left-hand side is an array and the right-hand side is a non-null scalar, try the optimized kernel. + if let (ColumnarValue::Array(array), ColumnarValue::Scalar(ref scalar)) = + (&lhs, &rhs) + { + if !scalar.is_null() { + if let Some(result_array) = + self.evaluate_array_scalar(array, scalar.clone())? + { + let final_array = result_array + .and_then(|a| to_result_type_array(&self.op, a, &result_type)); + return final_array.map(ColumnarValue::Array); } } - (_, _) => None, // default to array implementation - }; - - if let Some(result) = scalar_result { - return result.map(ColumnarValue::Array); } // if both arrays or both literals - extract arrays and continue execution @@ -506,7 +516,7 @@ impl PhysicalExpr for BinaryExpr { } } else if self.op.eq(&Operator::Or) { if interval.eq(&Interval::CERTAINLY_FALSE) { - // A certainly false logical conjunction can only derive from certainly + // A certainly false logical disjunction can only derive from certainly // false operands. Otherwise, we prove infeasibility. Ok((!left_interval.eq(&Interval::CERTAINLY_TRUE) && !right_interval.eq(&Interval::CERTAINLY_TRUE)) @@ -811,58 +821,199 @@ impl BinaryExpr { } } +enum ShortCircuitStrategy<'a> { + None, + ReturnLeft, + ReturnRight, + PreSelection(&'a BooleanArray), +} + +/// Based on the results calculated from the left side of the short-circuit operation, +/// if the proportion of `true` is less than 0.2 and the current operation is an `and`, +/// the `RecordBatch` will be filtered in advance. +const PRE_SELECTION_THRESHOLD: f32 = 0.2; + /// Checks if a logical operator (`AND`/`OR`) can short-circuit evaluation based on the left-hand side (lhs) result. /// -/// Short-circuiting occurs when evaluating the right-hand side (rhs) becomes unnecessary: -/// - For `AND`: if ALL values in `lhs` are `false`, the expression must be `false` regardless of rhs. -/// - For `OR`: if ALL values in `lhs` are `true`, the expression must be `true` regardless of rhs. -/// -/// Returns `true` if short-circuiting is possible, `false` otherwise. -/// +/// Short-circuiting occurs under these circumstances: +/// - For `AND`: +/// - if LHS is all false => short-circuit → return LHS +/// - if LHS is all true => short-circuit → return RHS +/// - if LHS is mixed and true_count/sum_count <= [`PRE_SELECTION_THRESHOLD`] -> pre-selection +/// - For `OR`: +/// - if LHS is all true => short-circuit → return LHS +/// - if LHS is all false => short-circuit → return RHS /// # Arguments -/// * `arg` - The left-hand side (lhs) columnar value (array or scalar) +/// * `lhs` - The left-hand side (lhs) columnar value (array or scalar) +/// * `lhs` - The left-hand side (lhs) columnar value (array or scalar) /// * `op` - The logical operator (`AND` or `OR`) /// /// # Implementation Notes /// 1. 
Only works with Boolean-typed arguments (other types automatically return `false`) /// 2. Handles both scalar values and array values -/// 3. For arrays, uses optimized `true_count()`/`false_count()` methods from arrow-rs. -/// `bool_or`/`bool_and` maybe a better choice too,for detailed discussion,see:[link](https://github.com/apache/datafusion/pull/15462#discussion_r2020558418) -fn check_short_circuit(arg: &ColumnarValue, op: &Operator) -> bool { - let data_type = arg.data_type(); - match (data_type, op) { - (DataType::Boolean, Operator::And) => { - match arg { - ColumnarValue::Array(array) => { - if let Ok(array) = as_boolean_array(&array) { - return array.false_count() == array.len(); - } +/// 3. For arrays, uses optimized bit counting techniques for boolean arrays +fn check_short_circuit<'a>( + lhs: &'a ColumnarValue, + op: &Operator, +) -> ShortCircuitStrategy<'a> { + // Quick reject for non-logical operators,and quick judgment when op is and + let is_and = match op { + Operator::And => true, + Operator::Or => false, + _ => return ShortCircuitStrategy::None, + }; + + // Non-boolean types can't be short-circuited + if lhs.data_type() != DataType::Boolean { + return ShortCircuitStrategy::None; + } + + match lhs { + ColumnarValue::Array(array) => { + // Fast path for arrays - try to downcast to boolean array + if let Ok(bool_array) = as_boolean_array(array) { + // Arrays with nulls can't be short-circuited + if bool_array.null_count() > 0 { + return ShortCircuitStrategy::None; } - ColumnarValue::Scalar(scalar) => { - if let ScalarValue::Boolean(Some(value)) = scalar { - return !value; + + let len = bool_array.len(); + if len == 0 { + return ShortCircuitStrategy::None; + } + + let true_count = bool_array.values().count_set_bits(); + if is_and { + // For AND, prioritize checking for all-false (short circuit case) + // Uses optimized false_count() method provided by Arrow + + // Short circuit if all values are false + if true_count == 0 { + return ShortCircuitStrategy::ReturnLeft; + } + + // If no false values, then all must be true + if true_count == len { + return ShortCircuitStrategy::ReturnRight; + } + + // determine if we can pre-selection + if true_count as f32 / len as f32 <= PRE_SELECTION_THRESHOLD { + return ShortCircuitStrategy::PreSelection(bool_array); + } + } else { + // For OR, prioritize checking for all-true (short circuit case) + // Uses optimized true_count() method provided by Arrow + + // Short circuit if all values are true + if true_count == len { + return ShortCircuitStrategy::ReturnLeft; + } + + // If no true values, then all must be false + if true_count == 0 { + return ShortCircuitStrategy::ReturnRight; } } } - false } - (DataType::Boolean, Operator::Or) => { - match arg { - ColumnarValue::Array(array) => { - if let Ok(array) = as_boolean_array(&array) { - return array.true_count() == array.len(); - } - } - ColumnarValue::Scalar(scalar) => { - if let ScalarValue::Boolean(Some(value)) = scalar { - return *value; - } + ColumnarValue::Scalar(scalar) => { + // Fast path for scalar values + if let ScalarValue::Boolean(Some(is_true)) = scalar { + // Return Left for: + // - AND with false value + // - OR with true value + if (is_and && !is_true) || (!is_and && *is_true) { + return ShortCircuitStrategy::ReturnLeft; + } else { + return ShortCircuitStrategy::ReturnRight; } } - false } - _ => false, } + + // If we can't short-circuit, indicate that normal evaluation should continue + ShortCircuitStrategy::None +} + +/// Creates a new boolean array based on the evaluation of 
the right expression, +/// but only for positions where the left_result is true. +/// +/// This function is used for short-circuit evaluation optimization of logical AND operations: +/// - When left_result has few true values, we only evaluate the right expression for those positions +/// - Values are copied from right_array where left_result is true +/// - All other positions are filled with false values +/// +/// # Parameters +/// - `left_result` Boolean array with selection mask (typically from left side of AND) +/// - `right_result` Result of evaluating right side of expression (only for selected positions) +/// +/// # Returns +/// A combined ColumnarValue with values from right_result where left_result is true +/// +/// # Example +/// Initial Data: { 1, 2, 3, 4, 5 } +/// Left Evaluation +/// (Condition: Equal to 2 or 3) +/// ↓ +/// Filtered Data: {2, 3} +/// Left Bitmap: { 0, 1, 1, 0, 0 } +/// ↓ +/// Right Evaluation +/// (Condition: Even numbers) +/// ↓ +/// Right Data: { 2 } +/// Right Bitmap: { 1, 0 } +/// ↓ +/// Combine Results +/// Final Bitmap: { 0, 1, 0, 0, 0 } +/// +/// # Note +/// Perhaps it would be better to modify `left_result` directly without creating a copy? +/// In practice, `left_result` should have only one owner, so making changes should be safe. +/// However, this is difficult to achieve under the immutable constraints of [`Arc`] and [`BooleanArray`]. +fn pre_selection_scatter( + left_result: &BooleanArray, + right_result: ColumnarValue, +) -> Result { + let right_boolean_array = match &right_result { + ColumnarValue::Array(array) => array.as_boolean(), + ColumnarValue::Scalar(_) => return Ok(right_result), + }; + + let result_len = left_result.len(); + + let mut result_array_builder = BooleanArray::builder(result_len); + + // keep track of current position we have in right boolean array + let mut right_array_pos = 0; + + // keep track of how much is filled + let mut last_end = 0; + SlicesIterator::new(left_result).for_each(|(start, end)| { + // the gap needs to be filled with false + if start > last_end { + result_array_builder.append_n(start - last_end, false); + } + + // copy values from right array for this slice + let len = end - start; + right_boolean_array + .slice(right_array_pos, len) + .iter() + .for_each(|v| result_array_builder.append_option(v)); + + right_array_pos += len; + last_end = end; + }); + + // Fill any remaining positions with false + if last_end < result_len { + result_array_builder.append_n(result_len - last_end, false); + } + let boolean_result = result_array_builder.finish(); + + Ok(ColumnarValue::Array(Arc::new(boolean_result))) } fn concat_elements(left: Arc, right: Arc) -> Result { @@ -919,10 +1070,14 @@ pub fn similar_to( mod tests { use super::*; use crate::expressions::{col, lit, try_cast, Column, Literal}; + use datafusion_expr::lit as expr_lit; use datafusion_common::plan_datafusion_err; use datafusion_physical_expr_common::physical_expr::fmt_sql; + use crate::planner::logical2physical; + use arrow::array::BooleanArray; + use datafusion_expr::col as logical_col; /// Performs a binary operation, applying any type coercion necessary fn binary_op( left: Arc, @@ -4895,9 +5050,7 @@ mod tests { #[test] fn test_check_short_circuit() { - use crate::planner::logical2physical; - use datafusion_expr::col as logical_col; - use datafusion_expr::lit; + // Test with non-nullable arrays let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, false), Field::new("b", DataType::Int32, false), @@ -4911,20 +5064,339 @@ mod tests { 
.unwrap(); // op: AND left: all false - let left_expr = logical2physical(&logical_col("a").eq(lit(2)), &schema); + let left_expr = logical2physical(&logical_col("a").eq(expr_lit(2)), &schema); let left_value = left_expr.evaluate(&batch).unwrap(); - assert!(check_short_circuit(&left_value, &Operator::And)); + assert!(matches!( + check_short_circuit(&left_value, &Operator::And), + ShortCircuitStrategy::ReturnLeft + )); + // op: AND left: not all false - let left_expr = logical2physical(&logical_col("a").eq(lit(3)), &schema); + let left_expr = logical2physical(&logical_col("a").eq(expr_lit(3)), &schema); let left_value = left_expr.evaluate(&batch).unwrap(); - assert!(!check_short_circuit(&left_value, &Operator::And)); + let ColumnarValue::Array(array) = &left_value else { + panic!("Expected ColumnarValue::Array"); + }; + let ShortCircuitStrategy::PreSelection(value) = + check_short_circuit(&left_value, &Operator::And) + else { + panic!("Expected ShortCircuitStrategy::PreSelection"); + }; + let expected_boolean_arr: Vec<_> = + as_boolean_array(array).unwrap().iter().collect(); + let boolean_arr: Vec<_> = value.iter().collect(); + assert_eq!(expected_boolean_arr, boolean_arr); + // op: OR left: all true - let left_expr = logical2physical(&logical_col("a").gt(lit(0)), &schema); + let left_expr = logical2physical(&logical_col("a").gt(expr_lit(0)), &schema); let left_value = left_expr.evaluate(&batch).unwrap(); - assert!(check_short_circuit(&left_value, &Operator::Or)); + assert!(matches!( + check_short_circuit(&left_value, &Operator::Or), + ShortCircuitStrategy::ReturnLeft + )); + // op: OR left: not all true - let left_expr = logical2physical(&logical_col("a").gt(lit(2)), &schema); + let left_expr: Arc = + logical2physical(&logical_col("a").gt(expr_lit(2)), &schema); let left_value = left_expr.evaluate(&batch).unwrap(); - assert!(!check_short_circuit(&left_value, &Operator::Or)); + assert!(matches!( + check_short_circuit(&left_value, &Operator::Or), + ShortCircuitStrategy::None + )); + + // Test with nullable arrays and null values + let schema_nullable = Arc::new(Schema::new(vec![ + Field::new("c", DataType::Boolean, true), + Field::new("d", DataType::Boolean, true), + ])); + + // Create arrays with null values + let c_array = Arc::new(BooleanArray::from(vec![ + Some(true), + Some(false), + None, + Some(true), + None, + ])) as ArrayRef; + let d_array = Arc::new(BooleanArray::from(vec![ + Some(false), + Some(true), + Some(false), + None, + Some(true), + ])) as ArrayRef; + + let batch_nullable = RecordBatch::try_new( + Arc::clone(&schema_nullable), + vec![Arc::clone(&c_array), Arc::clone(&d_array)], + ) + .unwrap(); + + // Case: Mixed values with nulls - shouldn't short-circuit for AND + let mixed_nulls = logical2physical(&logical_col("c"), &schema_nullable); + let mixed_nulls_value = mixed_nulls.evaluate(&batch_nullable).unwrap(); + assert!(matches!( + check_short_circuit(&mixed_nulls_value, &Operator::And), + ShortCircuitStrategy::None + )); + + // Case: Mixed values with nulls - shouldn't short-circuit for OR + assert!(matches!( + check_short_circuit(&mixed_nulls_value, &Operator::Or), + ShortCircuitStrategy::None + )); + + // Test with all nulls + let all_nulls = Arc::new(BooleanArray::from(vec![None, None, None])) as ArrayRef; + let null_batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("e", DataType::Boolean, true)])), + vec![all_nulls], + ) + .unwrap(); + + let null_expr = logical2physical(&logical_col("e"), &null_batch.schema()); + let null_value = 
null_expr.evaluate(&null_batch).unwrap(); + + // All nulls shouldn't short-circuit for AND or OR + assert!(matches!( + check_short_circuit(&null_value, &Operator::And), + ShortCircuitStrategy::None + )); + assert!(matches!( + check_short_circuit(&null_value, &Operator::Or), + ShortCircuitStrategy::None + )); + + // Test with scalar values + // Scalar true + let scalar_true = ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))); + assert!(matches!( + check_short_circuit(&scalar_true, &Operator::Or), + ShortCircuitStrategy::ReturnLeft + )); // Should short-circuit OR + assert!(matches!( + check_short_circuit(&scalar_true, &Operator::And), + ShortCircuitStrategy::ReturnRight + )); // Should return the RHS for AND + + // Scalar false + let scalar_false = ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))); + assert!(matches!( + check_short_circuit(&scalar_false, &Operator::And), + ShortCircuitStrategy::ReturnLeft + )); // Should short-circuit AND + assert!(matches!( + check_short_circuit(&scalar_false, &Operator::Or), + ShortCircuitStrategy::ReturnRight + )); // Should return the RHS for OR + + // Scalar null + let scalar_null = ColumnarValue::Scalar(ScalarValue::Boolean(None)); + assert!(matches!( + check_short_circuit(&scalar_null, &Operator::And), + ShortCircuitStrategy::None + )); + assert!(matches!( + check_short_circuit(&scalar_null, &Operator::Or), + ShortCircuitStrategy::None + )); + } + + /// Test for [pre_selection_scatter] + /// Since [check_short_circuit] ensures that the left side does not contain null and is neither all_true nor all_false, as well as not being empty, + /// the following tests have been designed: + /// 1. Test sparse left with interleaved true/false + /// 2. Test multiple consecutive true blocks + /// 3. Test multiple consecutive true blocks + /// 4. Test single true at first position + /// 5. Test single true at last position + /// 6. Test nulls in right array + /// 7. 
Test scalar right handling + #[test] + fn test_pre_selection_scatter() { + fn create_bool_array(bools: Vec) -> BooleanArray { + BooleanArray::from(bools.into_iter().map(Some).collect::>()) + } + // Test sparse left with interleaved true/false + { + // Left: [T, F, T, F, T] + // Right: [F, T, F] (values for 3 true positions) + let left = create_bool_array(vec![true, false, true, false, true]); + let right = ColumnarValue::Array(Arc::new(create_bool_array(vec![ + false, true, false, + ]))); + + let result = pre_selection_scatter(&left, right).unwrap(); + let result_arr = result.into_array(left.len()).unwrap(); + + let expected = create_bool_array(vec![false, false, true, false, false]); + assert_eq!(&expected, result_arr.as_boolean()); + } + // Test multiple consecutive true blocks + { + // Left: [F, T, T, F, T, T, T] + // Right: [T, F, F, T, F] + let left = + create_bool_array(vec![false, true, true, false, true, true, true]); + let right = ColumnarValue::Array(Arc::new(create_bool_array(vec![ + true, false, false, true, false, + ]))); + + let result = pre_selection_scatter(&left, right).unwrap(); + let result_arr = result.into_array(left.len()).unwrap(); + + let expected = + create_bool_array(vec![false, true, false, false, false, true, false]); + assert_eq!(&expected, result_arr.as_boolean()); + } + // Test single true at first position + { + // Left: [T, F, F] + // Right: [F] + let left = create_bool_array(vec![true, false, false]); + let right = ColumnarValue::Array(Arc::new(create_bool_array(vec![false]))); + + let result = pre_selection_scatter(&left, right).unwrap(); + let result_arr = result.into_array(left.len()).unwrap(); + + let expected = create_bool_array(vec![false, false, false]); + assert_eq!(&expected, result_arr.as_boolean()); + } + // Test single true at last position + { + // Left: [F, F, T] + // Right: [F] + let left = create_bool_array(vec![false, false, true]); + let right = ColumnarValue::Array(Arc::new(create_bool_array(vec![false]))); + + let result = pre_selection_scatter(&left, right).unwrap(); + let result_arr = result.into_array(left.len()).unwrap(); + + let expected = create_bool_array(vec![false, false, false]); + assert_eq!(&expected, result_arr.as_boolean()); + } + // Test nulls in right array + { + // Left: [F, T, F, T] + // Right: [None, Some(false)] (with null at first position) + let left = create_bool_array(vec![false, true, false, true]); + let right_arr = BooleanArray::from(vec![None, Some(false)]); + let right = ColumnarValue::Array(Arc::new(right_arr)); + + let result = pre_selection_scatter(&left, right).unwrap(); + let result_arr = result.into_array(left.len()).unwrap(); + + let expected = BooleanArray::from(vec![ + Some(false), + None, // null from right + Some(false), + Some(false), + ]); + assert_eq!(&expected, result_arr.as_boolean()); + } + // Test scalar right handling + { + // Left: [T, F, T] + // Right: Scalar true + let left = create_bool_array(vec![true, false, true]); + let right = ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))); + + let result = pre_selection_scatter(&left, right).unwrap(); + assert!(matches!(result, ColumnarValue::Scalar(_))); + } + } + + #[test] + fn test_evaluate_bounds_int32() { + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + + let a = Arc::new(Column::new("a", 0)) as _; + let b = Arc::new(Column::new("b", 1)) as _; + + // Test addition bounds + let add_expr = + binary_expr(Arc::clone(&a), Operator::Plus, Arc::clone(&b), 
&schema).unwrap(); + let add_bounds = add_expr + .evaluate_bounds(&[ + &Interval::make(Some(1), Some(10)).unwrap(), + &Interval::make(Some(5), Some(15)).unwrap(), + ]) + .unwrap(); + assert_eq!(add_bounds, Interval::make(Some(6), Some(25)).unwrap()); + + // Test subtraction bounds + let sub_expr = + binary_expr(Arc::clone(&a), Operator::Minus, Arc::clone(&b), &schema) + .unwrap(); + let sub_bounds = sub_expr + .evaluate_bounds(&[ + &Interval::make(Some(1), Some(10)).unwrap(), + &Interval::make(Some(5), Some(15)).unwrap(), + ]) + .unwrap(); + assert_eq!(sub_bounds, Interval::make(Some(-14), Some(5)).unwrap()); + + // Test multiplication bounds + let mul_expr = + binary_expr(Arc::clone(&a), Operator::Multiply, Arc::clone(&b), &schema) + .unwrap(); + let mul_bounds = mul_expr + .evaluate_bounds(&[ + &Interval::make(Some(1), Some(10)).unwrap(), + &Interval::make(Some(5), Some(15)).unwrap(), + ]) + .unwrap(); + assert_eq!(mul_bounds, Interval::make(Some(5), Some(150)).unwrap()); + + // Test division bounds + let div_expr = + binary_expr(Arc::clone(&a), Operator::Divide, Arc::clone(&b), &schema) + .unwrap(); + let div_bounds = div_expr + .evaluate_bounds(&[ + &Interval::make(Some(10), Some(20)).unwrap(), + &Interval::make(Some(2), Some(5)).unwrap(), + ]) + .unwrap(); + assert_eq!(div_bounds, Interval::make(Some(2), Some(10)).unwrap()); + } + + #[test] + fn test_evaluate_bounds_bool() { + let schema = Schema::new(vec![ + Field::new("a", DataType::Boolean, false), + Field::new("b", DataType::Boolean, false), + ]); + + let a = Arc::new(Column::new("a", 0)) as _; + let b = Arc::new(Column::new("b", 1)) as _; + + // Test OR bounds + let or_expr = + binary_expr(Arc::clone(&a), Operator::Or, Arc::clone(&b), &schema).unwrap(); + let or_bounds = or_expr + .evaluate_bounds(&[ + &Interval::make(Some(true), Some(true)).unwrap(), + &Interval::make(Some(false), Some(false)).unwrap(), + ]) + .unwrap(); + assert_eq!(or_bounds, Interval::make(Some(true), Some(true)).unwrap()); + + // Test AND bounds + let and_expr = + binary_expr(Arc::clone(&a), Operator::And, Arc::clone(&b), &schema).unwrap(); + let and_bounds = and_expr + .evaluate_bounds(&[ + &Interval::make(Some(true), Some(true)).unwrap(), + &Interval::make(Some(false), Some(false)).unwrap(), + ]) + .unwrap(); + assert_eq!( + and_bounds, + Interval::make(Some(false), Some(false)).unwrap() + ); } } diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 854c715eb0a2..1a74e78f1075 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -15,13 +15,12 @@ // specific language governing permissions and limitations // under the License. 
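The pre-selection path added to BinaryExpr::evaluate above evaluates the right-hand side only for rows selected by the left mask and then scatters those sparse results back to full length, filling unselected rows with false. A self-contained sketch of that scatter, using plain Vec in place of Arrow's BooleanArray (hypothetical helper name, not the actual kernel):

// Rows where `left` is false become `false`; rows where `left` is true take the
// next value from `right`, which was computed only for the selected rows.
fn scatter_selected(left: &[bool], right: &[Option<bool>]) -> Vec<Option<bool>> {
    let mut right_iter = right.iter().copied();
    left.iter()
        .map(|&selected| {
            if selected {
                // Nulls coming from the right side are preserved, as in the real kernel.
                right_iter.next().unwrap()
            } else {
                Some(false)
            }
        })
        .collect()
}

fn main() {
    // Mirrors the worked example in the pre_selection_scatter docs: left mask
    // {0, 1, 1, 0, 0}, right evaluated on the two selected rows as {1, 0}.
    let left = [false, true, true, false, false];
    let right = [Some(true), Some(false)];
    assert_eq!(
        scatter_selected(&left, &right),
        vec![Some(false), Some(true), Some(false), Some(false), Some(false)]
    );
}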
+use crate::expressions::try_cast; +use crate::PhysicalExpr; use std::borrow::Cow; use std::hash::Hash; use std::{any::Any, sync::Arc}; -use crate::expressions::try_cast; -use crate::PhysicalExpr; - use arrow::array::*; use arrow::compute::kernels::zip::zip; use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filter}; @@ -603,7 +602,7 @@ mod tests { use crate::expressions::{binary, cast, col, lit, BinaryExpr}; use arrow::buffer::Buffer; use arrow::datatypes::DataType::Float64; - use arrow::datatypes::*; + use arrow::datatypes::Field; use datafusion_common::cast::{as_float64_array, as_int32_array}; use datafusion_common::plan_err; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index a6766687a881..c91678317b75 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -23,7 +23,7 @@ use std::sync::Arc; use crate::physical_expr::PhysicalExpr; use arrow::compute::{can_cast_types, CastOptions}; -use arrow::datatypes::{DataType, DataType::*, Schema}; +use arrow::datatypes::{DataType, DataType::*, FieldRef, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; use datafusion_common::{not_impl_err, Result}; @@ -97,8 +97,10 @@ impl CastExpr { pub fn cast_options(&self) -> &CastOptions<'static> { &self.cast_options } - pub fn is_bigger_cast(&self, src: DataType) -> bool { - if src == self.cast_type { + + /// Check if the cast is a widening cast (e.g. from `Int8` to `Int16`). + pub fn is_bigger_cast(&self, src: &DataType) -> bool { + if self.cast_type.eq(src) { return true; } matches!( @@ -144,6 +146,16 @@ impl PhysicalExpr for CastExpr { value.cast_to(&self.cast_type, Some(&self.cast_options)) } + fn return_field(&self, input_schema: &Schema) -> Result { + Ok(self + .expr + .return_field(input_schema)? + .as_ref() + .clone() + .with_data_type(self.cast_type.clone()) + .into()) + } + fn children(&self) -> Vec<&Arc> { vec![&self.expr] } diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index ab5b35984753..5a11783a87e9 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -22,6 +22,7 @@ use std::hash::Hash; use std::sync::Arc; use crate::physical_expr::PhysicalExpr; +use arrow::datatypes::FieldRef; use arrow::{ datatypes::{DataType, Schema, SchemaRef}, record_batch::RecordBatch, @@ -127,6 +128,10 @@ impl PhysicalExpr for Column { Ok(ColumnarValue::Array(Arc::clone(batch.column(self.index)))) } + fn return_field(&self, input_schema: &Schema) -> Result { + Ok(input_schema.field(self.index).clone().into()) + } + fn children(&self) -> Vec<&Arc> { vec![] } diff --git a/datafusion/physical-expr/src/expressions/dynamic_filters.rs b/datafusion/physical-expr/src/expressions/dynamic_filters.rs index c0a3285f0e78..ba30b916b9f8 100644 --- a/datafusion/physical-expr/src/expressions/dynamic_filters.rs +++ b/datafusion/physical-expr/src/expressions/dynamic_filters.rs @@ -43,7 +43,7 @@ pub struct DynamicFilterPhysicalExpr { /// so that when we update `current()` in subsequent iterations we can re-apply the replacements. remapped_children: Option>>, /// The source of dynamic filters. - inner: Arc>>, + inner: Arc>, /// For testing purposes track the data type and nullability to make sure they don't change. 
/// If they do, there's a bug in the implementation. /// But this can have overhead in production, so it's only included in our tests. @@ -51,6 +51,25 @@ pub struct DynamicFilterPhysicalExpr { nullable: Arc>>, } +#[derive(Debug)] +struct Inner { + /// A counter that gets incremented every time the expression is updated so that we can track changes cheaply. + /// This is used for [`PhysicalExpr::snapshot_generation`] to have a cheap check for changes. + generation: u64, + expr: Arc, +} + +impl Inner { + fn new(expr: Arc) -> Self { + Self { + // Start with generation 1 which gives us a different result for [`PhysicalExpr::generation`] than the default 0. + // This is not currently used anywhere but it seems useful to have this simple distinction. + generation: 1, + expr, + } + } +} + impl Hash for DynamicFilterPhysicalExpr { fn hash(&self, state: &mut H) { let inner = self.current().expect("Failed to get current expression"); @@ -75,7 +94,7 @@ impl Eq for DynamicFilterPhysicalExpr {} impl Display for DynamicFilterPhysicalExpr { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let inner = self.current().expect("Failed to get current expression"); - write!(f, "DynamicFilterPhysicalExpr [ {} ]", inner) + write!(f, "DynamicFilterPhysicalExpr [ {inner} ]") } } @@ -111,7 +130,7 @@ impl DynamicFilterPhysicalExpr { Self { children, remapped_children: None, // Initially no remapped children - inner: Arc::new(RwLock::new(inner)), + inner: Arc::new(RwLock::new(Inner::new(inner))), data_type: Arc::new(RwLock::new(None)), nullable: Arc::new(RwLock::new(None)), } @@ -150,15 +169,17 @@ impl DynamicFilterPhysicalExpr { /// This will return the current expression with any children /// remapped to match calls to [`PhysicalExpr::with_new_children`]. pub fn current(&self) -> Result> { - let inner = self - .inner - .read() - .map_err(|_| { - datafusion_common::DataFusionError::Execution( - "Failed to acquire read lock for inner".to_string(), - ) - })? - .clone(); + let inner = Arc::clone( + &self + .inner + .read() + .map_err(|_| { + datafusion_common::DataFusionError::Execution( + "Failed to acquire read lock for inner".to_string(), + ) + })? + .expr, + ); let inner = Self::remap_children(&self.children, self.remapped_children.as_ref(), inner)?; Ok(inner) @@ -186,7 +207,10 @@ impl DynamicFilterPhysicalExpr { self.remapped_children.as_ref(), new_expr, )?; - *current = new_expr; + // Update the inner expression to the new expression. + current.expr = new_expr; + // Increment the generation to indicate that the expression has changed. + current.generation += 1; Ok(()) } } @@ -291,6 +315,14 @@ impl PhysicalExpr for DynamicFilterPhysicalExpr { // Return the current expression as a snapshot. Ok(Some(self.current()?)) } + + fn snapshot_generation(&self) -> u64 { + // Return the current generation of the expression. 
+ self.inner + .read() + .expect("Failed to acquire read lock for inner") + .generation + } } #[cfg(test)] @@ -342,7 +374,7 @@ mod test { ) .unwrap(); let snap = dynamic_filter_1.snapshot().unwrap().unwrap(); - insta::assert_snapshot!(format!("{snap:?}"), @r#"BinaryExpr { left: Column { name: "a", index: 0 }, op: Eq, right: Literal { value: Int32(42) }, fail_on_overflow: false }"#); + insta::assert_snapshot!(format!("{snap:?}"), @r#"BinaryExpr { left: Column { name: "a", index: 0 }, op: Eq, right: Literal { value: Int32(42), field: Field { name: "lit", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, fail_on_overflow: false }"#); let dynamic_filter_2 = reassign_predicate_columns( Arc::clone(&dynamic_filter) as Arc, &filter_schema_2, @@ -350,7 +382,7 @@ mod test { ) .unwrap(); let snap = dynamic_filter_2.snapshot().unwrap().unwrap(); - insta::assert_snapshot!(format!("{snap:?}"), @r#"BinaryExpr { left: Column { name: "a", index: 1 }, op: Eq, right: Literal { value: Int32(42) }, fail_on_overflow: false }"#); + insta::assert_snapshot!(format!("{snap:?}"), @r#"BinaryExpr { left: Column { name: "a", index: 1 }, op: Eq, right: Literal { value: Int32(42), field: Field { name: "lit", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, fail_on_overflow: false }"#); // Both filters allow evaluating the same expression let batch_1 = RecordBatch::try_new( Arc::clone(&filter_schema_1), diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 469f7bbee317..b6fe84ea5157 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -1451,7 +1451,7 @@ mod tests { let sql_string = fmt_sql(expr.as_ref()).to_string(); let display_string = expr.to_string(); assert_eq!(sql_string, "a IN (a, b)"); - assert_eq!(display_string, "Use a@0 IN (SET) ([Literal { value: Utf8(\"a\") }, Literal { value: Utf8(\"b\") }])"); + assert_eq!(display_string, "Use a@0 IN (SET) ([Literal { value: Utf8(\"a\"), field: Field { name: \"lit\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8(\"b\"), field: Field { name: \"lit\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }])"); // Test: a NOT IN ('a', 'b') let list = vec![lit("a"), lit("b")]; @@ -1459,7 +1459,7 @@ mod tests { let sql_string = fmt_sql(expr.as_ref()).to_string(); let display_string = expr.to_string(); assert_eq!(sql_string, "a NOT IN (a, b)"); - assert_eq!(display_string, "a@0 NOT IN (SET) ([Literal { value: Utf8(\"a\") }, Literal { value: Utf8(\"b\") }])"); + assert_eq!(display_string, "a@0 NOT IN (SET) ([Literal { value: Utf8(\"a\"), field: Field { name: \"lit\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8(\"b\"), field: Field { name: \"lit\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }])"); // Test: a IN ('a', 'b', NULL) let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))]; @@ -1467,7 +1467,7 @@ mod tests { let sql_string = fmt_sql(expr.as_ref()).to_string(); let display_string = expr.to_string(); assert_eq!(sql_string, "a IN (a, b, NULL)"); - assert_eq!(display_string, "Use a@0 IN (SET) ([Literal { value: Utf8(\"a\") }, Literal { value: Utf8(\"b\") }, Literal { value: Utf8(NULL) }])"); + assert_eq!(display_string, "Use a@0 IN (SET) ([Literal 
{ value: Utf8(\"a\"), field: Field { name: \"lit\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8(\"b\"), field: Field { name: \"lit\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8(NULL), field: Field { name: \"lit\", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} } }])"); // Test: a NOT IN ('a', 'b', NULL) let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))]; @@ -1475,7 +1475,7 @@ mod tests { let sql_string = fmt_sql(expr.as_ref()).to_string(); let display_string = expr.to_string(); assert_eq!(sql_string, "a NOT IN (a, b, NULL)"); - assert_eq!(display_string, "a@0 NOT IN (SET) ([Literal { value: Utf8(\"a\") }, Literal { value: Utf8(\"b\") }, Literal { value: Utf8(NULL) }])"); + assert_eq!(display_string, "a@0 NOT IN (SET) ([Literal { value: Utf8(\"a\"), field: Field { name: \"lit\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8(\"b\"), field: Field { name: \"lit\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8(NULL), field: Field { name: \"lit\", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} } }])"); Ok(()) } diff --git a/datafusion/physical-expr/src/expressions/is_not_null.rs b/datafusion/physical-expr/src/expressions/is_not_null.rs index 0619e7248858..ff05dab40126 100644 --- a/datafusion/physical-expr/src/expressions/is_not_null.rs +++ b/datafusion/physical-expr/src/expressions/is_not_null.rs @@ -17,10 +17,8 @@ //! IS NOT NULL expression -use std::hash::Hash; -use std::{any::Any, sync::Arc}; - use crate::PhysicalExpr; +use arrow::datatypes::FieldRef; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, @@ -28,6 +26,8 @@ use arrow::{ use datafusion_common::Result; use datafusion_common::ScalarValue; use datafusion_expr::ColumnarValue; +use std::hash::Hash; +use std::{any::Any, sync::Arc}; /// IS NOT NULL expression #[derive(Debug, Eq)] @@ -94,6 +94,10 @@ impl PhysicalExpr for IsNotNullExpr { } } + fn return_field(&self, input_schema: &Schema) -> Result { + self.arg.return_field(input_schema) + } + fn children(&self) -> Vec<&Arc> { vec![&self.arg] } diff --git a/datafusion/physical-expr/src/expressions/is_null.rs b/datafusion/physical-expr/src/expressions/is_null.rs index 4c6081f35cad..15c7c645bda0 100644 --- a/datafusion/physical-expr/src/expressions/is_null.rs +++ b/datafusion/physical-expr/src/expressions/is_null.rs @@ -17,10 +17,8 @@ //! 
IS NULL expression -use std::hash::Hash; -use std::{any::Any, sync::Arc}; - use crate::PhysicalExpr; +use arrow::datatypes::FieldRef; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, @@ -28,6 +26,8 @@ use arrow::{ use datafusion_common::Result; use datafusion_common::ScalarValue; use datafusion_expr::ColumnarValue; +use std::hash::Hash; +use std::{any::Any, sync::Arc}; /// IS NULL expression #[derive(Debug, Eq)] @@ -93,6 +93,10 @@ impl PhysicalExpr for IsNullExpr { } } + fn return_field(&self, input_schema: &Schema) -> Result { + self.arg.return_field(input_schema) + } + fn children(&self) -> Vec<&Arc> { vec![&self.arg] } diff --git a/datafusion/physical-expr/src/expressions/like.rs b/datafusion/physical-expr/src/expressions/like.rs index ebf9882665ba..e86c778d5161 100644 --- a/datafusion/physical-expr/src/expressions/like.rs +++ b/datafusion/physical-expr/src/expressions/like.rs @@ -15,15 +15,14 @@ // specific language governing permissions and limitations // under the License. -use std::hash::Hash; -use std::{any::Any, sync::Arc}; - use crate::PhysicalExpr; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::{internal_err, Result}; use datafusion_expr::ColumnarValue; use datafusion_physical_expr_common::datum::apply_cmp; +use std::hash::Hash; +use std::{any::Any, sync::Arc}; // Like expression #[derive(Debug, Eq)] diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs index 0d0c0ecc62c7..1a2ebf000f1d 100644 --- a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr/src/expressions/literal.rs @@ -23,26 +23,59 @@ use std::sync::Arc; use crate::physical_expr::PhysicalExpr; +use arrow::datatypes::{Field, FieldRef}; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::expr::FieldMetadata; use datafusion_expr::Expr; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::interval_arithmetic::Interval; use datafusion_expr_common::sort_properties::{ExprProperties, SortProperties}; /// Represents a literal value -#[derive(Debug, PartialEq, Eq, Hash)] +#[derive(Debug, PartialEq, Eq)] pub struct Literal { value: ScalarValue, + field: FieldRef, +} + +impl Hash for Literal { + fn hash(&self, state: &mut H) { + self.value.hash(state); + let metadata = self.field.metadata(); + let mut keys = metadata.keys().collect::>(); + keys.sort(); + for key in keys { + key.hash(state); + metadata.get(key).unwrap().hash(state); + } + } } impl Literal { /// Create a literal value expression pub fn new(value: ScalarValue) -> Self { - Self { value } + Self::new_with_metadata(value, None) + } + + /// Create a literal value expression + pub fn new_with_metadata( + value: ScalarValue, + metadata: Option, + ) -> Self { + let mut field = Field::new("lit".to_string(), value.data_type(), value.is_null()); + + if let Some(metadata) = metadata { + field = metadata.add_to_field(field); + } + + Self { + value, + field: field.into(), + } } /// Get the scalar value @@ -71,6 +104,10 @@ impl PhysicalExpr for Literal { Ok(self.value.is_null()) } + fn return_field(&self, _input_schema: &Schema) -> Result { + Ok(Arc::clone(&self.field)) + } + fn evaluate(&self, _batch: &RecordBatch) -> Result { Ok(ColumnarValue::Scalar(self.value.clone())) } @@ -102,7 +139,7 @@ impl PhysicalExpr for Literal { /// Create a literal expression pub fn lit(value: T) -> Arc { 
match value.lit() { - Expr::Literal(v) => Arc::new(Literal::new(v)), + Expr::Literal(v, _) => Arc::new(Literal::new(v)), _ => unreachable!(), } } @@ -112,7 +149,7 @@ mod tests { use super::*; use arrow::array::Int32Array; - use arrow::datatypes::*; + use arrow::datatypes::Field; use datafusion_common::cast::as_int32_array; use datafusion_physical_expr_common::physical_expr::fmt_sql; diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index d77207fbbcd7..8f46133ed0bb 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -43,6 +43,7 @@ pub use case::{case, CaseExpr}; pub use cast::{cast, CastExpr}; pub use column::{col, with_new_schema, Column}; pub use datafusion_expr::utils::format_state_name; +pub use dynamic_filters::DynamicFilterPhysicalExpr; pub use in_list::{in_list, InListExpr}; pub use is_not_null::{is_not_null, IsNotNullExpr}; pub use is_null::{is_null, IsNullExpr}; diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs index 33a1bae14d42..fa7224768a77 100644 --- a/datafusion/physical-expr/src/expressions/negative.rs +++ b/datafusion/physical-expr/src/expressions/negative.rs @@ -23,6 +23,7 @@ use std::sync::Arc; use crate::PhysicalExpr; +use arrow::datatypes::FieldRef; use arrow::{ compute::kernels::numeric::neg_wrapping, datatypes::{DataType, Schema}, @@ -103,6 +104,10 @@ impl PhysicalExpr for NegativeExpr { } } + fn return_field(&self, input_schema: &Schema) -> Result { + self.arg.return_field(input_schema) + } + fn children(&self) -> Vec<&Arc> { vec![&self.arg] } diff --git a/datafusion/physical-expr/src/expressions/no_op.rs b/datafusion/physical-expr/src/expressions/no_op.rs index 24d2f4d9e074..94610996c6b0 100644 --- a/datafusion/physical-expr/src/expressions/no_op.rs +++ b/datafusion/physical-expr/src/expressions/no_op.rs @@ -21,12 +21,11 @@ use std::any::Any; use std::hash::Hash; use std::sync::Arc; +use crate::PhysicalExpr; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; - -use crate::PhysicalExpr; use datafusion_common::{internal_err, Result}; use datafusion_expr::ColumnarValue; diff --git a/datafusion/physical-expr/src/expressions/not.rs b/datafusion/physical-expr/src/expressions/not.rs index 8a3348b43d20..8184ef601e54 100644 --- a/datafusion/physical-expr/src/expressions/not.rs +++ b/datafusion/physical-expr/src/expressions/not.rs @@ -24,7 +24,7 @@ use std::sync::Arc; use crate::PhysicalExpr; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::{DataType, FieldRef, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::{cast::as_boolean_array, internal_err, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; @@ -101,6 +101,10 @@ impl PhysicalExpr for NotExpr { } } + fn return_field(&self, input_schema: &Schema) -> Result { + self.arg.return_field(input_schema) + } + fn children(&self) -> Vec<&Arc> { vec![&self.arg] } diff --git a/datafusion/physical-expr/src/expressions/try_cast.rs b/datafusion/physical-expr/src/expressions/try_cast.rs index e49815cd8b64..b593dfe83209 100644 --- a/datafusion/physical-expr/src/expressions/try_cast.rs +++ b/datafusion/physical-expr/src/expressions/try_cast.rs @@ -23,7 +23,7 @@ use std::sync::Arc; use crate::PhysicalExpr; use arrow::compute; use arrow::compute::{cast_with_options, CastOptions}; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::{DataType, FieldRef, 
Schema}; use arrow::record_batch::RecordBatch; use compute::can_cast_types; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; @@ -110,6 +110,13 @@ impl PhysicalExpr for TryCastExpr { } } + fn return_field(&self, input_schema: &Schema) -> Result { + self.expr + .return_field(input_schema) + .map(|f| f.as_ref().clone().with_data_type(self.cast_type.clone())) + .map(Arc::new) + } + fn children(&self) -> Vec<&Arc> { vec![&self.expr] } diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index a53814c3ad2b..c44197bbbe6f 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -42,7 +42,7 @@ //! //! In order to use interval arithmetic to compute bounds for this expression, //! one would first determine intervals that represent the possible values of -//! `x` and `y`` Let's say that the interval for `x` is `[1, 2]` and the interval +//! `x` and `y` Let's say that the interval for `x` is `[1, 2]` and the interval //! for `y` is `[-3, 1]`. In the chart below, you can see how the computation //! takes place. //! @@ -148,12 +148,12 @@ use std::sync::Arc; use super::utils::{ convert_duration_type_to_interval, convert_interval_type_to_duration, get_inverse_op, }; -use crate::expressions::Literal; +use crate::expressions::{BinaryExpr, Literal}; use crate::utils::{build_dag, ExprTreeNode}; use crate::PhysicalExpr; use arrow::datatypes::{DataType, Schema}; -use datafusion_common::{internal_err, Result}; +use datafusion_common::{internal_err, not_impl_err, Result}; use datafusion_expr::interval_arithmetic::{apply_operator, satisfy_greater, Interval}; use datafusion_expr::Operator; @@ -645,6 +645,17 @@ impl ExprIntervalGraph { .map(|child| self.graph[*child].interval()) .collect::>(); let node_interval = self.graph[node].interval(); + // Special case: true OR could in principle be propagated by 3 interval sets, + // (i.e. left true, or right true, or both true) however we do not support this yet. 
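The comment above is the crux of the guard that follows: once an `OR` node is known to be `CERTAINLY_TRUE`, no single interval per child can carry that constraint. A throwaway counting sketch in plain Rust (not DataFusion's `Interval` type) makes the "3 interval sets" remark concrete:

```rust
fn main() {
    let pairs = [(false, false), (false, true), (true, false), (true, true)];

    // `l AND r == true` leaves exactly one consistent assignment, so each child
    // collapses to the single interval {true}; propagation is well defined.
    assert_eq!(pairs.iter().filter(|(l, r)| *l && *r).count(), 1);

    // `l OR r == true` leaves three assignments. Projecting them onto each child
    // yields {false, true} on both sides, and the product of those projections
    // (4 assignments) re-admits (false, false): one interval per child loses
    // information, which is why the solver bails out instead of propagating.
    let or_true: Vec<_> = pairs.iter().filter(|(l, r)| *l || *r).collect();
    assert_eq!(or_true.len(), 3);

    let left: std::collections::BTreeSet<_> = or_true.iter().map(|(l, _)| l).collect();
    let right: std::collections::BTreeSet<_> = or_true.iter().map(|(_, r)| r).collect();
    assert!(left.len() * right.len() > or_true.len());
}
```

Covering the `OR == true` case exactly would mean propagating the three pairs `(T, T)`, `(T, F)` and `(F, T)` separately, hence the `not_impl_err!` below until that support lands.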
+ if node_interval == &Interval::CERTAINLY_TRUE + && self.graph[node] + .expr + .as_any() + .downcast_ref::() + .is_some_and(|expr| expr.op() == &Operator::Or) + { + return not_impl_err!("OR operator cannot yet propagate true intervals"); + } let propagated_intervals = self.graph[node] .expr .propagate_constraints(node_interval, &children_intervals)?; @@ -857,8 +868,8 @@ mod tests { let mut r = StdRng::seed_from_u64(seed); let (left_given, right_given, left_expected, right_expected) = if ASC { - let left = r.gen_range((0 as $TYPE)..(1000 as $TYPE)); - let right = r.gen_range((0 as $TYPE)..(1000 as $TYPE)); + let left = r.random_range((0 as $TYPE)..(1000 as $TYPE)); + let right = r.random_range((0 as $TYPE)..(1000 as $TYPE)); ( (Some(left), None), (Some(right), None), @@ -866,8 +877,8 @@ mod tests { (Some(<$TYPE>::max(right, left + expr_right)), None), ) } else { - let left = r.gen_range((0 as $TYPE)..(1000 as $TYPE)); - let right = r.gen_range((0 as $TYPE)..(1000 as $TYPE)); + let left = r.random_range((0 as $TYPE)..(1000 as $TYPE)); + let right = r.random_range((0 as $TYPE)..(1000 as $TYPE)); ( (None, Some(left)), (None, Some(right)), diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 9f795c81fa48..3bdb9d84d827 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -30,6 +30,7 @@ pub mod analysis; pub mod binary_map { pub use datafusion_physical_expr_common::binary_map::{ArrowBytesSet, OutputType}; } +pub mod async_scalar_function; pub mod equivalence; pub mod expressions; pub mod intervals; @@ -37,6 +38,7 @@ mod partitioning; mod physical_expr; pub mod planner; mod scalar_function; +pub mod schema_rewriter; pub mod statistics; pub mod utils; pub mod window; @@ -54,20 +56,20 @@ pub use equivalence::{ }; pub use partitioning::{Distribution, Partitioning}; pub use physical_expr::{ - create_ordering, create_physical_sort_expr, create_physical_sort_exprs, - physical_exprs_bag_equal, physical_exprs_contains, physical_exprs_equal, - PhysicalExprRef, + add_offset_to_expr, add_offset_to_physical_sort_exprs, create_ordering, + create_physical_sort_expr, create_physical_sort_exprs, physical_exprs_bag_equal, + physical_exprs_contains, physical_exprs_equal, }; -pub use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +pub use datafusion_physical_expr_common::physical_expr::{PhysicalExpr, PhysicalExprRef}; pub use datafusion_physical_expr_common::sort_expr::{ - LexOrdering, LexRequirement, PhysicalSortExpr, PhysicalSortRequirement, + LexOrdering, LexRequirement, OrderingRequirements, PhysicalSortExpr, + PhysicalSortRequirement, }; pub use planner::{create_physical_expr, create_physical_exprs}; pub use scalar_function::ScalarFunctionExpr; - -pub use datafusion_physical_expr_common::utils::reverse_order_bys; +pub use schema_rewriter::PhysicalExprSchemaRewriter; pub use utils::{conjunction, conjunction_opt, split_conjunction}; // For backwards compatibility diff --git a/datafusion/physical-expr/src/partitioning.rs b/datafusion/physical-expr/src/partitioning.rs index eb7e1ea6282b..d6b2b1b046f7 100644 --- a/datafusion/physical-expr/src/partitioning.rs +++ b/datafusion/physical-expr/src/partitioning.rs @@ -199,18 +199,17 @@ impl Partitioning { /// Calculate the output partitioning after applying the given projection. 
pub fn project( &self, - projection_mapping: &ProjectionMapping, + mapping: &ProjectionMapping, input_eq_properties: &EquivalenceProperties, ) -> Self { if let Partitioning::Hash(exprs, part) = self { - let normalized_exprs = exprs - .iter() - .map(|expr| { - input_eq_properties - .project_expr(expr, projection_mapping) - .unwrap_or_else(|| { - Arc::new(UnKnownColumn::new(&expr.to_string())) - }) + let normalized_exprs = input_eq_properties + .project_expressions(exprs, mapping) + .zip(exprs) + .map(|(proj_expr, expr)| { + proj_expr.unwrap_or_else(|| { + Arc::new(UnKnownColumn::new(&expr.to_string())) + }) }) .collect(); Partitioning::Hash(normalized_exprs, *part) diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs index 63c4ccbb4b38..80dd8ce069b7 100644 --- a/datafusion/physical-expr/src/physical_expr.rs +++ b/datafusion/physical-expr/src/physical_expr.rs @@ -17,13 +17,40 @@ use std::sync::Arc; -use crate::create_physical_expr; +use crate::expressions::{self, Column}; +use crate::{create_physical_expr, LexOrdering, PhysicalSortExpr}; + +use arrow::compute::SortOptions; +use arrow::datatypes::Schema; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::{plan_err, Result}; use datafusion_common::{DFSchema, HashMap}; use datafusion_expr::execution_props::ExecutionProps; -pub(crate) use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -pub use datafusion_physical_expr_common::physical_expr::PhysicalExprRef; +use datafusion_expr::{Expr, SortExpr}; + use itertools::izip; +// Exports: +pub(crate) use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + +/// Adds the `offset` value to `Column` indices inside `expr`. This function is +/// generally used during the update of the right table schema in join operations. +pub fn add_offset_to_expr( + expr: Arc, + offset: isize, +) -> Result> { + expr.transform_down(|e| match e.as_any().downcast_ref::() { + Some(col) => { + let Some(idx) = col.index().checked_add_signed(offset) else { + return plan_err!("Column index overflow"); + }; + Ok(Transformed::yes(Arc::new(Column::new(col.name(), idx)))) + } + None => Ok(Transformed::no(e)), + }) + .data() +} + /// This function is similar to the `contains` method of `Vec`. It finds /// whether `expr` is among `physical_exprs`. pub fn physical_exprs_contains( @@ -60,26 +87,21 @@ pub fn physical_exprs_bag_equal( multi_set_lhs == multi_set_rhs } -use crate::{expressions, LexOrdering, PhysicalSortExpr}; -use arrow::compute::SortOptions; -use arrow::datatypes::Schema; -use datafusion_common::plan_err; -use datafusion_common::Result; -use datafusion_expr::{Expr, SortExpr}; - -/// Converts logical sort expressions to physical sort expressions +/// Converts logical sort expressions to physical sort expressions. /// -/// This function transforms a collection of logical sort expressions into their physical -/// representation that can be used during query execution. +/// This function transforms a collection of logical sort expressions into their +/// physical representation that can be used during query execution. /// /// # Arguments /// -/// * `schema` - The schema containing column definitions -/// * `sort_order` - A collection of logical sort expressions grouped into lexicographic orderings +/// * `schema` - The schema containing column definitions. +/// * `sort_order` - A collection of logical sort expressions grouped into +/// lexicographic orderings. 
/// /// # Returns /// -/// A vector of lexicographic orderings for physical execution, or an error if the transformation fails +/// A vector of lexicographic orderings for physical execution, or an error if +/// the transformation fails. /// /// # Examples /// @@ -114,18 +136,13 @@ pub fn create_ordering( for (group_idx, exprs) in sort_order.iter().enumerate() { // Construct PhysicalSortExpr objects from Expr objects: - let mut sort_exprs = LexOrdering::default(); + let mut sort_exprs = vec![]; for (expr_idx, sort) in exprs.iter().enumerate() { match &sort.expr { Expr::Column(col) => match expressions::col(&col.name, schema) { Ok(expr) => { - sort_exprs.push(PhysicalSortExpr { - expr, - options: SortOptions { - descending: !sort.asc, - nulls_first: sort.nulls_first, - }, - }); + let opts = SortOptions::new(!sort.asc, sort.nulls_first); + sort_exprs.push(PhysicalSortExpr::new(expr, opts)); } // Cannot find expression in the projected_schema, stop iterating // since rest of the orderings are violated @@ -141,9 +158,7 @@ pub fn create_ordering( } } } - if !sort_exprs.is_empty() { - all_sort_orders.push(sort_exprs); - } + all_sort_orders.extend(LexOrdering::new(sort_exprs)); } Ok(all_sort_orders) } @@ -154,17 +169,9 @@ pub fn create_physical_sort_expr( input_dfschema: &DFSchema, execution_props: &ExecutionProps, ) -> Result { - let SortExpr { - expr, - asc, - nulls_first, - } = e; - Ok(PhysicalSortExpr { - expr: create_physical_expr(expr, input_dfschema, execution_props)?, - options: SortOptions { - descending: !asc, - nulls_first: *nulls_first, - }, + create_physical_expr(&e.expr, input_dfschema, execution_props).map(|expr| { + let options = SortOptions::new(!e.asc, e.nulls_first); + PhysicalSortExpr::new(expr, options) }) } @@ -173,11 +180,24 @@ pub fn create_physical_sort_exprs( exprs: &[SortExpr], input_dfschema: &DFSchema, execution_props: &ExecutionProps, -) -> Result { +) -> Result> { exprs .iter() - .map(|expr| create_physical_sort_expr(expr, input_dfschema, execution_props)) - .collect::>() + .map(|e| create_physical_sort_expr(e, input_dfschema, execution_props)) + .collect() +} + +pub fn add_offset_to_physical_sort_exprs( + sort_exprs: impl IntoIterator, + offset: isize, +) -> Result> { + sort_exprs + .into_iter() + .map(|mut sort_expr| { + sort_expr.expr = add_offset_to_expr(sort_expr.expr, offset)?; + Ok(sort_expr) + }) + .collect() } #[cfg(test)] diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 8660bff796d5..fbc19b1202ee 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -28,7 +28,9 @@ use datafusion_common::{ exec_err, not_impl_err, plan_err, DFSchema, Result, ScalarValue, ToDFSchema, }; use datafusion_expr::execution_props::ExecutionProps; -use datafusion_expr::expr::{Alias, Cast, InList, Placeholder, ScalarFunction}; +use datafusion_expr::expr::{ + Alias, Cast, FieldMetadata, InList, Placeholder, ScalarFunction, +}; use datafusion_expr::var_provider::is_system_variables; use datafusion_expr::var_provider::VarType; use datafusion_expr::{ @@ -111,14 +113,28 @@ pub fn create_physical_expr( let input_schema: &Schema = &input_dfschema.into(); match e { - Expr::Alias(Alias { expr, .. }) => { - Ok(create_physical_expr(expr, input_dfschema, execution_props)?) + Expr::Alias(Alias { expr, metadata, .. 
}) => { + if let Expr::Literal(v, prior_metadata) = expr.as_ref() { + let new_metadata = FieldMetadata::merge_options( + prior_metadata.as_ref(), + metadata.as_ref(), + ); + Ok(Arc::new(Literal::new_with_metadata( + v.clone(), + new_metadata, + ))) + } else { + Ok(create_physical_expr(expr, input_dfschema, execution_props)?) + } } Expr::Column(c) => { let idx = input_dfschema.index_of_column(c)?; Ok(Arc::new(Column::new(&c.name, idx))) } - Expr::Literal(value) => Ok(Arc::new(Literal::new(value.clone()))), + Expr::Literal(value, metadata) => Ok(Arc::new(Literal::new_with_metadata( + value.clone(), + metadata.clone(), + ))), Expr::ScalarVariable(_, variable_names) => { if is_system_variables(variable_names) { match execution_props.get_var_provider(VarType::System) { @@ -168,7 +184,7 @@ pub fn create_physical_expr( let binary_op = binary_expr( expr.as_ref().clone(), Operator::IsNotDistinctFrom, - Expr::Literal(ScalarValue::Boolean(None)), + Expr::Literal(ScalarValue::Boolean(None), None), ); create_physical_expr(&binary_op, input_dfschema, execution_props) } @@ -176,7 +192,7 @@ pub fn create_physical_expr( let binary_op = binary_expr( expr.as_ref().clone(), Operator::IsDistinctFrom, - Expr::Literal(ScalarValue::Boolean(None)), + Expr::Literal(ScalarValue::Boolean(None), None), ); create_physical_expr(&binary_op, input_dfschema, execution_props) } @@ -347,7 +363,7 @@ pub fn create_physical_expr( list, negated, }) => match expr.as_ref() { - Expr::Literal(ScalarValue::Utf8(None)) => { + Expr::Literal(ScalarValue::Utf8(None), _) => { Ok(expressions::lit(ScalarValue::Boolean(None))) } _ => { @@ -380,7 +396,7 @@ where exprs .into_iter() .map(|expr| create_physical_expr(expr, input_dfschema, execution_props)) - .collect::>>() + .collect() } /// Convert a logical expression to a physical expression (without any simplification, etc) diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 44bbcc4928c6..d014bbb74caa 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -38,13 +38,13 @@ use crate::expressions::Literal; use crate::PhysicalExpr; use arrow::array::{Array, RecordBatch}; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::{DataType, FieldRef, Schema}; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf; use datafusion_expr::{ - expr_vec_fmt, ColumnarValue, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDF, + expr_vec_fmt, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, }; /// Physical expression of a scalar function @@ -53,8 +53,7 @@ pub struct ScalarFunctionExpr { fun: Arc, name: String, args: Vec>, - return_type: DataType, - nullable: bool, + return_field: FieldRef, } impl Debug for ScalarFunctionExpr { @@ -63,7 +62,7 @@ impl Debug for ScalarFunctionExpr { .field("fun", &"") .field("name", &self.name) .field("args", &self.args) - .field("return_type", &self.return_type) + .field("return_field", &self.return_field) .finish() } } @@ -74,14 +73,13 @@ impl ScalarFunctionExpr { name: &str, fun: Arc, args: Vec>, - return_type: DataType, + return_field: FieldRef, ) -> Self { Self { fun, name: name.to_owned(), args, - return_type, - nullable: true, + return_field, } } @@ -92,18 +90,17 @@ impl ScalarFunctionExpr { schema: &Schema, ) -> Result { let name = 
fun.name().to_string(); - let arg_types = args + let arg_fields = args .iter() - .map(|e| e.data_type(schema)) + .map(|e| e.return_field(schema)) .collect::>>()?; // verify that input data types is consistent with function's `TypeSignature` - data_types_with_scalar_udf(&arg_types, &fun)?; - - let nullables = args + let arg_types = arg_fields .iter() - .map(|e| e.nullable(schema)) - .collect::>>()?; + .map(|f| f.data_type().clone()) + .collect::>(); + data_types_with_scalar_udf(&arg_types, &fun)?; let arguments = args .iter() @@ -113,18 +110,16 @@ impl ScalarFunctionExpr { .map(|literal| literal.value()) }) .collect::>(); - let ret_args = ReturnTypeArgs { - arg_types: &arg_types, + let ret_args = ReturnFieldArgs { + arg_fields: &arg_fields, scalar_arguments: &arguments, - nullables: &nullables, }; - let (return_type, nullable) = fun.return_type_from_args(ret_args)?.into_parts(); + let return_field = fun.return_field_from_args(ret_args)?; Ok(Self { fun, name, args, - return_type, - nullable, + return_field, }) } @@ -145,16 +140,21 @@ impl ScalarFunctionExpr { /// Data type produced by this expression pub fn return_type(&self) -> &DataType { - &self.return_type + self.return_field.data_type() } pub fn with_nullable(mut self, nullable: bool) -> Self { - self.nullable = nullable; + self.return_field = self + .return_field + .as_ref() + .clone() + .with_nullable(nullable) + .into(); self } pub fn nullable(&self) -> bool { - self.nullable + self.return_field.is_nullable() } } @@ -171,11 +171,11 @@ impl PhysicalExpr for ScalarFunctionExpr { } fn data_type(&self, _input_schema: &Schema) -> Result { - Ok(self.return_type.clone()) + Ok(self.return_field.data_type().clone()) } fn nullable(&self, _input_schema: &Schema) -> Result { - Ok(self.nullable) + Ok(self.return_field.is_nullable()) } fn evaluate(&self, batch: &RecordBatch) -> Result { @@ -185,6 +185,12 @@ impl PhysicalExpr for ScalarFunctionExpr { .map(|e| e.evaluate(batch)) .collect::>>()?; + let arg_fields = self + .args + .iter() + .map(|e| e.return_field(batch.schema_ref())) + .collect::>>()?; + let input_empty = args.is_empty(); let input_all_scalar = args .iter() @@ -193,8 +199,9 @@ impl PhysicalExpr for ScalarFunctionExpr { // evaluate the function let output = self.fun.invoke_with_args(ScalarFunctionArgs { args, + arg_fields, number_rows: batch.num_rows(), - return_type: &self.return_type, + return_field: Arc::clone(&self.return_field), })?; if let ColumnarValue::Array(array) = &output { @@ -214,6 +221,10 @@ impl PhysicalExpr for ScalarFunctionExpr { Ok(output) } + fn return_field(&self, _input_schema: &Schema) -> Result { + Ok(Arc::clone(&self.return_field)) + } + fn children(&self) -> Vec<&Arc> { self.args.iter().collect() } @@ -222,15 +233,12 @@ impl PhysicalExpr for ScalarFunctionExpr { self: Arc, children: Vec>, ) -> Result> { - Ok(Arc::new( - ScalarFunctionExpr::new( - &self.name, - Arc::clone(&self.fun), - children, - self.return_type().clone(), - ) - .with_nullable(self.nullable), - )) + Ok(Arc::new(ScalarFunctionExpr::new( + &self.name, + Arc::clone(&self.fun), + children, + Arc::clone(&self.return_field), + ))) } fn evaluate_bounds(&self, children: &[&Interval]) -> Result { diff --git a/datafusion/physical-expr/src/schema_rewriter.rs b/datafusion/physical-expr/src/schema_rewriter.rs new file mode 100644 index 000000000000..b8759ea16d6e --- /dev/null +++ b/datafusion/physical-expr/src/schema_rewriter.rs @@ -0,0 +1,466 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Physical expression schema rewriting utilities + +use std::sync::Arc; + +use arrow::compute::can_cast_types; +use arrow::datatypes::{FieldRef, Schema}; +use datafusion_common::{ + exec_err, + tree_node::{Transformed, TransformedResult, TreeNode}, + Result, ScalarValue, +}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + +use crate::expressions::{self, CastExpr, Column}; + +/// Builder for rewriting physical expressions to match different schemas. +/// +/// # Example +/// +/// ```rust +/// use datafusion_physical_expr::schema_rewriter::PhysicalExprSchemaRewriter; +/// use arrow::datatypes::Schema; +/// +/// # fn example( +/// # predicate: std::sync::Arc, +/// # physical_file_schema: &Schema, +/// # logical_file_schema: &Schema, +/// # ) -> datafusion_common::Result<()> { +/// let rewriter = PhysicalExprSchemaRewriter::new(physical_file_schema, logical_file_schema); +/// let adapted_predicate = rewriter.rewrite(predicate)?; +/// # Ok(()) +/// # } +/// ``` +pub struct PhysicalExprSchemaRewriter<'a> { + physical_file_schema: &'a Schema, + logical_file_schema: &'a Schema, + partition_fields: Vec, + partition_values: Vec, +} + +impl<'a> PhysicalExprSchemaRewriter<'a> { + /// Create a new schema rewriter with the given schemas + pub fn new( + physical_file_schema: &'a Schema, + logical_file_schema: &'a Schema, + ) -> Self { + Self { + physical_file_schema, + logical_file_schema, + partition_fields: Vec::new(), + partition_values: Vec::new(), + } + } + + /// Add partition columns and their corresponding values + /// + /// When a column reference matches a partition field, it will be replaced + /// with the corresponding literal value from partition_values. + pub fn with_partition_columns( + mut self, + partition_fields: Vec, + partition_values: Vec, + ) -> Self { + self.partition_fields = partition_fields; + self.partition_values = partition_values; + self + } + + /// Rewrite the given physical expression to match the target schema + /// + /// This method applies the following transformations: + /// 1. Replaces partition column references with literal values + /// 2. Handles missing columns by inserting null literals + /// 3. 
Casts columns when logical and physical schemas have different types + pub fn rewrite(&self, expr: Arc) -> Result> { + expr.transform(|expr| self.rewrite_expr(expr)).data() + } + + fn rewrite_expr( + &self, + expr: Arc, + ) -> Result>> { + if let Some(column) = expr.as_any().downcast_ref::() { + return self.rewrite_column(Arc::clone(&expr), column); + } + + Ok(Transformed::no(expr)) + } + + fn rewrite_column( + &self, + expr: Arc, + column: &Column, + ) -> Result>> { + // Get the logical field for this column + let logical_field = match self.logical_file_schema.field_with_name(column.name()) + { + Ok(field) => field, + Err(e) => { + // If the column is a partition field, we can use the partition value + if let Some(partition_value) = self.get_partition_value(column.name()) { + return Ok(Transformed::yes(expressions::lit(partition_value))); + } + // If the column is not found in the logical schema and is not a partition value, return an error + // This should probably never be hit unless something upstream broke, but nontheless it's better + // for us to return a handleable error than to panic / do something unexpected. + return Err(e.into()); + } + }; + + // Check if the column exists in the physical schema + let physical_column_index = + match self.physical_file_schema.index_of(column.name()) { + Ok(index) => index, + Err(_) => { + if !logical_field.is_nullable() { + return exec_err!( + "Non-nullable column '{}' is missing from the physical schema", + column.name() + ); + } + // If the column is missing from the physical schema fill it in with nulls as `SchemaAdapter` would do. + // TODO: do we need to sync this with what the `SchemaAdapter` actually does? + // While the default implementation fills in nulls in theory a custom `SchemaAdapter` could do something else! + // See https://github.com/apache/datafusion/issues/16527 + let null_value = + ScalarValue::Null.cast_to(logical_field.data_type())?; + return Ok(Transformed::yes(expressions::lit(null_value))); + } + }; + let physical_field = self.physical_file_schema.field(physical_column_index); + + let column = match ( + column.index() == physical_column_index, + logical_field.data_type() == physical_field.data_type(), + ) { + // If the column index matches and the data types match, we can use the column as is + (true, true) => return Ok(Transformed::no(expr)), + // If the indexes or data types do not match, we need to create a new column expression + (true, _) => column.clone(), + (false, _) => { + Column::new_with_schema(logical_field.name(), self.physical_file_schema)? + } + }; + + if logical_field.data_type() == physical_field.data_type() { + // If the data types match, we can use the column as is + return Ok(Transformed::yes(Arc::new(column))); + } + + // We need to cast the column to the logical data type + // TODO: add optimization to move the cast from the column to literal expressions in the case of `col = 123` + // since that's much cheaper to evalaute. 
+ // See https://github.com/apache/datafusion/issues/15780#issuecomment-2824716928 + if !can_cast_types(physical_field.data_type(), logical_field.data_type()) { + return exec_err!( + "Cannot cast column '{}' from '{}' (physical data type) to '{}' (logical data type)", + column.name(), + physical_field.data_type(), + logical_field.data_type() + ); + } + + let cast_expr = Arc::new(CastExpr::new( + Arc::new(column), + logical_field.data_type().clone(), + None, + )); + + Ok(Transformed::yes(cast_expr)) + } + + fn get_partition_value(&self, column_name: &str) -> Option { + self.partition_fields + .iter() + .zip(self.partition_values.iter()) + .find(|(field, _)| field.name() == column_name) + .map(|(_, value)| value.clone()) + } +} + +#[cfg(test)] +mod tests { + use crate::expressions::{col, lit}; + + use super::*; + use arrow::{ + array::{RecordBatch, RecordBatchOptions}, + datatypes::{DataType, Field, Schema, SchemaRef}, + }; + use datafusion_common::{record_batch, ScalarValue}; + use datafusion_expr::Operator; + use itertools::Itertools; + use std::sync::Arc; + + fn create_test_schema() -> (Schema, Schema) { + let physical_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, true), + ]); + + let logical_schema = Schema::new(vec![ + Field::new("a", DataType::Int64, false), // Different type + Field::new("b", DataType::Utf8, true), + Field::new("c", DataType::Float64, true), // Missing from physical + ]); + + (physical_schema, logical_schema) + } + + #[test] + fn test_rewrite_column_with_type_cast() { + let (physical_schema, logical_schema) = create_test_schema(); + + let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema); + let column_expr = Arc::new(Column::new("a", 0)); + + let result = rewriter.rewrite(column_expr).unwrap(); + + // Should be wrapped in a cast expression + assert!(result.as_any().downcast_ref::().is_some()); + } + + #[test] + fn test_rewrite_mulit_column_expr_with_type_cast() { + let (physical_schema, logical_schema) = create_test_schema(); + let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema); + + // Create a complex expression: (a + 5) OR (c > 0.0) that tests the recursive case of the rewriter + let column_a = Arc::new(Column::new("a", 0)) as Arc; + let column_c = Arc::new(Column::new("c", 2)) as Arc; + let expr = expressions::BinaryExpr::new( + Arc::clone(&column_a), + Operator::Plus, + Arc::new(expressions::Literal::new(ScalarValue::Int64(Some(5)))), + ); + let expr = expressions::BinaryExpr::new( + Arc::new(expr), + Operator::Or, + Arc::new(expressions::BinaryExpr::new( + Arc::clone(&column_c), + Operator::Gt, + Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.0)))), + )), + ); + + let result = rewriter.rewrite(Arc::new(expr)).unwrap(); + println!("Rewritten expression: {result}"); + + let expected = expressions::BinaryExpr::new( + Arc::new(CastExpr::new( + Arc::new(Column::new("a", 0)), + DataType::Int64, + None, + )), + Operator::Plus, + Arc::new(expressions::Literal::new(ScalarValue::Int64(Some(5)))), + ); + let expected = Arc::new(expressions::BinaryExpr::new( + Arc::new(expected), + Operator::Or, + Arc::new(expressions::BinaryExpr::new( + lit(ScalarValue::Null), + Operator::Gt, + Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.0)))), + )), + )) as Arc; + + assert_eq!( + result.to_string(), + expected.to_string(), + "The rewritten expression did not match the expected output" + ); + } + + #[test] + fn test_rewrite_missing_column() 
-> Result<()> { + let (physical_schema, logical_schema) = create_test_schema(); + + let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema); + let column_expr = Arc::new(Column::new("c", 2)); + + let result = rewriter.rewrite(column_expr)?; + + // Should be replaced with a literal null + if let Some(literal) = result.as_any().downcast_ref::() { + assert_eq!(*literal.value(), ScalarValue::Float64(None)); + } else { + panic!("Expected literal expression"); + } + + Ok(()) + } + + #[test] + fn test_rewrite_partition_column() -> Result<()> { + let (physical_schema, logical_schema) = create_test_schema(); + + let partition_fields = + vec![Arc::new(Field::new("partition_col", DataType::Utf8, false))]; + let partition_values = vec![ScalarValue::Utf8(Some("test_value".to_string()))]; + + let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema) + .with_partition_columns(partition_fields, partition_values); + + let column_expr = Arc::new(Column::new("partition_col", 0)); + let result = rewriter.rewrite(column_expr)?; + + // Should be replaced with the partition value + if let Some(literal) = result.as_any().downcast_ref::() { + assert_eq!( + *literal.value(), + ScalarValue::Utf8(Some("test_value".to_string())) + ); + } else { + panic!("Expected literal expression"); + } + + Ok(()) + } + + #[test] + fn test_rewrite_no_change_needed() -> Result<()> { + let (physical_schema, logical_schema) = create_test_schema(); + + let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema); + let column_expr = Arc::new(Column::new("b", 1)) as Arc; + + let result = rewriter.rewrite(Arc::clone(&column_expr))?; + + // Should be the same expression (no transformation needed) + // We compare the underlying pointer through the trait object + assert!(std::ptr::eq( + column_expr.as_ref() as *const dyn PhysicalExpr, + result.as_ref() as *const dyn PhysicalExpr + )); + + Ok(()) + } + + #[test] + fn test_non_nullable_missing_column_error() { + let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let logical_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, false), // Non-nullable missing column + ]); + + let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema); + let column_expr = Arc::new(Column::new("b", 1)); + + let result = rewriter.rewrite(column_expr); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("Non-nullable column 'b' is missing")); + } + + /// Roughly stolen from ProjectionExec + fn batch_project( + expr: Vec>, + batch: &RecordBatch, + schema: SchemaRef, + ) -> Result { + let arrays = expr + .iter() + .map(|expr| { + expr.evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) + .collect::>>()?; + + if arrays.is_empty() { + let options = + RecordBatchOptions::new().with_row_count(Some(batch.num_rows())); + RecordBatch::try_new_with_options(Arc::clone(&schema), arrays, &options) + .map_err(Into::into) + } else { + RecordBatch::try_new(Arc::clone(&schema), arrays).map_err(Into::into) + } + } + + /// Example showing how we can use the `PhysicalExprSchemaRewriter` to adapt RecordBatches during a scan + /// to apply projections, type conversions and handling of missing columns all at once. 
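Before the adapt-batches test that follows, it may help to restate the per-column decision that `rewrite_column` makes above as a table of outcomes. The sketch below mirrors that logic using only arrow types; `ColumnPlan` and `plan_column` are illustrative names invented here, and the partition-column branch is omitted:

```rust
use arrow::compute::can_cast_types;
use arrow::datatypes::{DataType, Field, Schema};

/// What to do with a column reference (a simplified sketch, not the rewriter's
/// actual return type).
#[derive(Debug, PartialEq)]
enum ColumnPlan {
    UseAsIs,
    CastTo(DataType),
    FillNull(DataType),
}

fn plan_column(name: &str, physical: &Schema, logical: &Schema) -> Result<ColumnPlan, String> {
    let logical_field = logical.field_with_name(name).map_err(|e| e.to_string())?;
    let Ok(physical_field) = physical.field_with_name(name) else {
        // Missing from the file: fill with typed nulls, but only if that is legal.
        if logical_field.is_nullable() {
            return Ok(ColumnPlan::FillNull(logical_field.data_type().clone()));
        }
        return Err(format!("non-nullable column '{name}' missing from file"));
    };
    if physical_field.data_type() == logical_field.data_type() {
        Ok(ColumnPlan::UseAsIs)
    } else if can_cast_types(physical_field.data_type(), logical_field.data_type()) {
        Ok(ColumnPlan::CastTo(logical_field.data_type().clone()))
    } else {
        Err(format!("cannot cast column '{name}'"))
    }
}

fn main() {
    let physical = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
    let logical = Schema::new(vec![
        Field::new("a", DataType::Int64, false),
        Field::new("b", DataType::Utf8, true),
    ]);
    assert_eq!(plan_column("a", &physical, &logical), Ok(ColumnPlan::CastTo(DataType::Int64)));
    assert_eq!(plan_column("b", &physical, &logical), Ok(ColumnPlan::FillNull(DataType::Utf8)));
}
```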
+ #[test] + fn test_adapt_batches() { + let physical_batch = record_batch!( + ("a", Int32, vec![Some(1), None, Some(3)]), + ("extra", Utf8, vec![Some("x"), Some("y"), None]) + ) + .unwrap(); + + let physical_schema = physical_batch.schema(); + + let logical_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), // Different type + Field::new("b", DataType::Utf8, true), // Missing from physical + ])); + + let projection = vec![ + col("b", &logical_schema).unwrap(), + col("a", &logical_schema).unwrap(), + ]; + + let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema); + + let adapted_projection = projection + .into_iter() + .map(|expr| rewriter.rewrite(expr).unwrap()) + .collect_vec(); + + let adapted_schema = Arc::new(Schema::new( + adapted_projection + .iter() + .map(|expr| expr.return_field(&physical_schema).unwrap()) + .collect_vec(), + )); + + let res = batch_project( + adapted_projection, + &physical_batch, + Arc::clone(&adapted_schema), + ) + .unwrap(); + + assert_eq!(res.num_columns(), 2); + assert_eq!(res.column(0).data_type(), &DataType::Utf8); + assert_eq!(res.column(1).data_type(), &DataType::Int64); + assert_eq!( + res.column(0) + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .collect_vec(), + vec![None, None, None] + ); + assert_eq!( + res.column(1) + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .collect_vec(), + vec![Some(1), None, Some(3)] + ); + } +} diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index b4d0758fd2e8..5cfbf13a25cf 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -26,15 +26,13 @@ use crate::tree_node::ExprContext; use crate::PhysicalExpr; use crate::PhysicalSortExpr; -use arrow::datatypes::SchemaRef; +use arrow::datatypes::Schema; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; use datafusion_common::{HashMap, HashSet, Result}; use datafusion_expr::Operator; -use datafusion_physical_expr_common::sort_expr::LexOrdering; -use itertools::Itertools; use petgraph::graph::NodeIndex; use petgraph::stable_graph::StableGraph; @@ -244,7 +242,7 @@ pub fn collect_columns(expr: &Arc) -> HashSet { /// This may be helpful when dealing with projections. pub fn reassign_predicate_columns( pred: Arc, - schema: &SchemaRef, + schema: &Schema, ignore_not_found: bool, ) -> Result> { pred.transform_down(|expr| { @@ -266,15 +264,6 @@ pub fn reassign_predicate_columns( .data() } -/// Merge left and right sort expressions, checking for duplicates. 
-pub fn merge_vectors(left: &LexOrdering, right: &LexOrdering) -> LexOrdering { - left.iter() - .cloned() - .chain(right.iter().cloned()) - .unique() - .collect() -} - #[cfg(test)] pub(crate) mod tests { use std::any::Any; diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index a94d5b1212f5..6f0e7c963d14 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -27,14 +27,15 @@ use crate::window::window_expr::AggregateWindowExpr; use crate::window::{ PartitionBatches, PartitionWindowAggStates, SlidingAggregateWindowExpr, WindowExpr, }; -use crate::{reverse_order_bys, EquivalenceProperties, PhysicalExpr}; +use crate::{EquivalenceProperties, PhysicalExpr}; use arrow::array::Array; +use arrow::array::ArrayRef; +use arrow::datatypes::FieldRef; use arrow::record_batch::RecordBatch; -use arrow::{array::ArrayRef, datatypes::Field}; use datafusion_common::{DataFusionError, Result, ScalarValue}; -use datafusion_expr::{Accumulator, WindowFrame}; -use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_expr::{Accumulator, WindowFrame, WindowFrameBound, WindowFrameUnits}; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; /// A window expr that takes the form of an aggregate function. /// @@ -43,8 +44,9 @@ use datafusion_physical_expr_common::sort_expr::LexOrdering; pub struct PlainAggregateWindowExpr { aggregate: Arc, partition_by: Vec>, - order_by: LexOrdering, + order_by: Vec, window_frame: Arc, + is_constant_in_partition: bool, } impl PlainAggregateWindowExpr { @@ -52,14 +54,17 @@ impl PlainAggregateWindowExpr { pub fn new( aggregate: Arc, partition_by: &[Arc], - order_by: &LexOrdering, + order_by: &[PhysicalSortExpr], window_frame: Arc, ) -> Self { + let is_constant_in_partition = + Self::is_window_constant_in_partition(order_by, &window_frame); Self { aggregate, partition_by: partition_by.to_vec(), - order_by: order_by.clone(), + order_by: order_by.to_vec(), window_frame, + is_constant_in_partition, } } @@ -72,7 +77,7 @@ impl PlainAggregateWindowExpr { &self, eq_properties: &mut EquivalenceProperties, window_expr_index: usize, - ) { + ) -> Result<()> { if let Some(expr) = self .get_aggregate_expr() .get_result_ordering(window_expr_index) @@ -81,8 +86,33 @@ impl PlainAggregateWindowExpr { eq_properties, expr, &self.partition_by, - ); + )?; } + Ok(()) + } + + // Returns true if every row in the partition has the same window frame. This allows + // for preventing bound + function calculation for every row due to the values being the + // same. + // + // This occurs when both bounds fall under either condition below: + // 1. Bound is unbounded (`Preceding` or `Following`) + // 2. Bound is `CurrentRow` while using `Range` units with no order by clause + // This results in an invalid range specification. Following PostgreSQL’s convention, + // we interpret this as the entire partition being used for the current window frame. 
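The two conditions in the comment above reduce to a small predicate, which the `is_window_constant_in_partition` helper below implements over DataFusion's real frame types. A standalone sketch of the same rule, with simplified stand-in enums rather than the actual `WindowFrameBound`/`WindowFrameUnits`:

```rust
#[derive(Clone, Copy, PartialEq)]
enum Units { Rows, Range }

#[derive(Clone, Copy)]
enum Bound { UnboundedPreceding, Preceding(u64), CurrentRow, Following(u64), UnboundedFollowing }

/// True when every row in the partition sees the same frame, so the aggregate
/// needs to be evaluated only once per partition.
fn frame_is_constant(units: Units, start: Bound, end: Bound, has_order_by: bool) -> bool {
    let constant = |b: Bound| match b {
        Bound::UnboundedPreceding | Bound::UnboundedFollowing => true,
        // RANGE + CURRENT ROW without ORDER BY degenerates to the whole
        // partition, following the PostgreSQL interpretation described above.
        Bound::CurrentRow => units == Units::Range && !has_order_by,
        Bound::Preceding(_) | Bound::Following(_) => false,
    };
    constant(start) && constant(end)
}

fn main() {
    // SUM(x) OVER (): one value per partition.
    assert!(frame_is_constant(Units::Range, Bound::UnboundedPreceding, Bound::CurrentRow, false));
    // SUM(x) OVER (ORDER BY y): the default frame grows row by row.
    assert!(!frame_is_constant(Units::Range, Bound::UnboundedPreceding, Bound::CurrentRow, true));
    // SUM(x) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING): constant again.
    assert!(frame_is_constant(Units::Rows, Bound::UnboundedPreceding, Bound::UnboundedFollowing, false));
    // SUM(x) OVER (ORDER BY y ROWS BETWEEN 1 PRECEDING AND 2 FOLLOWING): sliding, not constant.
    assert!(!frame_is_constant(Units::Rows, Bound::Preceding(1), Bound::Following(2), true));
}
```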
+ fn is_window_constant_in_partition( + order_by: &[PhysicalSortExpr], + window_frame: &WindowFrame, + ) -> bool { + let is_constant_bound = |bound: &WindowFrameBound| match bound { + WindowFrameBound::CurrentRow => { + window_frame.units == WindowFrameUnits::Range && order_by.is_empty() + } + _ => bound.is_unbounded(), + }; + + is_constant_bound(&window_frame.start_bound) + && is_constant_bound(&window_frame.end_bound) } } @@ -95,7 +125,7 @@ impl WindowExpr for PlainAggregateWindowExpr { self } - fn field(&self) -> Result { + fn field(&self) -> Result { Ok(self.aggregate.field()) } @@ -141,8 +171,8 @@ impl WindowExpr for PlainAggregateWindowExpr { &self.partition_by } - fn order_by(&self) -> &LexOrdering { - self.order_by.as_ref() + fn order_by(&self) -> &[PhysicalSortExpr] { + &self.order_by } fn get_window_frame(&self) -> &Arc { @@ -156,14 +186,22 @@ impl WindowExpr for PlainAggregateWindowExpr { Arc::new(PlainAggregateWindowExpr::new( Arc::new(reverse_expr), &self.partition_by.clone(), - reverse_order_bys(self.order_by.as_ref()).as_ref(), + &self + .order_by + .iter() + .map(|e| e.reverse()) + .collect::>(), Arc::new(self.window_frame.reverse()), )) as _ } else { Arc::new(SlidingAggregateWindowExpr::new( Arc::new(reverse_expr), &self.partition_by.clone(), - reverse_order_bys(self.order_by.as_ref()).as_ref(), + &self + .order_by + .iter() + .map(|e| e.reverse()) + .collect::>(), Arc::new(self.window_frame.reverse()), )) as _ } @@ -212,4 +250,8 @@ impl AggregateWindowExpr for PlainAggregateWindowExpr { accumulator.evaluate() } } + + fn is_constant_in_partition(&self) -> bool { + self.is_constant_in_partition + } } diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs index 23967e78f07a..33921a57a6ce 100644 --- a/datafusion/physical-expr/src/window/sliding_aggregate.rs +++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs @@ -26,14 +26,13 @@ use crate::window::window_expr::AggregateWindowExpr; use crate::window::{ PartitionBatches, PartitionWindowAggStates, PlainAggregateWindowExpr, WindowExpr, }; -use crate::{expressions::PhysicalSortExpr, reverse_order_bys, PhysicalExpr}; +use crate::{expressions::PhysicalSortExpr, PhysicalExpr}; use arrow::array::{Array, ArrayRef}; -use arrow::datatypes::Field; +use arrow::datatypes::FieldRef; use arrow::record_batch::RecordBatch; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{Accumulator, WindowFrame}; -use datafusion_physical_expr_common::sort_expr::LexOrdering; /// A window expr that takes the form of an aggregate function that /// can be incrementally computed over sliding windows. 
@@ -43,7 +42,7 @@ use datafusion_physical_expr_common::sort_expr::LexOrdering; pub struct SlidingAggregateWindowExpr { aggregate: Arc, partition_by: Vec>, - order_by: LexOrdering, + order_by: Vec, window_frame: Arc, } @@ -52,13 +51,13 @@ impl SlidingAggregateWindowExpr { pub fn new( aggregate: Arc, partition_by: &[Arc], - order_by: &LexOrdering, + order_by: &[PhysicalSortExpr], window_frame: Arc, ) -> Self { Self { aggregate, partition_by: partition_by.to_vec(), - order_by: order_by.clone(), + order_by: order_by.to_vec(), window_frame, } } @@ -80,7 +79,7 @@ impl WindowExpr for SlidingAggregateWindowExpr { self } - fn field(&self) -> Result { + fn field(&self) -> Result { Ok(self.aggregate.field()) } @@ -108,8 +107,8 @@ impl WindowExpr for SlidingAggregateWindowExpr { &self.partition_by } - fn order_by(&self) -> &LexOrdering { - self.order_by.as_ref() + fn order_by(&self) -> &[PhysicalSortExpr] { + &self.order_by } fn get_window_frame(&self) -> &Arc { @@ -123,14 +122,22 @@ impl WindowExpr for SlidingAggregateWindowExpr { Arc::new(PlainAggregateWindowExpr::new( Arc::new(reverse_expr), &self.partition_by.clone(), - reverse_order_bys(self.order_by.as_ref()).as_ref(), + &self + .order_by + .iter() + .map(|e| e.reverse()) + .collect::>(), Arc::new(self.window_frame.reverse()), )) as _ } else { Arc::new(SlidingAggregateWindowExpr::new( Arc::new(reverse_expr), &self.partition_by.clone(), - reverse_order_bys(self.order_by.as_ref()).as_ref(), + &self + .order_by + .iter() + .map(|e| e.reverse()) + .collect::>(), Arc::new(self.window_frame.reverse()), )) as _ } @@ -157,7 +164,7 @@ impl WindowExpr for SlidingAggregateWindowExpr { expr: new_expr, options: req.options, }) - .collect::(); + .collect(); Some(Arc::new(SlidingAggregateWindowExpr { aggregate: self .aggregate @@ -210,4 +217,8 @@ impl AggregateWindowExpr for SlidingAggregateWindowExpr { accumulator.evaluate() } } + + fn is_constant_in_partition(&self) -> bool { + false + } } diff --git a/datafusion/physical-expr/src/window/standard.rs b/datafusion/physical-expr/src/window/standard.rs index 22e8aea83fe7..c3761aa78f72 100644 --- a/datafusion/physical-expr/src/window/standard.rs +++ b/datafusion/physical-expr/src/window/standard.rs @@ -24,23 +24,23 @@ use std::sync::Arc; use super::{StandardWindowFunctionExpr, WindowExpr}; use crate::window::window_expr::{get_orderby_values, WindowFn}; use crate::window::{PartitionBatches, PartitionWindowAggStates, WindowState}; -use crate::{reverse_order_bys, EquivalenceProperties, PhysicalExpr}; +use crate::{EquivalenceProperties, PhysicalExpr}; + use arrow::array::{new_empty_array, ArrayRef}; -use arrow::compute::SortOptions; -use arrow::datatypes::Field; +use arrow::datatypes::FieldRef; use arrow::record_batch::RecordBatch; use datafusion_common::utils::evaluate_partition_ranges; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::window_state::{WindowAggState, WindowFrameContext}; use datafusion_expr::WindowFrame; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; /// A window expr that takes the form of a [`StandardWindowFunctionExpr`]. 
#[derive(Debug)] pub struct StandardWindowExpr { expr: Arc, partition_by: Vec>, - order_by: LexOrdering, + order_by: Vec, window_frame: Arc, } @@ -49,13 +49,13 @@ impl StandardWindowExpr { pub fn new( expr: Arc, partition_by: &[Arc], - order_by: &LexOrdering, + order_by: &[PhysicalSortExpr], window_frame: Arc, ) -> Self { Self { expr, partition_by: partition_by.to_vec(), - order_by: order_by.clone(), + order_by: order_by.to_vec(), window_frame, } } @@ -70,15 +70,19 @@ impl StandardWindowExpr { /// If `self.expr` doesn't have an ordering, ordering equivalence properties /// are not updated. Otherwise, ordering equivalence properties are updated /// by the ordering of `self.expr`. - pub fn add_equal_orderings(&self, eq_properties: &mut EquivalenceProperties) { + pub fn add_equal_orderings( + &self, + eq_properties: &mut EquivalenceProperties, + ) -> Result<()> { let schema = eq_properties.schema(); if let Some(fn_res_ordering) = self.expr.get_result_ordering(schema) { add_new_ordering_expr_with_partition_by( eq_properties, fn_res_ordering, &self.partition_by, - ); + )?; } + Ok(()) } } @@ -92,7 +96,7 @@ impl WindowExpr for StandardWindowExpr { self.expr.name() } - fn field(&self) -> Result { + fn field(&self) -> Result { self.expr.field() } @@ -104,16 +108,15 @@ impl WindowExpr for StandardWindowExpr { &self.partition_by } - fn order_by(&self) -> &LexOrdering { - self.order_by.as_ref() + fn order_by(&self) -> &[PhysicalSortExpr] { + &self.order_by } fn evaluate(&self, batch: &RecordBatch) -> Result { let mut evaluator = self.expr.create_evaluator()?; let num_rows = batch.num_rows(); if evaluator.uses_window_frame() { - let sort_options: Vec = - self.order_by.iter().map(|o| o.options).collect(); + let sort_options = self.order_by.iter().map(|o| o.options).collect(); let mut row_wise_results = vec![]; let mut values = self.evaluate_args(batch)?; @@ -253,7 +256,11 @@ impl WindowExpr for StandardWindowExpr { Arc::new(StandardWindowExpr::new( reverse_expr, &self.partition_by.clone(), - reverse_order_bys(self.order_by.as_ref()).as_ref(), + &self + .order_by + .iter() + .map(|e| e.reverse()) + .collect::>(), Arc::new(self.window_frame.reverse()), )) as _ }) @@ -276,10 +283,10 @@ pub(crate) fn add_new_ordering_expr_with_partition_by( eqp: &mut EquivalenceProperties, expr: PhysicalSortExpr, partition_by: &[Arc], -) { +) -> Result<()> { if partition_by.is_empty() { // In the absence of a PARTITION BY, ordering of `self.expr` is global: - eqp.add_new_orderings([LexOrdering::new(vec![expr])]); + eqp.add_ordering([expr]); } else { // If we have a PARTITION BY, standard functions can not introduce // a global ordering unless the existing ordering is compatible @@ -287,10 +294,11 @@ pub(crate) fn add_new_ordering_expr_with_partition_by( // expressions and existing ordering expressions are equal (w.r.t. // set equality), we can prefix the ordering of `self.expr` with // the existing ordering. 
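The comment above describes when a window function's result ordering can extend an existing ordering under PARTITION BY. A rough standalone sketch of that prefixing rule follows, with plain strings standing in for physical expressions; the real implementation goes through `EquivalenceProperties::find_longest_permutation` and understands expression equivalences, which this toy version does not.

```rust
use std::collections::HashSet;

/// If every PARTITION BY expression appears (in some order) as a prefix of
/// the existing ordering, the window function's result ordering can be
/// appended after that prefix; otherwise no new global ordering is known.
fn extend_ordering(
    existing: &[String],
    partition_by: &[String],
    window_result: String,
) -> Option<Vec<String>> {
    if partition_by.is_empty() {
        // No PARTITION BY: the result ordering is global on its own.
        return Some(vec![window_result]);
    }
    let prefix: HashSet<&String> = existing.iter().take(partition_by.len()).collect();
    let wanted: HashSet<&String> = partition_by.iter().collect();
    (prefix == wanted).then(|| {
        let mut ordering = existing[..partition_by.len()].to_vec();
        ordering.push(window_result);
        ordering
    })
}

fn main() {
    // Existing ordering [a, b]; PARTITION BY b, a covers the same prefix set,
    // so the window function ordering can be appended: [a, b, row_number].
    let existing = vec!["a".to_string(), "b".to_string()];
    let partition_by = vec!["b".to_string(), "a".to_string()];
    let extended = extend_ordering(&existing, &partition_by, "row_number".into());
    assert_eq!(
        extended,
        Some(vec!["a".into(), "b".into(), "row_number".into()])
    );

    // PARTITION BY c is not covered by the existing ordering prefix.
    assert_eq!(extend_ordering(&existing, &["c".to_string()], "rn".into()), None);
}
```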
- let (mut ordering, _) = eqp.find_longest_permutation(partition_by); + let (mut ordering, _) = eqp.find_longest_permutation(partition_by)?; if ordering.len() == partition_by.len() { ordering.push(expr); - eqp.add_new_orderings([ordering]); + eqp.add_ordering(ordering); } } + Ok(()) } diff --git a/datafusion/physical-expr/src/window/standard_window_function_expr.rs b/datafusion/physical-expr/src/window/standard_window_function_expr.rs index 624b747d93f9..871f735e9a96 100644 --- a/datafusion/physical-expr/src/window/standard_window_function_expr.rs +++ b/datafusion/physical-expr/src/window/standard_window_function_expr.rs @@ -18,7 +18,7 @@ use crate::{PhysicalExpr, PhysicalSortExpr}; use arrow::array::ArrayRef; -use arrow::datatypes::{Field, SchemaRef}; +use arrow::datatypes::{FieldRef, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::Result; use datafusion_expr::PartitionEvaluator; @@ -41,7 +41,7 @@ pub trait StandardWindowFunctionExpr: Send + Sync + std::fmt::Debug { fn as_any(&self) -> &dyn Any; /// The field of the final result of evaluating this window function. - fn field(&self) -> Result; + fn field(&self) -> Result; /// Expressions that are passed to the [`PartitionEvaluator`]. fn expressions(&self) -> Vec>; diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index 793f2e5ee586..dd671e068571 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -20,12 +20,12 @@ use std::fmt::Debug; use std::ops::Range; use std::sync::Arc; -use crate::{LexOrdering, PhysicalExpr}; +use crate::PhysicalExpr; use arrow::array::{new_empty_array, Array, ArrayRef}; use arrow::compute::kernels::sort::SortColumn; use arrow::compute::SortOptions; -use arrow::datatypes::Field; +use arrow::datatypes::FieldRef; use arrow::record_batch::RecordBatch; use datafusion_common::utils::compare_rows; use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; @@ -33,6 +33,7 @@ use datafusion_expr::window_state::{ PartitionBatchState, WindowAggState, WindowFrameContext, WindowFrameStateGroups, }; use datafusion_expr::{Accumulator, PartitionEvaluator, WindowFrame, WindowFrameBound}; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use indexmap::IndexMap; @@ -67,7 +68,7 @@ pub trait WindowExpr: Send + Sync + Debug { fn as_any(&self) -> &dyn Any; /// The field of the final result of this window function. - fn field(&self) -> Result; + fn field(&self) -> Result; /// Human readable name such as `"MIN(c2)"` or `"RANK()"`. The default /// implementation returns placeholder text. @@ -109,14 +110,14 @@ pub trait WindowExpr: Send + Sync + Debug { fn partition_by(&self) -> &[Arc]; /// Expressions that's from the window function's order by clause, empty if absent - fn order_by(&self) -> &LexOrdering; + fn order_by(&self) -> &[PhysicalSortExpr]; /// Get order by columns, empty if absent fn order_by_columns(&self, batch: &RecordBatch) -> Result> { self.order_by() .iter() .map(|e| e.evaluate_to_sort_column(batch)) - .collect::>>() + .collect() } /// Get the window frame of this [WindowExpr]. 
@@ -138,7 +139,7 @@ pub trait WindowExpr: Send + Sync + Debug { .order_by() .iter() .map(|sort_expr| Arc::clone(&sort_expr.expr)) - .collect::>(); + .collect(); WindowPhysicalExpressions { args, partition_by_exprs, @@ -186,12 +187,15 @@ pub trait AggregateWindowExpr: WindowExpr { accumulator: &mut Box, ) -> Result; + /// Indicates whether this window function always produces the same result + /// for all rows in the partition. + fn is_constant_in_partition(&self) -> bool; + /// Evaluates the window function against the batch. fn aggregate_evaluate(&self, batch: &RecordBatch) -> Result { let mut accumulator = self.get_accumulator()?; let mut last_range = Range { start: 0, end: 0 }; - let sort_options: Vec = - self.order_by().iter().map(|o| o.options).collect(); + let sort_options = self.order_by().iter().map(|o| o.options).collect(); let mut window_frame_ctx = WindowFrameContext::new(Arc::clone(self.get_window_frame()), sort_options); self.get_result_column( @@ -239,8 +243,7 @@ pub trait AggregateWindowExpr: WindowExpr { // If there is no window state context, initialize it. let window_frame_ctx = state.window_frame_ctx.get_or_insert_with(|| { - let sort_options: Vec = - self.order_by().iter().map(|o| o.options).collect(); + let sort_options = self.order_by().iter().map(|o| o.options).collect(); WindowFrameContext::new(Arc::clone(self.get_window_frame()), sort_options) }); let out_col = self.get_result_column( @@ -260,6 +263,15 @@ pub trait AggregateWindowExpr: WindowExpr { /// Calculates the window expression result for the given record batch. /// Assumes that `record_batch` belongs to a single partition. + /// + /// # Arguments + /// * `accumulator`: The accumulator to use for the calculation. + /// * `record_batch`: batch belonging to the current partition (see [`PartitionBatchState`]). + /// * `most_recent_row`: the batch that contains the most recent row, if available (see [`PartitionBatchState`]). + /// * `last_range`: The last range of rows that were processed (see [`WindowAggState`]). + /// * `window_frame_ctx`: Details about the window frame (see [`WindowFrameContext`]). + /// * `idx`: The index of the current row in the record batch. + /// * `not_end`: is the current row not the end of the partition (see [`PartitionBatchState`]). #[allow(clippy::too_many_arguments)] fn get_result_column( &self, @@ -272,8 +284,18 @@ pub trait AggregateWindowExpr: WindowExpr { not_end: bool, ) -> Result { let values = self.evaluate_args(record_batch)?; - let order_bys = get_orderby_values(self.order_by_columns(record_batch)?); + if self.is_constant_in_partition() { + if not_end { + let field = self.field()?; + let out_type = field.data_type(); + return Ok(new_empty_array(out_type)); + } + accumulator.update_batch(&values)?; + let value = accumulator.evaluate()?; + return value.to_array_of_size(record_batch.num_rows()); + } + let order_bys = get_orderby_values(self.order_by_columns(record_batch)?); let most_recent_row_order_bys = most_recent_row .map(|batch| self.order_by_columns(batch)) .transpose()? 
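The constant-in-partition fast path above updates the accumulator once with the whole partition, evaluates it, and broadcasts the scalar via `to_array_of_size` (returning an empty array while the partition is still incomplete). Below is a minimal illustration of that broadcast, assuming the `arrow` crate is available and using a hand-rolled SUM in place of a real `Accumulator`.

```rust
use std::sync::Arc;

use arrow::array::{Array, ArrayRef, Int64Array};

/// Evaluate `SUM(v)` once for the whole partition and broadcast the result
/// to every row; this is what the constant-in-partition fast path avoids
/// recomputing row by row.
fn constant_sum_column(values: &Int64Array) -> ArrayRef {
    // One accumulator-style pass over the full partition...
    let total: i64 = values.iter().flatten().sum();
    // ...then the single result is repeated for each input row, mirroring
    // `ScalarValue::to_array_of_size(num_rows)` in the real code path.
    Arc::new(Int64Array::from(vec![total; values.len()])) as ArrayRef
}

fn main() {
    let partition = Int64Array::from(vec![Some(1), Some(2), None, Some(4)]);
    let out = constant_sum_column(&partition);
    assert_eq!(out.len(), partition.len()); // four rows, each holding 7
}
```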
@@ -344,13 +366,13 @@ pub(crate) fn is_end_bound_safe( window_frame_ctx: &WindowFrameContext, order_bys: &[ArrayRef], most_recent_order_bys: Option<&[ArrayRef]>, - sort_exprs: &LexOrdering, + sort_exprs: &[PhysicalSortExpr], idx: usize, ) -> Result { if sort_exprs.is_empty() { // Early return if no sort expressions are present: return Ok(false); - } + }; match window_frame_ctx { WindowFrameContext::Rows(window_frame) => { diff --git a/datafusion/physical-optimizer/Cargo.toml b/datafusion/physical-optimizer/Cargo.toml index aaadb09bcc98..c7795418bf10 100644 --- a/datafusion/physical-optimizer/Cargo.toml +++ b/datafusion/physical-optimizer/Cargo.toml @@ -46,6 +46,7 @@ datafusion-expr-common = { workspace = true, default-features = true } datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } datafusion-physical-plan = { workspace = true } +datafusion-pruning = { workspace = true } itertools = { workspace = true } log = { workspace = true } recursive = { workspace = true, optional = true } @@ -54,3 +55,4 @@ recursive = { workspace = true, optional = true } datafusion-expr = { workspace = true } datafusion-functions-nested = { workspace = true } insta = { workspace = true } +tokio = { workspace = true } diff --git a/datafusion/physical-optimizer/README.md b/datafusion/physical-optimizer/README.md index eb361d3f6779..374351b802c8 100644 --- a/datafusion/physical-optimizer/README.md +++ b/datafusion/physical-optimizer/README.md @@ -23,3 +23,10 @@ DataFusion is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. This crate contains the physical optimizer for DataFusion. + +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. + +[df]: https://crates.io/crates/datafusion +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/physical-optimizer/src/aggregate_statistics.rs b/datafusion/physical-optimizer/src/aggregate_statistics.rs index 28ee10eb650a..6c44c8fe86c5 100644 --- a/datafusion/physical-optimizer/src/aggregate_statistics.rs +++ b/datafusion/physical-optimizer/src/aggregate_statistics.rs @@ -53,7 +53,7 @@ impl PhysicalOptimizerRule for AggregateStatistics { .as_any() .downcast_ref::() .expect("take_optimizable() ensures that this is a AggregateExec"); - let stats = partial_agg_exec.input().statistics()?; + let stats = partial_agg_exec.input().partition_statistics(None)?; let mut projections = vec![]; for expr in partial_agg_exec.aggr_expr() { let field = expr.field(); diff --git a/datafusion/physical-optimizer/src/coalesce_async_exec_input.rs b/datafusion/physical-optimizer/src/coalesce_async_exec_input.rs new file mode 100644 index 000000000000..0b46c68f2dae --- /dev/null +++ b/datafusion/physical-optimizer/src/coalesce_async_exec_input.rs @@ -0,0 +1,71 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::PhysicalOptimizerRule; +use datafusion_common::config::ConfigOptions; +use datafusion_common::internal_err; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_physical_plan::async_func::AsyncFuncExec; +use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; +use datafusion_physical_plan::ExecutionPlan; +use std::sync::Arc; + +/// Optimizer rule that introduces CoalesceAsyncExec to reduce the number of async executions. +#[derive(Default, Debug)] +pub struct CoalesceAsyncExecInput {} + +impl CoalesceAsyncExecInput { + #[allow(missing_docs)] + pub fn new() -> Self { + Self::default() + } +} + +impl PhysicalOptimizerRule for CoalesceAsyncExecInput { + fn optimize( + &self, + plan: Arc, + config: &ConfigOptions, + ) -> datafusion_common::Result> { + let target_batch_size = config.execution.batch_size; + plan.transform(|plan| { + if let Some(async_exec) = plan.as_any().downcast_ref::() { + if async_exec.children().len() != 1 { + return internal_err!( + "Expected AsyncFuncExec to have exactly one child" + ); + } + let child = Arc::clone(async_exec.children()[0]); + let coalesce_exec = + Arc::new(CoalesceBatchesExec::new(child, target_batch_size)); + let coalesce_async_exec = plan.with_new_children(vec![coalesce_exec])?; + Ok(Transformed::yes(coalesce_async_exec)) + } else { + Ok(Transformed::no(plan)) + } + }) + .data() + } + + fn name(&self) -> &str { + "coalesce_async_exec_input" + } + + fn schema_check(&self) -> bool { + true + } +} diff --git a/datafusion/physical-optimizer/src/enforce_distribution.rs b/datafusion/physical-optimizer/src/enforce_distribution.rs index b314b43c6a14..39eb557ea601 100644 --- a/datafusion/physical-optimizer/src/enforce_distribution.rs +++ b/datafusion/physical-optimizer/src/enforce_distribution.rs @@ -42,7 +42,6 @@ use datafusion_physical_expr::utils::map_columns_before_projection; use datafusion_physical_expr::{ physical_exprs_equal, EquivalenceProperties, PhysicalExpr, PhysicalExprRef, }; -use datafusion_physical_expr_common::sort_expr::LexOrdering; use datafusion_physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, }; @@ -296,7 +295,7 @@ pub fn adjust_input_keys_ordering( join_type, projection, mode, - null_equals_null, + null_equality, .. }) = plan.as_any().downcast_ref::() { @@ -315,7 +314,7 @@ pub fn adjust_input_keys_ordering( // TODO: although projection is not used in the join here, because projection pushdown is after enforce_distribution. Maybe we need to handle it later. Same as filter. projection.clone(), PartitionMode::Partitioned, - *null_equals_null, + *null_equality, ) .map(|e| Arc::new(e) as _) }; @@ -335,7 +334,7 @@ pub fn adjust_input_keys_ordering( left.schema().fields().len(), ) .unwrap_or_default(), - JoinType::RightSemi | JoinType::RightAnti => { + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => { requirements.data.clone() } JoinType::Left @@ -365,7 +364,7 @@ pub fn adjust_input_keys_ordering( filter, join_type, sort_options, - null_equals_null, + null_equality, .. 
}) = plan.as_any().downcast_ref::() { @@ -380,7 +379,7 @@ pub fn adjust_input_keys_ordering( filter.clone(), *join_type, new_conditions.1, - *null_equals_null, + *null_equality, ) .map(|e| Arc::new(e) as _) }; @@ -617,7 +616,7 @@ pub fn reorder_join_keys_to_inputs( join_type, projection, mode, - null_equals_null, + null_equality, .. }) = plan_any.downcast_ref::() { @@ -643,7 +642,7 @@ pub fn reorder_join_keys_to_inputs( join_type, projection.clone(), PartitionMode::Partitioned, - *null_equals_null, + *null_equality, )?)); } } @@ -654,7 +653,7 @@ pub fn reorder_join_keys_to_inputs( filter, join_type, sort_options, - null_equals_null, + null_equality, .. }) = plan_any.downcast_ref::() { @@ -682,7 +681,7 @@ pub fn reorder_join_keys_to_inputs( filter.clone(), *join_type, new_sort_options, - *null_equals_null, + *null_equality, ) .map(|smj| Arc::new(smj) as _); } @@ -945,16 +944,10 @@ fn add_spm_on_top(input: DistributionContext) -> DistributionContext { // if any of the following conditions is true // - Preserving ordering is not helpful in terms of satisfying ordering requirements // - Usage of order preserving variants is not desirable - // (determined by flag `config.optimizer.bounded_order_preserving_variants`) - let should_preserve_ordering = input.plan.output_ordering().is_some(); - - let new_plan = if should_preserve_ordering { + // (determined by flag `config.optimizer.prefer_existing_sort`) + let new_plan = if let Some(ordering) = input.plan.output_ordering() { Arc::new(SortPreservingMergeExec::new( - input - .plan - .output_ordering() - .unwrap_or(&LexOrdering::default()) - .clone(), + ordering.clone(), Arc::clone(&input.plan), )) as _ } else { @@ -1018,7 +1011,7 @@ fn remove_dist_changing_operators( /// " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", /// " DataSourceExec: file_groups={2 groups: \[\[x], \[y]]}, projection=\[a, b, c, d, e], output_ordering=\[a@0 ASC], file_type=parquet", /// ``` -fn replace_order_preserving_variants( +pub fn replace_order_preserving_variants( mut context: DistributionContext, ) -> Result { context.children = context @@ -1035,7 +1028,9 @@ fn replace_order_preserving_variants( if is_sort_preserving_merge(&context.plan) { let child_plan = Arc::clone(&context.children[0].plan); - context.plan = Arc::new(CoalescePartitionsExec::new(child_plan)); + context.plan = Arc::new( + CoalescePartitionsExec::new(child_plan).with_fetch(context.plan.fetch()), + ); return Ok(context); } else if let Some(repartition) = context.plan.as_any().downcast_ref::() @@ -1112,7 +1107,8 @@ fn get_repartition_requirement_status( { // Decide whether adding a round robin is beneficial depending on // the statistical information we have on the number of rows: - let roundrobin_beneficial_stats = match child.statistics()?.num_rows { + let roundrobin_beneficial_stats = match child.partition_statistics(None)?.num_rows + { Precision::Exact(n_rows) => n_rows > batch_size, Precision::Inexact(n_rows) => !should_use_estimates || (n_rows > batch_size), Precision::Absent => true, @@ -1155,6 +1151,10 @@ fn get_repartition_requirement_status( /// operators to satisfy distribution requirements. Since this function /// takes care of such requirements, we should avoid manually adding data /// exchange operators in other places. +/// +/// This function is intended to be used in a bottom up traversal, as it +/// can first repartition (or newly partition) at the datasources -- these +/// source partitions may be later repartitioned with additional data exchange operators. 
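The statistics check above gates round-robin repartitioning on the input's row count. Here is a self-contained sketch of that decision rule with a simplified stand-in for `Precision` (illustrative only): exact counts must exceed the batch size, inexact counts only disqualify repartitioning when estimates are trusted, and missing statistics default to repartitioning.

```rust
/// Simplified stand-in for DataFusion's `Precision<usize>` row-count statistic.
enum RowCount {
    Exact(usize),
    Inexact(usize),
    Absent,
}

/// Mirrors the rule used by `ensure_distribution`: adding a round-robin
/// repartition is only considered beneficial when the input is (or may be)
/// larger than a single batch.
fn roundrobin_beneficial(
    num_rows: RowCount,
    batch_size: usize,
    use_estimates: bool,
) -> bool {
    match num_rows {
        RowCount::Exact(n) => n > batch_size,
        // Inexact counts are only trusted when estimates are enabled.
        RowCount::Inexact(n) => !use_estimates || n > batch_size,
        // With no statistics at all, assume repartitioning helps.
        RowCount::Absent => true,
    }
}

fn main() {
    let batch_size = 8192;
    assert!(!roundrobin_beneficial(RowCount::Exact(100), batch_size, true));
    assert!(roundrobin_beneficial(RowCount::Exact(1_000_000), batch_size, true));
    // A small inexact estimate still counts as beneficial when estimates are ignored.
    assert!(roundrobin_beneficial(RowCount::Inexact(100), batch_size, false));
    assert!(roundrobin_beneficial(RowCount::Absent, batch_size, true));
}
```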
pub fn ensure_distribution( dist_context: DistributionContext, config: &ConfigOptions, @@ -1244,6 +1244,10 @@ pub fn ensure_distribution( // When `repartition_file_scans` is set, attempt to increase // parallelism at the source. + // + // If repartitioning is not possible (a.k.a. None is returned from `ExecutionPlan::repartitioned`) + // then no repartitioning will have occurred. As the default implementation returns None, it is only + // specific physical plan nodes, such as certain datasources, which are repartitioned. if repartition_file_scans && roundrobin_beneficial_stats { if let Some(new_child) = child.plan.repartitioned(target_partitions, config)? @@ -1283,10 +1287,12 @@ pub fn ensure_distribution( // Either: // - Ordering requirement cannot be satisfied by preserving ordering through repartitions, or // - using order preserving variant is not desirable. + let sort_req = required_input_ordering.into_single(); let ordering_satisfied = child .plan .equivalence_properties() - .ordering_satisfy_requirement(&required_input_ordering); + .ordering_satisfy_requirement(sort_req.clone())?; + if (!ordering_satisfied || !order_preserving_variants_desirable) && child.data { @@ -1297,9 +1303,12 @@ pub fn ensure_distribution( // Make sure to satisfy ordering requirement: child = add_sort_above_with_check( child, - required_input_ordering.clone(), - None, - ); + sort_req, + plan.as_any() + .downcast_ref::() + .map(|output| output.fetch()) + .unwrap_or(None), + )?; } } // Stop tracking distribution changing operators diff --git a/datafusion/physical-optimizer/src/enforce_sorting/mod.rs b/datafusion/physical-optimizer/src/enforce_sorting/mod.rs index 20733b65692f..8a71b28486a2 100644 --- a/datafusion/physical-optimizer/src/enforce_sorting/mod.rs +++ b/datafusion/physical-optimizer/src/enforce_sorting/mod.rs @@ -46,6 +46,7 @@ use crate::enforce_sorting::replace_with_order_preserving_variants::{ use crate::enforce_sorting::sort_pushdown::{ assign_initial_requirements, pushdown_sorts, SortPushDown, }; +use crate::output_requirements::OutputRequirementExec; use crate::utils::{ add_sort_above, add_sort_above_with_check, is_coalesce_partitions, is_limit, is_repartition, is_sort, is_sort_preserving_merge, is_union, is_window, @@ -191,14 +192,20 @@ fn update_coalesce_ctx_children( } /// Performs optimizations based upon a series of subrules. -/// /// Refer to each subrule for detailed descriptions of the optimizations performed: -/// [`ensure_sorting`], [`parallelize_sorts`], [`replace_with_order_preserving_variants()`], -/// and [`pushdown_sorts`]. -/// /// Subrule application is ordering dependent. /// -/// The subrule `parallelize_sorts` is only applied if `repartition_sorts` is enabled. +/// Optimizer consists of 5 main parts which work sequentially +/// 1. [`ensure_sorting`] Works down-to-top to be able to remove unnecessary [`SortExec`]s, [`SortPreservingMergeExec`]s +/// add [`SortExec`]s if necessary by a requirement and adjusts window operators. +/// 2. [`parallelize_sorts`] (Optional, depends on the `repartition_sorts` configuration) +/// Responsible to identify and remove unnecessary partition unifier operators +/// such as [`SortPreservingMergeExec`], [`CoalescePartitionsExec`] follows [`SortExec`]s does possible simplifications. +/// 3. 
[`replace_with_order_preserving_variants()`] Replaces with alternative operators, for example can merge +/// a [`SortExec`] and a [`CoalescePartitionsExec`] into one [`SortPreservingMergeExec`] +/// or a [`SortExec`] + [`RepartitionExec`] combination into an order preserving [`RepartitionExec`] +/// 4. [`sort_pushdown`] Works top-down. Responsible to push down sort operators as deep as possible in the plan. +/// 5. `replace_with_partial_sort` Checks if it's possible to replace [`SortExec`]s with [`PartialSortExec`] operators impl PhysicalOptimizerRule for EnforceSorting { fn optimize( &self, @@ -251,87 +258,93 @@ impl PhysicalOptimizerRule for EnforceSorting { } } +/// Only interested with [`SortExec`]s and their unbounded children. +/// If the plan is not a [`SortExec`] or its child is not unbounded, returns the original plan. +/// Otherwise, by checking the requirement satisfaction searches for a replacement chance. +/// If there's one replaces the [`SortExec`] plan with a [`PartialSortExec`] fn replace_with_partial_sort( plan: Arc, ) -> Result> { let plan_any = plan.as_any(); - if let Some(sort_plan) = plan_any.downcast_ref::() { - let child = Arc::clone(sort_plan.children()[0]); - if !child.boundedness().is_unbounded() { - return Ok(plan); - } + let Some(sort_plan) = plan_any.downcast_ref::() else { + return Ok(plan); + }; - // here we're trying to find the common prefix for sorted columns that is required for the - // sort and already satisfied by the given ordering - let child_eq_properties = child.equivalence_properties(); - let sort_req = LexRequirement::from(sort_plan.expr().clone()); + // It's safe to get first child of the SortExec + let child = Arc::clone(sort_plan.children()[0]); + if !child.boundedness().is_unbounded() { + return Ok(plan); + } - let mut common_prefix_length = 0; - while child_eq_properties.ordering_satisfy_requirement(&LexRequirement { - inner: sort_req[0..common_prefix_length + 1].to_vec(), - }) { - common_prefix_length += 1; - } - if common_prefix_length > 0 { - return Ok(Arc::new( - PartialSortExec::new( - LexOrdering::new(sort_plan.expr().to_vec()), - Arc::clone(sort_plan.input()), - common_prefix_length, - ) - .with_preserve_partitioning(sort_plan.preserve_partitioning()) - .with_fetch(sort_plan.fetch()), - )); - } + // Here we're trying to find the common prefix for sorted columns that is required for the + // sort and already satisfied by the given ordering + let child_eq_properties = child.equivalence_properties(); + let sort_exprs = sort_plan.expr().clone(); + + let mut common_prefix_length = 0; + while child_eq_properties + .ordering_satisfy(sort_exprs[0..common_prefix_length + 1].to_vec())? + { + common_prefix_length += 1; + } + if common_prefix_length > 0 { + return Ok(Arc::new( + PartialSortExec::new( + sort_exprs, + Arc::clone(sort_plan.input()), + common_prefix_length, + ) + .with_preserve_partitioning(sort_plan.preserve_partitioning()) + .with_fetch(sort_plan.fetch()), + )); } Ok(plan) } -/// Transform [`CoalescePartitionsExec`] + [`SortExec`] into -/// [`SortExec`] + [`SortPreservingMergeExec`] as illustrated below: +/// Transform [`CoalescePartitionsExec`] + [`SortExec`] cascades into [`SortExec`] +/// + [`SortPreservingMergeExec`] cascades, as illustrated below. /// -/// The [`CoalescePartitionsExec`] + [`SortExec`] cascades -/// combine the partitions first, and then sort: +/// A [`CoalescePartitionsExec`] + [`SortExec`] cascade combines partitions +/// first, and then sorts: /// ```text -/// ┌ ─ ─ ─ ─ ─ ┐ -/// ┌─┬─┬─┐ -/// ││B│A│D│... 
├──┐ -/// └─┴─┴─┘ │ +/// ┌ ─ ─ ─ ─ ─ ┐ +/// ┌─┬─┬─┐ +/// ││B│A│D│... ├──┐ +/// └─┴─┴─┘ │ /// └ ─ ─ ─ ─ ─ ┘ │ ┌────────────────────────┐ ┌ ─ ─ ─ ─ ─ ─ ┐ ┌────────┐ ┌ ─ ─ ─ ─ ─ ─ ─ ┐ -/// Partition 1 │ │ Coalesce │ ┌─┬─┬─┬─┬─┐ │ │ ┌─┬─┬─┬─┬─┐ +/// Partition 1 │ │ Coalesce │ ┌─┬─┬─┬─┬─┐ │ │ ┌─┬─┬─┬─┬─┐ /// ├──▶(no ordering guarantees)│──▶││B│E│A│D│C│...───▶ Sort ├───▶││A│B│C│D│E│... │ -/// │ │ │ └─┴─┴─┴─┴─┘ │ │ └─┴─┴─┴─┴─┘ +/// │ │ │ └─┴─┴─┴─┴─┘ │ │ └─┴─┴─┴─┴─┘ /// ┌ ─ ─ ─ ─ ─ ┐ │ └────────────────────────┘ └ ─ ─ ─ ─ ─ ─ ┘ └────────┘ └ ─ ─ ─ ─ ─ ─ ─ ┘ -/// ┌─┬─┐ │ Partition Partition -/// ││E│C│ ... ├──┘ -/// └─┴─┘ -/// └ ─ ─ ─ ─ ─ ┘ -/// Partition 2 -/// ``` +/// ┌─┬─┐ │ Partition Partition +/// ││E│C│ ... ├──┘ +/// └─┴─┘ +/// └ ─ ─ ─ ─ ─ ┘ +/// Partition 2 +/// ``` /// /// -/// The [`SortExec`] + [`SortPreservingMergeExec`] cascades -/// sorts each partition first, then merge partitions while retaining the sort: +/// A [`SortExec`] + [`SortPreservingMergeExec`] cascade sorts each partition +/// first, then merges partitions while preserving the sort: /// ```text -/// ┌ ─ ─ ─ ─ ─ ┐ ┌────────┐ ┌ ─ ─ ─ ─ ─ ┐ -/// ┌─┬─┬─┐ │ │ ┌─┬─┬─┐ -/// ││B│A│D│... │──▶│ Sort │──▶││A│B│D│... │──┐ -/// └─┴─┴─┘ │ │ └─┴─┴─┘ │ +/// ┌ ─ ─ ─ ─ ─ ┐ ┌────────┐ ┌ ─ ─ ─ ─ ─ ┐ +/// ┌─┬─┬─┐ │ │ ┌─┬─┬─┐ +/// ││B│A│D│... │──▶│ Sort │──▶││A│B│D│... │──┐ +/// └─┴─┴─┘ │ │ └─┴─┴─┘ │ /// └ ─ ─ ─ ─ ─ ┘ └────────┘ └ ─ ─ ─ ─ ─ ┘ │ ┌─────────────────────┐ ┌ ─ ─ ─ ─ ─ ─ ─ ┐ -/// Partition 1 Partition 1 │ │ │ ┌─┬─┬─┬─┬─┐ +/// Partition 1 Partition 1 │ │ │ ┌─┬─┬─┬─┬─┐ /// ├──▶ SortPreservingMerge ├───▶││A│B│C│D│E│... │ -/// │ │ │ └─┴─┴─┴─┴─┘ +/// │ │ │ └─┴─┴─┴─┴─┘ /// ┌ ─ ─ ─ ─ ─ ┐ ┌────────┐ ┌ ─ ─ ─ ─ ─ ┐ │ └─────────────────────┘ └ ─ ─ ─ ─ ─ ─ ─ ┘ -/// ┌─┬─┐ │ │ ┌─┬─┐ │ Partition -/// ││E│C│ ... │──▶│ Sort ├──▶││C│E│ ... │──┘ -/// └─┴─┘ │ │ └─┴─┘ -/// └ ─ ─ ─ ─ ─ ┘ └────────┘ └ ─ ─ ─ ─ ─ ┘ -/// Partition 2 Partition 2 +/// ┌─┬─┐ │ │ ┌─┬─┐ │ Partition +/// ││E│C│ ... │──▶│ Sort ├──▶││C│E│ ... │──┘ +/// └─┴─┘ │ │ └─┴─┘ +/// └ ─ ─ ─ ─ ─ ┘ └────────┘ └ ─ ─ ─ ─ ─ ┘ +/// Partition 2 Partition 2 /// ``` /// -/// The latter [`SortExec`] + [`SortPreservingMergeExec`] cascade performs the -/// sort first on a per-partition basis, thereby parallelizing the sort. -/// +/// The latter [`SortExec`] + [`SortPreservingMergeExec`] cascade performs +/// sorting first on a per-partition basis, thereby parallelizing the sort. /// /// The outcome is that plans of the form /// ```text @@ -348,16 +361,32 @@ fn replace_with_partial_sort( /// " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", /// ``` /// by following connections from [`CoalescePartitionsExec`]s to [`SortExec`]s. -/// By performing sorting in parallel, we can increase performance in some scenarios. +/// By performing sorting in parallel, we can increase performance in some +/// scenarios. /// -/// This requires that there are no nodes between the [`SortExec`] and [`CoalescePartitionsExec`] -/// which require single partitioning. Do not parallelize when the following scenario occurs: +/// This optimization requires that there are no nodes between the [`SortExec`] +/// and the [`CoalescePartitionsExec`], which requires single partitioning. Do +/// not parallelize when the following scenario occurs: /// ```text /// "SortExec: expr=\[a@0 ASC\]", /// " ...nodes requiring single partitioning..." /// " CoalescePartitionsExec", /// " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", /// ``` +/// +/// **Steps** +/// 1. 
Checks if the plan is either a [`SortExec`], a [`SortPreservingMergeExec`], +/// or a [`CoalescePartitionsExec`]. Otherwise, does nothing. +/// 2. If the plan is a [`SortExec`] or a final [`SortPreservingMergeExec`] +/// (i.e. output partitioning is 1): +/// - Check for [`CoalescePartitionsExec`] in children. If found, check if +/// it can be removed (with possible [`RepartitionExec`]s). If so, remove +/// (see `remove_bottleneck_in_subplan`). +/// - If the plan is satisfying the ordering requirements, add a `SortExec`. +/// - Add an SPM above the plan and return. +/// 3. If the plan is a [`CoalescePartitionsExec`]: +/// - Check if it can be removed (with possible [`RepartitionExec`]s). +/// If so, remove (see `remove_bottleneck_in_subplan`). pub fn parallelize_sorts( mut requirements: PlanWithCorrespondingCoalescePartitions, ) -> Result> { @@ -388,7 +417,7 @@ pub fn parallelize_sorts( // deals with the children and their children and so on. requirements = requirements.children.swap_remove(0); - requirements = add_sort_above_with_check(requirements, sort_reqs, fetch); + requirements = add_sort_above_with_check(requirements, sort_reqs, fetch)?; let spm = SortPreservingMergeExec::new(sort_exprs, Arc::clone(&requirements.plan)); @@ -400,6 +429,7 @@ pub fn parallelize_sorts( ), )) } else if is_coalesce_partitions(&requirements.plan) { + let fetch = requirements.plan.fetch(); // There is an unnecessary `CoalescePartitionsExec` in the plan. // This will handle the recursive `CoalescePartitionsExec` plans. requirements = remove_bottleneck_in_subplan(requirements)?; @@ -408,7 +438,10 @@ pub fn parallelize_sorts( Ok(Transformed::yes( PlanWithCorrespondingCoalescePartitions::new( - Arc::new(CoalescePartitionsExec::new(Arc::clone(&requirements.plan))), + Arc::new( + CoalescePartitionsExec::new(Arc::clone(&requirements.plan)) + .with_fetch(fetch), + ), false, vec![requirements], ), @@ -420,6 +453,25 @@ pub fn parallelize_sorts( /// This function enforces sorting requirements and makes optimizations without /// violating these requirements whenever possible. Requires a bottom-up traversal. +/// +/// **Steps** +/// 1. Analyze if there are any immediate removals of [`SortExec`]s. If so, +/// removes them (see `analyze_immediate_sort_removal`). +/// 2. For each child of the plan, if the plan requires an input ordering: +/// - Checks if ordering is satisfied with the child. If not: +/// - If the child has an output ordering, removes the unnecessary +/// `SortExec`. +/// - Adds sort above the child plan. +/// - (Plan not requires input ordering) +/// - Checks if the `SortExec` is neutralized in the plan. If so, +/// removes it. +/// 3. Check and modify window operator: +/// - Checks if the plan is a window operator, and connected with a sort. +/// If so, either tries to update the window definition or removes +/// unnecessary [`SortExec`]s (see `adjust_window_sort_removal`). +/// 4. Check and remove possibly unnecessary SPM: +/// - Checks if the plan is SPM and child 1 output partitions, if so +/// decides this SPM is unnecessary and removes it from the plan. 
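As a rough illustration of the `parallelize_sorts` rewrite documented above (Coalesce + Sort into per-partition Sort + SortPreservingMerge), the sketch below applies the same shape of transformation to a toy plan enum; DataFusion's real implementation operates on `Arc<dyn ExecutionPlan>` trees and also handles fetch limits and intermediate `RepartitionExec`s, which this sketch omits.

```rust
/// Toy, illustration-only plan nodes; the real rewrite works on
/// `Arc<dyn ExecutionPlan>` trees inside `parallelize_sorts`.
#[derive(Debug, PartialEq)]
enum Plan {
    Source { partitions: usize },
    Coalesce(Box<Plan>),
    Sort { input: Box<Plan>, key: String },
    SortPreservingMerge { input: Box<Plan>, key: String },
}

/// Rewrite `Sort(Coalesce(x))` into `SortPreservingMerge(Sort(x))`, so the
/// sort runs once per partition and the merge keeps the order.
fn parallelize(plan: Plan) -> Plan {
    match plan {
        Plan::Sort { input, key } => match *input {
            Plan::Coalesce(child) => Plan::SortPreservingMerge {
                input: Box::new(Plan::Sort { input: child, key: key.clone() }),
                key,
            },
            other => Plan::Sort { input: Box::new(other), key },
        },
        other => other,
    }
}

fn main() {
    let before = Plan::Sort {
        input: Box::new(Plan::Coalesce(Box::new(Plan::Source { partitions: 8 }))),
        key: "a".to_string(),
    };
    let after = parallelize(before);
    let expected = Plan::SortPreservingMerge {
        input: Box::new(Plan::Sort {
            input: Box::new(Plan::Source { partitions: 8 }),
            key: "a".to_string(),
        }),
        key: "a".to_string(),
    };
    assert_eq!(after, expected);
}
```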
pub fn ensure_sorting( mut requirements: PlanWithCorrespondingSort, ) -> Result> { @@ -429,7 +481,7 @@ pub fn ensure_sorting( if requirements.children.is_empty() { return Ok(Transformed::no(requirements)); } - let maybe_requirements = analyze_immediate_sort_removal(requirements); + let maybe_requirements = analyze_immediate_sort_removal(requirements)?; requirements = if !maybe_requirements.transformed { maybe_requirements.data } else { @@ -448,12 +500,20 @@ pub fn ensure_sorting( if let Some(required) = required_ordering { let eq_properties = child.plan.equivalence_properties(); - if !eq_properties.ordering_satisfy_requirement(&required) { + let req = required.into_single(); + if !eq_properties.ordering_satisfy_requirement(req.clone())? { // Make sure we preserve the ordering requirements: if physical_ordering.is_some() { child = update_child_to_remove_unnecessary_sort(idx, child, plan)?; } - child = add_sort_above(child, required, None); + child = add_sort_above( + child, + req, + plan.as_any() + .downcast_ref::() + .map(|output| output.fetch()) + .unwrap_or(None), + ); child = update_sort_ctx_children_data(child, true)?; } } else if physical_ordering.is_none() @@ -489,60 +549,56 @@ pub fn ensure_sorting( update_sort_ctx_children_data(requirements, false).map(Transformed::yes) } -/// Analyzes a given [`SortExec`] (`plan`) to determine whether its input -/// already has a finer ordering than it enforces. +/// Analyzes if there are any immediate sort removals by checking the `SortExec`s +/// and their ordering requirement satisfactions with children +/// If the sort is unnecessary, either replaces it with [`SortPreservingMergeExec`]/`LimitExec` +/// or removes the [`SortExec`]. +/// Otherwise, returns the original plan fn analyze_immediate_sort_removal( mut node: PlanWithCorrespondingSort, -) -> Transformed { - if let Some(sort_exec) = node.plan.as_any().downcast_ref::() { - let sort_input = sort_exec.input(); - // If this sort is unnecessary, we should remove it: - if sort_input.equivalence_properties().ordering_satisfy( - sort_exec - .properties() - .output_ordering() - .unwrap_or(LexOrdering::empty()), - ) { - node.plan = if !sort_exec.preserve_partitioning() - && sort_input.output_partitioning().partition_count() > 1 - { - // Replace the sort with a sort-preserving merge: - let expr = LexOrdering::new(sort_exec.expr().to_vec()); - Arc::new( - SortPreservingMergeExec::new(expr, Arc::clone(sort_input)) - .with_fetch(sort_exec.fetch()), - ) as _ +) -> Result> { + let Some(sort_exec) = node.plan.as_any().downcast_ref::() else { + return Ok(Transformed::no(node)); + }; + let sort_input = sort_exec.input(); + // Check if the sort is unnecessary: + let properties = sort_exec.properties(); + if let Some(ordering) = properties.output_ordering().cloned() { + let eqp = sort_input.equivalence_properties(); + if !eqp.ordering_satisfy(ordering)? 
{ + return Ok(Transformed::no(node)); + } + } + node.plan = if !sort_exec.preserve_partitioning() + && sort_input.output_partitioning().partition_count() > 1 + { + // Replace the sort with a sort-preserving merge: + Arc::new( + SortPreservingMergeExec::new( + sort_exec.expr().clone(), + Arc::clone(sort_input), + ) + .with_fetch(sort_exec.fetch()), + ) as _ + } else { + // Remove the sort: + node.children = node.children.swap_remove(0).children; + if let Some(fetch) = sort_exec.fetch() { + // If the sort has a fetch, we need to add a limit: + if properties.output_partitioning().partition_count() == 1 { + Arc::new(GlobalLimitExec::new(Arc::clone(sort_input), 0, Some(fetch))) } else { - // Remove the sort: - node.children = node.children.swap_remove(0).children; - if let Some(fetch) = sort_exec.fetch() { - // If the sort has a fetch, we need to add a limit: - if sort_exec - .properties() - .output_partitioning() - .partition_count() - == 1 - { - Arc::new(GlobalLimitExec::new( - Arc::clone(sort_input), - 0, - Some(fetch), - )) - } else { - Arc::new(LocalLimitExec::new(Arc::clone(sort_input), fetch)) - } - } else { - Arc::clone(sort_input) - } - }; - for child in node.children.iter_mut() { - child.data = false; + Arc::new(LocalLimitExec::new(Arc::clone(sort_input), fetch)) } - node.data = false; - return Transformed::yes(node); + } else { + Arc::clone(sort_input) } + }; + for child in node.children.iter_mut() { + child.data = false; } - Transformed::no(node) + node.data = false; + Ok(Transformed::yes(node)) } /// Adjusts a [`WindowAggExec`] or a [`BoundedWindowAggExec`] to determine @@ -583,15 +639,13 @@ fn adjust_window_sort_removal( } else { // We were unable to change the window to accommodate the input, so we // will insert a sort. - let reqs = window_tree - .plan - .required_input_ordering() - .swap_remove(0) - .unwrap_or_default(); + let reqs = window_tree.plan.required_input_ordering().swap_remove(0); // Satisfy the ordering requirement so that the window can run: let mut child_node = window_tree.children.swap_remove(0); - child_node = add_sort_above(child_node, reqs, None); + if let Some(reqs) = reqs { + child_node = add_sort_above(child_node, reqs.into_single(), None); + } let child_plan = Arc::clone(&child_node.plan); window_tree.children.push(child_node); @@ -738,8 +792,7 @@ fn remove_corresponding_sort_from_sub_plan( let fetch = plan.fetch(); let plan = if let Some(ordering) = plan.output_ordering() { Arc::new( - SortPreservingMergeExec::new(LexOrdering::new(ordering.to_vec()), plan) - .with_fetch(fetch), + SortPreservingMergeExec::new(ordering.clone(), plan).with_fetch(fetch), ) as _ } else { Arc::new(CoalescePartitionsExec::new(plan)) as _ diff --git a/datafusion/physical-optimizer/src/enforce_sorting/replace_with_order_preserving_variants.rs b/datafusion/physical-optimizer/src/enforce_sorting/replace_with_order_preserving_variants.rs index 2c5c0d4d510e..b536e7960208 100644 --- a/datafusion/physical-optimizer/src/enforce_sorting/replace_with_order_preserving_variants.rs +++ b/datafusion/physical-optimizer/src/enforce_sorting/replace_with_order_preserving_variants.rs @@ -27,8 +27,7 @@ use crate::utils::{ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::Transformed; -use datafusion_common::Result; -use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_common::{internal_err, Result}; use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::execution_plan::EmissionType; use 
datafusion_physical_plan::repartition::RepartitionExec; @@ -93,7 +92,7 @@ pub fn update_order_preservation_ctx_children_data(opc: &mut OrderPreservationCo /// inside `sort_input` with their order-preserving variants. This will /// generate an alternative plan, which will be accepted or rejected later on /// depending on whether it helps us remove a `SortExec`. -fn plan_with_order_preserving_variants( +pub fn plan_with_order_preserving_variants( mut sort_input: OrderPreservationContext, // Flag indicating that it is desirable to replace `RepartitionExec`s with // `SortPreservingRepartitionExec`s: @@ -138,6 +137,19 @@ fn plan_with_order_preserving_variants( } else if is_coalesce_partitions(&sort_input.plan) && is_spm_better { let child = &sort_input.children[0].plan; if let Some(ordering) = child.output_ordering() { + let mut fetch = fetch; + if let Some(coalesce_fetch) = sort_input.plan.fetch() { + if let Some(sort_fetch) = fetch { + if coalesce_fetch < sort_fetch { + return internal_err!( + "CoalescePartitionsExec fetch [{:?}] should be greater than or equal to SortExec fetch [{:?}]", coalesce_fetch, sort_fetch + ); + } + } else { + // If the sort node does not have a fetch, we need to keep the coalesce node's fetch. + fetch = Some(coalesce_fetch); + } + }; // When the input of a `CoalescePartitionsExec` has an ordering, // replace it with a `SortPreservingMergeExec` if appropriate: let spm = SortPreservingMergeExec::new(ordering.clone(), Arc::clone(child)) @@ -154,7 +166,7 @@ fn plan_with_order_preserving_variants( /// Calculates the updated plan by replacing operators that preserve ordering /// inside `sort_input` with their order-breaking variants. This will restore /// the original plan modified by [`plan_with_order_preserving_variants`]. -fn plan_with_order_breaking_variants( +pub fn plan_with_order_breaking_variants( mut sort_input: OrderPreservationContext, ) -> Result { let plan = &sort_input.plan; @@ -166,18 +178,17 @@ fn plan_with_order_breaking_variants( .map(|(node, maintains, required_ordering)| { // Replace with non-order preserving variants as long as ordering is // not required by intermediate operators: - if maintains - && (is_sort_preserving_merge(plan) - || !required_ordering.is_some_and(|required_ordering| { - node.plan - .equivalence_properties() - .ordering_satisfy_requirement(&required_ordering) - })) - { - plan_with_order_breaking_variants(node) - } else { - Ok(node) + if !maintains { + return Ok(node); + } else if is_sort_preserving_merge(plan) { + return plan_with_order_breaking_variants(node); + } else if let Some(required_ordering) = required_ordering { + let eqp = node.plan.equivalence_properties(); + if eqp.ordering_satisfy_requirement(required_ordering.into_single())? { + return Ok(node); + } } + plan_with_order_breaking_variants(node) }) .collect::>()?; sort_input.data = false; @@ -189,10 +200,12 @@ fn plan_with_order_breaking_variants( let partitioning = plan.output_partitioning().clone(); sort_input.plan = Arc::new(RepartitionExec::try_new(child, partitioning)?) 
as _; } else if is_sort_preserving_merge(plan) { - // Replace `SortPreservingMergeExec` with a `CoalescePartitionsExec`: + // Replace `SortPreservingMergeExec` with a `CoalescePartitionsExec` + // SPM may have `fetch`, so pass it to the `CoalescePartitionsExec` let child = Arc::clone(&sort_input.children[0].plan); - let coalesce = CoalescePartitionsExec::new(child); - sort_input.plan = Arc::new(coalesce) as _; + let coalesce = + Arc::new(CoalescePartitionsExec::new(child).with_fetch(plan.fetch())); + sort_input.plan = coalesce; } else { return sort_input.update_plan_from_children(); } @@ -264,25 +277,18 @@ pub fn replace_with_order_preserving_variants( )?; // If the alternate plan makes this sort unnecessary, accept the alternate: - if alternate_plan - .plan - .equivalence_properties() - .ordering_satisfy( - requirements - .plan - .output_ordering() - .unwrap_or(LexOrdering::empty()), - ) - { - for child in alternate_plan.children.iter_mut() { - child.data = false; + if let Some(ordering) = requirements.plan.output_ordering() { + let eqp = alternate_plan.plan.equivalence_properties(); + if !eqp.ordering_satisfy(ordering.clone())? { + // The alternate plan does not help, use faster order-breaking variants: + alternate_plan = plan_with_order_breaking_variants(alternate_plan)?; + alternate_plan.data = false; + requirements.children = vec![alternate_plan]; + return Ok(Transformed::yes(requirements)); } - Ok(Transformed::yes(alternate_plan)) - } else { - // The alternate plan does not help, use faster order-breaking variants: - alternate_plan = plan_with_order_breaking_variants(alternate_plan)?; - alternate_plan.data = false; - requirements.children = vec![alternate_plan]; - Ok(Transformed::yes(requirements)) } + for child in alternate_plan.children.iter_mut() { + child.data = false; + } + Ok(Transformed::yes(alternate_plan)) } diff --git a/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs b/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs index 2e20608d0e9e..a9c0e4cb2858 100644 --- a/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs +++ b/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs @@ -24,12 +24,17 @@ use crate::utils::{ use arrow::datatypes::SchemaRef; use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{plan_err, HashSet, JoinSide, Result}; +use datafusion_common::{internal_err, HashSet, JoinSide, Result}; use datafusion_expr::JoinType; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::utils::collect_columns; -use datafusion_physical_expr::PhysicalSortRequirement; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; +use datafusion_physical_expr::{ + add_offset_to_physical_sort_exprs, EquivalenceProperties, +}; +use datafusion_physical_expr_common::sort_expr::{ + LexOrdering, LexRequirement, OrderingRequirements, PhysicalSortExpr, + PhysicalSortRequirement, +}; use datafusion_physical_plan::filter::FilterExec; use datafusion_physical_plan::joins::utils::{ calculate_join_output_ordering, ColumnIndex, @@ -50,7 +55,7 @@ use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; /// [`EnforceSorting`]: crate::enforce_sorting::EnforceSorting #[derive(Default, Clone, Debug)] pub struct ParentRequirements { - ordering_requirement: Option, + ordering_requirement: Option, fetch: Option, } @@ -69,6 +74,7 @@ pub fn assign_initial_requirements(sort_push_down: &mut SortPushDown) { } } +/// Tries to push down the sort requirements 
as far as possible, if decides a `SortExec` is unnecessary removes it. pub fn pushdown_sorts(sort_push_down: SortPushDown) -> Result { sort_push_down .transform_down(pushdown_sorts_helper) @@ -87,91 +93,107 @@ fn min_fetch(f1: Option, f2: Option) -> Option { fn pushdown_sorts_helper( mut sort_push_down: SortPushDown, ) -> Result> { - let plan = &sort_push_down.plan; - let parent_reqs = sort_push_down - .data - .ordering_requirement - .clone() - .unwrap_or_default(); - let satisfy_parent = plan - .equivalence_properties() - .ordering_satisfy_requirement(&parent_reqs); - - if is_sort(plan) { - let current_sort_fetch = plan.fetch(); - let parent_req_fetch = sort_push_down.data.fetch; - - let current_plan_reqs = plan - .output_ordering() - .cloned() - .map(LexRequirement::from) - .unwrap_or_default(); - let parent_is_stricter = plan - .equivalence_properties() - .requirements_compatible(&parent_reqs, ¤t_plan_reqs); - let current_is_stricter = plan - .equivalence_properties() - .requirements_compatible(¤t_plan_reqs, &parent_reqs); + let plan = sort_push_down.plan; + let parent_fetch = sort_push_down.data.fetch; - if !satisfy_parent && !parent_is_stricter { - // This new sort has different requirements than the ordering being pushed down. - // 1. add a `SortExec` here for the pushed down ordering (parent reqs). - // 2. continue sort pushdown, but with the new ordering of the new sort. + let Some(parent_requirement) = sort_push_down.data.ordering_requirement.clone() + else { + // If there are no ordering requirements from the parent, nothing to do + // unless we have a sort. + if is_sort(&plan) { + let Some(sort_ordering) = plan.output_ordering().cloned() else { + return internal_err!("SortExec should have output ordering"); + }; + // The sort is unnecessary, just propagate the stricter fetch and + // ordering requirements. + let fetch = min_fetch(plan.fetch(), parent_fetch); + sort_push_down = sort_push_down + .children + .swap_remove(0) + .update_plan_from_children()?; + sort_push_down.data.fetch = fetch; + sort_push_down.data.ordering_requirement = + Some(OrderingRequirements::from(sort_ordering)); + // Recursive call to helper, so it doesn't transform_down and miss + // the new node (previous child of sort): + return pushdown_sorts_helper(sort_push_down); + } + sort_push_down.plan = plan; + return Ok(Transformed::no(sort_push_down)); + }; - // remove current sort (which will be the new ordering to pushdown) - let new_reqs = current_plan_reqs; - sort_push_down = sort_push_down.children.swap_remove(0); - sort_push_down = sort_push_down.update_plan_from_children()?; // changed plan + let eqp = plan.equivalence_properties(); + let satisfy_parent = + eqp.ordering_satisfy_requirement(parent_requirement.first().clone())?; - // add back sort exec matching parent - sort_push_down = - add_sort_above(sort_push_down, parent_reqs, parent_req_fetch); + if is_sort(&plan) { + let Some(sort_ordering) = plan.output_ordering().cloned() else { + return internal_err!("SortExec should have output ordering"); + }; - // make pushdown requirements be the new ones. + let sort_fetch = plan.fetch(); + let parent_is_stricter = eqp.requirements_compatible( + parent_requirement.first().clone(), + sort_ordering.clone().into(), + ); + + // Remove the current sort as we are either going to prove that it is + // unnecessary, or replace it with a stricter sort. 
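The helper above repeatedly combines the sort's own fetch with the fetch pushed down from the parent. The body of `min_fetch` is not shown in this hunk; the standalone sketch below is one plausible "keep the stricter limit" implementation, where `None` means no limit.

```rust
/// Combine two optional row limits, keeping the stricter one.
/// `None` means "no limit", so any concrete limit wins over it.
fn min_fetch(a: Option<usize>, b: Option<usize>) -> Option<usize> {
    match (a, b) {
        (Some(x), Some(y)) => Some(x.min(y)),
        (Some(x), None) | (None, Some(x)) => Some(x),
        (None, None) => None,
    }
}

fn main() {
    // A sort with fetch 10 under a parent that pushes down fetch 100:
    assert_eq!(min_fetch(Some(10), Some(100)), Some(10));
    // Only one side has a limit: that limit is kept.
    assert_eq!(min_fetch(None, Some(5)), Some(5));
    // Neither side limits the output.
    assert_eq!(min_fetch(None, None), None);
}
```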
+ sort_push_down = sort_push_down + .children + .swap_remove(0) + .update_plan_from_children()?; + if !satisfy_parent && !parent_is_stricter { + // The sort was imposing a different ordering than the one being + // pushed down. Replace it with a sort that matches the pushed-down + // ordering, and continue the pushdown. + // Add back the sort: + sort_push_down = add_sort_above( + sort_push_down, + parent_requirement.into_single(), + parent_fetch, + ); + // Update pushdown requirements: sort_push_down.children[0].data = ParentRequirements { - ordering_requirement: Some(new_reqs), - fetch: current_sort_fetch, + ordering_requirement: Some(OrderingRequirements::from(sort_ordering)), + fetch: sort_fetch, }; + return Ok(Transformed::yes(sort_push_down)); } else { - // Don't add a SortExec - // Do update what sort requirements to keep pushing down - - // remove current sort, and get the sort's child - sort_push_down = sort_push_down.children.swap_remove(0); - sort_push_down = sort_push_down.update_plan_from_children()?; // changed plan - - // set the stricter fetch - sort_push_down.data.fetch = min_fetch(current_sort_fetch, parent_req_fetch); - - // set the stricter ordering - if current_is_stricter { - sort_push_down.data.ordering_requirement = Some(current_plan_reqs); + // Sort was unnecessary, just propagate the stricter fetch and + // ordering requirements: + sort_push_down.data.fetch = min_fetch(sort_fetch, parent_fetch); + let current_is_stricter = eqp.requirements_compatible( + sort_ordering.clone().into(), + parent_requirement.first().clone(), + ); + sort_push_down.data.ordering_requirement = if current_is_stricter { + Some(OrderingRequirements::from(sort_ordering)) } else { - sort_push_down.data.ordering_requirement = Some(parent_reqs); - } - - // recursive call to helper, so it doesn't transform_down and miss the new node (previous child of sort) + Some(parent_requirement) + }; + // Recursive call to helper, so it doesn't transform_down and miss + // the new node (previous child of sort): return pushdown_sorts_helper(sort_push_down); } - } else if parent_reqs.is_empty() { - // note: this `satisfy_parent`, but we don't want to push down anything. - // Nothing to do. - return Ok(Transformed::no(sort_push_down)); - } else if satisfy_parent { + } + + sort_push_down.plan = plan; + if satisfy_parent { // For non-sort operators which satisfy ordering: - let reqs = plan.required_input_ordering(); - let parent_req_fetch = sort_push_down.data.fetch; + let reqs = sort_push_down.plan.required_input_ordering(); for (child, order) in sort_push_down.children.iter_mut().zip(reqs) { child.data.ordering_requirement = order; - child.data.fetch = min_fetch(parent_req_fetch, child.data.fetch); + child.data.fetch = min_fetch(parent_fetch, child.data.fetch); } - } else if let Some(adjusted) = pushdown_requirement_to_children(plan, &parent_reqs)? { - // For operators that can take a sort pushdown. - - // Continue pushdown, with updated requirements: - let parent_fetch = sort_push_down.data.fetch; - let current_fetch = plan.fetch(); + } else if let Some(adjusted) = pushdown_requirement_to_children( + &sort_push_down.plan, + parent_requirement.clone(), + )? 
{ + // For operators that can take a sort pushdown, continue with updated + // requirements: + let current_fetch = sort_push_down.plan.fetch(); for (child, order) in sort_push_down.children.iter_mut().zip(adjusted) { child.data.ordering_requirement = order; child.data.fetch = min_fetch(current_fetch, parent_fetch); @@ -179,16 +201,13 @@ fn pushdown_sorts_helper( sort_push_down.data.ordering_requirement = None; } else { // Can not push down requirements, add new `SortExec`: - let sort_reqs = sort_push_down - .data - .ordering_requirement - .clone() - .unwrap_or_default(); - let fetch = sort_push_down.data.fetch; - sort_push_down = add_sort_above(sort_push_down, sort_reqs, fetch); + sort_push_down = add_sort_above( + sort_push_down, + parent_requirement.into_single(), + parent_fetch, + ); assign_initial_requirements(&mut sort_push_down); } - Ok(Transformed::yes(sort_push_down)) } @@ -196,21 +215,18 @@ fn pushdown_sorts_helper( /// If sort cannot be pushed down, return None. fn pushdown_requirement_to_children( plan: &Arc, - parent_required: &LexRequirement, -) -> Result>>> { + parent_required: OrderingRequirements, +) -> Result>>> { let maintains_input_order = plan.maintains_input_order(); if is_window(plan) { - let required_input_ordering = plan.required_input_ordering(); - let request_child = required_input_ordering[0].clone().unwrap_or_default(); + let mut required_input_ordering = plan.required_input_ordering(); + let maybe_child_requirement = required_input_ordering.swap_remove(0); let child_plan = plan.children().swap_remove(0); - - match determine_children_requirement(parent_required, &request_child, child_plan) - { - RequirementsCompatibility::Satisfy => { - let req = (!request_child.is_empty()) - .then(|| LexRequirement::new(request_child.to_vec())); - Ok(Some(vec![req])) - } + let Some(child_req) = maybe_child_requirement else { + return Ok(None); + }; + match determine_children_requirement(&parent_required, &child_req, child_plan) { + RequirementsCompatibility::Satisfy => Ok(Some(vec![Some(child_req)])), RequirementsCompatibility::Compatible(adjusted) => { // If parent requirements are more specific than output ordering // of the window plan, then we can deduce that the parent expects @@ -218,7 +234,7 @@ fn pushdown_requirement_to_children( // that's the case, we block the pushdown of sort operation. if !plan .equivalence_properties() - .ordering_satisfy_requirement(parent_required) + .ordering_satisfy_requirement(parent_required.into_single())? 
{ return Ok(None); } @@ -228,82 +244,71 @@ fn pushdown_requirement_to_children( RequirementsCompatibility::NonCompatible => Ok(None), } } else if let Some(sort_exec) = plan.as_any().downcast_ref::() { - let sort_req = LexRequirement::from( - sort_exec - .properties() - .output_ordering() - .cloned() - .unwrap_or(LexOrdering::default()), - ); - if sort_exec + let Some(sort_ordering) = sort_exec.properties().output_ordering().cloned() + else { + return internal_err!("SortExec should have output ordering"); + }; + sort_exec .properties() .eq_properties - .requirements_compatible(parent_required, &sort_req) - { - debug_assert!(!parent_required.is_empty()); - Ok(Some(vec![Some(LexRequirement::new( - parent_required.to_vec(), - ))])) - } else { - Ok(None) - } + .requirements_compatible( + parent_required.first().clone(), + sort_ordering.into(), + ) + .then(|| Ok(vec![Some(parent_required)])) + .transpose() } else if plan.fetch().is_some() && plan.supports_limit_pushdown() && plan .maintains_input_order() - .iter() - .all(|maintain| *maintain) + .into_iter() + .all(|maintain| maintain) { - let output_req = LexRequirement::from( - plan.properties() - .output_ordering() - .cloned() - .unwrap_or(LexOrdering::default()), - ); // Push down through operator with fetch when: // - requirement is aligned with output ordering // - it preserves ordering during execution - if plan - .properties() - .eq_properties - .requirements_compatible(parent_required, &output_req) - { - let req = (!parent_required.is_empty()) - .then(|| LexRequirement::new(parent_required.to_vec())); - Ok(Some(vec![req])) + let Some(ordering) = plan.properties().output_ordering() else { + return Ok(Some(vec![Some(parent_required)])); + }; + if plan.properties().eq_properties.requirements_compatible( + parent_required.first().clone(), + ordering.clone().into(), + ) { + Ok(Some(vec![Some(parent_required)])) } else { Ok(None) } } else if is_union(plan) { - // UnionExec does not have real sort requirements for its input. 
Here we change the adjusted_request_ordering to UnionExec's output ordering and - // propagate the sort requirements down to correct the unnecessary descendant SortExec under the UnionExec - let req = (!parent_required.is_empty()).then(|| parent_required.clone()); - Ok(Some(vec![req; plan.children().len()])) + // `UnionExec` does not have real sort requirements for its input, we + // just propagate the sort requirements down: + Ok(Some(vec![Some(parent_required); plan.children().len()])) } else if let Some(smj) = plan.as_any().downcast_ref::() { - // If the current plan is SortMergeJoinExec let left_columns_len = smj.left().schema().fields().len(); - let parent_required_expr = LexOrdering::from(parent_required.clone()); - match expr_source_side( - parent_required_expr.as_ref(), - smj.join_type(), - left_columns_len, - ) { - Some(JoinSide::Left) => try_pushdown_requirements_to_join( + let parent_ordering: Vec = parent_required + .first() + .iter() + .cloned() + .map(Into::into) + .collect(); + let eqp = smj.properties().equivalence_properties(); + match expr_source_side(eqp, parent_ordering, smj.join_type(), left_columns_len) { + Some((JoinSide::Left, ordering)) => try_pushdown_requirements_to_join( smj, - parent_required, - parent_required_expr.as_ref(), + parent_required.into_single(), + ordering, JoinSide::Left, ), - Some(JoinSide::Right) => { + Some((JoinSide::Right, ordering)) => { let right_offset = smj.schema().fields.len() - smj.right().schema().fields.len(); - let new_right_required = - shift_right_required(parent_required, right_offset)?; - let new_right_required_expr = LexOrdering::from(new_right_required); + let ordering = add_offset_to_physical_sort_exprs( + ordering, + -(right_offset as isize), + )?; try_pushdown_requirements_to_join( smj, - parent_required, - new_right_required_expr.as_ref(), + parent_required.into_single(), + ordering, JoinSide::Right, ) } @@ -318,28 +323,26 @@ fn pushdown_requirement_to_children( || plan.as_any().is::() // TODO: Add support for Projection push down || plan.as_any().is::() - || pushdown_would_violate_requirements(parent_required, plan.as_ref()) + || pushdown_would_violate_requirements(&parent_required, plan.as_ref()) { // If the current plan is a leaf node or can not maintain any of the input ordering, can not pushed down requirements. // For RepartitionExec, we always choose to not push down the sort requirements even the RepartitionExec(input_partition=1) could maintain input ordering. // Pushing down is not beneficial Ok(None) } else if is_sort_preserving_merge(plan) { - let new_ordering = LexOrdering::from(parent_required.clone()); + let new_ordering = LexOrdering::from(parent_required.first().clone()); let mut spm_eqs = plan.equivalence_properties().clone(); + let old_ordering = spm_eqs.output_ordering().unwrap(); // Sort preserving merge will have new ordering, one requirement above is pushed down to its below. - spm_eqs = spm_eqs.with_reorder(new_ordering); - // Do not push-down through SortPreservingMergeExec when - // ordering requirement invalidates requirement of sort preserving merge exec. - if !spm_eqs.ordering_satisfy(&plan.output_ordering().cloned().unwrap_or_default()) - { - Ok(None) - } else { + let change = spm_eqs.reorder(new_ordering)?; + if !change || spm_eqs.ordering_satisfy(old_ordering)? { // Can push-down through SortPreservingMergeExec, because parent requirement is finer // than SortPreservingMergeExec output ordering. 
- let req = (!parent_required.is_empty()) - .then(|| LexRequirement::new(parent_required.to_vec())); - Ok(Some(vec![req])) + Ok(Some(vec![Some(parent_required)])) + } else { + // Do not push-down through SortPreservingMergeExec when + // ordering requirement invalidates requirement of sort preserving merge exec. + Ok(None) } } else if let Some(hash_join) = plan.as_any().downcast_ref::() { handle_hash_join(hash_join, parent_required) @@ -352,22 +355,21 @@ fn pushdown_requirement_to_children( /// Return true if pushing the sort requirements through a node would violate /// the input sorting requirements for the plan fn pushdown_would_violate_requirements( - parent_required: &LexRequirement, + parent_required: &OrderingRequirements, child: &dyn ExecutionPlan, ) -> bool { child .required_input_ordering() - .iter() + .into_iter() + // If there is no requirement, pushing down would not violate anything. + .flatten() .any(|child_required| { - let Some(child_required) = child_required.as_ref() else { - // no requirements, so pushing down would not violate anything - return false; - }; - // check if the plan's requirements would still e satisfied if we pushed - // down the parent requirements + // Check if the plan's requirements would still be satisfied if we + // pushed down the parent requirements: child_required + .into_single() .iter() - .zip(parent_required.iter()) + .zip(parent_required.first().iter()) .all(|(c, p)| !c.compatible(p)) }) } @@ -378,25 +380,24 @@ fn pushdown_would_violate_requirements( /// - If parent requirements are more specific, push down parent requirements. /// - If they are not compatible, need to add a sort. fn determine_children_requirement( - parent_required: &LexRequirement, - request_child: &LexRequirement, + parent_required: &OrderingRequirements, + child_requirement: &OrderingRequirements, child_plan: &Arc, ) -> RequirementsCompatibility { - if child_plan - .equivalence_properties() - .requirements_compatible(request_child, parent_required) - { + let eqp = child_plan.equivalence_properties(); + if eqp.requirements_compatible( + child_requirement.first().clone(), + parent_required.first().clone(), + ) { // Child requirements are more specific, no need to push down. 
RequirementsCompatibility::Satisfy - } else if child_plan - .equivalence_properties() - .requirements_compatible(parent_required, request_child) - { + } else if eqp.requirements_compatible( + parent_required.first().clone(), + child_requirement.first().clone(), + ) { // Parent requirements are more specific, adjust child's requirements // and push down the new requirements: - let adjusted = (!parent_required.is_empty()) - .then(|| LexRequirement::new(parent_required.to_vec())); - RequirementsCompatibility::Compatible(adjusted) + RequirementsCompatibility::Compatible(Some(parent_required.clone())) } else { RequirementsCompatibility::NonCompatible } @@ -404,42 +405,41 @@ fn determine_children_requirement( fn try_pushdown_requirements_to_join( smj: &SortMergeJoinExec, - parent_required: &LexRequirement, - sort_expr: &LexOrdering, + parent_required: LexRequirement, + sort_exprs: Vec, push_side: JoinSide, -) -> Result>>> { - let left_eq_properties = smj.left().equivalence_properties(); - let right_eq_properties = smj.right().equivalence_properties(); +) -> Result>>> { let mut smj_required_orderings = smj.required_input_ordering(); - let right_requirement = smj_required_orderings.swap_remove(1); - let left_requirement = smj_required_orderings.swap_remove(0); - let left_ordering = &smj.left().output_ordering().cloned().unwrap_or_default(); - let right_ordering = &smj.right().output_ordering().cloned().unwrap_or_default(); + let ordering = LexOrdering::new(sort_exprs.clone()); let (new_left_ordering, new_right_ordering) = match push_side { JoinSide::Left => { - let left_eq_properties = - left_eq_properties.clone().with_reorder(sort_expr.clone()); - if left_eq_properties - .ordering_satisfy_requirement(&left_requirement.unwrap_or_default()) + let mut left_eq_properties = smj.left().equivalence_properties().clone(); + left_eq_properties.reorder(sort_exprs)?; + let Some(left_requirement) = smj_required_orderings.swap_remove(0) else { + return Ok(None); + }; + if !left_eq_properties + .ordering_satisfy_requirement(left_requirement.into_single())? { - // After re-ordering requirement is still satisfied - (sort_expr, right_ordering) - } else { return Ok(None); } + // After re-ordering, requirement is still satisfied: + (ordering.as_ref(), smj.right().output_ordering()) } JoinSide::Right => { - let right_eq_properties = - right_eq_properties.clone().with_reorder(sort_expr.clone()); - if right_eq_properties - .ordering_satisfy_requirement(&right_requirement.unwrap_or_default()) + let mut right_eq_properties = smj.right().equivalence_properties().clone(); + right_eq_properties.reorder(sort_exprs)?; + let Some(right_requirement) = smj_required_orderings.swap_remove(1) else { + return Ok(None); + }; + if !right_eq_properties + .ordering_satisfy_requirement(right_requirement.into_single())? { - // After re-ordering requirement is still satisfied - (left_ordering, sort_expr) - } else { return Ok(None); } + // After re-ordering, requirement is still satisfied: + (smj.left().output_ordering(), ordering.as_ref()) } JoinSide::None => return Ok(None), }; @@ -449,18 +449,19 @@ fn try_pushdown_requirements_to_join( new_left_ordering, new_right_ordering, join_type, - smj.on(), smj.left().schema().fields.len(), &smj.maintains_input_order(), Some(probe_side), - ); + )?; let mut smj_eqs = smj.properties().equivalence_properties().clone(); - // smj will have this ordering when its input changes. 
- smj_eqs = smj_eqs.with_reorder(new_output_ordering.unwrap_or_default()); - let should_pushdown = smj_eqs.ordering_satisfy_requirement(parent_required); + if let Some(new_output_ordering) = new_output_ordering { + // smj will have this ordering when its input changes. + smj_eqs.reorder(new_output_ordering)?; + } + let should_pushdown = smj_eqs.ordering_satisfy_requirement(parent_required)?; Ok(should_pushdown.then(|| { let mut required_input_ordering = smj.required_input_ordering(); - let new_req = Some(LexRequirement::from(sort_expr.clone())); + let new_req = ordering.map(Into::into); match push_side { JoinSide::Left => { required_input_ordering[0] = new_req; @@ -475,77 +476,78 @@ fn try_pushdown_requirements_to_join( } fn expr_source_side( - required_exprs: &LexOrdering, + eqp: &EquivalenceProperties, + mut ordering: Vec, join_type: JoinType, left_columns_len: usize, -) -> Option { +) -> Option<(JoinSide, Vec)> { + // TODO: Handle the case where a prefix of the ordering comes from the left + // and a suffix from the right. match join_type { JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full - | JoinType::LeftMark => { - let all_column_sides = required_exprs - .iter() - .filter_map(|r| { - r.expr.as_any().downcast_ref::().map(|col| { - if col.index() < left_columns_len { - JoinSide::Left - } else { - JoinSide::Right + | JoinType::LeftMark + | JoinType::RightMark => { + let eq_group = eqp.eq_group(); + let mut right_ordering = ordering.clone(); + let (mut valid_left, mut valid_right) = (true, true); + for (left, right) in ordering.iter_mut().zip(right_ordering.iter_mut()) { + let col = left.expr.as_any().downcast_ref::()?; + let eq_class = eq_group.get_equivalence_class(&left.expr); + if col.index() < left_columns_len { + if valid_right { + valid_right = eq_class.is_some_and(|cls| { + for expr in cls.iter() { + if expr + .as_any() + .downcast_ref::() + .is_some_and(|c| c.index() >= left_columns_len) + { + right.expr = Arc::clone(expr); + return true; + } + } + false + }); + } + } else if valid_left { + valid_left = eq_class.is_some_and(|cls| { + for expr in cls.iter() { + if expr + .as_any() + .downcast_ref::() + .is_some_and(|c| c.index() < left_columns_len) + { + left.expr = Arc::clone(expr); + return true; + } } - }) - }) - .collect::>(); - - // If the exprs are all coming from one side, the requirements can be pushed down - if all_column_sides.len() != required_exprs.len() { - None - } else if all_column_sides - .iter() - .all(|side| matches!(side, JoinSide::Left)) - { - Some(JoinSide::Left) - } else if all_column_sides - .iter() - .all(|side| matches!(side, JoinSide::Right)) - { - Some(JoinSide::Right) + false + }); + }; + if !(valid_left || valid_right) { + return None; + } + } + if valid_left { + Some((JoinSide::Left, ordering)) + } else if valid_right { + Some((JoinSide::Right, right_ordering)) } else { + // TODO: Handle the case where we can push down to both sides. 
None } } - JoinType::LeftSemi | JoinType::LeftAnti => required_exprs + JoinType::LeftSemi | JoinType::LeftAnti => ordering .iter() - .all(|e| e.expr.as_any().downcast_ref::().is_some()) - .then_some(JoinSide::Left), - JoinType::RightSemi | JoinType::RightAnti => required_exprs + .all(|e| e.expr.as_any().is::()) + .then_some((JoinSide::Left, ordering)), + JoinType::RightSemi | JoinType::RightAnti => ordering .iter() - .all(|e| e.expr.as_any().downcast_ref::().is_some()) - .then_some(JoinSide::Right), - } -} - -fn shift_right_required( - parent_required: &LexRequirement, - left_columns_len: usize, -) -> Result { - let new_right_required = parent_required - .iter() - .filter_map(|r| { - let col = r.expr.as_any().downcast_ref::()?; - col.index().checked_sub(left_columns_len).map(|offset| { - r.clone() - .with_expr(Arc::new(Column::new(col.name(), offset))) - }) - }) - .collect::>(); - if new_right_required.len() == parent_required.len() { - Ok(LexRequirement::new(new_right_required)) - } else { - plan_err!( - "Expect to shift all the parent required column indexes for SortMergeJoin" - ) + .all(|e| e.expr.as_any().is::()) + .then_some((JoinSide::Right, ordering)), } } @@ -565,16 +567,18 @@ fn shift_right_required( /// pushed down, `Ok(None)` if not. On error, returns a `Result::Err`. fn handle_custom_pushdown( plan: &Arc, - parent_required: &LexRequirement, + parent_required: OrderingRequirements, maintains_input_order: Vec, -) -> Result>>> { - // If there's no requirement from the parent or the plan has no children, return early - if parent_required.is_empty() || plan.children().is_empty() { +) -> Result>>> { + // If the plan has no children, return early: + if plan.children().is_empty() { return Ok(None); } - // Collect all unique column indices used in the parent-required sorting expression - let all_indices: HashSet = parent_required + // Collect all unique column indices used in the parent-required sorting + // expression: + let requirement = parent_required.into_single(); + let all_indices: HashSet = requirement .iter() .flat_map(|order| { collect_columns(&order.expr) @@ -584,14 +588,14 @@ fn handle_custom_pushdown( }) .collect(); - // Get the number of fields in each child's schema - let len_of_child_schemas: Vec = plan + // Get the number of fields in each child's schema: + let children_schema_lengths: Vec = plan .children() .iter() .map(|c| c.schema().fields().len()) .collect(); - // Find the index of the child that maintains input order + // Find the index of the order-maintaining child: let Some(maintained_child_idx) = maintains_input_order .iter() .enumerate() @@ -601,26 +605,28 @@ fn handle_custom_pushdown( return Ok(None); }; - // Check if all required columns come from the child that maintains input order - let start_idx = len_of_child_schemas[..maintained_child_idx] + // Check if all required columns come from the order-maintaining child: + let start_idx = children_schema_lengths[..maintained_child_idx] .iter() .sum::(); - let end_idx = start_idx + len_of_child_schemas[maintained_child_idx]; + let end_idx = start_idx + children_schema_lengths[maintained_child_idx]; let all_from_maintained_child = all_indices.iter().all(|i| i >= &start_idx && i < &end_idx); - // If all columns are from the maintained child, update the parent requirements + // If all columns are from the maintained child, update the parent requirements: if all_from_maintained_child { - let sub_offset = len_of_child_schemas + let sub_offset = children_schema_lengths .iter() .take(maintained_child_idx) .sum::(); - 
// Transform the parent-required expression for the child schema by adjusting columns - let updated_parent_req = parent_required - .iter() + // Transform the parent-required expression for the child schema by + // adjusting columns: + let updated_parent_req = requirement + .into_iter() .map(|req| { let child_schema = plan.children()[maintained_child_idx].schema(); - let updated_columns = Arc::clone(&req.expr) + let updated_columns = req + .expr .transform_up(|expr| { if let Some(col) = expr.as_any().downcast_ref::() { let new_index = col.index() - sub_offset; @@ -642,7 +648,8 @@ fn handle_custom_pushdown( .iter() .map(|&maintains_order| { if maintains_order { - Some(LexRequirement::new(updated_parent_req.clone())) + LexRequirement::new(updated_parent_req.clone()) + .map(OrderingRequirements::new) } else { None } @@ -659,16 +666,17 @@ fn handle_custom_pushdown( // for join type: Inner, Right, RightSemi, RightAnti fn handle_hash_join( plan: &HashJoinExec, - parent_required: &LexRequirement, -) -> Result>>> { - // If there's no requirement from the parent or the plan has no children - // or the join type is not Inner, Right, RightSemi, RightAnti, return early - if parent_required.is_empty() || !plan.maintains_input_order()[1] { + parent_required: OrderingRequirements, +) -> Result>>> { + // If the plan has no children or does not maintain the right side ordering, + // return early: + if !plan.maintains_input_order()[1] { return Ok(None); } // Collect all unique column indices used in the parent-required sorting expression - let all_indices: HashSet = parent_required + let requirement = parent_required.into_single(); + let all_indices: HashSet<_> = requirement .iter() .flat_map(|order| { collect_columns(&order.expr) @@ -694,11 +702,12 @@ fn handle_hash_join( // If all columns are from the right child, update the parent requirements if all_from_right_child { // Transform the parent-required expression for the child schema by adjusting columns - let updated_parent_req = parent_required - .iter() + let updated_parent_req = requirement + .into_iter() .map(|req| { let child_schema = plan.children()[1].schema(); - let updated_columns = Arc::clone(&req.expr) + let updated_columns = req + .expr .transform_up(|expr| { if let Some(col) = expr.as_any().downcast_ref::() { let index = projected_indices[col.index()].index; @@ -718,7 +727,7 @@ fn handle_hash_join( // Populating with the updated requirements for children that maintain order Ok(Some(vec![ None, - Some(LexRequirement::new(updated_parent_req)), + LexRequirement::new(updated_parent_req).map(OrderingRequirements::new), ])) } else { Ok(None) @@ -757,7 +766,7 @@ enum RequirementsCompatibility { /// Requirements satisfy Satisfy, /// Requirements compatible - Compatible(Option), + Compatible(Option), /// Requirements not compatible NonCompatible, } diff --git a/datafusion/physical-optimizer/src/ensure_coop.rs b/datafusion/physical-optimizer/src/ensure_coop.rs new file mode 100644 index 000000000000..0c0b63c0b3e7 --- /dev/null +++ b/datafusion/physical-optimizer/src/ensure_coop.rs @@ -0,0 +1,118 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! The [`EnsureCooperative`] optimizer rule inspects the physical plan to find all +//! portions of the plan that will not yield cooperatively. +//! It will insert `CooperativeExec` nodes where appropriate to ensure execution plans +//! always yield cooperatively. + +use std::fmt::{Debug, Formatter}; +use std::sync::Arc; + +use crate::PhysicalOptimizerRule; + +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; +use datafusion_common::Result; +use datafusion_physical_plan::coop::CooperativeExec; +use datafusion_physical_plan::execution_plan::{EvaluationType, SchedulingType}; +use datafusion_physical_plan::ExecutionPlan; + +/// `EnsureCooperative` is a [`PhysicalOptimizerRule`] that inspects the physical plan for +/// sub plans that do not participate in cooperative scheduling. The plan is subdivided into sub +/// plans on eager evaluation boundaries. Leaf nodes and eager evaluation roots are checked +/// to see if they participate in cooperative scheduling. Those that do not are wrapped in +/// a [`CooperativeExec`] parent. +pub struct EnsureCooperative {} + +impl EnsureCooperative { + pub fn new() -> Self { + Self {} + } +} + +impl Default for EnsureCooperative { + fn default() -> Self { + Self::new() + } +} + +impl Debug for EnsureCooperative { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct(self.name()).finish() + } +} + +impl PhysicalOptimizerRule for EnsureCooperative { + fn name(&self) -> &str { + "EnsureCooperative" + } + + fn optimize( + &self, + plan: Arc<dyn ExecutionPlan>, + _config: &ConfigOptions, + ) -> Result<Arc<dyn ExecutionPlan>> { + plan.transform_up(|plan| { + let is_leaf = plan.children().is_empty(); + let is_exchange = plan.properties().evaluation_type == EvaluationType::Eager; + if (is_leaf || is_exchange) + && plan.properties().scheduling_type != SchedulingType::Cooperative + { + // Wrap non-cooperative leaves or eager evaluation roots in a cooperative exec to + // ensure the plans they participate in are properly cooperative. + Ok(Transformed::new( + Arc::new(CooperativeExec::new(Arc::clone(&plan))), + true, + TreeNodeRecursion::Continue, + )) + } else { + Ok(Transformed::no(plan)) + } + }) + .map(|t| t.data) + } + + fn schema_check(&self) -> bool { + // Wrapping a leaf in CooperativeExec preserves the schema, so it is safe.
+ true + } +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion_common::config::ConfigOptions; + use datafusion_physical_plan::{displayable, test::scan_partitioned}; + use insta::assert_snapshot; + + #[tokio::test] + async fn test_cooperative_exec_for_custom_exec() { + let test_custom_exec = scan_partitioned(1); + let config = ConfigOptions::new(); + let optimized = EnsureCooperative::new() + .optimize(test_custom_exec, &config) + .unwrap(); + + let display = displayable(optimized.as_ref()).indent(true).to_string(); + // Use insta snapshot to ensure full plan structure + assert_snapshot!(display, @r###" + CooperativeExec + DataSourceExec: partitions=1, partition_sizes=[1] + "###); + } +} diff --git a/datafusion/physical-optimizer/src/filter_pushdown.rs b/datafusion/physical-optimizer/src/filter_pushdown.rs new file mode 100644 index 000000000000..885280576b4b --- /dev/null +++ b/datafusion/physical-optimizer/src/filter_pushdown.rs @@ -0,0 +1,568 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use crate::PhysicalOptimizerRule; + +use datafusion_common::{config::ConfigOptions, Result}; +use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_plan::filter_pushdown::{ + ChildPushdownResult, FilterPushdownPhase, FilterPushdownPropagation, + PredicateSupport, PredicateSupports, +}; +use datafusion_physical_plan::{with_new_children_if_necessary, ExecutionPlan}; + +use itertools::izip; + +/// Attempts to recursively push given filters from the top of the tree into leafs. +/// +/// # Default Implementation +/// +/// The default implementation in [`ExecutionPlan::gather_filters_for_pushdown`] +/// and [`ExecutionPlan::handle_child_pushdown_result`] assumes that: +/// +/// * Parent filters can't be passed onto children (determined by [`ExecutionPlan::gather_filters_for_pushdown`]) +/// * This node has no filters to contribute (determined by [`ExecutionPlan::gather_filters_for_pushdown`]). +/// * Any filters that could not be pushed down to the children are marked as unsupported (determined by [`ExecutionPlan::handle_child_pushdown_result`]). +/// +/// # Example: Push filter into a `DataSourceExec` +/// +/// For example, consider the following plan: +/// +/// ```text +/// ┌──────────────────────┐ +/// │ CoalesceBatchesExec │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ FilterExec │ +/// │ filters = [ id=1] │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ DataSourceExec │ +/// │ projection = * │ +/// └──────────────────────┘ +/// ``` +/// +/// Our goal is to move the `id = 1` filter from the [`FilterExec`] node to the `DataSourceExec` node. 
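As a side note on what such a predicate looks like at the physical level, the `id = 1` filter above could be built as a `PhysicalExpr` roughly as follows. This is only an illustrative sketch: the column index (`0`), the `Int64` type, and the helper name are assumptions for the example, not part of this patch.

```rust
use std::sync::Arc;

use datafusion_common::ScalarValue;
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal};
use datafusion_physical_expr::PhysicalExpr;

// Hypothetical helper: the `id = 1` predicate from the diagram, assuming `id`
// is the first column of the scanned schema and has an Int64 type.
fn id_equals_one() -> Arc<dyn PhysicalExpr> {
    Arc::new(BinaryExpr::new(
        Arc::new(Column::new("id", 0)),
        Operator::Eq,
        Arc::new(Literal::new(ScalarValue::Int64(Some(1)))),
    ))
}
```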
+/// +/// If this filter is selective pushing it into the scan can avoid massive +/// amounts of data being read from the source (the projection is `*` so all +/// matching columns are read). +/// +/// The new plan looks like: +/// +/// ```text +/// ┌──────────────────────┐ +/// │ CoalesceBatchesExec │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ DataSourceExec │ +/// │ projection = * │ +/// │ filters = [ id=1] │ +/// └──────────────────────┘ +/// ``` +/// +/// # Example: Push filters with `ProjectionExec` +/// +/// Let's consider a more complex example involving a [`ProjectionExec`] +/// node in between the [`FilterExec`] and `DataSourceExec` nodes that +/// creates a new column that the filter depends on. +/// +/// ```text +/// ┌──────────────────────┐ +/// │ CoalesceBatchesExec │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ FilterExec │ +/// │ filters = │ +/// │ [cost>50,id=1] │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ ProjectionExec │ +/// │ cost = price * 1.2 │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ DataSourceExec │ +/// │ projection = * │ +/// └──────────────────────┘ +/// ``` +/// +/// We want to push down the filters `[id=1]` to the `DataSourceExec` node, +/// but can't push down `cost>50` because it requires the [`ProjectionExec`] +/// node to be executed first. A simple thing to do would be to split up the +/// filter into two separate filters and push down the first one: +/// +/// ```text +/// ┌──────────────────────┐ +/// │ CoalesceBatchesExec │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ FilterExec │ +/// │ filters = │ +/// │ [cost>50] │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ ProjectionExec │ +/// │ cost = price * 1.2 │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ DataSourceExec │ +/// │ projection = * │ +/// │ filters = [ id=1] │ +/// └──────────────────────┘ +/// ``` +/// +/// We can actually however do better by pushing down `price * 1.2 > 50` +/// instead of `cost > 50`: +/// +/// ```text +/// ┌──────────────────────┐ +/// │ CoalesceBatchesExec │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ ProjectionExec │ +/// │ cost = price * 1.2 │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ DataSourceExec │ +/// │ projection = * │ +/// │ filters = [id=1, │ +/// │ price * 1.2 > 50] │ +/// └──────────────────────┘ +/// ``` +/// +/// # Example: Push filters within a subtree +/// +/// There are also cases where we may be able to push down filters within a +/// subtree but not the entire tree. 
A good example of this is aggregation +/// nodes: +/// +/// ```text +/// ┌──────────────────────┐ +/// │ ProjectionExec │ +/// │ projection = * │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ FilterExec │ +/// │ filters = [sum > 10] │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌───────────────────────┐ +/// │ AggregateExec │ +/// │ group by = [id] │ +/// │ aggregate = │ +/// │ [sum(price)] │ +/// └───────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ FilterExec │ +/// │ filters = [id=1] │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ DataSourceExec │ +/// │ projection = * │ +/// └──────────────────────┘ +/// ``` +/// +/// The transformation here is to push down the `id=1` filter to the +/// `DataSourceExec` node: +/// +/// ```text +/// ┌──────────────────────┐ +/// │ ProjectionExec │ +/// │ projection = * │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ FilterExec │ +/// │ filters = [sum > 10] │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌───────────────────────┐ +/// │ AggregateExec │ +/// │ group by = [id] │ +/// │ aggregate = │ +/// │ [sum(price)] │ +/// └───────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ DataSourceExec │ +/// │ projection = * │ +/// │ filters = [id=1] │ +/// └──────────────────────┘ +/// ``` +/// +/// The point here is that: +/// 1. We cannot push down `sum > 10` through the [`AggregateExec`] node into the `DataSourceExec` node. +/// Any filters above the [`AggregateExec`] node are not pushed down. +/// This is determined by calling [`ExecutionPlan::gather_filters_for_pushdown`] on the [`AggregateExec`] node. +/// 2. We need to keep recursing into the tree so that we can discover the other [`FilterExec`] node and push +/// down the `id=1` filter. +/// +/// # Example: Push filters through Joins +/// +/// It is also possible to push down filters through joins and filters that +/// originate from joins. For example, a hash join where we build a hash +/// table of the left side and probe the right side (ignoring why we would +/// choose this order, typically it depends on the size of each table, +/// etc.). +/// +/// ```text +/// ┌─────────────────────┐ +/// │ FilterExec │ +/// │ filters = │ +/// │ [d.size > 100] │ +/// └─────────────────────┘ +/// │ +/// │ +/// ┌──────────▼──────────┐ +/// │ │ +/// │ HashJoinExec │ +/// │ [u.dept@hash(d.id)] │ +/// │ │ +/// └─────────────────────┘ +/// │ +/// ┌────────────┴────────────┐ +/// ┌──────────▼──────────┐ ┌──────────▼──────────┐ +/// │ DataSourceExec │ │ DataSourceExec │ +/// │ alias [users as u] │ │ alias [dept as d] │ +/// │ │ │ │ +/// └─────────────────────┘ └─────────────────────┘ +/// ``` +/// +/// There are two pushdowns we can do here: +/// 1. Push down the `d.size > 100` filter through the `HashJoinExec` node to the `DataSourceExec` +/// node for the `departments` table. +/// 2. Push down the hash table state from the `HashJoinExec` node to the `DataSourceExec` node to avoid reading +/// rows from the `users` table that will be eliminated by the join. +/// This can be done via a bloom filter or similar and is not (yet) supported +/// in DataFusion. See . 
+/// +/// ```text +/// ┌─────────────────────┐ +/// │ │ +/// │ HashJoinExec │ +/// │ [u.dept@hash(d.id)] │ +/// │ │ +/// └─────────────────────┘ +/// │ +/// ┌────────────┴────────────┐ +/// ┌──────────▼──────────┐ ┌──────────▼──────────┐ +/// │ DataSourceExec │ │ DataSourceExec │ +/// │ alias [users as u] │ │ alias [dept as d] │ +/// │ filters = │ │ filters = │ +/// │ [depg@hash(d.id)] │ │ [ d.size > 100] │ +/// └─────────────────────┘ └─────────────────────┘ +/// ``` +/// +/// You may notice in this case that the filter is *dynamic*: the hash table +/// is built _after_ the `departments` table is read and at runtime. We +/// don't have a concrete `InList` filter or similar to push down at +/// optimization time. These sorts of dynamic filters are handled by +/// building a specialized [`PhysicalExpr`] that can be evaluated at runtime +/// and internally maintains a reference to the hash table or other state. +/// +/// To make working with these sorts of dynamic filters more tractable we have the method [`PhysicalExpr::snapshot`] +/// which attempts to simplify a dynamic filter into a "basic" non-dynamic filter. +/// For a join this could mean converting it to an `InList` filter or a min/max filter for example. +/// See `datafusion/physical-plan/src/dynamic_filters.rs` for more details. +/// +/// # Example: Push TopK filters into Scans +/// +/// Another form of dynamic filter is pushing down the state of a `TopK` +/// operator for queries like `SELECT * FROM t ORDER BY id LIMIT 10`: +/// +/// ```text +/// ┌──────────────────────┐ +/// │ TopK │ +/// │ limit = 10 │ +/// │ order by = [id] │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ DataSourceExec │ +/// │ projection = * │ +/// └──────────────────────┘ +/// ``` +/// +/// We can avoid large amounts of data processing by transforming this into: +/// +/// ```text +/// ┌──────────────────────┐ +/// │ TopK │ +/// │ limit = 10 │ +/// │ order by = [id] │ +/// └──────────────────────┘ +/// │ +/// ▼ +/// ┌──────────────────────┐ +/// │ DataSourceExec │ +/// │ projection = * │ +/// │ filters = │ +/// │ [id < @ TopKHeap] │ +/// └──────────────────────┘ +/// ``` +/// +/// Now as we fill our `TopK` heap we can push down the state of the heap to +/// the `DataSourceExec` node to avoid reading files / row groups / pages / +/// rows that could not possibly be in the top 10. +/// +/// This is not yet implemented in DataFusion. See +/// +/// +/// [`PhysicalExpr`]: datafusion_physical_plan::PhysicalExpr +/// [`PhysicalExpr::snapshot`]: datafusion_physical_plan::PhysicalExpr::snapshot +/// [`FilterExec`]: datafusion_physical_plan::filter::FilterExec +/// [`ProjectionExec`]: datafusion_physical_plan::projection::ProjectionExec +/// [`AggregateExec`]: datafusion_physical_plan::aggregates::AggregateExec +#[derive(Debug)] +pub struct FilterPushdown { + phase: FilterPushdownPhase, + name: String, +} + +impl FilterPushdown { + fn new_with_phase(phase: FilterPushdownPhase) -> Self { + let name = match phase { + FilterPushdownPhase::Pre => "FilterPushdown", + FilterPushdownPhase::Post => "FilterPushdown(Post)", + } + .to_string(); + Self { phase, name } + } + + /// Create a new [`FilterPushdown`] optimizer rule that runs in the pre-optimization phase. + /// See [`FilterPushdownPhase`] for more details. + pub fn new() -> Self { + Self::new_with_phase(FilterPushdownPhase::Pre) + } + + /// Create a new [`FilterPushdown`] optimizer rule that runs in the post-optimization phase. + /// See [`FilterPushdownPhase`] for more details. 
+ pub fn new_post_optimization() -> Self { + Self::new_with_phase(FilterPushdownPhase::Post) + } +} + +impl Default for FilterPushdown { + fn default() -> Self { + Self::new() + } +} + +impl PhysicalOptimizerRule for FilterPushdown { + fn optimize( + &self, + plan: Arc, + config: &ConfigOptions, + ) -> Result> { + Ok( + push_down_filters(Arc::clone(&plan), vec![], config, self.phase)? + .updated_node + .unwrap_or(plan), + ) + } + + fn name(&self) -> &str { + &self.name + } + + fn schema_check(&self) -> bool { + true // Filter pushdown does not change the schema of the plan + } +} + +/// Support state of each predicate for the children of the node. +/// These predicates are coming from the parent node. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ParentPredicateStates { + NoChildren, + Unsupported, + Supported, +} + +fn push_down_filters( + node: Arc, + parent_predicates: Vec>, + config: &ConfigOptions, + phase: FilterPushdownPhase, +) -> Result>> { + // If the node has any child, these will be rewritten as supported or unsupported + let mut parent_predicates_pushdown_states = + vec![ParentPredicateStates::NoChildren; parent_predicates.len()]; + let mut self_filters_pushdown_supports = vec![]; + let mut new_children = Vec::with_capacity(node.children().len()); + + let children = node.children(); + let filter_description = + node.gather_filters_for_pushdown(phase, parent_predicates.clone(), config)?; + + for (child, parent_filters, self_filters) in izip!( + children, + filter_description.parent_filters(), + filter_description.self_filters() + ) { + // Here, `parent_filters` are the predicates which are provided by the parent node of + // the current node, and tried to be pushed down over the child which the loop points + // currently. `self_filters` are the predicates which are provided by the current node, + // and tried to be pushed down over the child similarly. + + let num_self_filters = self_filters.len(); + let mut parent_supported_predicate_indices = vec![]; + let mut all_predicates = self_filters; + + // Iterate over each predicate coming from the parent + for (idx, filter) in parent_filters.into_iter().enumerate() { + // Check if we can push this filter down to our child. 
+ // These supports are defined in `gather_filters_for_pushdown()` + match filter { + PredicateSupport::Supported(predicate) => { + // Queue this filter up for pushdown to this child + all_predicates.push(predicate); + parent_supported_predicate_indices.push(idx); + // Mark this filter as supported by our children if no child has marked it as unsupported + if parent_predicates_pushdown_states[idx] + != ParentPredicateStates::Unsupported + { + parent_predicates_pushdown_states[idx] = + ParentPredicateStates::Supported; + } + } + PredicateSupport::Unsupported(_) => { + // Mark as unsupported by our children + parent_predicates_pushdown_states[idx] = + ParentPredicateStates::Unsupported; + } + } + } + + // Any filters that could not be pushed down to a child are marked as not-supported to our parents + let result = push_down_filters(Arc::clone(child), all_predicates, config, phase)?; + + if let Some(new_child) = result.updated_node { + // If we have a filter pushdown result, we need to update our children + new_children.push(new_child); + } else { + // If we don't have a filter pushdown result, we need to update our children + new_children.push(Arc::clone(child)); + } + + // Our child doesn't know the difference between filters that were passed down + // from our parents and filters that the current node injected. We need to de-entangle + // this since we do need to distinguish between them. + let mut all_filters = result.filters.into_inner(); + let parent_predicates = all_filters.split_off(num_self_filters); + let self_predicates = all_filters; + self_filters_pushdown_supports.push(PredicateSupports::new(self_predicates)); + + for (idx, result) in parent_supported_predicate_indices + .iter() + .zip(parent_predicates) + { + let current_node_state = match result { + PredicateSupport::Supported(_) => ParentPredicateStates::Supported, + PredicateSupport::Unsupported(_) => ParentPredicateStates::Unsupported, + }; + match (current_node_state, parent_predicates_pushdown_states[*idx]) { + (r, ParentPredicateStates::NoChildren) => { + // If we have no result, use the current state from this child + parent_predicates_pushdown_states[*idx] = r; + } + (ParentPredicateStates::Supported, ParentPredicateStates::Supported) => { + // If the current child and all previous children are supported, + // the filter continues to support it + parent_predicates_pushdown_states[*idx] = + ParentPredicateStates::Supported; + } + _ => { + // Either the current child or a previous child marked this filter as unsupported + parent_predicates_pushdown_states[*idx] = + ParentPredicateStates::Unsupported; + } + } + } + } + // Re-create this node with new children + let updated_node = with_new_children_if_necessary(Arc::clone(&node), new_children)?; + // Remap the result onto the parent filters as they were given to us. + // Any filters that were not pushed down to any children are marked as unsupported. + let parent_pushdown_result = PredicateSupports::new( + parent_predicates_pushdown_states + .into_iter() + .zip(parent_predicates) + .map(|(state, filter)| match state { + ParentPredicateStates::NoChildren => { + PredicateSupport::Unsupported(filter) + } + ParentPredicateStates::Unsupported => { + PredicateSupport::Unsupported(filter) + } + ParentPredicateStates::Supported => PredicateSupport::Supported(filter), + }) + .collect(), + ); + // TODO: by calling `handle_child_pushdown_result` we are assuming that the + // `ExecutionPlan` implementation will not change the plan itself. 
+ // Should we have a separate method for dynamic pushdown that does not allow modifying the plan? + let mut res = updated_node.handle_child_pushdown_result( + phase, + ChildPushdownResult { + parent_filters: parent_pushdown_result, + self_filters: self_filters_pushdown_supports, + }, + config, + )?; + // Compare pointers for new_node and node, if they are different we must replace + // ourselves because of changes in our children. + if res.updated_node.is_none() && !Arc::ptr_eq(&updated_node, &node) { + res.updated_node = Some(updated_node) + } + Ok(res) +} diff --git a/datafusion/physical-optimizer/src/join_selection.rs b/datafusion/physical-optimizer/src/join_selection.rs index 5a772ccdd249..dc220332141b 100644 --- a/datafusion/physical-optimizer/src/join_selection.rs +++ b/datafusion/physical-optimizer/src/join_selection.rs @@ -65,8 +65,8 @@ pub(crate) fn should_swap_join_order( // Get the left and right table's total bytes // If both the left and right tables contain total_byte_size statistics, // use `total_byte_size` to determine `should_swap_join_order`, else use `num_rows` - let left_stats = left.statistics()?; - let right_stats = right.statistics()?; + let left_stats = left.partition_statistics(None)?; + let right_stats = right.partition_statistics(None)?; // First compare `total_byte_size` of left and right side, // if information in this field is insufficient fallback to the `num_rows` match ( @@ -91,7 +91,7 @@ fn supports_collect_by_thresholds( ) -> bool { // Currently we do not trust the 0 value from stats, due to stats collection might have bug // TODO check the logic in datasource::get_statistics_with_limit() - let Ok(stats) = plan.statistics() else { + let Ok(stats) = plan.partition_statistics(None) else { return false; }; @@ -245,7 +245,7 @@ pub(crate) fn try_collect_left( hash_join.join_type(), hash_join.projection.clone(), PartitionMode::CollectLeft, - hash_join.null_equals_null(), + hash_join.null_equality(), )?))) } } @@ -257,7 +257,7 @@ pub(crate) fn try_collect_left( hash_join.join_type(), hash_join.projection.clone(), PartitionMode::CollectLeft, - hash_join.null_equals_null(), + hash_join.null_equality(), )?))), (false, true) => { if hash_join.join_type().supports_swap() { @@ -292,7 +292,7 @@ pub(crate) fn partitioned_hash_join( hash_join.join_type(), hash_join.projection.clone(), PartitionMode::Partitioned, - hash_join.null_equals_null(), + hash_join.null_equality(), )?)) } } @@ -459,7 +459,7 @@ fn hash_join_convert_symmetric_subrule( JoinSide::Right => hash_join.right().output_ordering(), JoinSide::None => unreachable!(), } - .map(|p| LexOrdering::new(p.to_vec())) + .cloned() }) .flatten() }; @@ -474,7 +474,7 @@ fn hash_join_convert_symmetric_subrule( hash_join.on().to_vec(), hash_join.filter().cloned(), hash_join.join_type(), - hash_join.null_equals_null(), + hash_join.null_equality(), left_order, right_order, mode, diff --git a/datafusion/physical-optimizer/src/lib.rs b/datafusion/physical-optimizer/src/lib.rs index 35503f3b0b5f..c828cc696063 100644 --- a/datafusion/physical-optimizer/src/lib.rs +++ b/datafusion/physical-optimizer/src/lib.rs @@ -25,17 +25,20 @@ #![deny(clippy::clone_on_ref_ptr)] pub mod aggregate_statistics; +pub mod coalesce_async_exec_input; pub mod coalesce_batches; pub mod combine_partial_final_agg; pub mod enforce_distribution; pub mod enforce_sorting; +pub mod ensure_coop; +pub mod filter_pushdown; pub mod join_selection; pub mod limit_pushdown; pub mod limited_distinct_aggregation; pub mod optimizer; pub mod output_requirements; pub mod 
projection_pushdown; -pub mod pruning; +pub use datafusion_pruning as pruning; pub mod sanity_checker; pub mod topk_aggregation; pub mod update_aggr_exprs; diff --git a/datafusion/physical-optimizer/src/limit_pushdown.rs b/datafusion/physical-optimizer/src/limit_pushdown.rs index 5887cb51a727..7469c3af9344 100644 --- a/datafusion/physical-optimizer/src/limit_pushdown.rs +++ b/datafusion/physical-optimizer/src/limit_pushdown.rs @@ -246,16 +246,7 @@ pub fn pushdown_limit_helper( Ok((Transformed::no(pushdown_plan), global_state)) } } else { - // Add fetch or a `LimitExec`: - // If the plan's children have limit and the child's limit < parent's limit, we shouldn't change the global state to true, - // because the children limit will be overridden if the global state is changed. - if !pushdown_plan - .children() - .iter() - .any(|&child| extract_limit(child).is_some()) - { - global_state.satisfied = true; - } + global_state.satisfied = true; pushdown_plan = if let Some(plan_with_fetch) = maybe_fetchable { if global_skip > 0 { add_global_limit(plan_with_fetch, global_skip, Some(global_fetch)) diff --git a/datafusion/physical-optimizer/src/optimizer.rs b/datafusion/physical-optimizer/src/optimizer.rs index bab31150e250..f9ad521b4f7c 100644 --- a/datafusion/physical-optimizer/src/optimizer.rs +++ b/datafusion/physical-optimizer/src/optimizer.rs @@ -25,6 +25,8 @@ use crate::coalesce_batches::CoalesceBatches; use crate::combine_partial_final_agg::CombinePartialFinalAggregate; use crate::enforce_distribution::EnforceDistribution; use crate::enforce_sorting::EnforceSorting; +use crate::ensure_coop::EnsureCooperative; +use crate::filter_pushdown::FilterPushdown; use crate::join_selection::JoinSelection; use crate::limit_pushdown::LimitPushdown; use crate::limited_distinct_aggregation::LimitedDistinctAggregation; @@ -34,6 +36,7 @@ use crate::sanity_checker::SanityCheckPlan; use crate::topk_aggregation::TopKAggregation; use crate::update_aggr_exprs::OptimizeAggregateOrder; +use crate::coalesce_async_exec_input::CoalesceAsyncExecInput; use datafusion_common::config::ConfigOptions; use datafusion_common::Result; use datafusion_physical_plan::ExecutionPlan; @@ -94,6 +97,12 @@ impl PhysicalOptimizer { // as that rule may inject other operations in between the different AggregateExecs. // Applying the rule early means only directly-connected AggregateExecs must be examined. Arc::new(LimitedDistinctAggregation::new()), + // The FilterPushdown rule tries to push down filters as far as it can. + // For example, it will push down filtering from a `FilterExec` to `DataSourceExec`. + // Note that this does not push down dynamic filters (such as those created by a `SortExec` operator in TopK mode), + // those are handled by the later `FilterPushdown` rule. + // See `FilterPushdownPhase` for more details. + Arc::new(FilterPushdown::new()), // The EnforceDistribution rule is for adding essential repartitioning to satisfy distribution // requirements. Please make sure that the whole plan tree is determined before this rule. // This rule increases parallelism if doing so is beneficial to the physical plan; i.e. at @@ -113,6 +122,7 @@ impl PhysicalOptimizer { // The CoalesceBatches rule will not influence the distribution and ordering of the // whole plan tree. Therefore, to avoid influencing other rules, it should run last. Arc::new(CoalesceBatches::new()), + Arc::new(CoalesceAsyncExecInput::new()), // Remove the ancillary output requirement operator since we are done with the planning // phase. 
Arc::new(OutputRequirements::new_remove_mode()), @@ -132,6 +142,11 @@ impl PhysicalOptimizer { // are not present, the load of executors such as join or union will be // reduced by narrowing their input tables. Arc::new(ProjectionPushdown::new()), + Arc::new(EnsureCooperative::new()), + // This FilterPushdown handles dynamic filters that may have references to the source ExecutionPlan. + // Therefore it should be run at the end of the optimization process since any changes to the plan may break the dynamic filter's references. + // See `FilterPushdownPhase` for more details. + Arc::new(FilterPushdown::new_post_optimization()), // The SanityCheckPlan rule checks whether the order and // distribution requirements of each node in the plan // is satisfied. It will also reject non-runnable query diff --git a/datafusion/physical-optimizer/src/output_requirements.rs b/datafusion/physical-optimizer/src/output_requirements.rs index 3ca0547aa11d..044d27811be6 100644 --- a/datafusion/physical-optimizer/src/output_requirements.rs +++ b/datafusion/physical-optimizer/src/output_requirements.rs @@ -30,16 +30,17 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{Result, Statistics}; use datafusion_execution::TaskContext; -use datafusion_physical_expr::{Distribution, LexRequirement, PhysicalSortRequirement}; +use datafusion_physical_expr::Distribution; +use datafusion_physical_expr_common::sort_expr::OrderingRequirements; use datafusion_physical_plan::projection::{ - make_with_child, update_expr, ProjectionExec, + make_with_child, update_expr, update_ordering_requirement, ProjectionExec, }; use datafusion_physical_plan::sorts::sort::SortExec; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, SendableRecordBatchStream, + DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, + SendableRecordBatchStream, }; -use datafusion_physical_plan::{ExecutionPlanProperties, PlanProperties}; /// This rule either adds or removes [`OutputRequirements`]s to/from the physical /// plan according to its `mode` attribute, which is set by the constructors @@ -94,7 +95,7 @@ enum RuleMode { #[derive(Debug)] pub struct OutputRequirementExec { input: Arc, - order_requirement: Option, + order_requirement: Option, dist_requirement: Distribution, cache: PlanProperties, } @@ -102,7 +103,7 @@ pub struct OutputRequirementExec { impl OutputRequirementExec { pub fn new( input: Arc, - requirements: Option, + requirements: Option, dist_requirement: Distribution, ) -> Self { let cache = Self::compute_properties(&input); @@ -176,7 +177,7 @@ impl ExecutionPlan for OutputRequirementExec { vec![&self.input] } - fn required_input_ordering(&self) -> Vec> { + fn required_input_ordering(&self) -> Vec> { vec![self.order_requirement.clone()] } @@ -200,7 +201,11 @@ impl ExecutionPlan for OutputRequirementExec { } fn statistics(&self) -> Result { - self.input.statistics() + self.input.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + self.input.partition_statistics(partition) } fn try_swapping_with_projection( @@ -208,23 +213,23 @@ impl ExecutionPlan for OutputRequirementExec { projection: &ProjectionExec, ) -> Result>> { // If the projection does not narrow the schema, we should not try to push it down: - if projection.expr().len() >= 
projection.input().schema().fields().len() { + let proj_exprs = projection.expr(); + if proj_exprs.len() >= projection.input().schema().fields().len() { return Ok(None); } - let mut updated_sort_reqs = LexRequirement::new(vec![]); - // None or empty_vec can be treated in the same way. - if let Some(reqs) = &self.required_input_ordering()[0] { - for req in &reqs.inner { - let Some(new_expr) = update_expr(&req.expr, projection.expr(), false)? + let mut requirements = self.required_input_ordering().swap_remove(0); + if let Some(reqs) = requirements { + let mut updated_reqs = vec![]; + let (lexes, soft) = reqs.into_alternatives(); + for lex in lexes.into_iter() { + let Some(updated_lex) = update_ordering_requirement(lex, proj_exprs)? else { return Ok(None); }; - updated_sort_reqs.push(PhysicalSortRequirement { - expr: new_expr, - options: req.options, - }); + updated_reqs.push(updated_lex); } + requirements = OrderingRequirements::new_alternatives(updated_reqs, soft); } let dist_req = match &self.required_input_distribution()[0] { @@ -242,15 +247,10 @@ impl ExecutionPlan for OutputRequirementExec { dist => dist.clone(), }; - make_with_child(projection, &self.input()) - .map(|input| { - OutputRequirementExec::new( - input, - (!updated_sort_reqs.is_empty()).then_some(updated_sort_reqs), - dist_req, - ) - }) - .map(|e| Some(Arc::new(e) as _)) + make_with_child(projection, &self.input()).map(|input| { + let e = OutputRequirementExec::new(input, requirements, dist_req); + Some(Arc::new(e) as _) + }) } } @@ -313,17 +313,18 @@ fn require_top_ordering_helper( if children.len() != 1 { Ok((plan, false)) } else if let Some(sort_exec) = plan.as_any().downcast_ref::() { - // In case of constant columns, output ordering of SortExec would give an empty set. - // Therefore; we check the sort expression field of the SortExec to assign the requirements. + // In case of constant columns, output ordering of the `SortExec` would + // be an empty set. Therefore; we check the sort expression field to + // assign the requirements. let req_ordering = sort_exec.expr(); - let req_dist = sort_exec.required_input_distribution()[0].clone(); - let reqs = LexRequirement::from(req_ordering.clone()); + let req_dist = sort_exec.required_input_distribution().swap_remove(0); + let reqs = OrderingRequirements::from(req_ordering.clone()); Ok(( Arc::new(OutputRequirementExec::new(plan, Some(reqs), req_dist)) as _, true, )) } else if let Some(spm) = plan.as_any().downcast_ref::() { - let reqs = LexRequirement::from(spm.expr().clone()); + let reqs = OrderingRequirements::from(spm.expr().clone()); Ok(( Arc::new(OutputRequirementExec::new( plan, @@ -333,7 +334,9 @@ fn require_top_ordering_helper( true, )) } else if plan.maintains_input_order()[0] - && plan.required_input_ordering()[0].is_none() + && (plan.required_input_ordering()[0] + .as_ref() + .is_none_or(|o| matches!(o, OrderingRequirements::Soft(_)))) { // Keep searching for a `SortExec` as long as ordering is maintained, // and on-the-way operators do not themselves require an ordering. 
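To make the `LexRequirement` to `OrderingRequirements` migration above a bit more concrete, here is a minimal sketch of the shape an operator's ordering requirement now takes. It only uses constructors that are visible in this diff; the helper name and the `sort_exprs` parameter are hypothetical.

```rust
use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr};
use datafusion_physical_expr_common::sort_expr::OrderingRequirements;

// Hypothetical helper: wrap an operator's sort expressions in the hard
// `OrderingRequirements` form used by `required_input_ordering`. In this
// version `LexOrdering::new` returns `None` for an empty expression list,
// so an empty input naturally maps to "no requirement".
fn to_ordering_requirement(
    sort_exprs: Vec<PhysicalSortExpr>,
) -> Option<OrderingRequirements> {
    LexOrdering::new(sort_exprs).map(OrderingRequirements::from)
}
```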
diff --git a/datafusion/physical-optimizer/src/sanity_checker.rs b/datafusion/physical-optimizer/src/sanity_checker.rs index 8edbb0f09114..acc70d39f057 100644 --- a/datafusion/physical-optimizer/src/sanity_checker.rs +++ b/datafusion/physical-optimizer/src/sanity_checker.rs @@ -137,7 +137,8 @@ pub fn check_plan_sanity( ) { let child_eq_props = child.equivalence_properties(); if let Some(sort_req) = sort_req { - if !child_eq_props.ordering_satisfy_requirement(&sort_req) { + let sort_req = sort_req.into_single(); + if !child_eq_props.ordering_satisfy_requirement(sort_req.clone())? { let plan_str = get_plan_string(&plan); return plan_err!( "Plan: {:?} does not satisfy order requirements: {}. Child-{} order: {}", diff --git a/datafusion/physical-optimizer/src/topk_aggregation.rs b/datafusion/physical-optimizer/src/topk_aggregation.rs index faedea55ca15..bff0b1e49684 100644 --- a/datafusion/physical-optimizer/src/topk_aggregation.rs +++ b/datafusion/physical-optimizer/src/topk_aggregation.rs @@ -25,7 +25,6 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::Result; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::LexOrdering; use datafusion_physical_plan::aggregates::AggregateExec; use datafusion_physical_plan::execution_plan::CardinalityEffect; use datafusion_physical_plan::projection::ProjectionExec; @@ -131,7 +130,7 @@ impl TopKAggregation { Ok(Transformed::no(plan)) }; let child = Arc::clone(child).transform_down(closure).data().ok()?; - let sort = SortExec::new(LexOrdering::new(sort.expr().to_vec()), child) + let sort = SortExec::new(sort.expr().clone(), child) .with_fetch(sort.fetch()) .with_preserve_partitioning(sort.preserve_partitioning()); Some(Arc::new(sort)) diff --git a/datafusion/physical-optimizer/src/update_aggr_exprs.rs b/datafusion/physical-optimizer/src/update_aggr_exprs.rs index 6228ed10ec34..61bc715592af 100644 --- a/datafusion/physical-optimizer/src/update_aggr_exprs.rs +++ b/datafusion/physical-optimizer/src/update_aggr_exprs.rs @@ -24,15 +24,10 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{plan_datafusion_err, Result}; use datafusion_physical_expr::aggregate::AggregateFunctionExpr; -use datafusion_physical_expr::{ - reverse_order_bys, EquivalenceProperties, PhysicalSortRequirement, -}; -use datafusion_physical_expr::{LexOrdering, LexRequirement}; -use datafusion_physical_plan::aggregates::concat_slices; +use datafusion_physical_expr::{EquivalenceProperties, PhysicalSortRequirement}; +use datafusion_physical_plan::aggregates::{concat_slices, AggregateExec}; use datafusion_physical_plan::windows::get_ordered_partition_by_indices; -use datafusion_physical_plan::{ - aggregates::AggregateExec, ExecutionPlan, ExecutionPlanProperties, -}; +use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; use crate::PhysicalOptimizerRule; @@ -90,32 +85,30 @@ impl PhysicalOptimizerRule for OptimizeAggregateOrder { return Ok(Transformed::no(plan)); } let input = aggr_exec.input(); - let mut aggr_expr = aggr_exec.aggr_expr().to_vec(); + let mut aggr_exprs = aggr_exec.aggr_expr().to_vec(); let groupby_exprs = aggr_exec.group_expr().input_exprs(); // If the existing ordering satisfies a prefix of the GROUP BY // expressions, prefix requirements with this section. In this // case, aggregation will work more efficiently. 
- let indices = get_ordered_partition_by_indices(&groupby_exprs, input); + let indices = get_ordered_partition_by_indices(&groupby_exprs, input)?; let requirement = indices .iter() .map(|&idx| { PhysicalSortRequirement::new( - Arc::<dyn PhysicalExpr>::clone( - &groupby_exprs[idx], - ), + Arc::clone(&groupby_exprs[idx]), None, ) }) .collect::<Vec<_>>(); - aggr_expr = try_convert_aggregate_if_better( - aggr_expr, + aggr_exprs = try_convert_aggregate_if_better( + aggr_exprs, &requirement, input.equivalence_properties(), )?; - let aggr_exec = aggr_exec.with_new_aggr_exprs(aggr_expr); + let aggr_exec = aggr_exec.with_new_aggr_exprs(aggr_exprs); Ok(Transformed::yes(Arc::new(aggr_exec) as _)) } else { @@ -159,31 +152,30 @@ fn try_convert_aggregate_if_better( aggr_exprs .into_iter() .map(|aggr_expr| { - let aggr_sort_exprs = aggr_expr.order_bys().unwrap_or(LexOrdering::empty()); - let reverse_aggr_sort_exprs = reverse_order_bys(aggr_sort_exprs); - let aggr_sort_reqs = LexRequirement::from(aggr_sort_exprs.clone()); - let reverse_aggr_req = LexRequirement::from(reverse_aggr_sort_exprs); - + let order_bys = aggr_expr.order_bys(); // If the aggregate expression benefits from input ordering, and // there is an actual ordering enabling this, try to update the // aggregate expression to benefit from the existing ordering. // Otherwise, leave it as is. - if aggr_expr.order_sensitivity().is_beneficial() && !aggr_sort_reqs.is_empty() - { - let reqs = LexRequirement { - inner: concat_slices(prefix_requirement, &aggr_sort_reqs), - }; - - let prefix_requirement = LexRequirement { - inner: prefix_requirement.to_vec(), - }; - - if eq_properties.ordering_satisfy_requirement(&reqs) { + if !aggr_expr.order_sensitivity().is_beneficial() { + Ok(aggr_expr) + } else if !order_bys.is_empty() { + if eq_properties.ordering_satisfy_requirement(concat_slices( + prefix_requirement, + &order_bys + .iter() + .map(|e| e.clone().into()) + .collect::<Vec<_>>(), + ))? { // Existing ordering satisfies the aggregator requirements: aggr_expr.with_beneficial_ordering(true)?.map(Arc::new) - } else if eq_properties.ordering_satisfy_requirement(&LexRequirement { - inner: concat_slices(&prefix_requirement, &reverse_aggr_req), - }) { + } else if eq_properties.ordering_satisfy_requirement(concat_slices( + prefix_requirement, + &order_bys + .iter() + .map(|e| e.reverse().into()) + .collect::<Vec<_>>(), + ))?
{ // Converting to reverse enables more efficient execution // given the existing ordering (if possible): aggr_expr diff --git a/datafusion/physical-optimizer/src/utils.rs b/datafusion/physical-optimizer/src/utils.rs index 57a193315a5c..3655e555a744 100644 --- a/datafusion/physical-optimizer/src/utils.rs +++ b/datafusion/physical-optimizer/src/utils.rs @@ -17,8 +17,8 @@ use std::sync::Arc; -use datafusion_physical_expr::LexRequirement; -use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_common::Result; +use datafusion_physical_expr::{LexOrdering, LexRequirement}; use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::repartition::RepartitionExec; @@ -40,14 +40,18 @@ pub fn add_sort_above<T: Clone + Default>( sort_requirements: LexRequirement, fetch: Option<usize>, ) -> PlanContext<T> { - let mut sort_expr = LexOrdering::from(sort_requirements); - sort_expr.retain(|sort_expr| { - !node - .plan + let mut sort_reqs: Vec<_> = sort_requirements.into(); + sort_reqs.retain(|sort_expr| { + node.plan .equivalence_properties() .is_expr_constant(&sort_expr.expr) + .is_none() }); - let mut new_sort = SortExec::new(sort_expr, Arc::clone(&node.plan)).with_fetch(fetch); + let sort_exprs = sort_reqs.into_iter().map(Into::into).collect::<Vec<_>>(); + let Some(ordering) = LexOrdering::new(sort_exprs) else { + return node; + }; + let mut new_sort = SortExec::new(ordering, Arc::clone(&node.plan)).with_fetch(fetch); if node.plan.output_partitioning().partition_count() > 1 { new_sort = new_sort.with_preserve_partitioning(true); } @@ -61,15 +65,15 @@ pub fn add_sort_above_with_check<T: Clone + Default>( node: PlanContext<T>, sort_requirements: LexRequirement, fetch: Option<usize>, -) -> PlanContext<T> { +) -> Result<PlanContext<T>> { if !node .plan .equivalence_properties() - .ordering_satisfy_requirement(&sort_requirements) + .ordering_satisfy_requirement(sort_requirements.clone())? { - add_sort_above(node, sort_requirements, fetch) + Ok(add_sort_above(node, sort_requirements, fetch)) } else { - node + Ok(node) } } diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index 5210ee26755c..095ee78cd0d6 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -36,6 +36,8 @@ workspace = true [features] force_hash_collisions = [] +tokio_coop = [] +tokio_coop_fallback = [] [lib] name = "datafusion_physical_plan" @@ -86,3 +88,7 @@ name = "partial_ordering" [[bench]] harness = false name = "spill_io" + +[[bench]] +harness = false +name = "sort_preserving_merge" diff --git a/datafusion/physical-plan/README.md b/datafusion/physical-plan/README.md index ec604253fd2e..37cc1658015c 100644 --- a/datafusion/physical-plan/README.md +++ b/datafusion/physical-plan/README.md @@ -24,4 +24,9 @@ This crate is a submodule of DataFusion that contains the `ExecutionPlan` trait and the various implementations of that trait for built in operators such as filters, projections, joins, aggregations, etc. +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well.
+ [df]: https://crates.io/crates/datafusion +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/physical-plan/benches/partial_ordering.rs b/datafusion/physical-plan/benches/partial_ordering.rs index 422826abcc8b..e1a9d0b583e9 100644 --- a/datafusion/physical-plan/benches/partial_ordering.rs +++ b/datafusion/physical-plan/benches/partial_ordering.rs @@ -18,11 +18,10 @@ use std::sync::Arc; use arrow::array::{ArrayRef, Int32Array}; -use arrow_schema::{DataType, Field, Schema, SortOptions}; -use criterion::{criterion_group, criterion_main, Criterion}; -use datafusion_physical_expr::{expressions::col, LexOrdering, PhysicalSortExpr}; use datafusion_physical_plan::aggregates::order::GroupOrderingPartial; +use criterion::{criterion_group, criterion_main, Criterion}; + const BATCH_SIZE: usize = 8192; fn create_test_arrays(num_columns: usize) -> Vec { @@ -39,31 +38,15 @@ fn bench_new_groups(c: &mut Criterion) { // Test with 1, 2, 4, and 8 order indices for num_columns in [1, 2, 4, 8] { - let fields: Vec = (0..num_columns) - .map(|i| Field::new(format!("col{}", i), DataType::Int32, false)) - .collect(); - let schema = Schema::new(fields); - let order_indices: Vec = (0..num_columns).collect(); - let ordering = LexOrdering::new( - (0..num_columns) - .map(|i| { - PhysicalSortExpr::new( - col(&format!("col{}", i), &schema).unwrap(), - SortOptions::default(), - ) - }) - .collect(), - ); - group.bench_function(format!("order_indices_{}", num_columns), |b| { + group.bench_function(format!("order_indices_{num_columns}"), |b| { let batch_group_values = create_test_arrays(num_columns); let group_indices: Vec = (0..BATCH_SIZE).collect(); b.iter(|| { let mut ordering = - GroupOrderingPartial::try_new(&schema, &order_indices, &ordering) - .unwrap(); + GroupOrderingPartial::try_new(order_indices.clone()).unwrap(); ordering .new_groups(&batch_group_values, &group_indices, BATCH_SIZE) .unwrap(); diff --git a/datafusion/physical-plan/benches/sort_preserving_merge.rs b/datafusion/physical-plan/benches/sort_preserving_merge.rs new file mode 100644 index 000000000000..f223fd806b69 --- /dev/null +++ b/datafusion/physical-plan/benches/sort_preserving_merge.rs @@ -0,0 +1,197 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::{ + array::{ArrayRef, StringArray, UInt64Array}, + record_batch::RecordBatch, +}; +use arrow_schema::{SchemaRef, SortOptions}; +use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::{expressions::col, LexOrdering, PhysicalSortExpr}; +use datafusion_physical_plan::test::TestMemoryExec; +use datafusion_physical_plan::{ + collect, sorts::sort_preserving_merge::SortPreservingMergeExec, +}; + +use std::sync::Arc; + +const BENCH_ROWS: usize = 1_000_000; // 1 million rows + +fn get_large_string(idx: usize) -> String { + let base_content = [ + concat!( + "# Advanced Topics in Computer Science\n\n", + "## Summary\nThis article explores complex system design patterns and...\n\n", + "```rust\nfn process_data(data: &mut [i32]) {\n // Parallel processing example\n data.par_iter_mut().for_each(|x| *x *= 2);\n}\n```\n\n", + "## Performance Considerations\nWhen implementing concurrent systems...\n" + ), + concat!( + "## API Documentation\n\n", + "```json\n{\n \"endpoint\": \"/api/v2/users\",\n \"methods\": [\"GET\", \"POST\"],\n \"parameters\": {\n \"page\": \"number\"\n }\n}\n```\n\n", + "# Authentication Guide\nSecure your API access using OAuth 2.0...\n" + ), + concat!( + "# Data Processing Pipeline\n\n", + "```python\nfrom multiprocessing import Pool\n\ndef main():\n with Pool(8) as p:\n results = p.map(process_item, data)\n```\n\n", + "## Summary of Optimizations\n1. Batch processing\n2. Memory pooling\n3. Concurrent I/O operations\n" + ), + concat!( + "# System Architecture Overview\n\n", + "## Components\n- Load Balancer\n- Database Cluster\n- Cache Service\n\n", + "```go\nfunc main() {\n router := gin.Default()\n router.GET(\"/api/health\", healthCheck)\n router.Run(\":8080\")\n}\n```\n" + ), + concat!( + "## Configuration Reference\n\n", + "```yaml\nserver:\n port: 8080\n max_threads: 32\n\ndatabase:\n url: postgres://user@prod-db:5432/main\n```\n\n", + "# Deployment Strategies\nBlue-green deployment patterns with...\n" + ), + ]; + base_content[idx % base_content.len()].to_string() +} + +fn generate_sorted_string_column(rows: usize) -> ArrayRef { + let mut values = Vec::with_capacity(rows); + for i in 0..rows { + values.push(get_large_string(i)); + } + values.sort(); + Arc::new(StringArray::from(values)) +} + +fn generate_sorted_u64_column(rows: usize) -> ArrayRef { + Arc::new(UInt64Array::from((0_u64..rows as u64).collect::>())) +} + +fn create_partitions( + num_partitions: usize, + num_columns: usize, + num_rows: usize, +) -> Vec> { + (0..num_partitions) + .map(|_| { + let rows = (0..num_columns) + .map(|i| { + ( + format!("col-{i}"), + if IS_LARGE_COLUMN_TYPE { + generate_sorted_string_column(num_rows) + } else { + generate_sorted_u64_column(num_rows) + }, + ) + }) + .collect::>(); + + let batch = RecordBatch::try_from_iter(rows).unwrap(); + vec![batch] + }) + .collect() +} + +struct BenchData { + bench_name: String, + partitions: Vec>, + schema: SchemaRef, + sort_order: LexOrdering, +} + +fn get_bench_data() -> Vec { + let mut ret = Vec::new(); + let mut push_bench_data = |bench_name: &str, partitions: Vec>| { + let schema = partitions[0][0].schema(); + // Define sort order (col1 ASC, col2 ASC, col3 ASC) + let sort_order = LexOrdering::new(schema.fields().iter().map(|field| { + PhysicalSortExpr::new( + col(field.name(), &schema).unwrap(), + SortOptions::default(), + ) + })) + .unwrap(); + ret.push(BenchData { + bench_name: bench_name.to_string(), + partitions, + schema, + sort_order, + }); + 
}; + // 1. single large string column + { + let partitions = create_partitions::(3, 1, BENCH_ROWS); + push_bench_data("single_large_string_column_with_1m_rows", partitions); + } + // 2. single u64 column + { + let partitions = create_partitions::(3, 1, BENCH_ROWS); + push_bench_data("single_u64_column_with_1m_rows", partitions); + } + // 3. multiple large string columns + { + let partitions = create_partitions::(3, 3, BENCH_ROWS); + push_bench_data("multiple_large_string_columns_with_1m_rows", partitions); + } + // 4. multiple u64 columns + { + let partitions = create_partitions::(3, 3, BENCH_ROWS); + push_bench_data("multiple_u64_columns_with_1m_rows", partitions); + } + ret +} + +/// Add a benchmark to test the optimization effect of reusing Rows. +/// Run this benchmark with: +/// ```sh +/// cargo bench --features="bench" --bench sort_preserving_merge -- --sample-size=10 +/// ``` +fn bench_merge_sorted_preserving(c: &mut Criterion) { + let task_ctx = Arc::new(TaskContext::default()); + let bench_data = get_bench_data(); + for data in bench_data.into_iter() { + let BenchData { + bench_name, + partitions, + schema, + sort_order, + } = data; + c.bench_function( + &format!("bench_merge_sorted_preserving/{bench_name}"), + |b| { + b.iter_batched( + || { + let exec = TestMemoryExec::try_new_exec( + &partitions, + schema.clone(), + None, + ) + .unwrap(); + Arc::new(SortPreservingMergeExec::new(sort_order.clone(), exec)) + }, + |merge_exec| { + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + collect(merge_exec, task_ctx.clone()).await.unwrap(); + }); + }, + BatchSize::LargeInput, + ) + }, + ); + } +} + +criterion_group!(benches, bench_merge_sorted_preserving); +criterion_main!(benches); diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes.rs index c4525256dbae..be1f68ea453f 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes.rs @@ -24,6 +24,7 @@ use arrow::array::{ use arrow::buffer::{OffsetBuffer, ScalarBuffer}; use arrow::datatypes::{ByteArrayType, DataType, GenericBinaryType}; use datafusion_common::utils::proxy::VecAllocExt; +use datafusion_common::{DataFusionError, Result}; use datafusion_physical_expr_common::binary_map::{OutputType, INITIAL_BUFFER_CAPACITY}; use itertools::izip; use std::mem::size_of; @@ -50,6 +51,8 @@ where offsets: Vec, /// Nulls nulls: MaybeNullBufferBuilder, + /// The maximum size of the buffer for `0` + max_buffer_size: usize, } impl ByteGroupValueBuilder @@ -62,6 +65,11 @@ where buffer: BufferBuilder::new(INITIAL_BUFFER_CAPACITY), offsets: vec![O::default()], nulls: MaybeNullBufferBuilder::new(), + max_buffer_size: if O::IS_LARGE { + i64::MAX as usize + } else { + i32::MAX as usize + }, } } @@ -73,7 +81,7 @@ where self.do_equal_to_inner(lhs_row, array, rhs_row) } - fn append_val_inner(&mut self, array: &ArrayRef, row: usize) + fn append_val_inner(&mut self, array: &ArrayRef, row: usize) -> Result<()> where B: ByteArrayType, { @@ -85,8 +93,10 @@ where self.offsets.push(O::usize_as(offset)); } else { self.nulls.append(false); - self.do_append_val_inner(arr, row); + self.do_append_val_inner(arr, row)?; } + + Ok(()) } fn vectorized_equal_to_inner( @@ -116,7 +126,11 @@ where } } - fn vectorized_append_inner(&mut self, array: &ArrayRef, rows: &[usize]) + fn vectorized_append_inner( + &mut self, + array: &ArrayRef, + 
rows: &[usize], + ) -> Result<()> where B: ByteArrayType, { @@ -134,22 +148,14 @@ where match all_null_or_non_null { None => { for &row in rows { - if arr.is_null(row) { - self.nulls.append(true); - // nulls need a zero length in the offset buffer - let offset = self.buffer.len(); - self.offsets.push(O::usize_as(offset)); - } else { - self.nulls.append(false); - self.do_append_val_inner(arr, row); - } + self.append_val_inner::(array, row)? } } Some(true) => { self.nulls.append_n(rows.len(), false); for &row in rows { - self.do_append_val_inner(arr, row); + self.do_append_val_inner(arr, row)?; } } @@ -161,6 +167,8 @@ where self.offsets.resize(new_len, O::usize_as(offset)); } } + + Ok(()) } fn do_equal_to_inner( @@ -181,13 +189,26 @@ where self.value(lhs_row) == (array.value(rhs_row).as_ref() as &[u8]) } - fn do_append_val_inner(&mut self, array: &GenericByteArray, row: usize) + fn do_append_val_inner( + &mut self, + array: &GenericByteArray, + row: usize, + ) -> Result<()> where B: ByteArrayType, { let value: &[u8] = array.value(row).as_ref(); self.buffer.append_slice(value); + + if self.buffer.len() > self.max_buffer_size { + return Err(DataFusionError::Execution(format!( + "offset overflow, buffer size > {}", + self.max_buffer_size + ))); + } + self.offsets.push(O::usize_as(self.buffer.len())); + Ok(()) } /// return the current value of the specified row irrespective of null @@ -224,7 +245,7 @@ where } } - fn append_val(&mut self, column: &ArrayRef, row: usize) { + fn append_val(&mut self, column: &ArrayRef, row: usize) -> Result<()> { // Sanity array type match self.output_type { OutputType::Binary => { @@ -232,17 +253,19 @@ where column.data_type(), DataType::Binary | DataType::LargeBinary )); - self.append_val_inner::>(column, row) + self.append_val_inner::>(column, row)? } OutputType::Utf8 => { debug_assert!(matches!( column.data_type(), DataType::Utf8 | DataType::LargeUtf8 )); - self.append_val_inner::>(column, row) + self.append_val_inner::>(column, row)? } _ => unreachable!("View types should use `ArrowBytesViewMap`"), }; + + Ok(()) } fn vectorized_equal_to( @@ -282,24 +305,26 @@ where } } - fn vectorized_append(&mut self, column: &ArrayRef, rows: &[usize]) { + fn vectorized_append(&mut self, column: &ArrayRef, rows: &[usize]) -> Result<()> { match self.output_type { OutputType::Binary => { debug_assert!(matches!( column.data_type(), DataType::Binary | DataType::LargeBinary )); - self.vectorized_append_inner::>(column, rows) + self.vectorized_append_inner::>(column, rows)? } OutputType::Utf8 => { debug_assert!(matches!( column.data_type(), DataType::Utf8 | DataType::LargeUtf8 )); - self.vectorized_append_inner::>(column, rows) + self.vectorized_append_inner::>(column, rows)? } _ => unreachable!("View types should use `ArrowBytesViewMap`"), }; + + Ok(()) } fn len(&self) -> usize { @@ -318,6 +343,7 @@ where mut buffer, offsets, nulls, + .. 
} = *self; let null_buffer = nulls.build(); @@ -406,27 +432,50 @@ mod tests { use crate::aggregates::group_values::multi_group_by::bytes::ByteGroupValueBuilder; use arrow::array::{ArrayRef, NullBufferBuilder, StringArray}; + use datafusion_common::DataFusionError; use datafusion_physical_expr::binary_map::OutputType; use super::GroupColumn; + #[test] + fn test_byte_group_value_builder_overflow() { + let mut builder = ByteGroupValueBuilder::::new(OutputType::Utf8); + + let large_string = "a".repeat(1024 * 1024); + + let array = + Arc::new(StringArray::from(vec![Some(large_string.as_str())])) as ArrayRef; + + // Append items until our buffer length is i32::MAX as usize + for _ in 0..2047 { + builder.append_val(&array, 0).unwrap(); + } + + assert!(matches!( + builder.append_val(&array, 0), + Err(DataFusionError::Execution(e)) if e.contains("offset overflow") + )); + + assert_eq!(builder.value(2046), large_string.as_bytes()); + } + #[test] fn test_byte_take_n() { let mut builder = ByteGroupValueBuilder::::new(OutputType::Utf8); let array = Arc::new(StringArray::from(vec![Some("a"), None])) as ArrayRef; // a, null, null - builder.append_val(&array, 0); - builder.append_val(&array, 1); - builder.append_val(&array, 1); + builder.append_val(&array, 0).unwrap(); + builder.append_val(&array, 1).unwrap(); + builder.append_val(&array, 1).unwrap(); // (a, null) remaining: null let output = builder.take_n(2); assert_eq!(&output, &array); // null, a, null, a - builder.append_val(&array, 0); - builder.append_val(&array, 1); - builder.append_val(&array, 0); + builder.append_val(&array, 0).unwrap(); + builder.append_val(&array, 1).unwrap(); + builder.append_val(&array, 0).unwrap(); // (null, a) remaining: (null, a) let output = builder.take_n(2); @@ -440,9 +489,9 @@ mod tests { ])) as ArrayRef; // null, a, longstringfortest, null, null - builder.append_val(&array, 2); - builder.append_val(&array, 1); - builder.append_val(&array, 1); + builder.append_val(&array, 2).unwrap(); + builder.append_val(&array, 1).unwrap(); + builder.append_val(&array, 1).unwrap(); // (null, a, longstringfortest, null) remaining: (null) let output = builder.take_n(4); @@ -461,7 +510,7 @@ mod tests { builder_array: &ArrayRef, append_rows: &[usize]| { for &index in append_rows { - builder.append_val(builder_array, index); + builder.append_val(builder_array, index).unwrap(); } }; @@ -484,7 +533,9 @@ mod tests { let append = |builder: &mut ByteGroupValueBuilder, builder_array: &ArrayRef, append_rows: &[usize]| { - builder.vectorized_append(builder_array, append_rows); + builder + .vectorized_append(builder_array, append_rows) + .unwrap(); }; let equal_to = |builder: &ByteGroupValueBuilder, @@ -518,7 +569,9 @@ mod tests { None, None, ])) as _; - builder.vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4]); + builder + .vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4]) + .unwrap(); let mut equal_to_results = vec![true; all_nulls_input_array.len()]; builder.vectorized_equal_to( @@ -542,7 +595,9 @@ mod tests { Some("string4"), Some("string5"), ])) as _; - builder.vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4]); + builder + .vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4]) + .unwrap(); let mut equal_to_results = vec![true; all_not_nulls_input_array.len()]; builder.vectorized_equal_to( diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes_view.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes_view.rs index b6d97b5d788d..63018874a1e4 
100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes_view.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes_view.rs @@ -20,6 +20,7 @@ use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder; use arrow::array::{make_view, Array, ArrayRef, AsArray, ByteView, GenericByteViewArray}; use arrow::buffer::{Buffer, ScalarBuffer}; use arrow::datatypes::ByteViewType; +use datafusion_common::Result; use itertools::izip; use std::marker::PhantomData; use std::mem::{replace, size_of}; @@ -148,14 +149,7 @@ impl ByteViewGroupValueBuilder { match all_null_or_non_null { None => { for &row in rows { - // Null row case, set and return - if arr.is_valid(row) { - self.nulls.append(false); - self.do_append_val_inner(arr, row); - } else { - self.nulls.append(true); - self.views.push(0); - } + self.append_val_inner(array, row); } } @@ -493,8 +487,9 @@ impl GroupColumn for ByteViewGroupValueBuilder { self.equal_to_inner(lhs_row, array, rhs_row) } - fn append_val(&mut self, array: &ArrayRef, row: usize) { - self.append_val_inner(array, row) + fn append_val(&mut self, array: &ArrayRef, row: usize) -> Result<()> { + self.append_val_inner(array, row); + Ok(()) } fn vectorized_equal_to( @@ -507,8 +502,9 @@ impl GroupColumn for ByteViewGroupValueBuilder { self.vectorized_equal_to_inner(group_indices, array, rows, equal_to_results); } - fn vectorized_append(&mut self, array: &ArrayRef, rows: &[usize]) { + fn vectorized_append(&mut self, array: &ArrayRef, rows: &[usize]) -> Result<()> { self.vectorized_append_inner(array, rows); + Ok(()) } fn len(&self) -> usize { @@ -563,7 +559,7 @@ mod tests { ]); let builder_array: ArrayRef = Arc::new(builder_array); for row in 0..builder_array.len() { - builder.append_val(&builder_array, row); + builder.append_val(&builder_array, row).unwrap(); } let output = Box::new(builder).build(); @@ -578,7 +574,7 @@ mod tests { builder_array: &ArrayRef, append_rows: &[usize]| { for &index in append_rows { - builder.append_val(builder_array, index); + builder.append_val(builder_array, index).unwrap(); } }; @@ -601,7 +597,9 @@ mod tests { let append = |builder: &mut ByteViewGroupValueBuilder, builder_array: &ArrayRef, append_rows: &[usize]| { - builder.vectorized_append(builder_array, append_rows); + builder + .vectorized_append(builder_array, append_rows) + .unwrap(); }; let equal_to = |builder: &ByteViewGroupValueBuilder, @@ -636,7 +634,9 @@ mod tests { None, None, ])) as _; - builder.vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4]); + builder + .vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4]) + .unwrap(); let mut equal_to_results = vec![true; all_nulls_input_array.len()]; builder.vectorized_equal_to( @@ -660,7 +660,9 @@ mod tests { Some("stringview4"), Some("stringview5"), ])) as _; - builder.vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4]); + builder + .vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4]) + .unwrap(); let mut equal_to_results = vec![true; all_not_nulls_input_array.len()]; builder.vectorized_equal_to( @@ -841,7 +843,7 @@ mod tests { // ####### Test situation 1~5 ####### for row in 0..first_ones_to_append { - builder.append_val(&input_array, row); + builder.append_val(&input_array, row).unwrap(); } assert_eq!(builder.completed.len(), 2); @@ -879,7 +881,7 @@ mod tests { assert!(builder.views.is_empty()); for row in first_ones_to_append..first_ones_to_append + second_ones_to_append { - builder.append_val(&input_array, row); + 
builder.append_val(&input_array, row).unwrap(); } assert!(builder.completed.is_empty()); @@ -894,7 +896,7 @@ mod tests { ByteViewGroupValueBuilder::::new().with_max_block_size(60); for row in 0..final_ones_to_append { - builder.append_val(&input_array, row); + builder.append_val(&input_array, row).unwrap(); } assert_eq!(builder.completed.len(), 3); diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs index ac96a98edfe1..2ac0389454de 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs @@ -65,7 +65,7 @@ pub trait GroupColumn: Send + Sync { fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool; /// Appends the row at `row` in `array` to this builder - fn append_val(&mut self, array: &ArrayRef, row: usize); + fn append_val(&mut self, array: &ArrayRef, row: usize) -> Result<()>; /// The vectorized version equal to /// @@ -86,7 +86,7 @@ pub trait GroupColumn: Send + Sync { ); /// The vectorized version `append_val` - fn vectorized_append(&mut self, array: &ArrayRef, rows: &[usize]); + fn vectorized_append(&mut self, array: &ArrayRef, rows: &[usize]) -> Result<()>; /// Returns the number of rows stored in this builder fn len(&self) -> usize; @@ -270,7 +270,7 @@ impl GroupValuesColumn { map_size: 0, group_values: vec![], hashes_buffer: Default::default(), - random_state: Default::default(), + random_state: crate::aggregates::AGGREGATION_HASH_SEED, }) } @@ -384,7 +384,7 @@ impl GroupValuesColumn { let mut checklen = 0; let group_idx = self.group_values[0].len(); for (i, group_value) in self.group_values.iter_mut().enumerate() { - group_value.append_val(&cols[i], row); + group_value.append_val(&cols[i], row)?; let len = group_value.len(); if i == 0 { checklen = len; @@ -460,14 +460,14 @@ impl GroupValuesColumn { self.collect_vectorized_process_context(&batch_hashes, groups); // 2. Perform `vectorized_append` - self.vectorized_append(cols); + self.vectorized_append(cols)?; // 3. Perform `vectorized_equal_to` self.vectorized_equal_to(cols, groups); // 4. 
Perform scalarized inter for remaining rows // (about remaining rows, can see comments for `remaining_row_indices`) - self.scalarized_intern_remaining(cols, &batch_hashes, groups); + self.scalarized_intern_remaining(cols, &batch_hashes, groups)?; self.hashes_buffer = batch_hashes; @@ -563,13 +563,13 @@ impl GroupValuesColumn { } /// Perform `vectorized_append`` for `rows` in `vectorized_append_row_indices` - fn vectorized_append(&mut self, cols: &[ArrayRef]) { + fn vectorized_append(&mut self, cols: &[ArrayRef]) -> Result<()> { if self .vectorized_operation_buffers .append_row_indices .is_empty() { - return; + return Ok(()); } let iter = self.group_values.iter_mut().zip(cols.iter()); @@ -577,8 +577,10 @@ impl GroupValuesColumn { group_column.vectorized_append( col, &self.vectorized_operation_buffers.append_row_indices, - ); + )?; } + + Ok(()) } /// Perform `vectorized_equal_to` @@ -719,13 +721,13 @@ impl GroupValuesColumn { cols: &[ArrayRef], batch_hashes: &[u64], groups: &mut [usize], - ) { + ) -> Result<()> { if self .vectorized_operation_buffers .remaining_row_indices .is_empty() { - return; + return Ok(()); } let mut map = mem::take(&mut self.map); @@ -758,7 +760,7 @@ impl GroupValuesColumn { let group_idx = self.group_values[0].len(); let mut checklen = 0; for (i, group_value) in self.group_values.iter_mut().enumerate() { - group_value.append_val(&cols[i], row); + group_value.append_val(&cols[i], row)?; let len = group_value.len(); if i == 0 { checklen = len; @@ -795,6 +797,7 @@ impl GroupValuesColumn { } self.map = map; + Ok(()) } fn scalarized_equal_to_remaining( @@ -1756,11 +1759,9 @@ mod tests { (i, actual_line), (i, expected_line), "Inconsistent result\n\n\ - Actual batch:\n{}\n\ - Expected batch:\n{}\n\ + Actual batch:\n{formatted_actual_batch}\n\ + Expected batch:\n{formatted_expected_batch}\n\ ", - formatted_actual_batch, - formatted_expected_batch, ); } } diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs index e9c3c42e632b..afec25fd3d66 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs @@ -17,9 +17,11 @@ use crate::aggregates::group_values::multi_group_by::{nulls_equal_to, GroupColumn}; use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder; +use arrow::array::ArrowNativeTypeOp; use arrow::array::{cast::AsArray, Array, ArrayRef, ArrowPrimitiveType, PrimitiveArray}; use arrow::buffer::ScalarBuffer; use arrow::datatypes::DataType; +use datafusion_common::Result; use datafusion_execution::memory_pool::proxy::VecAllocExt; use itertools::izip; use std::iter; @@ -71,7 +73,7 @@ impl GroupColumn self.group_values[lhs_row] == array.as_primitive::().value(rhs_row) } - fn append_val(&mut self, array: &ArrayRef, row: usize) { + fn append_val(&mut self, array: &ArrayRef, row: usize) -> Result<()> { // Perf: skip null check if input can't have nulls if NULLABLE { if array.is_null(row) { @@ -84,6 +86,8 @@ impl GroupColumn } else { self.group_values.push(array.as_primitive::().value(row)); } + + Ok(()) } fn vectorized_equal_to( @@ -118,11 +122,11 @@ impl GroupColumn // Otherwise, we need to check their values } - *equal_to_result = self.group_values[lhs_row] == array.value(rhs_row); + *equal_to_result = self.group_values[lhs_row].is_eq(array.value(rhs_row)); } } - fn vectorized_append(&mut self, array: &ArrayRef, rows: 
&[usize]) { + fn vectorized_append(&mut self, array: &ArrayRef, rows: &[usize]) -> Result<()> { let arr = array.as_primitive::(); let null_count = array.null_count(); @@ -167,6 +171,8 @@ impl GroupColumn } } } + + Ok(()) } fn len(&self) -> usize { @@ -222,7 +228,7 @@ mod tests { builder_array: &ArrayRef, append_rows: &[usize]| { for &index in append_rows { - builder.append_val(builder_array, index); + builder.append_val(builder_array, index).unwrap(); } }; @@ -245,7 +251,9 @@ mod tests { let append = |builder: &mut PrimitiveGroupValueBuilder, builder_array: &ArrayRef, append_rows: &[usize]| { - builder.vectorized_append(builder_array, append_rows); + builder + .vectorized_append(builder_array, append_rows) + .unwrap(); }; let equal_to = |builder: &PrimitiveGroupValueBuilder, @@ -335,7 +343,7 @@ mod tests { builder_array: &ArrayRef, append_rows: &[usize]| { for &index in append_rows { - builder.append_val(builder_array, index); + builder.append_val(builder_array, index).unwrap(); } }; @@ -358,7 +366,9 @@ mod tests { let append = |builder: &mut PrimitiveGroupValueBuilder, builder_array: &ArrayRef, append_rows: &[usize]| { - builder.vectorized_append(builder_array, append_rows); + builder + .vectorized_append(builder_array, append_rows) + .unwrap(); }; let equal_to = |builder: &PrimitiveGroupValueBuilder, @@ -432,7 +442,9 @@ mod tests { None, None, ])) as _; - builder.vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4]); + builder + .vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4]) + .unwrap(); let mut equal_to_results = vec![true; all_nulls_input_array.len()]; builder.vectorized_equal_to( @@ -456,7 +468,9 @@ mod tests { Some(4), Some(5), ])) as _; - builder.vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4]); + builder + .vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4]) + .unwrap(); let mut equal_to_results = vec![true; all_not_nulls_input_array.len()]; builder.vectorized_equal_to( diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index 63751d470313..34893fcc4ed9 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -82,7 +82,7 @@ impl GroupValuesRows { pub fn try_new(schema: SchemaRef) -> Result { // Print a debugging message, so it is clear when the (slower) fallback // GroupValuesRows is used. 
- debug!("Creating GroupValuesRows for schema: {}", schema); + debug!("Creating GroupValuesRows for schema: {schema}"); let row_converter = RowConverter::new( schema .fields() @@ -106,7 +106,7 @@ impl GroupValuesRows { group_values: None, hashes_buffer: Default::default(), rows_buffer, - random_state: Default::default(), + random_state: crate::aggregates::AGGREGATION_HASH_SEED, }) } } @@ -202,6 +202,7 @@ impl GroupValues for GroupValuesRows { EmitTo::All => { let output = self.row_converter.convert_rows(&group_values)?; group_values.clear(); + self.map.clear(); output } EmitTo::First(n) => { diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs index d945d3ddcbf5..8b1905e54041 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs @@ -81,11 +81,14 @@ hash_float!(f16, f32, f64); pub struct GroupValuesPrimitive { /// The data type of the output array data_type: DataType, - /// Stores the group index based on the hash of its value + /// Stores the `(group_index, hash)` based on the hash of its value /// - /// We don't store the hashes as hashing fixed width primitives - /// is fast enough for this not to benefit performance - map: HashTable, + /// We also store `hash` is for reducing cost of rehashing. Such cost + /// is obvious in high cardinality group by situation. + /// More details can see: + /// + /// + map: HashTable<(usize, u64)>, /// The group index of the null value if any null_group: Option, /// The values for each group index @@ -102,7 +105,7 @@ impl GroupValuesPrimitive { map: HashTable::with_capacity(128), values: Vec::with_capacity(128), null_group: None, - random_state: Default::default(), + random_state: crate::aggregates::AGGREGATION_HASH_SEED, } } } @@ -127,15 +130,15 @@ where let hash = key.hash(state); let insert = self.map.entry( hash, - |g| unsafe { self.values.get_unchecked(*g).is_eq(key) }, - |g| unsafe { self.values.get_unchecked(*g).hash(state) }, + |&(g, _)| unsafe { self.values.get_unchecked(g).is_eq(key) }, + |&(_, h)| h, ); match insert { - hashbrown::hash_table::Entry::Occupied(o) => *o.get(), + hashbrown::hash_table::Entry::Occupied(o) => o.get().0, hashbrown::hash_table::Entry::Vacant(v) => { let g = self.values.len(); - v.insert(g); + v.insert((g, hash)); self.values.push(key); g } @@ -148,7 +151,7 @@ where } fn size(&self) -> usize { - self.map.capacity() * size_of::() + self.values.allocated_size() + self.map.capacity() * size_of::<(usize, u64)>() + self.values.allocated_size() } fn is_empty(&self) -> bool { @@ -181,12 +184,13 @@ where build_primitive(std::mem::take(&mut self.values), self.null_group.take()) } EmitTo::First(n) => { - self.map.retain(|group_idx| { + self.map.retain(|entry| { // Decrement group index by n + let group_idx = entry.0; match group_idx.checked_sub(n) { // Group index was >= n, shift value down Some(sub) => { - *group_idx = sub; + entry.0 = sub; true } // Group index was < n, so remove from table diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 8906468f68db..14b2d0a932c2 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -27,7 +27,6 @@ use crate::aggregates::{ }; use crate::execution_plan::{CardinalityEffect, EmissionType}; use 
crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; -use crate::projection::get_field_metadata; use crate::windows::get_ordered_partition_by_indices; use crate::{ DisplayFormatType, Distribution, ExecutionPlan, InputOrderMode, @@ -37,18 +36,22 @@ use crate::{ use arrow::array::{ArrayRef, UInt16Array, UInt32Array, UInt64Array, UInt8Array}; use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; +use arrow_schema::FieldRef; use datafusion_common::stats::Precision; use datafusion_common::{internal_err, not_impl_err, Constraint, Constraints, Result}; use datafusion_execution::TaskContext; use datafusion_expr::{Accumulator, Aggregate}; use datafusion_physical_expr::aggregate::AggregateFunctionExpr; +use datafusion_physical_expr::equivalence::ProjectionMapping; +use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::{ - equivalence::ProjectionMapping, expressions::Column, physical_exprs_contains, - ConstExpr, EquivalenceProperties, LexOrdering, LexRequirement, PhysicalExpr, - PhysicalSortRequirement, + physical_exprs_contains, ConstExpr, EquivalenceProperties, +}; +use datafusion_physical_expr_common::physical_expr::{fmt_sql, PhysicalExpr}; +use datafusion_physical_expr_common::sort_expr::{ + LexOrdering, LexRequirement, OrderingRequirements, PhysicalSortRequirement, }; -use datafusion_physical_expr_common::physical_expr::fmt_sql; use itertools::Itertools; pub(crate) mod group_values; @@ -58,6 +61,10 @@ mod row_hash; mod topk; mod topk_stream; +/// Hard-coded seed for aggregations to ensure hash values differ from `RepartitionExec`, avoiding collisions. +const AGGREGATION_HASH_SEED: ahash::RandomState = + ahash::RandomState::with_seeds('A' as u64, 'G' as u64, 'G' as u64, 'R' as u64); + /// Aggregation modes /// /// See [`Accumulator::state`] for background information on multi-phase @@ -274,7 +281,7 @@ impl PhysicalGroupBy { } /// Returns the fields that are used as the grouping keys. - fn group_fields(&self, input_schema: &Schema) -> Result> { + fn group_fields(&self, input_schema: &Schema) -> Result> { let mut fields = Vec::with_capacity(self.num_group_exprs()); for ((expr, name), group_expr_nullable) in self.expr.iter().zip(self.exprs_nullable().into_iter()) @@ -285,17 +292,19 @@ impl PhysicalGroupBy { expr.data_type(input_schema)?, group_expr_nullable || expr.nullable(input_schema)?, ) - .with_metadata( - get_field_metadata(expr, input_schema).unwrap_or_default(), - ), + .with_metadata(expr.return_field(input_schema)?.metadata().clone()) + .into(), ); } if !self.is_single() { - fields.push(Field::new( - Aggregate::INTERNAL_GROUPING_ID, - Aggregate::grouping_id_type(self.expr.len()), - false, - )); + fields.push( + Field::new( + Aggregate::INTERNAL_GROUPING_ID, + Aggregate::grouping_id_type(self.expr.len()), + false, + ) + .into(), + ); } Ok(fields) } @@ -304,7 +313,7 @@ impl PhysicalGroupBy { /// /// This might be different from the `group_fields` that might contain internal expressions that /// should not be part of the output schema. 
- fn output_fields(&self, input_schema: &Schema) -> Result> { + fn output_fields(&self, input_schema: &Schema) -> Result> { let mut fields = self.group_fields(input_schema)?; fields.truncate(self.num_output_exprs()); Ok(fields) @@ -349,6 +358,7 @@ impl PartialEq for PhysicalGroupBy { } } +#[allow(clippy::large_enum_variant)] enum StreamType { AggregateStream(AggregateStream), GroupedHash(GroupedHashAggregateStream), @@ -390,7 +400,7 @@ pub struct AggregateExec { pub input_schema: SchemaRef, /// Execution metrics metrics: ExecutionPlanMetricsSet, - required_input_ordering: Option, + required_input_ordering: Option, /// Describes how the input is ordered relative to the group by columns input_order_mode: InputOrderMode, cache: PlanProperties, @@ -477,16 +487,13 @@ impl AggregateExec { // If existing ordering satisfies a prefix of the GROUP BY expressions, // prefix requirements with this section. In this case, aggregation will // work more efficiently. - let indices = get_ordered_partition_by_indices(&groupby_exprs, &input); - let mut new_requirement = LexRequirement::new( - indices - .iter() - .map(|&idx| PhysicalSortRequirement { - expr: Arc::clone(&groupby_exprs[idx]), - options: None, - }) - .collect::>(), - ); + let indices = get_ordered_partition_by_indices(&groupby_exprs, &input)?; + let mut new_requirements = indices + .iter() + .map(|&idx| { + PhysicalSortRequirement::new(Arc::clone(&groupby_exprs[idx]), None) + }) + .collect::>(); let req = get_finer_aggregate_exprs_requirement( &mut aggr_expr, @@ -494,8 +501,10 @@ impl AggregateExec { input_eq_properties, &mode, )?; - new_requirement.inner.extend(req); - new_requirement = new_requirement.collapse(); + new_requirements.extend(req); + + let required_input_ordering = + LexRequirement::new(new_requirements).map(OrderingRequirements::new_soft); // If our aggregation has grouping sets then our base grouping exprs will // be expanded based on the flags in `group_by.groups` where for each @@ -520,10 +529,7 @@ impl AggregateExec { // construct a map from the input expression to the output expression of the Aggregation group by let group_expr_mapping = - ProjectionMapping::try_new(&group_by.expr, &input.schema())?; - - let required_input_ordering = - (!new_requirement.is_empty()).then_some(new_requirement); + ProjectionMapping::try_new(group_by.expr.clone(), &input.schema())?; let cache = Self::compute_properties( &input, @@ -532,7 +538,7 @@ impl AggregateExec { &mode, &input_order_mode, aggr_expr.as_slice(), - ); + )?; Ok(AggregateExec { mode, @@ -623,7 +629,7 @@ impl AggregateExec { } /// Finds the DataType and SortDirection for this Aggregate, if there is one - pub fn get_minmax_desc(&self) -> Option<(Field, bool)> { + pub fn get_minmax_desc(&self) -> Option<(FieldRef, bool)> { let agg_expr = self.aggr_expr.iter().exactly_one().ok()?; agg_expr.get_minmax_desc() } @@ -647,7 +653,7 @@ impl AggregateExec { return false; } // ensure there are no order by expressions - if self.aggr_expr().iter().any(|e| e.order_bys().is_some()) { + if !self.aggr_expr().iter().all(|e| e.order_bys().is_empty()) { return false; } // ensure there is no output ordering; can this rule be relaxed? 
@@ -655,8 +661,8 @@ impl AggregateExec { return false; } // ensure no ordering is required on the input - if self.required_input_ordering()[0].is_some() { - return false; + if let Some(requirement) = self.required_input_ordering().swap_remove(0) { + return matches!(requirement, OrderingRequirements::Hard(_)); } true } @@ -669,7 +675,7 @@ impl AggregateExec { mode: &AggregateMode, input_order_mode: &InputOrderMode, aggr_exprs: &[Arc], - ) -> PlanProperties { + ) -> Result { // Construct equivalence properties: let mut eq_properties = input .equivalence_properties() @@ -677,13 +683,12 @@ impl AggregateExec { // If the group by is empty, then we ensure that the operator will produce // only one row, and mark the generated result as a constant value. - if group_expr_mapping.map.is_empty() { - let mut constants = eq_properties.constants().to_vec(); + if group_expr_mapping.is_empty() { let new_constants = aggr_exprs.iter().enumerate().map(|(idx, func)| { - ConstExpr::new(Arc::new(Column::new(func.name(), idx))) + let column = Arc::new(Column::new(func.name(), idx)); + ConstExpr::from(column as Arc) }); - constants.extend(new_constants); - eq_properties = eq_properties.with_constants(constants); + eq_properties.add_constants(new_constants)?; } // Group by expression will be a distinct value after the aggregation. @@ -691,13 +696,11 @@ impl AggregateExec { let mut constraints = eq_properties.constraints().to_vec(); let new_constraint = Constraint::Unique( group_expr_mapping - .map .iter() - .filter_map(|(_, target_col)| { - target_col - .as_any() - .downcast_ref::() - .map(|c| c.index()) + .flat_map(|(_, target_cols)| { + target_cols.iter().flat_map(|(expr, _)| { + expr.as_any().downcast_ref::().map(|c| c.index()) + }) }) .collect(), ); @@ -724,17 +727,80 @@ impl AggregateExec { input.pipeline_behavior() }; - PlanProperties::new( + Ok(PlanProperties::new( eq_properties, output_partitioning, emission_type, input.boundedness(), - ) + )) } pub fn input_order_mode(&self) -> &InputOrderMode { &self.input_order_mode } + + fn statistics_inner(&self, child_statistics: Statistics) -> Result { + // TODO stats: group expressions: + // - once expressions will be able to compute their own stats, use it here + // - case where we group by on a column for which with have the `distinct` stat + // TODO stats: aggr expression: + // - aggregations sometimes also preserve invariants such as min, max... + + let column_statistics = { + // self.schema: [, ] + let mut column_statistics = Statistics::unknown_column(&self.schema()); + + for (idx, (expr, _)) in self.group_by.expr.iter().enumerate() { + if let Some(col) = expr.as_any().downcast_ref::() { + column_statistics[idx].max_value = child_statistics.column_statistics + [col.index()] + .max_value + .clone(); + + column_statistics[idx].min_value = child_statistics.column_statistics + [col.index()] + .min_value + .clone(); + } + } + + column_statistics + }; + match self.mode { + AggregateMode::Final | AggregateMode::FinalPartitioned + if self.group_by.expr.is_empty() => + { + Ok(Statistics { + num_rows: Precision::Exact(1), + column_statistics, + total_byte_size: Precision::Absent, + }) + } + _ => { + // When the input row count is 1, we can adopt that statistic keeping its reliability. + // When it is larger than 1, we degrade the precision since it may decrease after aggregation. 
+ let num_rows = if let Some(value) = child_statistics.num_rows.get_value() + { + if *value > 1 { + child_statistics.num_rows.to_inexact() + } else if *value == 0 { + child_statistics.num_rows + } else { + // num_rows = 1 case + let grouping_set_num = self.group_by.groups.len(); + child_statistics.num_rows.map(|x| x * grouping_set_num) + } + } else { + Precision::Absent + }; + Ok(Statistics { + num_rows, + column_statistics, + total_byte_size: Precision::Absent, + }) + } + } + } } impl DisplayAs for AggregateExec { @@ -888,7 +954,7 @@ impl ExecutionPlan for AggregateExec { } } - fn required_input_ordering(&self) -> Vec> { + fn required_input_ordering(&self) -> Vec> { vec![self.required_input_ordering.clone()] } @@ -941,50 +1007,11 @@ impl ExecutionPlan for AggregateExec { } fn statistics(&self) -> Result { - // TODO stats: group expressions: - // - once expressions will be able to compute their own stats, use it here - // - case where we group by on a column for which with have the `distinct` stat - // TODO stats: aggr expression: - // - aggregations sometimes also preserve invariants such as min, max... - let column_statistics = Statistics::unknown_column(&self.schema()); - match self.mode { - AggregateMode::Final | AggregateMode::FinalPartitioned - if self.group_by.expr.is_empty() => - { - Ok(Statistics { - num_rows: Precision::Exact(1), - column_statistics, - total_byte_size: Precision::Absent, - }) - } - _ => { - // When the input row count is 0 or 1, we can adopt that statistic keeping its reliability. - // When it is larger than 1, we degrade the precision since it may decrease after aggregation. - let num_rows = if let Some(value) = - self.input().statistics()?.num_rows.get_value() - { - if *value > 1 { - self.input().statistics()?.num_rows.to_inexact() - } else if *value == 0 { - // Aggregation on an empty table creates a null row. - self.input() - .statistics()? - .num_rows - .add(&Precision::Exact(1)) - } else { - // num_rows = 1 case - self.input().statistics()?.num_rows - } - } else { - Precision::Absent - }; - Ok(Statistics { - num_rows, - column_statistics, - total_byte_size: Precision::Absent, - }) - } - } + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + self.statistics_inner(self.input().partition_statistics(partition)?) } fn cardinality_effect(&self) -> CardinalityEffect { @@ -1044,16 +1071,14 @@ fn get_aggregate_expr_req( aggr_expr: &AggregateFunctionExpr, group_by: &PhysicalGroupBy, agg_mode: &AggregateMode, -) -> LexOrdering { +) -> Option { // If the aggregation function is ordering requirement is not absolutely // necessary, or the aggregation is performing a "second stage" calculation, // then ignore the ordering requirement. if !aggr_expr.order_sensitivity().hard_requires() || !agg_mode.is_first_stage() { - return LexOrdering::default(); + return None; } - - let mut req = aggr_expr.order_bys().cloned().unwrap_or_default(); - + let mut sort_exprs = aggr_expr.order_bys().to_vec(); // In non-first stage modes, we accumulate data (using `merge_batch`) from // different partitions (i.e. merge partial results). During this merge, we // consider the ordering of each partial result. Hence, we do not need to @@ -1064,38 +1089,11 @@ fn get_aggregate_expr_req( // will definitely be satisfied -- Each group by expression will have // distinct values per group, hence all requirements are satisfied. 
let physical_exprs = group_by.input_exprs(); - req.retain(|sort_expr| { + sort_exprs.retain(|sort_expr| { !physical_exprs_contains(&physical_exprs, &sort_expr.expr) }); } - req -} - -/// Computes the finer ordering for between given existing ordering requirement -/// of aggregate expression. -/// -/// # Parameters -/// -/// * `existing_req` - The existing lexical ordering that needs refinement. -/// * `aggr_expr` - A reference to an aggregate expression trait object. -/// * `group_by` - Information about the physical grouping (e.g group by expression). -/// * `eq_properties` - Equivalence properties relevant to the computation. -/// * `agg_mode` - The mode of aggregation (e.g., Partial, Final, etc.). -/// -/// # Returns -/// -/// An `Option` representing the computed finer lexical ordering, -/// or `None` if there is no finer ordering; e.g. the existing requirement and -/// the aggregator requirement is incompatible. -fn finer_ordering( - existing_req: &LexOrdering, - aggr_expr: &AggregateFunctionExpr, - group_by: &PhysicalGroupBy, - eq_properties: &EquivalenceProperties, - agg_mode: &AggregateMode, -) -> Option { - let aggr_req = get_aggregate_expr_req(aggr_expr, group_by, agg_mode); - eq_properties.get_finer_ordering(existing_req, aggr_req.as_ref()) + LexOrdering::new(sort_exprs) } /// Concatenates the given slices. @@ -1103,7 +1101,23 @@ pub fn concat_slices(lhs: &[T], rhs: &[T]) -> Vec { [lhs, rhs].concat() } -/// Get the common requirement that satisfies all the aggregate expressions. +// Determines if the candidate ordering is finer than the current ordering. +// Returns `None` if they are incomparable, `Some(true)` if there is no current +// ordering or candidate ordering is finer, and `Some(false)` otherwise. +fn determine_finer( + current: &Option, + candidate: &LexOrdering, +) -> Option { + if let Some(ordering) = current { + candidate.partial_cmp(ordering).map(|cmp| cmp.is_gt()) + } else { + Some(true) + } +} + +/// Gets the common requirement that satisfies all the aggregate expressions. +/// When possible, chooses the requirement that is already satisfied by the +/// equivalence properties. /// /// # Parameters /// @@ -1118,75 +1132,75 @@ pub fn concat_slices(lhs: &[T], rhs: &[T]) -> Vec { /// /// # Returns /// -/// A `LexRequirement` instance, which is the requirement that satisfies all the -/// aggregate requirements. Returns an error in case of conflicting requirements. +/// A `Result>` instance, which is the requirement +/// that satisfies all the aggregate requirements. Returns an error in case of +/// conflicting requirements. pub fn get_finer_aggregate_exprs_requirement( aggr_exprs: &mut [Arc], group_by: &PhysicalGroupBy, eq_properties: &EquivalenceProperties, agg_mode: &AggregateMode, -) -> Result { - let mut requirement = LexOrdering::default(); +) -> Result> { + let mut requirement = None; for aggr_expr in aggr_exprs.iter_mut() { - if let Some(finer_ordering) = - finer_ordering(&requirement, aggr_expr, group_by, eq_properties, agg_mode) - { - if eq_properties.ordering_satisfy(finer_ordering.as_ref()) { - // Requirement is satisfied by existing ordering - requirement = finer_ordering; + let Some(aggr_req) = get_aggregate_expr_req(aggr_expr, group_by, agg_mode) + .and_then(|o| eq_properties.normalize_sort_exprs(o)) + else { + // There is no aggregate ordering requirement, or it is trivially + // satisfied -- we can skip this expression. + continue; + }; + // If the common requirement is finer than the current expression's, + // we can skip this expression. 
If the latter is finer than the former, + // adopt it if it is satisfied by the equivalence properties. Otherwise, + // defer the analysis to the reverse expression. + let forward_finer = determine_finer(&requirement, &aggr_req); + if let Some(finer) = forward_finer { + if !finer { + continue; + } else if eq_properties.ordering_satisfy(aggr_req.clone())? { + requirement = Some(aggr_req); continue; } } if let Some(reverse_aggr_expr) = aggr_expr.reverse_expr() { - if let Some(finer_ordering) = finer_ordering( - &requirement, - &reverse_aggr_expr, - group_by, - eq_properties, - agg_mode, - ) { - if eq_properties.ordering_satisfy(finer_ordering.as_ref()) { - // Reverse requirement is satisfied by exiting ordering. - // Hence reverse the aggregator - requirement = finer_ordering; - *aggr_expr = Arc::new(reverse_aggr_expr); - continue; - } - } - } - if let Some(finer_ordering) = - finer_ordering(&requirement, aggr_expr, group_by, eq_properties, agg_mode) - { - // There is a requirement that both satisfies existing requirement and current - // aggregate requirement. Use updated requirement - requirement = finer_ordering; - continue; - } - if let Some(reverse_aggr_expr) = aggr_expr.reverse_expr() { - if let Some(finer_ordering) = finer_ordering( - &requirement, - &reverse_aggr_expr, - group_by, - eq_properties, - agg_mode, - ) { - // There is a requirement that both satisfies existing requirement and reverse - // aggregate requirement. Use updated requirement - requirement = finer_ordering; + let Some(rev_aggr_req) = + get_aggregate_expr_req(&reverse_aggr_expr, group_by, agg_mode) + .and_then(|o| eq_properties.normalize_sort_exprs(o)) + else { + // The reverse requirement is trivially satisfied -- just reverse + // the expression and continue with the next one: *aggr_expr = Arc::new(reverse_aggr_expr); continue; + }; + // If the common requirement is finer than the reverse expression's, + // just reverse it and continue the loop with the next aggregate + // expression. If the latter is finer than the former, adopt it if + // it is satisfied by the equivalence properties. Otherwise, adopt + // the forward expression. + if let Some(finer) = determine_finer(&requirement, &rev_aggr_req) { + if !finer { + *aggr_expr = Arc::new(reverse_aggr_expr); + } else if eq_properties.ordering_satisfy(rev_aggr_req.clone())? { + *aggr_expr = Arc::new(reverse_aggr_expr); + requirement = Some(rev_aggr_req); + } else { + requirement = Some(aggr_req); + } + } else if forward_finer.is_some() { + requirement = Some(aggr_req); + } else { + // Neither the existing requirement nor the current aggregate + // requirement satisfy the other (forward or reverse), this + // means they are conflicting. + return not_impl_err!( + "Conflicting ordering requirements in aggregate functions is not supported" + ); } } - - // Neither the existing requirement and current aggregate requirement satisfy the other, this means - // requirements are conflicting. Currently, we do not support - // conflicting requirements. - return not_impl_err!( - "Conflicting ordering requirements in aggregate functions is not supported" - ); } - Ok(LexRequirement::from(requirement)) + Ok(requirement.map_or_else(Vec::new, |o| o.into_iter().map(Into::into).collect())) } /// Returns physical expressions for arguments to evaluate against a batch. @@ -1209,9 +1223,7 @@ pub fn aggregate_expressions( // Append ordering requirements to expressions' results. This // way order sensitive aggregators can satisfy requirement // themselves. 
- if let Some(ordering_req) = agg.order_bys() { - result.extend(ordering_req.iter().map(|item| Arc::clone(&item.expr))); - } + result.extend(agg.order_bys().iter().map(|item| Arc::clone(&item.expr))); result }) .collect()), @@ -1924,6 +1936,13 @@ mod tests { } fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_some() { + return Ok(Statistics::new_unknown(self.schema().as_ref())); + } let (_, batches) = some_data(); Ok(common::compute_record_batch_statistics( &[batches], @@ -2211,14 +2230,14 @@ mod tests { schema: &Schema, sort_options: SortOptions, ) -> Result> { - let ordering_req = [PhysicalSortExpr { + let order_bys = vec![PhysicalSortExpr { expr: col("b", schema)?, options: sort_options, }]; let args = [col("b", schema)?]; AggregateExprBuilder::new(first_value_udaf(), args.to_vec()) - .order_by(LexOrdering::new(ordering_req.to_vec())) + .order_by(order_bys) .schema(Arc::new(schema.clone())) .alias(String::from("first_value(b) ORDER BY [b ASC NULLS LAST]")) .build() @@ -2230,13 +2249,13 @@ mod tests { schema: &Schema, sort_options: SortOptions, ) -> Result> { - let ordering_req = [PhysicalSortExpr { + let order_bys = vec![PhysicalSortExpr { expr: col("b", schema)?, options: sort_options, }]; let args = [col("b", schema)?]; AggregateExprBuilder::new(last_value_udaf(), args.to_vec()) - .order_by(LexOrdering::new(ordering_req.to_vec())) + .order_by(order_bys) .schema(Arc::new(schema.clone())) .alias(String::from("last_value(b) ORDER BY [b ASC NULLS LAST]")) .build() @@ -2358,9 +2377,7 @@ mod tests { async fn test_get_finest_requirements() -> Result<()> { let test_schema = create_test_schema()?; - // Assume column a and b are aliases - // Assume also that a ASC and c DESC describe the same global ordering for the table. (Since they are ordering equivalent). - let options1 = SortOptions { + let options = SortOptions { descending: false, nulls_first: false, }; @@ -2369,58 +2386,51 @@ mod tests { let col_c = &col("c", &test_schema)?; let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); // Columns a and b are equal. 
- eq_properties.add_equal_conditions(col_a, col_b)?; + eq_properties.add_equal_conditions(Arc::clone(col_a), Arc::clone(col_b))?; // Aggregate requirements are // [None], [a ASC], [a ASC, b ASC, c ASC], [a ASC, b ASC] respectively let order_by_exprs = vec![ - None, - Some(vec![PhysicalSortExpr { + vec![], + vec![PhysicalSortExpr { expr: Arc::clone(col_a), - options: options1, - }]), - Some(vec![ + options, + }], + vec![ PhysicalSortExpr { expr: Arc::clone(col_a), - options: options1, + options, }, PhysicalSortExpr { expr: Arc::clone(col_b), - options: options1, + options, }, PhysicalSortExpr { expr: Arc::clone(col_c), - options: options1, + options, }, - ]), - Some(vec![ + ], + vec![ PhysicalSortExpr { expr: Arc::clone(col_a), - options: options1, + options, }, PhysicalSortExpr { expr: Arc::clone(col_b), - options: options1, + options, }, - ]), + ], ]; - let common_requirement = LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::clone(col_a), - options: options1, - }, - PhysicalSortExpr { - expr: Arc::clone(col_c), - options: options1, - }, - ]); + let common_requirement = vec![ + PhysicalSortRequirement::new(Arc::clone(col_a), Some(options)), + PhysicalSortRequirement::new(Arc::clone(col_c), Some(options)), + ]; let mut aggr_exprs = order_by_exprs .into_iter() .map(|order_by_expr| { - let ordering_req = order_by_expr.unwrap_or_default(); AggregateExprBuilder::new(array_agg_udaf(), vec![Arc::clone(col_a)]) .alias("a") - .order_by(LexOrdering::new(ordering_req.to_vec())) + .order_by(order_by_expr) .schema(Arc::clone(&test_schema)) .build() .map(Arc::new) @@ -2428,14 +2438,13 @@ mod tests { }) .collect::>(); let group_by = PhysicalGroupBy::new_single(vec![]); - let res = get_finer_aggregate_exprs_requirement( + let result = get_finer_aggregate_exprs_requirement( &mut aggr_exprs, &group_by, &eq_properties, &AggregateMode::Partial, )?; - let res = LexOrdering::from(res); - assert_eq!(res, common_requirement); + assert_eq!(result, common_requirement); Ok(()) } diff --git a/datafusion/physical-plan/src/aggregates/order/mod.rs b/datafusion/physical-plan/src/aggregates/order/mod.rs index 0b742b3d20fd..bbcb30d877cf 100644 --- a/datafusion/physical-plan/src/aggregates/order/mod.rs +++ b/datafusion/physical-plan/src/aggregates/order/mod.rs @@ -15,12 +15,11 @@ // specific language governing permissions and limitations // under the License. 
+use std::mem::size_of; + use arrow::array::ArrayRef; -use arrow::datatypes::Schema; use datafusion_common::Result; use datafusion_expr::EmitTo; -use datafusion_physical_expr_common::sort_expr::LexOrdering; -use std::mem::size_of; mod full; mod partial; @@ -42,15 +41,11 @@ pub enum GroupOrdering { impl GroupOrdering { /// Create a `GroupOrdering` for the specified ordering - pub fn try_new( - input_schema: &Schema, - mode: &InputOrderMode, - ordering: &LexOrdering, - ) -> Result { + pub fn try_new(mode: &InputOrderMode) -> Result { match mode { InputOrderMode::Linear => Ok(GroupOrdering::None), InputOrderMode::PartiallySorted(order_indices) => { - GroupOrderingPartial::try_new(input_schema, order_indices, ordering) + GroupOrderingPartial::try_new(order_indices.clone()) .map(GroupOrdering::Partial) } InputOrderMode::Sorted => Ok(GroupOrdering::Full(GroupOrderingFull::new())), diff --git a/datafusion/physical-plan/src/aggregates/order/partial.rs b/datafusion/physical-plan/src/aggregates/order/partial.rs index c7a75e5f2640..3e495900f77a 100644 --- a/datafusion/physical-plan/src/aggregates/order/partial.rs +++ b/datafusion/physical-plan/src/aggregates/order/partial.rs @@ -15,18 +15,17 @@ // specific language governing permissions and limitations // under the License. +use std::cmp::Ordering; +use std::mem::size_of; +use std::sync::Arc; + use arrow::array::ArrayRef; use arrow::compute::SortOptions; -use arrow::datatypes::Schema; use arrow_ord::partition::partition; use datafusion_common::utils::{compare_rows, get_row_at_idx}; use datafusion_common::{Result, ScalarValue}; use datafusion_execution::memory_pool::proxy::VecAllocExt; use datafusion_expr::EmitTo; -use datafusion_physical_expr_common::sort_expr::LexOrdering; -use std::cmp::Ordering; -use std::mem::size_of; -use std::sync::Arc; /// Tracks grouping state when the data is ordered by some subset of /// the group keys. @@ -118,17 +117,11 @@ impl State { impl GroupOrderingPartial { /// TODO: Remove unnecessary `input_schema` parameter. 
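// Editor's sketch (not part of the patch): after this change `GroupOrdering::try_new`
// dispatches purely on the `InputOrderMode`, with no schema or sort expressions.
// The enums below are local stand-ins for the real DataFusion types, kept only
// to show the mapping.
#[derive(Debug)]
enum InputOrderMode {
    Linear,
    PartiallySorted(Vec<usize>),
    Sorted,
}

#[derive(Debug)]
enum GroupOrdering {
    None,
    Partial { order_indices: Vec<usize> },
    Full,
}

fn group_ordering_for(mode: InputOrderMode) -> GroupOrdering {
    match mode {
        // No usable input ordering: track nothing.
        InputOrderMode::Linear => GroupOrdering::None,
        // Input is sorted on a subset of the group keys, identified by index.
        InputOrderMode::PartiallySorted(order_indices) => {
            debug_assert!(!order_indices.is_empty());
            GroupOrdering::Partial { order_indices }
        }
        // Input is fully sorted on the group keys.
        InputOrderMode::Sorted => GroupOrdering::Full,
    }
}

fn main() {
    let modes = [
        InputOrderMode::Linear,
        InputOrderMode::PartiallySorted(vec![0]),
        InputOrderMode::Sorted,
    ];
    for mode in modes {
        println!("{:?}", group_ordering_for(mode));
    }
}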
- pub fn try_new( - _input_schema: &Schema, - order_indices: &[usize], - ordering: &LexOrdering, - ) -> Result { - assert!(!order_indices.is_empty()); - assert!(order_indices.len() <= ordering.len()); - + pub fn try_new(order_indices: Vec) -> Result { + debug_assert!(!order_indices.is_empty()); Ok(Self { state: State::Start, - order_indices: order_indices.to_vec(), + order_indices, }) } @@ -276,29 +269,15 @@ impl GroupOrderingPartial { #[cfg(test)] mod tests { - use arrow::array::Int32Array; - use arrow_schema::{DataType, Field}; - use datafusion_physical_expr::{expressions::col, PhysicalSortExpr}; - use super::*; + use arrow::array::Int32Array; + #[test] fn test_group_ordering_partial() -> Result<()> { - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Int32, false), - ]); - // Ordered on column a let order_indices = vec![0]; - - let ordering = LexOrdering::new(vec![PhysicalSortExpr::new( - col("a", &schema)?, - SortOptions::default(), - )]); - - let mut group_ordering = - GroupOrderingPartial::try_new(&schema, &order_indices, &ordering)?; + let mut group_ordering = GroupOrderingPartial::try_new(order_indices)?; let batch_group_values: Vec = vec![ Arc::new(Int32Array::from(vec![1, 2, 3])), diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 232565a04466..6233abde63c6 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -21,6 +21,8 @@ use std::sync::Arc; use std::task::{Context, Poll}; use std::vec; +use super::order::GroupOrdering; +use super::AggregateExec; use crate::aggregates::group_values::{new_group_values, GroupValues}; use crate::aggregates::order::GroupOrderingFull; use crate::aggregates::{ @@ -29,28 +31,24 @@ use crate::aggregates::{ }; use crate::metrics::{BaselineMetrics, MetricBuilder, RecordOutput}; use crate::sorts::sort::sort_batch; -use crate::sorts::streaming_merge::StreamingMergeBuilder; +use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder}; use crate::spill::spill_manager::SpillManager; use crate::stream::RecordBatchStreamAdapter; -use crate::{aggregates, metrics, ExecutionPlan, PhysicalExpr}; +use crate::{aggregates, metrics, PhysicalExpr}; use crate::{RecordBatchStream, SendableRecordBatchStream}; use arrow::array::*; -use arrow::compute::SortOptions; use arrow::datatypes::SchemaRef; use datafusion_common::{internal_err, DataFusionError, Result}; -use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::proxy::VecAllocExt; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; use datafusion_expr::{EmitTo, GroupsAccumulator}; +use datafusion_physical_expr::aggregate::AggregateFunctionExpr; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::{GroupsAccumulatorAdapter, PhysicalSortExpr}; - -use super::order::GroupOrdering; -use super::AggregateExec; -use datafusion_physical_expr::aggregate::AggregateFunctionExpr; use datafusion_physical_expr_common::sort_expr::LexOrdering; + use futures::ready; use futures::stream::{Stream, StreamExt}; use log::debug; @@ -100,7 +98,7 @@ struct SpillState { // ======================================================================== /// If data has previously been spilled, the locations of the /// spill files (in Arrow IPC format) - spills: Vec, + spills: Vec, /// true when streaming merge 
is in progress is_stream_merging: bool, @@ -519,30 +517,31 @@ impl GroupedHashAggregateStream { let partial_agg_schema = Arc::new(partial_agg_schema); - let spill_expr = group_schema - .fields - .into_iter() - .enumerate() - .map(|(idx, field)| PhysicalSortExpr { - expr: Arc::new(Column::new(field.name().as_str(), idx)) as _, - options: SortOptions::default(), - }) - .collect(); + let spill_expr = + group_schema + .fields + .into_iter() + .enumerate() + .map(|(idx, field)| { + PhysicalSortExpr::new_default(Arc::new(Column::new( + field.name().as_str(), + idx, + )) as _) + }); + let Some(spill_expr) = LexOrdering::new(spill_expr) else { + return internal_err!("Spill expression is empty"); + }; - let name = format!("GroupedHashAggregateStream[{partition}]"); + let agg_fn_names = aggregate_exprs + .iter() + .map(|expr| expr.human_display()) + .collect::>() + .join(", "); + let name = format!("GroupedHashAggregateStream[{partition}] ({agg_fn_names})"); let reservation = MemoryConsumer::new(name) .with_can_spill(true) .register(context.memory_pool()); - let (ordering, _) = agg - .properties() - .equivalence_properties() - .find_longest_permutation(&agg_group_by.output_exprs()); - let group_ordering = GroupOrdering::try_new( - &group_schema, - &agg.input_order_mode, - ordering.as_ref(), - )?; - + let group_ordering = GroupOrdering::try_new(&agg.input_order_mode)?; let group_values = new_group_values(group_schema, &group_ordering)?; timer.done(); @@ -552,7 +551,8 @@ impl GroupedHashAggregateStream { context.runtime_env(), metrics::SpillMetrics::new(&agg.metrics, partition), Arc::clone(&partial_agg_schema), - ); + ) + .with_compression_type(context.session_config().spill_compression()); let spill_state = SpillState { spills: vec![], @@ -996,16 +996,24 @@ impl GroupedHashAggregateStream { let Some(emit) = self.emit(EmitTo::All, true)? else { return Ok(()); }; - let sorted = sort_batch(&emit, self.spill_state.spill_expr.as_ref(), None)?; + let sorted = sort_batch(&emit, &self.spill_state.spill_expr, None)?; // Spill sorted state to disk - let spillfile = self.spill_state.spill_manager.spill_record_batch_by_size( - &sorted, - "HashAggSpill", - self.batch_size, - )?; + let spillfile = self + .spill_state + .spill_manager + .spill_record_batch_by_size_and_return_max_batch_memory( + &sorted, + "HashAggSpill", + self.batch_size, + )?; match spillfile { - Some(spillfile) => self.spill_state.spills.push(spillfile), + Some((spillfile, max_record_batch_memory)) => { + self.spill_state.spills.push(SortedSpillFile { + file: spillfile, + max_record_batch_memory, + }) + } None => { return internal_err!( "Calling spill with no intermediate batch to spill" @@ -1063,18 +1071,17 @@ impl GroupedHashAggregateStream { streams.push(Box::pin(RecordBatchStreamAdapter::new( Arc::clone(&schema), futures::stream::once(futures::future::lazy(move |_| { - sort_batch(&batch, expr.as_ref(), None) + sort_batch(&batch, &expr, None) })), ))); - for spill in self.spill_state.spills.drain(..) 
{ - let stream = self.spill_state.spill_manager.read_spill_as_stream(spill)?; - streams.push(stream); - } + self.spill_state.is_stream_merging = true; self.input = StreamingMergeBuilder::new() .with_streams(streams) .with_schema(schema) - .with_expressions(self.spill_state.spill_expr.as_ref()) + .with_spill_manager(self.spill_state.spill_manager.clone()) + .with_sorted_spill_files(std::mem::take(&mut self.spill_state.spills)) + .with_expressions(self.spill_state.spill_expr) .with_metrics(self.baseline_metrics.clone()) .with_batch_size(self.batch_size) .with_reservation(self.reservation.new_empty()) diff --git a/datafusion/physical-plan/src/aggregates/topk/hash_table.rs b/datafusion/physical-plan/src/aggregates/topk/hash_table.rs index ae44eb35e6d0..47052fd52511 100644 --- a/datafusion/physical-plan/src/aggregates/topk/hash_table.rs +++ b/datafusion/physical-plan/src/aggregates/topk/hash_table.rs @@ -461,7 +461,7 @@ mod tests { let (_heap_idxs, map_idxs): (Vec<_>, Vec<_>) = heap_to_map.into_iter().unzip(); let ids = unsafe { map.take_all(map_idxs) }; assert_eq!( - format!("{:?}", ids), + format!("{ids:?}"), r#"[Some("1"), Some("2"), Some("3"), Some("4"), Some("5")]"# ); assert_eq!(map.len(), 0, "Map should have been cleared!"); diff --git a/datafusion/physical-plan/src/aggregates/topk/heap.rs b/datafusion/physical-plan/src/aggregates/topk/heap.rs index 8b4b07d211a0..ce47504daf03 100644 --- a/datafusion/physical-plan/src/aggregates/topk/heap.rs +++ b/datafusion/physical-plan/src/aggregates/topk/heap.rs @@ -348,7 +348,7 @@ impl TopKHeap { prefix, connector, hi.val, idx, hi.map_idx )); let new_prefix = if is_tail { "" } else { "│ " }; - let child_prefix = format!("{}{}", prefix, new_prefix); + let child_prefix = format!("{prefix}{new_prefix}"); let left_idx = idx * 2 + 1; let right_idx = idx * 2 + 2; @@ -372,7 +372,7 @@ impl Display for TopKHeap { if !self.heap.is_empty() { self._tree_print(0, String::new(), true, &mut output); } - write!(f, "{}", output) + write!(f, "{output}") } } diff --git a/datafusion/physical-plan/src/async_func.rs b/datafusion/physical-plan/src/async_func.rs new file mode 100644 index 000000000000..7e9ae827d5d1 --- /dev/null +++ b/datafusion/physical-plan/src/async_func.rs @@ -0,0 +1,299 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
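// Editor's sketch (not part of the patch): the spill bookkeeping introduced in
// `row_hash.rs` above records, next to each spill file, the largest in-memory
// `RecordBatch` written to it so the later streaming merge can size its memory
// reservation. The struct and function here are illustrative stand-ins, not the
// real `SpillManager` API.
use std::path::PathBuf;
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array, RecordBatch};
use arrow::datatypes::{DataType, Field, Schema};

struct SortedSpill {
    file: PathBuf,
    max_record_batch_memory: usize,
}

// Pretend to spill `batch` in chunks of `batch_size` rows and track the largest chunk.
fn spill_by_size(batch: &RecordBatch, batch_size: usize, path: PathBuf) -> SortedSpill {
    let mut max_record_batch_memory = 0;
    let mut offset = 0;
    while offset < batch.num_rows() {
        let len = batch_size.min(batch.num_rows() - offset);
        let chunk = batch.slice(offset, len);
        // A real implementation would append `chunk` to an IPC writer here.
        max_record_batch_memory = max_record_batch_memory.max(chunk.get_array_memory_size());
        offset += len;
    }
    SortedSpill { file: path, max_record_batch_memory }
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let column: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
    let batch = RecordBatch::try_new(schema, vec![column])?;
    let spill = spill_by_size(&batch, 2, PathBuf::from("/tmp/hash_agg_spill.arrow"));
    println!(
        "largest spilled chunk: {} bytes -> {:?}",
        spill.max_record_batch_memory, spill.file
    );
    Ok(())
}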
+ +use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; +use crate::stream::RecordBatchStreamAdapter; +use crate::{ + DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, +}; +use arrow::array::RecordBatch; +use arrow_schema::{Fields, Schema, SchemaRef}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; +use datafusion_common::{internal_err, Result}; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_physical_expr::async_scalar_function::AsyncFuncExpr; +use datafusion_physical_expr::equivalence::ProjectionMapping; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::ScalarFunctionExpr; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use futures::stream::StreamExt; +use log::trace; +use std::any::Any; +use std::sync::Arc; + +/// This structure evaluates a set of async expressions on a record +/// batch producing a new record batch +/// +/// The schema of the output of the AsyncFuncExec is: +/// Input columns followed by one column for each async expression +#[derive(Debug)] +pub struct AsyncFuncExec { + /// The async expressions to evaluate + async_exprs: Vec>, + input: Arc, + cache: PlanProperties, + metrics: ExecutionPlanMetricsSet, +} + +impl AsyncFuncExec { + pub fn try_new( + async_exprs: Vec>, + input: Arc, + ) -> Result { + let async_fields = async_exprs + .iter() + .map(|async_expr| async_expr.field(input.schema().as_ref())) + .collect::>>()?; + + // compute the output schema: input schema then async expressions + let fields: Fields = input + .schema() + .fields() + .iter() + .cloned() + .chain(async_fields.into_iter().map(Arc::new)) + .collect(); + + let schema = Arc::new(Schema::new(fields)); + let tuples = async_exprs + .iter() + .map(|expr| (Arc::clone(&expr.func), expr.name().to_string())) + .collect::>(); + let async_expr_mapping = ProjectionMapping::try_new(tuples, &input.schema())?; + let cache = + AsyncFuncExec::compute_properties(&input, schema, &async_expr_mapping)?; + Ok(Self { + input, + async_exprs, + cache, + metrics: ExecutionPlanMetricsSet::new(), + }) + } + + /// This function creates the cache object that stores the plan properties + /// such as schema, equivalence properties, ordering, partitioning, etc. 
+ fn compute_properties( + input: &Arc, + schema: SchemaRef, + async_expr_mapping: &ProjectionMapping, + ) -> Result { + Ok(PlanProperties::new( + input + .equivalence_properties() + .project(async_expr_mapping, schema), + input.output_partitioning().clone(), + input.pipeline_behavior(), + input.boundedness(), + )) + } +} + +impl DisplayAs for AsyncFuncExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + let expr: Vec = self + .async_exprs + .iter() + .map(|async_expr| async_expr.to_string()) + .collect(); + let exprs = expr.join(", "); + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "AsyncFuncExec: async_expr=[{exprs}]") + } + DisplayFormatType::TreeRender => { + writeln!(f, "format=async_expr")?; + writeln!(f, "async_expr={exprs}")?; + Ok(()) + } + } + } +} + +impl ExecutionPlan for AsyncFuncExec { + fn name(&self) -> &str { + "async_func" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.cache + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + if children.len() != 1 { + return internal_err!("AsyncFuncExec wrong number of children"); + } + Ok(Arc::new(AsyncFuncExec::try_new( + self.async_exprs.clone(), + Arc::clone(&children[0]), + )?)) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + trace!( + "Start AsyncFuncExpr::execute for partition {} of context session_id {} and task_id {:?}", + partition, + context.session_id(), + context.task_id() + ); + // TODO figure out how to record metrics + + // first execute the input stream + let input_stream = self.input.execute(partition, Arc::clone(&context))?; + + // now, for each record batch, evaluate the async expressions and add the columns to the result + let async_exprs_captured = Arc::new(self.async_exprs.clone()); + let schema_captured = self.schema(); + let config_option_ref = Arc::new(context.session_config().options().clone()); + + let stream_with_async_functions = input_stream.then(move |batch| { + // need to clone *again* to capture the async_exprs and schema in the + // stream and satisfy lifetime requirements. + let async_exprs_captured = Arc::clone(&async_exprs_captured); + let schema_captured = Arc::clone(&schema_captured); + let config_option = Arc::clone(&config_option_ref); + + async move { + let batch = batch?; + // append the result of evaluating the async expressions to the output + let mut output_arrays = batch.columns().to_vec(); + for async_expr in async_exprs_captured.iter() { + let output = + async_expr.invoke_with_args(&batch, &config_option).await?; + output_arrays.push(output.to_array(batch.num_rows())?); + } + let batch = RecordBatch::try_new(schema_captured, output_arrays)?; + Ok(batch) + } + }); + + // Adapt the stream with the output schema + let adapter = + RecordBatchStreamAdapter::new(self.schema(), stream_with_async_functions); + Ok(Box::pin(adapter)) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } +} + +const ASYNC_FN_PREFIX: &str = "__async_fn_"; + +/// Maps async_expressions to new columns +/// +/// The output of the async functions are appended, in order, to the end of the input schema +#[derive(Debug)] +pub struct AsyncMapper { + /// the number of columns in the input plan + /// used to generate the output column names. 
+ /// the first async expr is `__async_fn_0`, the second is `__async_fn_1`, etc + num_input_columns: usize, + /// the expressions to map + pub async_exprs: Vec>, +} + +impl AsyncMapper { + pub fn new(num_input_columns: usize) -> Self { + Self { + num_input_columns, + async_exprs: Vec::new(), + } + } + + pub fn is_empty(&self) -> bool { + self.async_exprs.is_empty() + } + + pub fn next_column_name(&self) -> String { + format!("{}{}", ASYNC_FN_PREFIX, self.async_exprs.len()) + } + + /// Finds any references to async functions in the expression and adds them to the map + pub fn find_references( + &mut self, + physical_expr: &Arc, + schema: &Schema, + ) -> Result<()> { + // recursively look for references to async functions + physical_expr.apply(|expr| { + if let Some(scalar_func_expr) = + expr.as_any().downcast_ref::() + { + if scalar_func_expr.fun().as_async().is_some() { + let next_name = self.next_column_name(); + self.async_exprs.push(Arc::new(AsyncFuncExpr::try_new( + next_name, + Arc::clone(expr), + schema, + )?)); + } + } + Ok(TreeNodeRecursion::Continue) + })?; + Ok(()) + } + + /// If the expression matches any of the async functions, return the new column + pub fn map_expr( + &self, + expr: Arc, + ) -> Transformed> { + // find the first matching async function if any + let Some(idx) = + self.async_exprs + .iter() + .enumerate() + .find_map(|(idx, async_expr)| { + if async_expr.func == Arc::clone(&expr) { + Some(idx) + } else { + None + } + }) + else { + return Transformed::no(expr); + }; + // rewrite in terms of the output column + Transformed::yes(self.output_column(idx)) + } + + /// return the output column for the async function at index idx + pub fn output_column(&self, idx: usize) -> Arc { + let async_expr = &self.async_exprs[idx]; + let output_idx = self.num_input_columns + idx; + Arc::new(Column::new(async_expr.name(), output_idx)) + } +} diff --git a/datafusion/physical-plan/src/coalesce_batches.rs b/datafusion/physical-plan/src/coalesce_batches.rs index 5244038b9ae2..78bd4b4fc3a0 100644 --- a/datafusion/physical-plan/src/coalesce_batches.rs +++ b/datafusion/physical-plan/src/coalesce_batches.rs @@ -32,9 +32,15 @@ use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_common::Result; use datafusion_execution::TaskContext; +use datafusion_physical_expr::PhysicalExpr; use crate::coalesce::{BatchCoalescer, CoalescerState}; use crate::execution_plan::CardinalityEffect; +use crate::filter_pushdown::{ + ChildPushdownResult, FilterDescription, FilterPushdownPhase, + FilterPushdownPropagation, +}; +use datafusion_common::config::ConfigOptions; use futures::ready; use futures::stream::{Stream, StreamExt}; @@ -192,7 +198,16 @@ impl ExecutionPlan for CoalesceBatchesExec { } fn statistics(&self) -> Result { - Statistics::with_fetch(self.input.statistics()?, self.schema(), self.fetch, 0, 1) + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + self.input.partition_statistics(partition)?.with_fetch( + self.schema(), + self.fetch, + 0, + 1, + ) } fn with_fetch(&self, limit: Option) -> Option> { @@ -212,6 +227,27 @@ impl ExecutionPlan for CoalesceBatchesExec { fn cardinality_effect(&self) -> CardinalityEffect { CardinalityEffect::Equal } + + fn gather_filters_for_pushdown( + &self, + _phase: FilterPushdownPhase, + parent_filters: Vec>, + _config: &ConfigOptions, + ) -> Result { + Ok(FilterDescription::new_with_child_count(1) + .all_parent_filters_supported(parent_filters)) + } + + fn 
handle_child_pushdown_result( + &self, + _phase: FilterPushdownPhase, + child_pushdown_result: ChildPushdownResult, + _config: &ConfigOptions, + ) -> Result>> { + Ok(FilterPushdownPropagation::transparent( + child_pushdown_result, + )) + } } /// Stream for [`CoalesceBatchesExec`]. See [`CoalesceBatchesExec`] for more details. @@ -321,6 +357,7 @@ impl CoalesceBatchesStream { } } CoalesceBatchesStreamState::ReturnBuffer => { + let _timer = cloned_time.timer(); // Combine buffered batches into one batch and return it. let batch = self.coalescer.finish_batch()?; // Set to pull state for the next iteration. @@ -333,6 +370,7 @@ impl CoalesceBatchesStream { // If buffer is empty, return None indicating the stream is fully consumed. Poll::Ready(None) } else { + let _timer = cloned_time.timer(); // If the buffer still contains batches, prepare to return them. let batch = self.coalescer.finish_batch()?; Poll::Ready(Some(Ok(batch))) diff --git a/datafusion/physical-plan/src/coalesce_partitions.rs b/datafusion/physical-plan/src/coalesce_partitions.rs index 95a0c8f6ce83..976ff70502b7 100644 --- a/datafusion/physical-plan/src/coalesce_partitions.rs +++ b/datafusion/physical-plan/src/coalesce_partitions.rs @@ -27,7 +27,7 @@ use super::{ DisplayAs, ExecutionPlanProperties, PlanProperties, SendableRecordBatchStream, Statistics, }; -use crate::execution_plan::CardinalityEffect; +use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType}; use crate::projection::{make_with_child, ProjectionExec}; use crate::{DisplayFormatType, ExecutionPlan, Partitioning}; @@ -59,6 +59,12 @@ impl CoalescePartitionsExec { } } + /// Update fetch with the argument + pub fn with_fetch(mut self, fetch: Option) -> Self { + self.fetch = fetch; + self + } + /// Input execution plan pub fn input(&self) -> &Arc { &self.input @@ -66,6 +72,16 @@ impl CoalescePartitionsExec { /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. fn compute_properties(input: &Arc) -> PlanProperties { + let input_partitions = input.output_partitioning().partition_count(); + let (drive, scheduling) = if input_partitions > 1 { + (EvaluationType::Eager, SchedulingType::Cooperative) + } else { + ( + input.properties().evaluation_type, + input.properties().scheduling_type, + ) + }; + // Coalescing partitions loses existing orderings: let mut eq_properties = input.equivalence_properties().clone(); eq_properties.clear_orderings(); @@ -76,6 +92,8 @@ impl CoalescePartitionsExec { input.pipeline_behavior(), input.boundedness(), ) + .with_evaluation_type(drive) + .with_scheduling_type(scheduling) } } @@ -190,7 +208,13 @@ impl ExecutionPlan for CoalescePartitionsExec { } fn statistics(&self) -> Result { - Statistics::with_fetch(self.input.statistics()?, self.schema(), self.fetch, 0, 1) + self.partition_statistics(None) + } + + fn partition_statistics(&self, _partition: Option) -> Result { + self.input + .partition_statistics(None)? + .with_fetch(self.schema(), self.fetch, 0, 1) } fn supports_limit_pushdown(&self) -> bool { diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index a8d4a3ddf3d1..35f3e8d16e22 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -18,8 +18,7 @@ //! 
Defines common code used in execution plans use std::fs; -use std::fs::{metadata, File}; -use std::path::{Path, PathBuf}; +use std::fs::metadata; use std::sync::Arc; use super::SendableRecordBatchStream; @@ -28,10 +27,9 @@ use crate::{ColumnStatistics, Statistics}; use arrow::array::Array; use arrow::datatypes::Schema; -use arrow::ipc::writer::{FileWriter, IpcWriteOptions}; use arrow::record_batch::RecordBatch; use datafusion_common::stats::Precision; -use datafusion_common::{plan_err, DataFusionError, Result}; +use datafusion_common::{plan_err, Result}; use datafusion_execution::memory_pool::MemoryReservation; use futures::{StreamExt, TryStreamExt}; @@ -180,77 +178,6 @@ pub fn compute_record_batch_statistics( } } -/// Write in Arrow IPC File format. -pub struct IPCWriter { - /// Path - pub path: PathBuf, - /// Inner writer - pub writer: FileWriter, - /// Batches written - pub num_batches: usize, - /// Rows written - pub num_rows: usize, - /// Bytes written - pub num_bytes: usize, -} - -impl IPCWriter { - /// Create new writer - pub fn new(path: &Path, schema: &Schema) -> Result { - let file = File::create(path).map_err(|e| { - DataFusionError::Execution(format!( - "Failed to create partition file at {path:?}: {e:?}" - )) - })?; - Ok(Self { - num_batches: 0, - num_rows: 0, - num_bytes: 0, - path: path.into(), - writer: FileWriter::try_new(file, schema)?, - }) - } - - /// Create new writer with IPC write options - pub fn new_with_options( - path: &Path, - schema: &Schema, - write_options: IpcWriteOptions, - ) -> Result { - let file = File::create(path).map_err(|e| { - DataFusionError::Execution(format!( - "Failed to create partition file at {path:?}: {e:?}" - )) - })?; - Ok(Self { - num_batches: 0, - num_rows: 0, - num_bytes: 0, - path: path.into(), - writer: FileWriter::try_new_with_options(file, schema, write_options)?, - }) - } - /// Write one single batch - pub fn write(&mut self, batch: &RecordBatch) -> Result<()> { - self.writer.write(batch)?; - self.num_batches += 1; - self.num_rows += batch.num_rows(); - let num_bytes: usize = batch.get_array_memory_size(); - self.num_bytes += num_bytes; - Ok(()) - } - - /// Finish the writer - pub fn finish(&mut self) -> Result<()> { - self.writer.finish().map_err(Into::into) - } - - /// Path write to - pub fn path(&self) -> &Path { - &self.path - } -} - /// Checks if the given projection is valid for the given schema. pub fn can_project( schema: &arrow::datatypes::SchemaRef, diff --git a/datafusion/physical-plan/src/coop.rs b/datafusion/physical-plan/src/coop.rs new file mode 100644 index 000000000000..be0afa07eac2 --- /dev/null +++ b/datafusion/physical-plan/src/coop.rs @@ -0,0 +1,370 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Utilities for improved cooperative scheduling. +//! +//! 
# Cooperative scheduling +//! +//! A single call to `poll_next` on a top-level [`Stream`] may potentially perform a lot of work +//! before it returns a `Poll::Pending`. Think for instance of calculating an aggregation over a +//! large dataset. +//! If a `Stream` runs for a long period of time without yielding back to the Tokio executor, +//! it can starve other tasks waiting on that executor to execute them. +//! Additionally, this prevents the query execution from being cancelled. +//! +//! To ensure that `Stream` implementations yield regularly, operators can insert explicit yield +//! points using the utilities in this module. For most operators this is **not** necessary. The +//! `Stream`s of the built-in DataFusion operators that generate (rather than manipulate) +//! `RecordBatch`es such as `DataSourceExec` and those that eagerly consume `RecordBatch`es +//! (for instance, `RepartitionExec`) contain yield points that will make most query `Stream`s yield +//! periodically. +//! +//! There are a couple of types of operators that _should_ insert yield points: +//! - New source operators that do not make use of Tokio resources +//! - Exchange like operators that do not use Tokio's `Channel` implementation to pass data between +//! tasks +//! +//! ## Adding yield points +//! +//! Yield points can be inserted manually using the facilities provided by the +//! [Tokio coop module](https://docs.rs/tokio/latest/tokio/task/coop/index.html) such as +//! [`tokio::task::coop::consume_budget`](https://docs.rs/tokio/latest/tokio/task/coop/fn.consume_budget.html). +//! +//! Another option is to use the wrapper `Stream` implementation provided by this module which will +//! consume a unit of task budget every time a `RecordBatch` is produced. +//! Wrapper `Stream`s can be created using the [`cooperative`] and [`make_cooperative`] functions. +//! +//! [`cooperative`] is a generic function that takes ownership of the wrapped [`RecordBatchStream`]. +//! This function has the benefit of not requiring an additional heap allocation and can avoid +//! dynamic dispatch. +//! +//! [`make_cooperative`] is a non-generic function that wraps a [`SendableRecordBatchStream`]. This +//! can be used to wrap dynamically typed, heap allocated [`RecordBatchStream`]s. +//! +//! ## Automatic cooperation +//! +//! The `EnsureCooperative` physical optimizer rule, which is included in the default set of +//! optimizer rules, inspects query plans for potential cooperative scheduling issues. +//! It injects the [`CooperativeExec`] wrapper `ExecutionPlan` into the query plan where necessary. +//! This `ExecutionPlan` uses [`make_cooperative`] to wrap the `Stream` of its input. +//! +//! The optimizer rule currently checks the plan for exchange-like operators and leave operators +//! that report [`SchedulingType::NonCooperative`] in their [plan properties](ExecutionPlan::properties). 
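// Editor's sketch (not part of the patch) of the usage described in the module
// documentation above: wrap a plain record-batch stream with `make_cooperative`
// so it consumes one unit of Tokio task budget per batch. The
// `datafusion::physical_plan::coop` import path is assumed from this module's
// location and may differ depending on how the crate re-exports it.
use std::sync::Arc;

use arrow::array::RecordBatch;
use arrow::datatypes::Schema;
use datafusion::error::Result;
use datafusion::physical_plan::coop::make_cooperative;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::physical_plan::SendableRecordBatchStream;
use futures::{stream, StreamExt};

fn empty_batches(n: usize) -> SendableRecordBatchStream {
    let schema = Arc::new(Schema::empty());
    let stream_schema = Arc::clone(&schema);
    let s = stream::iter(
        (0..n).map(move |_| Ok(RecordBatch::new_empty(Arc::clone(&stream_schema)))),
    );
    Box::pin(RecordBatchStreamAdapter::new(schema, s))
}

#[tokio::main]
async fn main() -> Result<()> {
    // Without the wrapper this loop could run for a long time without yielding;
    // the cooperative wrapper charges task budget per batch instead.
    let mut stream = make_cooperative(empty_batches(1_000));
    let mut count = 0;
    while let Some(batch) = stream.next().await {
        let _ = batch?;
        count += 1;
    }
    println!("consumed {count} batches cooperatively");
    Ok(())
}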
+ +#[cfg(any( + datafusion_coop = "tokio_fallback", + not(any(datafusion_coop = "tokio", datafusion_coop = "per_stream")) +))] +use futures::Future; +use std::any::Any; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use crate::execution_plan::CardinalityEffect::{self, Equal}; +use crate::{ + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, RecordBatchStream, + SendableRecordBatchStream, +}; +use arrow::record_batch::RecordBatch; +use arrow_schema::Schema; +use datafusion_common::{internal_err, Result, Statistics}; +use datafusion_execution::TaskContext; + +use crate::execution_plan::SchedulingType; +use crate::stream::RecordBatchStreamAdapter; +use futures::{Stream, StreamExt}; + +/// A stream that passes record batches through unchanged while cooperating with the Tokio runtime. +/// It consumes cooperative scheduling budget for each returned [`RecordBatch`], +/// allowing other tasks to execute when the budget is exhausted. +/// +/// See the [module level documentation](crate::coop) for an in-depth discussion. +pub struct CooperativeStream +where + T: RecordBatchStream + Unpin, +{ + inner: T, + #[cfg(datafusion_coop = "per_stream")] + budget: u8, +} + +#[cfg(datafusion_coop = "per_stream")] +// Magic value that matches Tokio's task budget value +const YIELD_FREQUENCY: u8 = 128; + +impl CooperativeStream +where + T: RecordBatchStream + Unpin, +{ + /// Creates a new `CooperativeStream` that wraps the provided stream. + /// The resulting stream will cooperate with the Tokio scheduler by consuming a unit of + /// scheduling budget when the wrapped `Stream` returns a record batch. + pub fn new(inner: T) -> Self { + Self { + inner, + #[cfg(datafusion_coop = "per_stream")] + budget: YIELD_FREQUENCY, + } + } +} + +impl Stream for CooperativeStream +where + T: RecordBatchStream + Unpin, +{ + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + #[cfg(datafusion_coop = "tokio")] + { + // TODO this should be the default implementation + // Enable once https://github.com/tokio-rs/tokio/issues/7403 is merged and released + let coop = std::task::ready!(tokio::task::coop::poll_proceed(cx)); + let value = self.inner.poll_next_unpin(cx); + if value.is_ready() { + coop.made_progress(); + } + value + } + + #[cfg(any( + datafusion_coop = "tokio_fallback", + not(any(datafusion_coop = "tokio", datafusion_coop = "per_stream")) + ))] + { + // This is a temporary placeholder implementation that may have slightly + // worse performance compared to `poll_proceed` + if !tokio::task::coop::has_budget_remaining() { + cx.waker().wake_by_ref(); + return Poll::Pending; + } + + let value = self.inner.poll_next_unpin(cx); + if value.is_ready() { + // In contrast to `poll_proceed` we are not able to consume + // budget before proceeding to do work. Instead, we try to consume budget + // after the work has been done and just assume that that succeeded. + // The poll result is ignored because we don't want to discard + // or buffer the Ready result we got from the inner stream. 
+ let consume = tokio::task::coop::consume_budget(); + let consume_ref = std::pin::pin!(consume); + let _ = consume_ref.poll(cx); + } + value + } + + #[cfg(datafusion_coop = "per_stream")] + { + if self.budget == 0 { + self.budget = YIELD_FREQUENCY; + cx.waker().wake_by_ref(); + return Poll::Pending; + } + + let value = { self.inner.poll_next_unpin(cx) }; + + if value.is_ready() { + self.budget -= 1; + } else { + self.budget = YIELD_FREQUENCY; + } + value + } + } +} + +impl RecordBatchStream for CooperativeStream +where + T: RecordBatchStream + Unpin, +{ + fn schema(&self) -> Arc { + self.inner.schema() + } +} + +/// An execution plan decorator that enables cooperative multitasking. +/// It wraps the streams produced by its input execution plan using the [`make_cooperative`] function, +/// which makes the stream participate in Tokio cooperative scheduling. +#[derive(Debug)] +pub struct CooperativeExec { + input: Arc, + properties: PlanProperties, +} + +impl CooperativeExec { + /// Creates a new `CooperativeExec` operator that wraps the given input execution plan. + pub fn new(input: Arc) -> Self { + let properties = input + .properties() + .clone() + .with_scheduling_type(SchedulingType::Cooperative); + + Self { input, properties } + } + + /// Returns a reference to the wrapped input execution plan. + pub fn input(&self) -> &Arc { + &self.input + } +} + +impl DisplayAs for CooperativeExec { + fn fmt_as( + &self, + _t: DisplayFormatType, + f: &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result { + write!(f, "CooperativeExec") + } +} + +impl ExecutionPlan for CooperativeExec { + fn name(&self) -> &str { + "CooperativeExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> Arc { + self.input.schema() + } + + fn properties(&self) -> &PlanProperties { + &self.properties + } + + fn maintains_input_order(&self) -> Vec { + self.input.maintains_input_order() + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn with_new_children( + self: Arc, + mut children: Vec>, + ) -> Result> { + if children.len() != 1 { + return internal_err!("CooperativeExec requires exactly one child"); + } + Ok(Arc::new(CooperativeExec::new(children.swap_remove(0)))) + } + + fn execute( + &self, + partition: usize, + task_ctx: Arc, + ) -> Result { + let child_stream = self.input.execute(partition, task_ctx)?; + Ok(make_cooperative(child_stream)) + } + + fn partition_statistics(&self, partition: Option) -> Result { + self.input.partition_statistics(partition) + } + + fn supports_limit_pushdown(&self) -> bool { + true + } + + fn cardinality_effect(&self) -> CardinalityEffect { + Equal + } +} + +/// Creates a [`CooperativeStream`] wrapper around the given [`RecordBatchStream`]. +/// This wrapper collaborates with the Tokio cooperative scheduler by consuming a unit of +/// scheduling budget for each returned record batch. +pub fn cooperative(stream: T) -> CooperativeStream +where + T: RecordBatchStream + Unpin + Send + 'static, +{ + CooperativeStream::new(stream) +} + +/// Wraps a `SendableRecordBatchStream` inside a [`CooperativeStream`] to enable cooperative multitasking. +/// Since `SendableRecordBatchStream` is a `dyn RecordBatchStream` this requires the use of dynamic +/// method dispatch. +/// When the stream type is statically known, consider use the generic [`cooperative`] function +/// to allow static method dispatch. 
+pub fn make_cooperative(stream: SendableRecordBatchStream) -> SendableRecordBatchStream { + // TODO is there a more elegant way to overload cooperative + Box::pin(cooperative(RecordBatchStreamAdapter::new( + stream.schema(), + stream, + ))) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::stream::RecordBatchStreamAdapter; + + use arrow_schema::SchemaRef; + + use futures::{stream, StreamExt}; + + // This is the hardcoded value Tokio uses + const TASK_BUDGET: usize = 128; + + /// Helper: construct a SendableRecordBatchStream containing `n` empty batches + fn make_empty_batches(n: usize) -> SendableRecordBatchStream { + let schema: SchemaRef = Arc::new(Schema::empty()); + let schema_for_stream = Arc::clone(&schema); + + let s = + stream::iter((0..n).map(move |_| { + Ok(RecordBatch::new_empty(Arc::clone(&schema_for_stream))) + })); + + Box::pin(RecordBatchStreamAdapter::new(schema, s)) + } + + #[tokio::test] + async fn yield_less_than_threshold() -> Result<()> { + let count = TASK_BUDGET - 10; + let inner = make_empty_batches(count); + let out = make_cooperative(inner).collect::>().await; + assert_eq!(out.len(), count); + Ok(()) + } + + #[tokio::test] + async fn yield_equal_to_threshold() -> Result<()> { + let count = TASK_BUDGET; + let inner = make_empty_batches(count); + let out = make_cooperative(inner).collect::>().await; + assert_eq!(out.len(), count); + Ok(()) + } + + #[tokio::test] + async fn yield_more_than_threshold() -> Result<()> { + let count = TASK_BUDGET + 20; + let inner = make_empty_batches(count); + let out = make_cooperative(inner).collect::>().await; + assert_eq!(out.len(), count); + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/display.rs b/datafusion/physical-plan/src/display.rs index e247f5ad9d19..56335f13d01b 100644 --- a/datafusion/physical-plan/src/display.rs +++ b/datafusion/physical-plan/src/display.rs @@ -394,8 +394,8 @@ impl ExecutionPlanVisitor for IndentVisitor<'_, '_> { } } if self.show_statistics { - let stats = plan.statistics().map_err(|_e| fmt::Error)?; - write!(self.f, ", statistics=[{}]", stats)?; + let stats = plan.partition_statistics(None).map_err(|_e| fmt::Error)?; + write!(self.f, ", statistics=[{stats}]")?; } if self.show_schema { write!( @@ -479,8 +479,8 @@ impl ExecutionPlanVisitor for GraphvizVisitor<'_, '_> { }; let statistics = if self.show_statistics { - let stats = plan.statistics().map_err(|_e| fmt::Error)?; - format!("statistics=[{}]", stats) + let stats = plan.partition_statistics(None).map_err(|_e| fmt::Error)?; + format!("statistics=[{stats}]") } else { "".to_string() }; @@ -495,7 +495,7 @@ impl ExecutionPlanVisitor for GraphvizVisitor<'_, '_> { self.f, id, &label, - Some(&format!("{}{}{}", metrics, delimiter, statistics)), + Some(&format!("{metrics}{delimiter}{statistics}")), )?; if let Some(parent_node_id) = self.parents.last() { @@ -686,7 +686,7 @@ impl TreeRenderVisitor<'_, '_> { &render_text, Self::NODE_RENDER_WIDTH - 2, ); - write!(self.f, "{}", render_text)?; + write!(self.f, "{render_text}")?; if render_y == halfway_point && node.child_positions.len() > 1 { write!(self.f, "{}", Self::LMIDDLE)?; @@ -856,10 +856,10 @@ impl TreeRenderVisitor<'_, '_> { if str.is_empty() { str = key.to_string(); } else if !is_multiline && total_size < available_width { - str = format!("{}: {}", key, str); + str = format!("{key}: {str}"); is_inlined = true; } else { - str = format!("{}:\n{}", key, str); + str = format!("{key}:\n{str}"); } if is_inlined && was_inlined { @@ -902,7 +902,7 @@ impl TreeRenderVisitor<'_, '_> { let 
render_width = source.chars().count(); if render_width > max_render_width { let truncated = &source[..max_render_width - 3]; - format!("{}...", truncated) + format!("{truncated}...") } else { let total_spaces = max_render_width - render_width; let half_spaces = total_spaces / 2; @@ -1034,27 +1034,22 @@ impl fmt::Display for ProjectSchemaDisplay<'_> { } pub fn display_orderings(f: &mut Formatter, orderings: &[LexOrdering]) -> fmt::Result { - if let Some(ordering) = orderings.first() { - if !ordering.is_empty() { - let start = if orderings.len() == 1 { - ", output_ordering=" - } else { - ", output_orderings=[" - }; - write!(f, "{}", start)?; - for (idx, ordering) in - orderings.iter().enumerate().filter(|(_, o)| !o.is_empty()) - { - match idx { - 0 => write!(f, "[{}]", ordering)?, - _ => write!(f, ", [{}]", ordering)?, - } + if !orderings.is_empty() { + let start = if orderings.len() == 1 { + ", output_ordering=" + } else { + ", output_orderings=[" + }; + write!(f, "{start}")?; + for (idx, ordering) in orderings.iter().enumerate() { + match idx { + 0 => write!(f, "[{ordering}]")?, + _ => write!(f, ", [{ordering}]")?, } - let end = if orderings.len() == 1 { "" } else { "]" }; - write!(f, "{}", end)?; } + let end = if orderings.len() == 1 { "" } else { "]" }; + write!(f, "{end}")?; } - Ok(()) } @@ -1120,6 +1115,13 @@ mod tests { } fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_some() { + return Ok(Statistics::new_unknown(self.schema().as_ref())); + } match self { Self::Panic => panic!("expected panic"), Self::Error => { diff --git a/datafusion/physical-plan/src/empty.rs b/datafusion/physical-plan/src/empty.rs index 3fdde39df6f1..40b4ec61dc10 100644 --- a/datafusion/physical-plan/src/empty.rs +++ b/datafusion/physical-plan/src/empty.rs @@ -33,6 +33,7 @@ use datafusion_common::{internal_err, Result}; use datafusion_execution::TaskContext; use datafusion_physical_expr::EquivalenceProperties; +use crate::execution_plan::SchedulingType; use log::trace; /// Execution plan for empty relation with produce_one_row=false @@ -81,6 +82,7 @@ impl EmptyExec { EmissionType::Incremental, Boundedness::Bounded, ) + .with_scheduling_type(SchedulingType::Cooperative) } } @@ -150,6 +152,20 @@ impl ExecutionPlan for EmptyExec { } fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if let Some(partition) = partition { + if partition >= self.partitions { + return internal_err!( + "EmptyExec invalid partition {} (expected less than {})", + partition, + self.partitions + ); + } + } + let batch = self .data() .expect("Create empty RecordBatch should not fail"); diff --git a/datafusion/physical-plan/src/execution_plan.rs b/datafusion/physical-plan/src/execution_plan.rs index 2bc5706ee0e1..90385c58a6ac 100644 --- a/datafusion/physical-plan/src/execution_plan.rs +++ b/datafusion/physical-plan/src/execution_plan.rs @@ -16,6 +16,10 @@ // under the License. 
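// Editor's sketch (not part of the patch) of the per-partition statistics API
// introduced in this diff: `partition_statistics(None)` reports plan-wide
// statistics, `partition_statistics(Some(i))` reports partition `i`, and
// out-of-range indices are rejected. `EmptyExec` is used only because it is
// easy to construct; its single-argument `new(schema)` constructor is assumed.
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion::error::Result;
use datafusion::physical_plan::empty::EmptyExec;
use datafusion::physical_plan::ExecutionPlan;

fn main() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let plan = EmptyExec::new(schema);

    // Statistics for the whole plan.
    let all = plan.partition_statistics(None)?;
    println!("total rows: {:?}", all.num_rows);

    // Statistics for the only partition of this plan.
    let first = plan.partition_statistics(Some(0))?;
    println!("partition 0 rows: {:?}", first.num_rows);

    // Out-of-range partitions are rejected rather than silently ignored.
    assert!(plan.partition_statistics(Some(42)).is_err());
    Ok(())
}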
pub use crate::display::{DefaultDisplay, DisplayAs, DisplayFormatType, VerboseDisplay}; +use crate::filter_pushdown::{ + ChildPushdownResult, FilterDescription, FilterPushdownPhase, + FilterPushdownPropagation, +}; pub use crate::metrics::Metric; pub use crate::ordering::InputOrderMode; pub use crate::stream::EmptyRecordBatchStream; @@ -38,8 +42,6 @@ use crate::coalesce_partitions::CoalescePartitionsExec; use crate::display::DisplayableExecutionPlan; use crate::metrics::MetricsSet; use crate::projection::ProjectionExec; -use crate::repartition::RepartitionExec; -use crate::sorts::sort_preserving_merge::SortPreservingMergeExec; use crate::stream::RecordBatchStreamAdapter; use arrow::array::{Array, RecordBatch}; @@ -48,8 +50,8 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::{exec_err, Constraints, Result}; use datafusion_common_runtime::JoinSet; use datafusion_execution::TaskContext; -use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; -use datafusion_physical_expr_common::sort_expr::LexRequirement; +use datafusion_physical_expr::EquivalenceProperties; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements}; use futures::stream::{StreamExt, TryStreamExt}; @@ -136,7 +138,7 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { /// NOTE that checking `!is_empty()` does **not** check for a /// required input ordering. Instead, the correct check is that at /// least one entry must be `Some` - fn required_input_ordering(&self) -> Vec> { + fn required_input_ordering(&self) -> Vec> { vec![None; self.children().len()] } @@ -267,11 +269,13 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { /// batch is superlinear. See this [general guideline][async-guideline] for more context /// on this point, which explains why one should avoid spending a long time without /// reaching an `await`/yield point in asynchronous runtimes. - /// This can be achieved by manually returning [`Poll::Pending`] and setting up wakers - /// appropriately, or the use of [`tokio::task::yield_now()`] when appropriate. + /// This can be achieved by using the utilities from the [`coop`](crate::coop) module, by + /// manually returning [`Poll::Pending`] and setting up wakers appropriately, or by calling + /// [`tokio::task::yield_now()`] when appropriate. /// In special cases that warrant manual yielding, determination for "regularly" may be - /// made using a timer (being careful with the overhead-heavy system call needed to - /// take the time), or by counting rows or batches. + /// made using the [Tokio task budget](https://docs.rs/tokio/latest/tokio/task/coop/index.html), + /// a timer (being careful with the overhead-heavy system call needed to take the time), or by + /// counting rows or batches. /// /// The [cancellation benchmark] tracks some cases of how quickly queries can /// be cancelled. @@ -423,10 +427,30 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { /// /// For TableScan executors, which supports filter pushdown, special attention /// needs to be paid to whether the stats returned by this method are exact or not + #[deprecated(since = "48.0.0", note = "Use `partition_statistics` method instead")] fn statistics(&self) -> Result { Ok(Statistics::new_unknown(&self.schema())) } + /// Returns statistics for a specific partition of this `ExecutionPlan` node. + /// If statistics are not available, should return [`Statistics::new_unknown`] + /// (the default), not an error. 
+    /// If `partition` is `None`, it returns statistics for the entire plan.
+    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
+        if let Some(idx) = partition {
+            // Validate partition index
+            let partition_count = self.properties().partitioning.partition_count();
+            if idx >= partition_count {
+                return internal_err!(
+                    "Invalid partition index: {}, the partition count is {}",
+                    idx,
+                    partition_count
+                );
+            }
+        }
+        Ok(Statistics::new_unknown(&self.schema()))
+    }
+
     /// Returns `true` if a limit can be safely pushed down through this
     /// `ExecutionPlan` node.
     ///
@@ -467,6 +491,154 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync {
     ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
         Ok(None)
     }
+
+    /// Collect filters that this node can push down to its children.
+    /// Filters that are being pushed down from parents are passed in,
+    /// and the node may generate additional filters to push down.
+    /// For example, given the plan FilterExec -> HashJoinExec -> DataSourceExec,
+    /// what will happen is that we recurse down the plan calling `ExecutionPlan::gather_filters_for_pushdown`:
+    /// 1. `FilterExec::gather_filters_for_pushdown` is called with no parent
+    ///    filters so it only returns that `FilterExec` wants to push down its own predicate.
+    /// 2. `HashJoinExec::gather_filters_for_pushdown` is called with the filter from
+    ///    `FilterExec`, which it only allows to push down to one side of the join (unless it's on the join key)
+    ///    but it also adds its own filters (e.g. pushing down a bloom filter of the hash table to the scan side of the join).
+    /// 3. `DataSourceExec::gather_filters_for_pushdown` is called with both filters from `HashJoinExec`
+    ///    and `FilterExec`, however `DataSourceExec::gather_filters_for_pushdown` doesn't actually do anything
+    ///    since it has no children and no additional filters to push down.
+    ///    It's only once [`ExecutionPlan::handle_child_pushdown_result`] is called on `DataSourceExec` as we recurse
+    ///    up the plan that `DataSourceExec` can actually bind the filters.
+    ///
+    /// The default implementation bars all parent filters from being pushed down and adds no new filters.
+    /// This is the safest option, making filter pushdown opt-in on a per-node basis.
+    ///
+    /// There are two different phases in filter pushdown, which some operators may handle the same and some differently.
+    /// Depending on the phase the operator may or may not be allowed to modify the plan.
+    /// See [`FilterPushdownPhase`] for more details.
+    fn gather_filters_for_pushdown(
+        &self,
+        _phase: FilterPushdownPhase,
+        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
+        _config: &ConfigOptions,
+    ) -> Result<FilterDescription> {
+        Ok(
+            FilterDescription::new_with_child_count(self.children().len())
+                .all_parent_filters_unsupported(parent_filters),
+        )
+    }
+
+    /// Handle the result of a child pushdown.
+    /// This method is called as we recurse back up the plan tree after pushing
+    /// filters down to child nodes via [`ExecutionPlan::gather_filters_for_pushdown`].
+    /// It allows the current node to process the results of filter pushdown from
+    /// its children, deciding whether to absorb filters, modify the plan, or pass
+    /// filters back up to its parent.
+    ///
+    /// **Purpose and Context:**
+    /// Filter pushdown is a critical optimization in DataFusion that aims to
+    /// reduce the amount of data processed by applying filters as early as
+    /// possible in the query plan. This method is part of the second phase of
+    /// filter pushdown, where results are propagated back up the tree after
+    /// being pushed down.
Each node can inspect the pushdown results from its + /// children and decide how to handle any unapplied filters, potentially + /// optimizing the plan structure or filter application. + /// + /// **Behavior in Different Nodes:** + /// - For a `DataSourceExec`, this often means absorbing the filters to apply + /// them during the scan phase (late materialization), reducing the data + /// read from the source. + /// - A `FilterExec` may absorb any filters its children could not handle, + /// combining them with its own predicate. If no filters remain (i.e., the + /// predicate becomes trivially true), it may remove itself from the plan + /// altogether. It typically marks parent filters as supported, indicating + /// they have been handled. + /// - A `HashJoinExec` might ignore the pushdown result if filters need to + /// be applied during the join operation. It passes the parent filters back + /// up wrapped in [`FilterPushdownPropagation::transparent`], discarding + /// any self-filters from children. + /// + /// **Example Walkthrough:** + /// Consider a query plan: `FilterExec (f1) -> HashJoinExec -> DataSourceExec`. + /// 1. **Downward Phase (`gather_filters_for_pushdown`):** Starting at + /// `FilterExec`, the filter `f1` is gathered and pushed down to + /// `HashJoinExec`. `HashJoinExec` may allow `f1` to pass to one side of + /// the join or add its own filters (e.g., a min-max filter from the build side), + /// then pushes filters to `DataSourceExec`. `DataSourceExec`, being a leaf node, + /// has no children to push to, so it prepares to handle filters in the + /// upward phase. + /// 2. **Upward Phase (`handle_child_pushdown_result`):** Starting at + /// `DataSourceExec`, it absorbs applicable filters from `HashJoinExec` + /// for late materialization during scanning, marking them as supported. + /// `HashJoinExec` receives the result, decides whether to apply any + /// remaining filters during the join, and passes unhandled filters back + /// up to `FilterExec`. `FilterExec` absorbs any unhandled filters, + /// updates its predicate if necessary, or removes itself if the predicate + /// becomes trivial (e.g., `lit(true)`), and marks filters as supported + /// for its parent. + /// + /// The default implementation is a no-op that passes the result of pushdown + /// from the children to its parent transparently, ensuring no filters are + /// lost if a node does not override this behavior. + /// + /// **Notes for Implementation:** + /// When returning filters via [`FilterPushdownPropagation`], the order of + /// filters need not match the order they were passed in via + /// `child_pushdown_result`. However, preserving the order is recommended for + /// debugging and ease of reasoning about the resulting plans. + /// + /// **Helper Methods for Customization:** + /// There are various helper methods to simplify implementing this method: + /// - [`FilterPushdownPropagation::unsupported`]: Indicates that the node + /// does not support filter pushdown at all, rejecting all filters. + /// - [`FilterPushdownPropagation::transparent`]: Indicates that the node + /// supports filter pushdown but does not modify it, simply transmitting + /// the children's pushdown results back up to its parent. 
+ /// - [`PredicateSupports::new_with_supported_check`]: Takes a callback to + /// dynamically determine support for each filter, useful with + /// [`FilterPushdownPropagation::with_filters`] and + /// [`FilterPushdownPropagation::with_updated_node`] to build mixed results + /// of supported and unsupported filters. + /// + /// **Filter Pushdown Phases:** + /// There are two different phases in filter pushdown (`Pre` and others), + /// which some operators may handle differently. Depending on the phase, the + /// operator may or may not be allowed to modify the plan. See + /// [`FilterPushdownPhase`] for more details on phase-specific behavior. + /// + /// [`PredicateSupport::Supported`]: crate::filter_pushdown::PredicateSupport::Supported + /// [`PredicateSupports::new_with_supported_check`]: crate::filter_pushdown::PredicateSupports::new_with_supported_check + fn handle_child_pushdown_result( + &self, + _phase: FilterPushdownPhase, + child_pushdown_result: ChildPushdownResult, + _config: &ConfigOptions, + ) -> Result>> { + Ok(FilterPushdownPropagation::transparent( + child_pushdown_result, + )) + } + + /// Injects arbitrary run-time state into this execution plan, returning a new plan + /// instance that incorporates that state *if* it is relevant to the concrete + /// node implementation. + /// + /// This is a generic entry point: the `state` can be any type wrapped in + /// `Arc`. A node that cares about the state should + /// down-cast it to the concrete type it expects and, if successful, return a + /// modified copy of itself that captures the provided value. If the state is + /// not applicable, the default behaviour is to return `None` so that parent + /// nodes can continue propagating the attempt further down the plan tree. + /// + /// For example, [`WorkTableExec`](crate::work_table::WorkTableExec) + /// down-casts the supplied state to an `Arc` + /// in order to wire up the working table used during recursive-CTE execution. + /// Similar patterns can be followed by custom nodes that need late-bound + /// dependencies or shared state. + fn with_new_state( + &self, + _state: Arc, + ) -> Option> { + None + } } /// [`ExecutionPlan`] Invariant Level @@ -519,13 +691,15 @@ pub trait ExecutionPlanProperties { /// If this ExecutionPlan makes no changes to the schema of the rows flowing /// through it or how columns within each row relate to each other, it /// should return the equivalence properties of its input. For - /// example, since `FilterExec` may remove rows from its input, but does not + /// example, since [`FilterExec`] may remove rows from its input, but does not /// otherwise modify them, it preserves its input equivalence properties. /// However, since `ProjectionExec` may calculate derived expressions, it /// needs special handling. /// /// See also [`ExecutionPlan::maintains_input_order`] and [`Self::output_ordering`] /// for related concepts. + /// + /// [`FilterExec`]: crate::filter::FilterExec fn equivalence_properties(&self) -> &EquivalenceProperties; } @@ -639,6 +813,49 @@ pub enum EmissionType { Both, } +/// Represents whether an operator's `Stream` has been implemented to actively cooperate with the +/// Tokio scheduler or not. Please refer to the [`coop`](crate::coop) module for more details. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SchedulingType { + /// The stream generated by [`execute`](ExecutionPlan::execute) does not actively participate in + /// cooperative scheduling. 
This means the implementation of the `Stream` returned by + /// [`ExecutionPlan::execute`] does not contain explicit task budget consumption such as + /// [`tokio::task::coop::consume_budget`]. + /// + /// `NonCooperative` is the default value and is acceptable for most operators. Please refer to + /// the [`coop`](crate::coop) module for details on when it may be useful to use + /// `Cooperative` instead. + NonCooperative, + /// The stream generated by [`execute`](ExecutionPlan::execute) actively participates in + /// cooperative scheduling by consuming task budget when it was able to produce a + /// [`RecordBatch`]. + Cooperative, +} + +/// Represents how an operator's `Stream` implementation generates `RecordBatch`es. +/// +/// Most operators in DataFusion generate `RecordBatch`es when asked to do so by a call to +/// `Stream::poll_next`. This is known as demand-driven or lazy evaluation. +/// +/// Some operators like `Repartition` need to drive `RecordBatch` generation themselves though. This +/// is known as data-driven or eager evaluation. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum EvaluationType { + /// The stream generated by [`execute`](ExecutionPlan::execute) only generates `RecordBatch` + /// instances when it is demanded by invoking `Stream::poll_next`. + /// Filter, projection, and join are examples of such lazy operators. + /// + /// Lazy operators are also known as demand-driven operators. + Lazy, + /// The stream generated by [`execute`](ExecutionPlan::execute) eagerly generates `RecordBatch` + /// in one or more spawned Tokio tasks. Eager evaluation is only started the first time + /// `Stream::poll_next` is called. + /// Examples of eager operators are repartition, coalesce partitions, and sort preserving merge. + /// + /// Eager operators are also known as a data-driven operators. + Eager, +} + /// Utility to determine an operator's boundedness based on its children's boundedness. /// /// Assumes boundedness can be inferred from child operators: @@ -727,6 +944,8 @@ pub struct PlanProperties { pub emission_type: EmissionType, /// See [ExecutionPlanProperties::boundedness] pub boundedness: Boundedness, + pub evaluation_type: EvaluationType, + pub scheduling_type: SchedulingType, /// See [ExecutionPlanProperties::output_ordering] output_ordering: Option, } @@ -746,6 +965,8 @@ impl PlanProperties { partitioning, emission_type, boundedness, + evaluation_type: EvaluationType::Lazy, + scheduling_type: SchedulingType::NonCooperative, output_ordering, } } @@ -777,6 +998,22 @@ impl PlanProperties { self } + /// Set the [`SchedulingType`]. + /// + /// Defaults to [`SchedulingType::NonCooperative`] + pub fn with_scheduling_type(mut self, scheduling_type: SchedulingType) -> Self { + self.scheduling_type = scheduling_type; + self + } + + /// Set the [`EvaluationType`]. + /// + /// Defaults to [`EvaluationType::Lazy`] + pub fn with_evaluation_type(mut self, drive_type: EvaluationType) -> Self { + self.evaluation_type = drive_type; + self + } + /// Overwrite constraints with its new value. pub fn with_constraints(mut self, constraints: Constraints) -> Self { self.eq_properties = self.eq_properties.with_constraints(constraints); @@ -803,30 +1040,12 @@ impl PlanProperties { /// Indicate whether a data exchange is needed for the input of `plan`, which will be very helpful /// especially for the distributed engine to judge whether need to deal with shuffling. 
-/// Currently there are 3 kinds of execution plan which needs data exchange +/// Currently, there are 3 kinds of execution plan which needs data exchange /// 1. RepartitionExec for changing the partition number between two `ExecutionPlan`s /// 2. CoalescePartitionsExec for collapsing all of the partitions into one without ordering guarantee /// 3. SortPreservingMergeExec for collapsing all of the sorted partitions into one with ordering guarantee pub fn need_data_exchange(plan: Arc) -> bool { - if let Some(repartition) = plan.as_any().downcast_ref::() { - !matches!( - repartition.properties().output_partitioning(), - Partitioning::RoundRobinBatch(_) - ) - } else if let Some(coalesce) = plan.as_any().downcast_ref::() - { - coalesce.input().output_partitioning().partition_count() > 1 - } else if let Some(sort_preserving_merge) = - plan.as_any().downcast_ref::() - { - sort_preserving_merge - .input() - .output_partitioning() - .partition_count() - > 1 - } else { - false - } + plan.properties().evaluation_type == EvaluationType::Eager } /// Returns a copy of this plan if we change any child according to the pointer comparison. @@ -1072,17 +1291,17 @@ pub enum CardinalityEffect { #[cfg(test)] mod tests { - use super::*; - use arrow::array::{DictionaryArray, Int32Array, NullArray, RunArray}; - use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use std::any::Any; use std::sync::Arc; + use super::*; + use crate::{DisplayAs, DisplayFormatType, ExecutionPlan}; + + use arrow::array::{DictionaryArray, Int32Array, NullArray, RunArray}; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::{Result, Statistics}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; - use crate::{DisplayAs, DisplayFormatType, ExecutionPlan}; - #[derive(Debug)] pub struct EmptyExec; @@ -1137,6 +1356,10 @@ mod tests { fn statistics(&self) -> Result { unimplemented!() } + + fn partition_statistics(&self, _partition: Option) -> Result { + unimplemented!() + } } #[derive(Debug)] @@ -1200,6 +1423,10 @@ mod tests { fn statistics(&self) -> Result { unimplemented!() } + + fn partition_statistics(&self, _partition: Option) -> Result { + unimplemented!() + } } #[test] diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index a8a9973ea043..252af9ebcd49 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -16,6 +16,7 @@ // under the License. 
use std::any::Any; +use std::collections::HashMap; use std::pin::Pin; use std::sync::Arc; use std::task::{ready, Context, Poll}; @@ -26,6 +27,10 @@ use super::{ }; use crate::common::can_project; use crate::execution_plan::CardinalityEffect; +use crate::filter_pushdown::{ + ChildPushdownResult, FilterDescription, FilterPushdownPhase, + FilterPushdownPropagation, +}; use crate::projection::{ make_with_child, try_embed_projection, update_expr, EmbeddedProjection, ProjectionExec, @@ -39,25 +44,32 @@ use arrow::compute::filter_record_batch; use arrow::datatypes::{DataType, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::cast::as_boolean_array; +use datafusion_common::config::ConfigOptions; use datafusion_common::stats::Precision; +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, +}; use datafusion_common::{ internal_err, plan_err, project_schema, DataFusionError, Result, ScalarValue, }; use datafusion_execution::TaskContext; use datafusion_expr::Operator; use datafusion_physical_expr::equivalence::ProjectionMapping; -use datafusion_physical_expr::expressions::BinaryExpr; +use datafusion_physical_expr::expressions::{lit, BinaryExpr, Column}; use datafusion_physical_expr::intervals::utils::check_support; use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{ - analyze, split_conjunction, AcrossPartitions, AnalysisContext, ConstExpr, - ExprBoundaries, PhysicalExpr, + analyze, conjunction, split_conjunction, AcrossPartitions, AnalysisContext, + ConstExpr, ExprBoundaries, PhysicalExpr, }; use datafusion_physical_expr_common::physical_expr::fmt_sql; use futures::stream::{Stream, StreamExt}; +use itertools::Itertools; use log::trace; +const FILTER_EXEC_DEFAULT_SELECTIVITY: u8 = 20; + /// FilterExec evaluates a boolean predicate against all input batches to determine which rows to /// include in its output batches. #[derive(Debug, Clone)] @@ -84,7 +96,7 @@ impl FilterExec { ) -> Result { match predicate.data_type(input.schema().as_ref())? { DataType::Boolean => { - let default_selectivity = 20; + let default_selectivity = FILTER_EXEC_DEFAULT_SELECTIVITY; let cache = Self::compute_properties( &input, &predicate, @@ -170,12 +182,11 @@ impl FilterExec { /// Calculates `Statistics` for `FilterExec`, by applying selectivity (either default, or estimated) to input statistics. 
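+    ///
+    /// As a rough illustration of the default-selectivity path (a sketch, not the
+    /// exact implementation): when the predicate is not supported by range analysis,
+    /// the input row count and byte size are simply scaled by the selectivity, e.g.
+    ///
+    /// ```
+    /// // hypothetical input estimate, for illustration only
+    /// let input_rows: f64 = 1000.0;
+    /// let default_selectivity: u8 = 20; // i.e. FILTER_EXEC_DEFAULT_SELECTIVITY
+    /// let estimated_rows = input_rows * (default_selectivity as f64 / 100.0);
+    /// // ~200 rows, reported as an inexact estimate
+    /// # let _ = estimated_rows;
+    /// ```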
fn statistics_helper( - input: &Arc, + schema: SchemaRef, + input_stats: Statistics, predicate: &Arc, default_selectivity: u8, ) -> Result { - let input_stats = input.statistics()?; - let schema = input.schema(); if !check_support(predicate, &schema) { let selectivity = default_selectivity as f64 / 100.0; let mut stats = input_stats.to_inexact(); @@ -189,7 +200,7 @@ impl FilterExec { let num_rows = input_stats.num_rows; let total_byte_size = input_stats.total_byte_size; let input_analysis_ctx = AnalysisContext::try_from_statistics( - &input.schema(), + &schema, &input_stats.column_statistics, )?; @@ -223,24 +234,18 @@ impl FilterExec { if let Some(binary) = conjunction.as_any().downcast_ref::() { if binary.op() == &Operator::Eq { // Filter evaluates to single value for all partitions - if input_eqs.is_expr_constant(binary.left()) { - let (expr, across_parts) = ( - binary.right(), - input_eqs.get_expr_constant_value(binary.right()), - ); - res_constants.push( - ConstExpr::new(Arc::clone(expr)) - .with_across_partitions(across_parts), - ); - } else if input_eqs.is_expr_constant(binary.right()) { - let (expr, across_parts) = ( - binary.left(), - input_eqs.get_expr_constant_value(binary.left()), - ); - res_constants.push( - ConstExpr::new(Arc::clone(expr)) - .with_across_partitions(across_parts), - ); + if input_eqs.is_expr_constant(binary.left()).is_some() { + let across = input_eqs + .is_expr_constant(binary.right()) + .unwrap_or_default(); + res_constants + .push(ConstExpr::new(Arc::clone(binary.right()), across)); + } else if input_eqs.is_expr_constant(binary.right()).is_some() { + let across = input_eqs + .is_expr_constant(binary.left()) + .unwrap_or_default(); + res_constants + .push(ConstExpr::new(Arc::clone(binary.left()), across)); } } } @@ -256,11 +261,16 @@ impl FilterExec { ) -> Result { // Combine the equal predicates with the input equivalence properties // to construct the equivalence properties: - let stats = Self::statistics_helper(input, predicate, default_selectivity)?; + let stats = Self::statistics_helper( + input.schema(), + input.partition_statistics(None)?, + predicate, + default_selectivity, + )?; let mut eq_properties = input.equivalence_properties().clone(); let (equal_pairs, _) = collect_columns_from_predicate(predicate); for (lhs, rhs) in equal_pairs { - eq_properties.add_equal_conditions(lhs, rhs)? + eq_properties.add_equal_conditions(Arc::clone(lhs), Arc::clone(rhs))? } // Add the columns that have only one viable value (singleton) after // filtering to constants. @@ -272,15 +282,13 @@ impl FilterExec { .min_value .get_value(); let expr = Arc::new(column) as _; - ConstExpr::new(expr) - .with_across_partitions(AcrossPartitions::Uniform(value.cloned())) + ConstExpr::new(expr, AcrossPartitions::Uniform(value.cloned())) }); // This is for statistics - eq_properties = eq_properties.with_constants(constants); + eq_properties.add_constants(constants)?; // This is for logical constant (for example: a = '1', then a could be marked as a constant) // to do: how to deal with multiple situation to represent = (for example c1 between 0 and 0) - eq_properties = - eq_properties.with_constants(Self::extend_constants(input, predicate)); + eq_properties.add_constants(Self::extend_constants(input, predicate))?; let mut output_partitioning = input.output_partitioning().clone(); // If contains projection, update the PlanProperties. 
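+        // (the projection changes the output schema, so the partitioning and
+        // equivalence expressions must be remapped onto the projected columns)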
@@ -396,8 +404,14 @@ impl ExecutionPlan for FilterExec { /// The output statistics of a filtering operation can be estimated if the /// predicate's selectivity value can be determined for the incoming data. fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + let input_stats = self.input.partition_statistics(partition)?; let stats = Self::statistics_helper( - &self.input, + self.schema(), + input_stats, self.predicate(), self.default_selectivity, )?; @@ -433,6 +447,141 @@ impl ExecutionPlan for FilterExec { } try_embed_projection(projection, self) } + + fn gather_filters_for_pushdown( + &self, + phase: FilterPushdownPhase, + parent_filters: Vec>, + _config: &ConfigOptions, + ) -> Result { + if !matches!(phase, FilterPushdownPhase::Pre) { + return Ok(FilterDescription::new_with_child_count(1) + .all_parent_filters_supported(parent_filters)); + } + let self_filter = split_conjunction(&self.predicate) + .into_iter() + .cloned() + .collect_vec(); + + let parent_filters = if let Some(projection_indices) = self.projection.as_ref() { + // We need to invert the projection on any referenced columns in the filter + // Create a mapping from the output columns to the input columns (the inverse of the projection) + let inverse_projection = projection_indices + .iter() + .enumerate() + .map(|(i, &p)| (p, i)) + .collect::>(); + parent_filters + .into_iter() + .map(|f| { + f.transform_up(|expr| { + let mut res = + if let Some(col) = expr.as_any().downcast_ref::() { + let index = col.index(); + let index_in_input_schema = + inverse_projection.get(&index).ok_or_else(|| { + DataFusionError::Internal(format!( + "Column {index} not found in projection" + )) + })?; + Transformed::yes(Arc::new(Column::new( + col.name(), + *index_in_input_schema, + )) as _) + } else { + Transformed::no(expr) + }; + // Columns can only exist in the leaves, no need to try all nodes + res.tnr = TreeNodeRecursion::Jump; + Ok(res) + }) + .data() + }) + .collect::>>()? 
+ } else { + parent_filters + }; + + Ok(FilterDescription::new_with_child_count(1) + .all_parent_filters_supported(parent_filters) + .with_self_filters_for_children(vec![self_filter])) + } + + fn handle_child_pushdown_result( + &self, + phase: FilterPushdownPhase, + child_pushdown_result: ChildPushdownResult, + _config: &ConfigOptions, + ) -> Result>> { + if !matches!(phase, FilterPushdownPhase::Pre) { + return Ok(FilterPushdownPropagation::transparent( + child_pushdown_result, + )); + } + // We absorb any parent filters that were not handled by our children + let mut unhandled_filters = + child_pushdown_result.parent_filters.collect_unsupported(); + assert_eq!( + child_pushdown_result.self_filters.len(), + 1, + "FilterExec should only have one child" + ); + let unsupported_self_filters = + child_pushdown_result.self_filters[0].collect_unsupported(); + unhandled_filters.extend(unsupported_self_filters); + + // If we have unhandled filters, we need to create a new FilterExec + let filter_input = Arc::clone(self.input()); + let new_predicate = conjunction(unhandled_filters); + let updated_node = if new_predicate.eq(&lit(true)) { + // FilterExec is no longer needed, but we may need to leave a projection in place + match self.projection() { + Some(projection_indices) => { + let filter_child_schema = filter_input.schema(); + let proj_exprs = projection_indices + .iter() + .map(|p| { + let field = filter_child_schema.field(*p).clone(); + ( + Arc::new(Column::new(field.name(), *p)) + as Arc, + field.name().to_string(), + ) + }) + .collect::>(); + Some(Arc::new(ProjectionExec::try_new(proj_exprs, filter_input)?) + as Arc) + } + None => { + // No projection needed, just return the input + Some(filter_input) + } + } + } else if new_predicate.eq(&self.predicate) { + // The new predicate is the same as our current predicate + None + } else { + // Create a new FilterExec with the new predicate + let new = FilterExec { + predicate: Arc::clone(&new_predicate), + input: Arc::clone(&filter_input), + metrics: self.metrics.clone(), + default_selectivity: self.default_selectivity, + cache: Self::compute_properties( + &filter_input, + &new_predicate, + self.default_selectivity, + self.projection.as_ref(), + )?, + projection: None, + }; + Some(Arc::new(new) as _) + }; + Ok(FilterPushdownPropagation { + filters: child_pushdown_result.parent_filters.make_supported(), + updated_node, + }) + } } impl EmbeddedProjection for FilterExec { @@ -703,7 +852,7 @@ mod tests { let filter: Arc = Arc::new(FilterExec::try_new(predicate, input)?); - let statistics = filter.statistics()?; + let statistics = filter.partition_statistics(None)?; assert_eq!(statistics.num_rows, Precision::Inexact(25)); assert_eq!( statistics.total_byte_size, @@ -753,7 +902,7 @@ mod tests { sub_filter, )?); - let statistics = filter.statistics()?; + let statistics = filter.partition_statistics(None)?; assert_eq!(statistics.num_rows, Precision::Inexact(16)); assert_eq!( statistics.column_statistics, @@ -813,7 +962,7 @@ mod tests { binary(col("a", &schema)?, Operator::GtEq, lit(10i32), &schema)?, b_gt_5, )?); - let statistics = filter.statistics()?; + let statistics = filter.partition_statistics(None)?; // On a uniform distribution, only fifteen rows will satisfy the // filter that 'a' proposed (a >= 10 AND a <= 25) (15/100) and only // 5 rows will satisfy the filter that 'b' proposed (b > 45) (5/50). 
@@ -858,7 +1007,7 @@ mod tests { let filter: Arc = Arc::new(FilterExec::try_new(predicate, input)?); - let statistics = filter.statistics()?; + let statistics = filter.partition_statistics(None)?; assert_eq!(statistics.num_rows, Precision::Absent); Ok(()) @@ -931,7 +1080,7 @@ mod tests { )); let filter: Arc = Arc::new(FilterExec::try_new(predicate, input)?); - let statistics = filter.statistics()?; + let statistics = filter.partition_statistics(None)?; // 0.5 (from a) * 0.333333... (from b) * 0.798387... (from c) ≈ 0.1330... // num_rows after ceil => 133.0... => 134 // total_byte_size after ceil => 532.0... => 533 @@ -1027,10 +1176,10 @@ mod tests { )), )); // Since filter predicate passes all entries, statistics after filter shouldn't change. - let expected = input.statistics()?.column_statistics; + let expected = input.partition_statistics(None)?.column_statistics; let filter: Arc = Arc::new(FilterExec::try_new(predicate, input)?); - let statistics = filter.statistics()?; + let statistics = filter.partition_statistics(None)?; assert_eq!(statistics.num_rows, Precision::Inexact(1000)); assert_eq!(statistics.total_byte_size, Precision::Inexact(4000)); @@ -1083,7 +1232,7 @@ mod tests { )); let filter: Arc = Arc::new(FilterExec::try_new(predicate, input)?); - let statistics = filter.statistics()?; + let statistics = filter.partition_statistics(None)?; assert_eq!(statistics.num_rows, Precision::Inexact(0)); assert_eq!(statistics.total_byte_size, Precision::Inexact(0)); @@ -1143,7 +1292,7 @@ mod tests { )); let filter: Arc = Arc::new(FilterExec::try_new(predicate, input)?); - let statistics = filter.statistics()?; + let statistics = filter.partition_statistics(None)?; assert_eq!(statistics.num_rows, Precision::Inexact(490)); assert_eq!(statistics.total_byte_size, Precision::Inexact(1960)); @@ -1193,7 +1342,7 @@ mod tests { )); let filter: Arc = Arc::new(FilterExec::try_new(predicate, input)?); - let filter_statistics = filter.statistics()?; + let filter_statistics = filter.partition_statistics(None)?; let expected_filter_statistics = Statistics { num_rows: Precision::Absent, @@ -1227,7 +1376,7 @@ mod tests { )); let filter: Arc = Arc::new(FilterExec::try_new(predicate, input)?); - let filter_statistics = filter.statistics()?; + let filter_statistics = filter.partition_statistics(None)?; // First column is "a", and it is a column with only one value after the filter. 
assert!(filter_statistics.column_statistics[0].is_singleton()); @@ -1274,11 +1423,11 @@ mod tests { Arc::new(Literal::new(ScalarValue::Decimal128(Some(10), 10, 10))), )); let filter = FilterExec::try_new(predicate, input)?; - let statistics = filter.statistics()?; + let statistics = filter.partition_statistics(None)?; assert_eq!(statistics.num_rows, Precision::Inexact(200)); assert_eq!(statistics.total_byte_size, Precision::Inexact(800)); let filter = filter.with_default_selectivity(40)?; - let statistics = filter.statistics()?; + let statistics = filter.partition_statistics(None)?; assert_eq!(statistics.num_rows, Precision::Inexact(400)); assert_eq!(statistics.total_byte_size, Precision::Inexact(1600)); Ok(()) @@ -1312,7 +1461,7 @@ mod tests { Arc::new(EmptyExec::new(Arc::clone(&schema))), )?; - exec.statistics().unwrap(); + exec.partition_statistics(None).unwrap(); Ok(()) } diff --git a/datafusion/physical-plan/src/filter_pushdown.rs b/datafusion/physical-plan/src/filter_pushdown.rs new file mode 100644 index 000000000000..725abd7fc8b5 --- /dev/null +++ b/datafusion/physical-plan/src/filter_pushdown.rs @@ -0,0 +1,437 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; +use std::vec::IntoIter; + +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + +#[derive(Debug, Clone, Copy)] +pub enum FilterPushdownPhase { + /// Pushdown that happens before most other optimizations. + /// This pushdown allows static filters that do not reference any [`ExecutionPlan`]s to be pushed down. + /// Filters that reference an [`ExecutionPlan`] cannot be pushed down at this stage since the whole plan tree may be rewritten + /// by other optimizations. + /// Implementers are however allowed to modify the execution plan themselves during this phase, for example by returning a completely + /// different [`ExecutionPlan`] from [`ExecutionPlan::handle_child_pushdown_result`]. + /// + /// Pushdown of [`FilterExec`] into `DataSourceExec` is an example of a pre-pushdown. + /// Unlike filter pushdown in the logical phase, which operates on the logical plan to push filters into the logical table scan, + /// the `Pre` phase in the physical plan targets the actual physical scan, pushing filters down to specific data source implementations. + /// For example, Parquet supports filter pushdown to reduce data read during scanning, while CSV typically does not. + /// + /// [`ExecutionPlan`]: crate::ExecutionPlan + /// [`FilterExec`]: crate::filter::FilterExec + /// [`ExecutionPlan::handle_child_pushdown_result`]: crate::ExecutionPlan::handle_child_pushdown_result + Pre, + /// Pushdown that happens after most other optimizations. + /// This stage of filter pushdown allows filters that reference an [`ExecutionPlan`] to be pushed down. 
+ /// Since subsequent optimizations should not change the structure of the plan tree except for calling [`ExecutionPlan::with_new_children`] + /// (which generally preserves internal references) it is safe for references between [`ExecutionPlan`]s to be established at this stage. + /// + /// This phase is used to link a [`SortExec`] (with a TopK operator) or a [`HashJoinExec`] to a `DataSourceExec`. + /// + /// [`ExecutionPlan`]: crate::ExecutionPlan + /// [`ExecutionPlan::with_new_children`]: crate::ExecutionPlan::with_new_children + /// [`SortExec`]: crate::sorts::sort::SortExec + /// [`HashJoinExec`]: crate::joins::HashJoinExec + /// [`ExecutionPlan::handle_child_pushdown_result`]: crate::ExecutionPlan::handle_child_pushdown_result + Post, +} + +impl std::fmt::Display for FilterPushdownPhase { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + FilterPushdownPhase::Pre => write!(f, "Pre"), + FilterPushdownPhase::Post => write!(f, "Post"), + } + } +} + +/// The result of a plan for pushing down a filter into a child node. +/// This contains references to filters so that nodes can mutate a filter +/// before pushing it down to a child node (e.g. to adjust a projection) +/// or can directly take ownership of `Unsupported` filters that their children +/// could not handle. +#[derive(Debug, Clone)] +pub enum PredicateSupport { + Supported(Arc), + Unsupported(Arc), +} + +impl PredicateSupport { + pub fn into_inner(self) -> Arc { + match self { + PredicateSupport::Supported(expr) | PredicateSupport::Unsupported(expr) => { + expr + } + } + } +} + +/// A thin wrapper around [`PredicateSupport`]s that allows for easy collection of +/// supported and unsupported filters. Inner vector stores each predicate for one node. +#[derive(Debug, Clone)] +pub struct PredicateSupports(Vec); + +impl PredicateSupports { + /// Create a new FilterPushdowns with the given filters and their pushdown status. + pub fn new(pushdowns: Vec) -> Self { + Self(pushdowns) + } + + /// Create a new [`PredicateSupport`] with all filters as supported. + pub fn all_supported(filters: Vec>) -> Self { + let pushdowns = filters + .into_iter() + .map(PredicateSupport::Supported) + .collect(); + Self::new(pushdowns) + } + + /// Create a new [`PredicateSupport`] with all filters as unsupported. + pub fn all_unsupported(filters: Vec>) -> Self { + let pushdowns = filters + .into_iter() + .map(PredicateSupport::Unsupported) + .collect(); + Self::new(pushdowns) + } + + /// Create a new [`PredicateSupport`] with filterrs marked as supported if + /// `f` returns true and unsupported otherwise. + pub fn new_with_supported_check( + filters: Vec>, + check: impl Fn(&Arc) -> bool, + ) -> Self { + let pushdowns = filters + .into_iter() + .map(|f| { + if check(&f) { + PredicateSupport::Supported(f) + } else { + PredicateSupport::Unsupported(f) + } + }) + .collect(); + Self::new(pushdowns) + } + + /// Transform all filters to supported, returning a new [`PredicateSupports`] + /// with all filters as [`PredicateSupport::Supported`]. + /// This does not modify the original [`PredicateSupport`]. 
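+    ///
+    /// A minimal usage sketch (module path and the `lit(true)` filter are illustrative):
+    ///
+    /// ```
+    /// use datafusion_physical_plan::filter_pushdown::PredicateSupports;
+    /// use datafusion_physical_expr::expressions::lit;
+    ///
+    /// let pushdowns = PredicateSupports::all_unsupported(vec![lit(true)]);
+    /// // A node that can absorb every remaining filter reports them all as supported:
+    /// assert!(pushdowns.make_supported().is_all_supported());
+    /// ```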
+ pub fn make_supported(self) -> Self { + let pushdowns = self + .0 + .into_iter() + .map(|f| match f { + PredicateSupport::Supported(expr) => PredicateSupport::Supported(expr), + PredicateSupport::Unsupported(expr) => PredicateSupport::Supported(expr), + }) + .collect(); + Self::new(pushdowns) + } + + /// Transform all filters to unsupported, returning a new [`PredicateSupports`] + /// with all filters as [`PredicateSupport::Supported`]. + /// This does not modify the original [`PredicateSupport`]. + pub fn make_unsupported(self) -> Self { + let pushdowns = self + .0 + .into_iter() + .map(|f| match f { + PredicateSupport::Supported(expr) => PredicateSupport::Unsupported(expr), + u @ PredicateSupport::Unsupported(_) => u, + }) + .collect(); + Self::new(pushdowns) + } + + /// Collect unsupported filters into a Vec, without removing them from the original + /// [`PredicateSupport`]. + pub fn collect_unsupported(&self) -> Vec> { + self.0 + .iter() + .filter_map(|f| match f { + PredicateSupport::Unsupported(expr) => Some(Arc::clone(expr)), + PredicateSupport::Supported(_) => None, + }) + .collect() + } + + /// Collect supported filters into a Vec, without removing them from the original + /// [`PredicateSupport`]. + pub fn collect_supported(&self) -> Vec> { + self.0 + .iter() + .filter_map(|f| match f { + PredicateSupport::Supported(expr) => Some(Arc::clone(expr)), + PredicateSupport::Unsupported(_) => None, + }) + .collect() + } + + /// Collect all filters into a Vec, without removing them from the original + /// FilterPushdowns. + pub fn collect_all(self) -> Vec> { + self.0 + .into_iter() + .map(|f| match f { + PredicateSupport::Supported(expr) + | PredicateSupport::Unsupported(expr) => expr, + }) + .collect() + } + + pub fn into_inner(self) -> Vec { + self.0 + } + + /// Return an iterator over the inner `Vec`. + pub fn iter(&self) -> impl Iterator { + self.0.iter() + } + + /// Return the number of filters in the inner `Vec`. + pub fn len(&self) -> usize { + self.0.len() + } + + /// Check if the inner `Vec` is empty. + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + /// Check if all filters are supported. + pub fn is_all_supported(&self) -> bool { + self.0 + .iter() + .all(|f| matches!(f, PredicateSupport::Supported(_))) + } + + /// Check if all filters are unsupported. + pub fn is_all_unsupported(&self) -> bool { + self.0 + .iter() + .all(|f| matches!(f, PredicateSupport::Unsupported(_))) + } +} + +impl IntoIterator for PredicateSupports { + type Item = PredicateSupport; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +/// The result of pushing down filters into a child node. +/// This is the result provided to nodes in [`ExecutionPlan::handle_child_pushdown_result`]. +/// Nodes process this result and convert it into a [`FilterPushdownPropagation`] +/// that is returned to their parent. +/// +/// [`ExecutionPlan::handle_child_pushdown_result`]: crate::ExecutionPlan::handle_child_pushdown_result +#[derive(Debug, Clone)] +pub struct ChildPushdownResult { + /// The combined result of pushing down each parent filter into each child. 
+    /// For example, given the filters `[a, b]` and children `[1, 2, 3]`, the matrix of responses is:
+    ///
+    // | filter | child 1     | child 2   | child 3   | result      |
+    // |--------|-------------|-----------|-----------|-------------|
+    // | a      | Supported   | Supported | Supported | Supported   |
+    // | b      | Unsupported | Supported | Supported | Unsupported |
+    ///
+    /// That is: if any child marks a filter as unsupported, or if the filter was not pushed
+    /// down into any child, then the result is unsupported.
+    /// If at least one child received the filter and all children that received it mark it as supported,
+    /// then the result is supported.
+    pub parent_filters: PredicateSupports,
+    /// The result of pushing down each filter this node provided into each of its children.
+    /// This is not combined with the parent filters so that nodes can treat each child independently.
+    pub self_filters: Vec<PredicateSupports>,
+}
+
+/// The result of pushing down filters into a node that it returns to its parent.
+/// This is what nodes return from [`ExecutionPlan::handle_child_pushdown_result`] to communicate
+/// to the optimizer:
+///
+/// 1. What to do with any parent filters that were not completely handled by the children.
+/// 2. If the node needs to be replaced in the execution plan with a new node or not.
+///
+/// [`ExecutionPlan::handle_child_pushdown_result`]: crate::ExecutionPlan::handle_child_pushdown_result
+#[derive(Debug, Clone)]
+pub struct FilterPushdownPropagation<T> {
+    pub filters: PredicateSupports,
+    pub updated_node: Option<T>,
+}
+
+impl<T> FilterPushdownPropagation<T> {
+    /// Create a new [`FilterPushdownPropagation`] that simply echoes back up to
+    /// the parent the result of pushing down the filters into the children.
+    pub fn transparent(child_pushdown_result: ChildPushdownResult) -> Self {
+        Self {
+            filters: child_pushdown_result.parent_filters,
+            updated_node: None,
+        }
+    }
+
+    /// Create a new [`FilterPushdownPropagation`] that tells the parent node
+    /// that none of the parent filters were pushed down.
+    pub fn unsupported(parent_filters: Vec<Arc<dyn PhysicalExpr>>) -> Self {
+        let unsupported = PredicateSupports::all_unsupported(parent_filters);
+        Self {
+            filters: unsupported,
+            updated_node: None,
+        }
+    }
+
+    /// Create a new [`FilterPushdownPropagation`] with the specified filter support.
+    pub fn with_filters(filters: PredicateSupports) -> Self {
+        Self {
+            filters,
+            updated_node: None,
+        }
+    }
+
+    /// Bind an updated node to the [`FilterPushdownPropagation`].
+    pub fn with_updated_node(mut self, updated_node: T) -> Self {
+        self.updated_node = Some(updated_node);
+        self
+    }
+}
+
+#[derive(Debug, Clone)]
+struct ChildFilterDescription {
+    /// Description of which parent filters can be pushed down into this node.
+    /// Since we need to transmit filter pushdown results back to this node's parent,
+    /// we need to track each parent filter for each child, even those that are unsupported / won't be pushed down.
+    /// We do this using a [`PredicateSupports`], which simplifies manipulating supported/unsupported filters.
+    parent_filters: PredicateSupports,
+    /// Description of which filters this node is pushing down to its children.
+    /// Since this is not transmitted back to the parents, we can have variable-sized inner arrays
+    /// instead of having to track supported/unsupported.
+    self_filters: Vec<Arc<dyn PhysicalExpr>>,
+}
+
+impl ChildFilterDescription {
+    fn new() -> Self {
+        Self {
+            parent_filters: PredicateSupports::new(vec![]),
+            self_filters: vec![],
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct FilterDescription {
+    /// A filter description for each child.
+    /// This includes which parent filters and which self filters (from the node in question)
+    /// will get pushed down to each child.
+    child_filter_descriptions: Vec<ChildFilterDescription>,
+}
+
+impl FilterDescription {
+    pub fn new_with_child_count(num_children: usize) -> Self {
+        Self {
+            child_filter_descriptions: vec![ChildFilterDescription::new(); num_children],
+        }
+    }
+
+    pub fn parent_filters(&self) -> Vec<PredicateSupports> {
+        self.child_filter_descriptions
+            .iter()
+            .map(|d| &d.parent_filters)
+            .cloned()
+            .collect()
+    }
+
+    pub fn self_filters(&self) -> Vec<Vec<Arc<dyn PhysicalExpr>>> {
+        self.child_filter_descriptions
+            .iter()
+            .map(|d| &d.self_filters)
+            .cloned()
+            .collect()
+    }
+
+    /// Mark all parent filters as supported for all children.
+    /// This is the case if the node allows filters to be pushed down through it
+    /// without any modification.
+    /// This broadcasts the parent filters to all children.
+    /// If handling of parent filters is different for each child, then you should set the
+    /// field directly.
+    /// For example, nodes like [`RepartitionExec`] that let filters pass through them transparently
+    /// use this to mark all parent filters as supported.
+    ///
+    /// [`RepartitionExec`]: crate::repartition::RepartitionExec
+    pub fn all_parent_filters_supported(
+        mut self,
+        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
+    ) -> Self {
+        let supported = PredicateSupports::all_supported(parent_filters);
+        for child in &mut self.child_filter_descriptions {
+            child.parent_filters = supported.clone();
+        }
+        self
+    }
+
+    /// Mark all parent filters as unsupported for all children.
+    /// This is the case if the node does not allow filters to be pushed down through it.
+    /// This broadcasts the parent filters to all children.
+    /// If handling of parent filters is different for each child, then you should set the
+    /// field directly.
+    /// For example, the default implementation of filter pushdown in [`ExecutionPlan`]
+    /// assumes that filters cannot be pushed down to children.
+    ///
+    /// [`ExecutionPlan`]: crate::ExecutionPlan
+    pub fn all_parent_filters_unsupported(
+        mut self,
+        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
+    ) -> Self {
+        let unsupported = PredicateSupports::all_unsupported(parent_filters);
+        for child in &mut self.child_filter_descriptions {
+            child.parent_filters = unsupported.clone();
+        }
+        self
+    }
+
+    /// Add a filter generated / owned by the current node to be pushed down to all children.
+    /// This assumes that there is a single filter that gets pushed down to all children
+    /// equally.
+    /// If there are multiple filters, or pushdown to children is not homogeneous, then
+    /// you should set the field directly.
+    /// For example (see the sketch below):
+    /// - `TopK` pushes down a single filter to all children, so it can use this method.
+    /// - `HashJoinExec` pushes down a filter only to the probe side, so it cannot use this method.
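+    ///
+    /// A minimal usage sketch (the `lit(true)` predicates are placeholders):
+    ///
+    /// ```
+    /// use datafusion_physical_plan::filter_pushdown::FilterDescription;
+    /// use datafusion_physical_expr::expressions::lit;
+    ///
+    /// // A node with one child that lets the parent filters through unchanged
+    /// // and additionally pushes its own predicate down to that child.
+    /// let parent_filters = vec![lit(true)];
+    /// let _description = FilterDescription::new_with_child_count(1)
+    ///     .all_parent_filters_supported(parent_filters)
+    ///     .with_self_filter(lit(true));
+    /// ```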
+ pub fn with_self_filter(mut self, predicate: Arc) -> Self { + for child in &mut self.child_filter_descriptions { + child.self_filters = vec![Arc::clone(&predicate)]; + } + self + } + + pub fn with_self_filters_for_children( + mut self, + filters: Vec>>, + ) -> Self { + for (child, filters) in self.child_filter_descriptions.iter_mut().zip(filters) { + child.self_filters = filters; + } + self + } +} diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index 8dd1addff15c..e4d554ceb62c 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -115,7 +115,7 @@ impl CrossJoinExec { }; let schema = Arc::new(Schema::new(all_columns).with_metadata(metadata)); - let cache = Self::compute_properties(&left, &right, Arc::clone(&schema)); + let cache = Self::compute_properties(&left, &right, Arc::clone(&schema)).unwrap(); CrossJoinExec { left, @@ -142,7 +142,7 @@ impl CrossJoinExec { left: &Arc, right: &Arc, schema: SchemaRef, - ) -> PlanProperties { + ) -> Result { // Calculate equivalence properties // TODO: Check equivalence properties of cross join, it may preserve // ordering in some cases. @@ -154,7 +154,7 @@ impl CrossJoinExec { &[false, false], None, &[], - ); + )?; // Get output partitioning: // TODO: Optimize the cross join implementation to generate M * N @@ -162,14 +162,14 @@ impl CrossJoinExec { let output_partitioning = adjust_right_output_partitioning( right.output_partitioning(), left.schema().fields.len(), - ); + )?; - PlanProperties::new( + Ok(PlanProperties::new( eq_properties, output_partitioning, EmissionType::Final, boundedness_from_children([left, right]), - ) + )) } /// Returns a new `ExecutionPlan` that computes the same join as this one, @@ -337,10 +337,15 @@ impl ExecutionPlan for CrossJoinExec { } fn statistics(&self) -> Result { - Ok(stats_cartesian_product( - self.left.statistics()?, - self.right.statistics()?, - )) + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + // Get the all partitions statistics of the left + let left_stats = self.left.partition_statistics(None)?; + let right_stats = self.right.partition_statistics(partition)?; + + Ok(stats_cartesian_product(left_stats, right_stats)) } /// Tries to swap the projection with its input [`CrossJoinExec`]. 
If it can be done, @@ -869,7 +874,7 @@ mod tests { assert_contains!( err.to_string(), - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: CrossJoinExec" + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n CrossJoinExec" ); Ok(()) diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index e8904db0f3ea..770399290dca 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -70,7 +70,7 @@ use arrow::util::bit_util; use datafusion_common::utils::memory::estimate_memory_size; use datafusion_common::{ internal_datafusion_err, internal_err, plan_err, project_schema, DataFusionError, - JoinSide, JoinType, Result, + JoinSide, JoinType, NullEquality, Result, }; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; @@ -86,6 +86,10 @@ use datafusion_physical_expr_common::physical_expr::fmt_sql; use futures::{ready, Stream, StreamExt, TryStreamExt}; use parking_lot::Mutex; +/// Hard-coded seed to ensure hash values from the hash join differ from `RepartitionExec`, avoiding collisions. +const HASH_JOIN_SEED: RandomState = + RandomState::with_seeds('J' as u64, 'O' as u64, 'I' as u64, 'N' as u64); + /// HashTable and input data for the left (build side) of a join struct JoinLeftData { /// The hash table with indices into `batch` @@ -349,11 +353,8 @@ pub struct HashJoinExec { pub projection: Option>, /// Information of index and left / right placement of columns column_indices: Vec, - /// Null matching behavior: If `null_equals_null` is true, rows that have - /// `null`s in both left and right equijoin columns will be matched. - /// Otherwise, rows that have `null`s in the join columns will not be - /// matched and thus will not appear in the output. - pub null_equals_null: bool, + /// The equality null-handling behavior of the join algorithm. + pub null_equality: NullEquality, /// Cache holding plan properties like equivalences, output partitioning etc. cache: PlanProperties, } @@ -372,7 +373,7 @@ impl HashJoinExec { join_type: &JoinType, projection: Option>, partition_mode: PartitionMode, - null_equals_null: bool, + null_equality: NullEquality, ) -> Result { let left_schema = left.schema(); let right_schema = right.schema(); @@ -385,7 +386,7 @@ impl HashJoinExec { let (join_schema, column_indices) = build_join_schema(&left_schema, &right_schema, join_type); - let random_state = RandomState::with_seeds(0, 0, 0, 0); + let random_state = HASH_JOIN_SEED; let join_schema = Arc::new(join_schema); @@ -415,7 +416,7 @@ impl HashJoinExec { metrics: ExecutionPlanMetricsSet::new(), projection, column_indices, - null_equals_null, + null_equality, cache, }) } @@ -456,9 +457,9 @@ impl HashJoinExec { &self.mode } - /// Get null_equals_null - pub fn null_equals_null(&self) -> bool { - self.null_equals_null + /// Get null_equality + pub fn null_equality(&self) -> NullEquality { + self.null_equality } /// Calculate order preservation flags for this hash join. 
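+    /// The build (left) side is buffered into a hash table, so its input order is
+    /// never preserved; the probe (right) side streams through, so its order is kept
+    /// for the probe-driven join types matched below.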
@@ -471,6 +472,7 @@ impl HashJoinExec { | JoinType::Right | JoinType::RightAnti | JoinType::RightSemi + | JoinType::RightMark ), ] } @@ -505,7 +507,7 @@ impl HashJoinExec { &self.join_type, projection, self.mode, - self.null_equals_null, + self.null_equality, ) } @@ -528,17 +530,17 @@ impl HashJoinExec { &Self::maintains_input_order(join_type), Some(Self::probe_side()), on, - ); + )?; let mut output_partitioning = match mode { PartitionMode::CollectLeft => { - asymmetric_join_output_partitioning(left, right, &join_type) + asymmetric_join_output_partitioning(left, right, &join_type)? } PartitionMode::Auto => Partitioning::UnknownPartitioning( right.output_partitioning().partition_count(), ), PartitionMode::Partitioned => { - symmetric_join_output_partitioning(left, right, &join_type) + symmetric_join_output_partitioning(left, right, &join_type)? } }; @@ -552,7 +554,8 @@ impl HashJoinExec { | JoinType::LeftSemi | JoinType::RightSemi | JoinType::Right - | JoinType::RightAnti => EmissionType::Incremental, + | JoinType::RightAnti + | JoinType::RightMark => EmissionType::Incremental, // If we need to generate unmatched rows from the *build side*, // we need to emit them at the end. JoinType::Left @@ -613,7 +616,7 @@ impl HashJoinExec { self.join_type(), ), partition_mode, - self.null_equals_null(), + self.null_equality(), )?; // In case of anti / semi joins or if there is embedded projection in HashJoinExec, output column order is preserved, no need to add projection again if matches!( @@ -660,7 +663,7 @@ impl DisplayAs for HashJoinExec { let on = self .on .iter() - .map(|(c1, c2)| format!("({}, {})", c1, c2)) + .map(|(c1, c2)| format!("({c1}, {c2})")) .collect::>() .join(", "); write!( @@ -682,7 +685,7 @@ impl DisplayAs for HashJoinExec { if *self.join_type() != JoinType::Inner { writeln!(f, "join_type={:?}", self.join_type)?; } - writeln!(f, "on={}", on) + writeln!(f, "on={on}") } } } @@ -761,7 +764,7 @@ impl ExecutionPlan for HashJoinExec { &self.join_type, self.projection.clone(), self.mode, - self.null_equals_null, + self.null_equality, )?)) } @@ -865,7 +868,7 @@ impl ExecutionPlan for HashJoinExec { column_indices: column_indices_after_projection, random_state: self.random_state.clone(), join_metrics, - null_equals_null: self.null_equals_null, + null_equality: self.null_equality, state: HashJoinStreamState::WaitBuildSide, build_side: BuildSide::Initial(BuildSideInitialState { left_fut }), batch_size, @@ -879,12 +882,19 @@ impl ExecutionPlan for HashJoinExec { } fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_some() { + return Ok(Statistics::new_unknown(&self.schema())); + } // TODO stats: it is not possible in general to know the output size of joins // There are some special cases though, for example: // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)` let stats = estimate_join_statistics( - Arc::clone(&self.left), - Arc::clone(&self.right), + self.left.partition_statistics(None)?, + self.right.partition_statistics(None)?, self.on.clone(), &self.join_type, &self.join_schema, @@ -927,7 +937,7 @@ impl ExecutionPlan for HashJoinExec { // Returned early if projection is not None None, *self.partition_mode(), - self.null_equals_null, + self.null_equality, )?))) } else { try_embed_projection(projection, self) @@ -1218,8 +1228,8 @@ struct HashJoinStream { join_metrics: BuildProbeJoinMetrics, /// Information of index and left / right placement of columns 
column_indices: Vec, - /// If null_equals_null is true, null == null else null != null - null_equals_null: bool, + /// Defines the null equality for the join. + null_equality: NullEquality, /// State of the stream state: HashJoinStreamState, /// Build side @@ -1291,13 +1301,13 @@ fn lookup_join_hashmap( build_hashmap: &JoinHashMap, build_side_values: &[ArrayRef], probe_side_values: &[ArrayRef], - null_equals_null: bool, + null_equality: NullEquality, hashes_buffer: &[u64], limit: usize, offset: JoinHashMapOffset, ) -> Result<(UInt64Array, UInt32Array, Option)> { - let (probe_indices, build_indices, next_offset) = build_hashmap - .get_matched_indices_with_limit_offset(hashes_buffer, None, limit, offset); + let (probe_indices, build_indices, next_offset) = + build_hashmap.get_matched_indices_with_limit_offset(hashes_buffer, limit, offset); let build_indices: UInt64Array = build_indices.into(); let probe_indices: UInt32Array = probe_indices.into(); @@ -1307,7 +1317,7 @@ fn lookup_join_hashmap( &probe_indices, build_side_values, probe_side_values, - null_equals_null, + null_equality, )?; Ok((build_indices, probe_indices, next_offset)) @@ -1317,22 +1327,21 @@ fn lookup_join_hashmap( fn eq_dyn_null( left: &dyn Array, right: &dyn Array, - null_equals_null: bool, + null_equality: NullEquality, ) -> Result { // Nested datatypes cannot use the underlying not_distinct/eq function and must use a special // implementation // if left.data_type().is_nested() { - let op = if null_equals_null { - Operator::IsNotDistinctFrom - } else { - Operator::Eq + let op = match null_equality { + NullEquality::NullEqualsNothing => Operator::Eq, + NullEquality::NullEqualsNull => Operator::IsNotDistinctFrom, }; return Ok(compare_op_for_nested(op, &left, &right)?); } - match (left.data_type(), right.data_type()) { - _ if null_equals_null => not_distinct(&left, &right), - _ => eq(&left, &right), + match null_equality { + NullEquality::NullEqualsNothing => eq(&left, &right), + NullEquality::NullEqualsNull => not_distinct(&left, &right), } } @@ -1341,7 +1350,7 @@ pub fn equal_rows_arr( indices_right: &UInt32Array, left_arrays: &[ArrayRef], right_arrays: &[ArrayRef], - null_equals_null: bool, + null_equality: NullEquality, ) -> Result<(UInt64Array, UInt32Array)> { let mut iter = left_arrays.iter().zip(right_arrays.iter()); @@ -1354,7 +1363,7 @@ pub fn equal_rows_arr( let arr_left = take(first_left.as_ref(), indices_left, None)?; let arr_right = take(first_right.as_ref(), indices_right, None)?; - let mut equal: BooleanArray = eq_dyn_null(&arr_left, &arr_right, null_equals_null)?; + let mut equal: BooleanArray = eq_dyn_null(&arr_left, &arr_right, null_equality)?; // Use map and try_fold to iterate over the remaining pairs of arrays. // In each iteration, take is used on the pair of arrays and their equality is determined. 
@@ -1363,7 +1372,7 @@ pub fn equal_rows_arr( .map(|(left, right)| { let arr_left = take(left.as_ref(), indices_left, None)?; let arr_right = take(right.as_ref(), indices_right, None)?; - eq_dyn_null(arr_left.as_ref(), arr_right.as_ref(), null_equals_null) + eq_dyn_null(arr_left.as_ref(), arr_right.as_ref(), null_equality) }) .try_fold(equal, |acc, equal2| and(&acc, &equal2?))?; @@ -1483,7 +1492,7 @@ impl HashJoinStream { build_side.left_data.hash_map(), build_side.left_data.values(), &state.values, - self.null_equals_null, + self.null_equality, &self.hashes_buffer, self.batch_size, state.offset, @@ -1550,15 +1559,27 @@ impl HashJoinStream { self.right_side_ordered, )?; - let result = build_batch_from_indices( - &self.schema, - build_side.left_data.batch(), - &state.batch, - &left_indices, - &right_indices, - &self.column_indices, - JoinSide::Left, - )?; + let result = if self.join_type == JoinType::RightMark { + build_batch_from_indices( + &self.schema, + &state.batch, + build_side.left_data.batch(), + &left_indices, + &right_indices, + &self.column_indices, + JoinSide::Right, + )? + } else { + build_batch_from_indices( + &self.schema, + build_side.left_data.batch(), + &state.batch, + &left_indices, + &right_indices, + &self.column_indices, + JoinSide::Left, + )? + }; self.join_metrics.output_batches.add(1); self.join_metrics.output_rows.add(result.num_rows()); @@ -1701,7 +1722,7 @@ mod tests { right: Arc, on: JoinOn, join_type: &JoinType, - null_equals_null: bool, + null_equality: NullEquality, ) -> Result { HashJoinExec::try_new( left, @@ -1711,7 +1732,7 @@ mod tests { join_type, None, PartitionMode::CollectLeft, - null_equals_null, + null_equality, ) } @@ -1721,7 +1742,7 @@ mod tests { on: JoinOn, filter: JoinFilter, join_type: &JoinType, - null_equals_null: bool, + null_equality: NullEquality, ) -> Result { HashJoinExec::try_new( left, @@ -1731,7 +1752,7 @@ mod tests { join_type, None, PartitionMode::CollectLeft, - null_equals_null, + null_equality, ) } @@ -1740,10 +1761,10 @@ mod tests { right: Arc, on: JoinOn, join_type: &JoinType, - null_equals_null: bool, + null_equality: NullEquality, context: Arc, ) -> Result<(Vec, Vec)> { - let join = join(left, right, on, join_type, null_equals_null)?; + let join = join(left, right, on, join_type, null_equality)?; let columns_header = columns(&join.schema()); let stream = join.execute(0, context)?; @@ -1757,7 +1778,7 @@ mod tests { right: Arc, on: JoinOn, join_type: &JoinType, - null_equals_null: bool, + null_equality: NullEquality, context: Arc, ) -> Result<(Vec, Vec)> { join_collect_with_partition_mode( @@ -1766,7 +1787,7 @@ mod tests { on, join_type, PartitionMode::Partitioned, - null_equals_null, + null_equality, context, ) .await @@ -1778,7 +1799,7 @@ mod tests { on: JoinOn, join_type: &JoinType, partition_mode: PartitionMode, - null_equals_null: bool, + null_equality: NullEquality, context: Arc, ) -> Result<(Vec, Vec)> { let partition_count = 4; @@ -1828,7 +1849,7 @@ mod tests { join_type, None, partition_mode, - null_equals_null, + null_equality, )?; let columns = columns(&join.schema()); @@ -1873,7 +1894,7 @@ mod tests { Arc::clone(&right), on.clone(), &JoinType::Inner, - false, + NullEquality::NullEqualsNothing, task_ctx, ) .await?; @@ -1920,7 +1941,7 @@ mod tests { Arc::clone(&right), on.clone(), &JoinType::Inner, - false, + NullEquality::NullEqualsNothing, task_ctx, ) .await?; @@ -1960,8 +1981,15 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) 
as _, )]; - let (columns, batches) = - join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; + let (columns, batches) = join_collect( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); @@ -1999,8 +2027,15 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; - let (columns, batches) = - join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; + let (columns, batches) = join_collect( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); @@ -2046,8 +2081,15 @@ mod tests { ), ]; - let (columns, batches) = - join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; + let (columns, batches) = join_collect( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]); @@ -2117,8 +2159,15 @@ mod tests { ), ]; - let (columns, batches) = - join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; + let (columns, batches) = join_collect( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]); @@ -2183,8 +2232,15 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; - let (columns, batches) = - join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; + let (columns, batches) = join_collect( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); @@ -2233,7 +2289,13 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) 
as _, )]; - let join = join(left, right, on, &JoinType::Inner, false)?; + let join = join( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + )?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); @@ -2326,7 +2388,14 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema()).unwrap()) as _, )]; - let join = join(left, right, on, &JoinType::Left, false).unwrap(); + let join = join( + left, + right, + on, + &JoinType::Left, + NullEquality::NullEqualsNothing, + ) + .unwrap(); let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); @@ -2369,7 +2438,14 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _, )]; - let join = join(left, right, on, &JoinType::Full, false).unwrap(); + let join = join( + left, + right, + on, + &JoinType::Full, + NullEquality::NullEqualsNothing, + ) + .unwrap(); let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); @@ -2410,7 +2486,14 @@ mod tests { )]; let schema = right.schema(); let right = TestMemoryExec::try_new_exec(&[vec![right]], schema, None).unwrap(); - let join = join(left, right, on, &JoinType::Left, false).unwrap(); + let join = join( + left, + right, + on, + &JoinType::Left, + NullEquality::NullEqualsNothing, + ) + .unwrap(); let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); @@ -2447,7 +2530,14 @@ mod tests { )]; let schema = right.schema(); let right = TestMemoryExec::try_new_exec(&[vec![right]], schema, None).unwrap(); - let join = join(left, right, on, &JoinType::Full, false).unwrap(); + let join = join( + left, + right, + on, + &JoinType::Full, + NullEquality::NullEqualsNothing, + ) + .unwrap(); let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); @@ -2492,7 +2582,7 @@ mod tests { Arc::clone(&right), on.clone(), &JoinType::Left, - false, + NullEquality::NullEqualsNothing, task_ctx, ) .await?; @@ -2537,7 +2627,7 @@ mod tests { Arc::clone(&right), on.clone(), &JoinType::Left, - false, + NullEquality::NullEqualsNothing, task_ctx, ) .await?; @@ -2590,7 +2680,13 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; - let join = join(left, right, on, &JoinType::LeftSemi, false)?; + let join = join( + left, + right, + on, + &JoinType::LeftSemi, + NullEquality::NullEqualsNothing, + )?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1"]); @@ -2652,7 +2748,7 @@ mod tests { on.clone(), filter, &JoinType::LeftSemi, - false, + NullEquality::NullEqualsNothing, )?; let columns_header = columns(&join.schema()); @@ -2685,7 +2781,14 @@ mod tests { Arc::new(intermediate_schema), ); - let join = join_with_filter(left, right, on, filter, &JoinType::LeftSemi, false)?; + let join = join_with_filter( + left, + right, + on, + filter, + &JoinType::LeftSemi, + NullEquality::NullEqualsNothing, + )?; let columns_header = columns(&join.schema()); assert_eq!(columns_header, vec!["a1", "b1", "c1"]); @@ -2719,7 +2822,13 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) 
as _, )]; - let join = join(left, right, on, &JoinType::RightSemi, false)?; + let join = join( + left, + right, + on, + &JoinType::RightSemi, + NullEquality::NullEqualsNothing, + )?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a2", "b2", "c2"]); @@ -2781,7 +2890,7 @@ mod tests { on.clone(), filter, &JoinType::RightSemi, - false, + NullEquality::NullEqualsNothing, )?; let columns = columns(&join.schema()); @@ -2816,8 +2925,14 @@ mod tests { Arc::new(intermediate_schema.clone()), ); - let join = - join_with_filter(left, right, on, filter, &JoinType::RightSemi, false)?; + let join = join_with_filter( + left, + right, + on, + filter, + &JoinType::RightSemi, + NullEquality::NullEqualsNothing, + )?; let stream = join.execute(0, task_ctx)?; let batches = common::collect(stream).await?; @@ -2848,7 +2963,13 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; - let join = join(left, right, on, &JoinType::LeftAnti, false)?; + let join = join( + left, + right, + on, + &JoinType::LeftAnti, + NullEquality::NullEqualsNothing, + )?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1"]); @@ -2907,7 +3028,7 @@ mod tests { on.clone(), filter, &JoinType::LeftAnti, - false, + NullEquality::NullEqualsNothing, )?; let columns_header = columns(&join.schema()); @@ -2944,7 +3065,14 @@ mod tests { Arc::new(intermediate_schema), ); - let join = join_with_filter(left, right, on, filter, &JoinType::LeftAnti, false)?; + let join = join_with_filter( + left, + right, + on, + filter, + &JoinType::LeftAnti, + NullEquality::NullEqualsNothing, + )?; let columns_header = columns(&join.schema()); assert_eq!(columns_header, vec!["a1", "b1", "c1"]); @@ -2981,7 +3109,13 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; - let join = join(left, right, on, &JoinType::RightAnti, false)?; + let join = join( + left, + right, + on, + &JoinType::RightAnti, + NullEquality::NullEqualsNothing, + )?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a2", "b2", "c2"]); @@ -3041,7 +3175,7 @@ mod tests { on.clone(), filter, &JoinType::RightAnti, - false, + NullEquality::NullEqualsNothing, )?; let columns_header = columns(&join.schema()); @@ -3082,8 +3216,14 @@ mod tests { Arc::new(intermediate_schema), ); - let join = - join_with_filter(left, right, on, filter, &JoinType::RightAnti, false)?; + let join = join_with_filter( + left, + right, + on, + filter, + &JoinType::RightAnti, + NullEquality::NullEqualsNothing, + )?; let columns_header = columns(&join.schema()); assert_eq!(columns_header, vec!["a2", "b2", "c2"]); @@ -3127,8 +3267,15 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let (columns, batches) = - join_collect(left, right, on, &JoinType::Right, false, task_ctx).await?; + let (columns, batches) = join_collect( + left, + right, + on, + &JoinType::Right, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); @@ -3166,9 +3313,15 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) 
as _, )]; - let (columns, batches) = - partitioned_join_collect(left, right, on, &JoinType::Right, false, task_ctx) - .await?; + let (columns, batches) = partitioned_join_collect( + left, + right, + on, + &JoinType::Right, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); @@ -3206,7 +3359,13 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _, )]; - let join = join(left, right, on, &JoinType::Full, false)?; + let join = join( + left, + right, + on, + &JoinType::Full, + NullEquality::NullEqualsNothing, + )?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); @@ -3254,7 +3413,7 @@ mod tests { Arc::clone(&right), on.clone(), &JoinType::LeftMark, - false, + NullEquality::NullEqualsNothing, task_ctx, ) .await?; @@ -3299,7 +3458,7 @@ mod tests { Arc::clone(&right), on.clone(), &JoinType::LeftMark, - false, + NullEquality::NullEqualsNothing, task_ctx, ) .await?; @@ -3320,9 +3479,98 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] + #[tokio::test] + async fn join_right_mark(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), // 6 does not exist on the left + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + )]; + + let (columns, batches) = join_collect( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + &JoinType::RightMark, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; + assert_eq!(columns, vec!["a2", "b1", "c2", "mark"]); + + let expected = [ + "+----+----+----+-------+", + "| a2 | b1 | c2 | mark |", + "+----+----+----+-------+", + "| 10 | 4 | 70 | true |", + "| 20 | 5 | 80 | true |", + "| 30 | 6 | 90 | false |", + "+----+----+----+-------+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + #[apply(batch_sizes)] + #[tokio::test] + async fn partitioned_join_right_mark(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30, 40]), + ("b1", &vec![4, 4, 5, 6]), // 6 does not exist on the left + ("c2", &vec![60, 70, 80, 90]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) 
as _, + )]; + + let (columns, batches) = partitioned_join_collect( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + &JoinType::RightMark, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; + assert_eq!(columns, vec!["a2", "b1", "c2", "mark"]); + + let expected = [ + "+----+----+----+-------+", + "| a2 | b1 | c2 | mark |", + "+----+----+----+-------+", + "| 10 | 4 | 60 | true |", + "| 20 | 4 | 70 | true |", + "| 30 | 5 | 80 | true |", + "| 40 | 6 | 90 | false |", + "+----+----+----+-------+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + #[test] fn join_with_hash_collision() -> Result<()> { - let mut hashmap_left = HashTable::with_capacity(2); + let mut hashmap_left = HashTable::with_capacity(4); let left = build_table_i32( ("a", &vec![10, 20]), ("x", &vec![100, 200]), @@ -3337,9 +3585,15 @@ mod tests { hashes_buff, )?; - // Create hash collisions (same hashes) + // Maps both values to both indices (1 and 2, representing input 0 and 1) + // 0 -> (0, 1) + // 1 -> (0, 2) + // The equality check will make sure only hashes[0] maps to 0 and hashes[1] maps to 1 hashmap_left.insert_unique(hashes[0], (hashes[0], 1), |(h, _)| *h); + hashmap_left.insert_unique(hashes[0], (hashes[0], 2), |(h, _)| *h); + hashmap_left.insert_unique(hashes[1], (hashes[1], 1), |(h, _)| *h); + hashmap_left.insert_unique(hashes[1], (hashes[1], 2), |(h, _)| *h); let next = vec![2, 0]; @@ -3368,7 +3622,7 @@ mod tests { &join_hash_map, &[left_keys_values], &[right_keys_values], - false, + NullEquality::NullEqualsNothing, &hashes_buffer, 8192, (0, None), @@ -3404,7 +3658,13 @@ mod tests { Arc::new(Column::new_with_schema("b", &right.schema()).unwrap()) as _, )]; - let join = join(left, right, on, &JoinType::Inner, false)?; + let join = join( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + )?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a", "b", "c", "a", "b", "c"]); @@ -3474,7 +3734,14 @@ mod tests { )]; let filter = prepare_join_filter(); - let join = join_with_filter(left, right, on, filter, &JoinType::Inner, false)?; + let join = join_with_filter( + left, + right, + on, + filter, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + )?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a", "b", "c", "a", "b", "c"]); @@ -3516,7 +3783,14 @@ mod tests { )]; let filter = prepare_join_filter(); - let join = join_with_filter(left, right, on, filter, &JoinType::Left, false)?; + let join = join_with_filter( + left, + right, + on, + filter, + &JoinType::Left, + NullEquality::NullEqualsNothing, + )?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a", "b", "c", "a", "b", "c"]); @@ -3561,7 +3835,14 @@ mod tests { )]; let filter = prepare_join_filter(); - let join = join_with_filter(left, right, on, filter, &JoinType::Right, false)?; + let join = join_with_filter( + left, + right, + on, + filter, + &JoinType::Right, + NullEquality::NullEqualsNothing, + )?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a", "b", "c", "a", "b", "c"]); @@ -3605,7 +3886,14 @@ mod tests { )]; let filter = prepare_join_filter(); - let join = join_with_filter(left, right, on, filter, &JoinType::Full, false)?; + let join = join_with_filter( + left, + right, + on, + filter, + &JoinType::Full, + NullEquality::NullEqualsNothing, + )?; let columns = columns(&join.schema()); assert_eq!(columns, vec!["a", "b", "c", "a", "b", "c"]); @@ -3742,6 +4030,15 @@ mod tests { "| 3 | 7 | 9 | false |", "+----+----+----+-------+", 
]; + let expected_right_mark = vec![ + "+----+----+----+-------+", + "| a2 | b2 | c2 | mark |", + "+----+----+----+-------+", + "| 10 | 4 | 70 | true |", + "| 20 | 5 | 80 | true |", + "| 30 | 6 | 90 | false |", + "+----+----+----+-------+", + ]; let test_cases = vec![ (JoinType::Inner, expected_inner), @@ -3753,6 +4050,7 @@ mod tests { (JoinType::RightSemi, expected_right_semi), (JoinType::RightAnti, expected_right_anti), (JoinType::LeftMark, expected_left_mark), + (JoinType::RightMark, expected_right_mark), ]; for (join_type, expected) in test_cases { @@ -3762,7 +4060,7 @@ mod tests { on.clone(), &join_type, PartitionMode::CollectLeft, - false, + NullEquality::NullEqualsNothing, Arc::clone(&task_ctx), ) .await?; @@ -3794,7 +4092,13 @@ mod tests { Arc::new(Column::new_with_schema("date", &right.schema()).unwrap()) as _, )]; - let join = join(left, right, on, &JoinType::Inner, false)?; + let join = join( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + )?; let task_ctx = Arc::new(TaskContext::default()); let stream = join.execute(0, task_ctx)?; @@ -3853,7 +4157,7 @@ mod tests { Arc::clone(&right_input) as Arc, on.clone(), &join_type, - false, + NullEquality::NullEqualsNothing, ) .unwrap(); let task_ctx = Arc::new(TaskContext::default()); @@ -3967,7 +4271,7 @@ mod tests { Arc::clone(&right), on.clone(), &join_type, - false, + NullEquality::NullEqualsNothing, ) .unwrap(); @@ -3990,10 +4294,7 @@ mod tests { assert_eq!( batches.len(), expected_batch_count, - "expected {} output batches for {} join with batch_size = {}", - expected_batch_count, - join_type, - batch_size + "expected {expected_batch_count} output batches for {join_type} join with batch_size = {batch_size}" ); let expected = match join_type { @@ -4035,6 +4336,7 @@ mod tests { JoinType::RightSemi, JoinType::RightAnti, JoinType::LeftMark, + JoinType::RightMark, ]; for join_type in join_types { @@ -4049,7 +4351,7 @@ mod tests { Arc::clone(&right), on.clone(), &join_type, - false, + NullEquality::NullEqualsNothing, )?; let stream = join.execute(0, task_ctx)?; @@ -4058,12 +4360,12 @@ mod tests { // Asserting that operator-level reservation attempting to overallocate assert_contains!( err.to_string(), - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: HashJoinInput" + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n HashJoinInput" ); assert_contains!( err.to_string(), - "Failed to allocate additional 120 bytes for HashJoinInput" + "Failed to allocate additional 120.0 B for HashJoinInput" ); } @@ -4130,7 +4432,7 @@ mod tests { &join_type, None, PartitionMode::Partitioned, - false, + NullEquality::NullEqualsNothing, )?; let stream = join.execute(1, task_ctx)?; @@ -4139,13 +4441,13 @@ mod tests { // Asserting that stream-level reservation attempting to overallocate assert_contains!( err.to_string(), - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: HashJoinInput[1]" + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n HashJoinInput[1]" ); assert_contains!( err.to_string(), - "Failed to allocate additional 120 bytes for HashJoinInput[1]" + "Failed to allocate additional 120.0 B for HashJoinInput[1]" ); } @@ -4190,8 +4492,15 @@ mod tests { Arc::new(Column::new_with_schema("n2", &right.schema())?) 
as _, )]; - let (columns, batches) = - join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; + let (columns, batches) = join_collect( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; assert_eq!(columns, vec!["n1", "n2"]); @@ -4227,7 +4536,7 @@ mod tests { Arc::clone(&right), on.clone(), &JoinType::Inner, - true, + NullEquality::NullEqualsNull, Arc::clone(&task_ctx), ) .await?; @@ -4242,8 +4551,15 @@ mod tests { "#); } - let (_, batches_null_neq) = - join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; + let (_, batches_null_neq) = join_collect( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; let expected_null_neq = ["+----+----+", "| n1 | n2 |", "+----+----+", "+----+----+"]; diff --git a/datafusion/physical-plan/src/joins/join_hash_map.rs b/datafusion/physical-plan/src/joins/join_hash_map.rs index 7af0aeca0fd6..521e19d7bf44 100644 --- a/datafusion/physical-plan/src/joins/join_hash_map.rs +++ b/datafusion/physical-plan/src/joins/join_hash_map.rs @@ -116,20 +116,10 @@ pub(crate) type JoinHashMapOffset = (usize, Option); macro_rules! chain_traverse { ( $input_indices:ident, $match_indices:ident, $hash_values:ident, $next_chain:ident, - $input_idx:ident, $chain_idx:ident, $deleted_offset:ident, $remaining_output:ident + $input_idx:ident, $chain_idx:ident, $remaining_output:ident ) => { - let mut i = $chain_idx - 1; + let mut match_row_idx = $chain_idx - 1; loop { - let match_row_idx = if let Some(offset) = $deleted_offset { - // This arguments means that we prune the next index way before here. - if i < offset as u64 { - // End of the list due to pruning - break; - } - i - offset as u64 - } else { - i - }; $match_indices.push(match_row_idx); $input_indices.push($input_idx as u32); $remaining_output -= 1; @@ -150,7 +140,7 @@ macro_rules! chain_traverse { // end of list break; } - i = next - 1; + match_row_idx = next - 1; } }; } @@ -168,6 +158,11 @@ pub trait JoinHashMapType { /// Returns a reference to the next. fn get_list(&self) -> &Self::NextType; + // Whether values in the hashmap are distinct (no duplicate keys) + fn is_distinct(&self) -> bool { + false + } + /// Updates hashmap from iterator of row indices & row hashes pairs. 
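The `is_distinct` hook introduced here lets `get_matched_indices_with_limit_offset` (below) bypass the `next`-chain walk when every build-side key occurs exactly once (`map.len() == next.len()` for `JoinHashMap`). A minimal sketch of that fast path, assuming the `hashbrown::HashTable` API already used in this file and the map's convention of storing `(hash, build_row_index + 1)` pairs; the free function is illustrative only and omits the limit/offset bookkeeping of the real method:

use hashbrown::HashTable;

// One `find` per probe row; no `next`-chain traversal is needed because each
// stored entry represents a single build row. (Sketch only, not part of this patch.)
fn distinct_lookup(probe_hashes: &[u64], map: &HashTable<(u64, u64)>) -> Vec<(u32, u64)> {
    let mut out = Vec::with_capacity(probe_hashes.len());
    for (probe_idx, &hash) in probe_hashes.iter().enumerate() {
        if let Some((_, stored)) = map.find(hash, |(h, _)| *h == hash) {
            // Stored indices are one-based, mirroring the `*index - 1` in the
            // patched fast path; hash collisions are resolved later by the
            // key-equality check, as the collision test above relies on.
            out.push((probe_idx as u32, *stored - 1));
        }
    }
    out
}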
fn update_from_iter<'a>( &mut self, @@ -257,17 +252,35 @@ pub trait JoinHashMapType { fn get_matched_indices_with_limit_offset( &self, hash_values: &[u64], - deleted_offset: Option, limit: usize, offset: JoinHashMapOffset, ) -> (Vec, Vec, Option) { - let mut input_indices = vec![]; - let mut match_indices = vec![]; - - let mut remaining_output = limit; + let mut input_indices = Vec::with_capacity(limit); + let mut match_indices = Vec::with_capacity(limit); let hash_map: &HashTable<(u64, u64)> = self.get_map(); let next_chain = self.get_list(); + // Check if hashmap consists of unique values + // If so, we can skip the chain traversal + if self.is_distinct() { + let start = offset.0; + let end = (start + limit).min(hash_values.len()); + for (row_idx, &hash_value) in hash_values[start..end].iter().enumerate() { + if let Some((_, index)) = + hash_map.find(hash_value, |(hash, _)| hash_value == *hash) + { + input_indices.push(start as u32 + row_idx as u32); + match_indices.push(*index - 1); + } + } + if end == hash_values.len() { + // No more values to process + return (input_indices, match_indices, None); + } + return (input_indices, match_indices, Some((end, None))); + } + + let mut remaining_output = limit; // Calculate initial `hash_values` index before iterating let to_skip = match offset { @@ -286,7 +299,6 @@ pub trait JoinHashMapType { next_chain, initial_idx, initial_next_idx, - deleted_offset, remaining_output ); @@ -295,6 +307,7 @@ pub trait JoinHashMapType { }; let mut row_idx = to_skip; + for hash_value in &hash_values[to_skip..] { if let Some((_, index)) = hash_map.find(*hash_value, |(hash, _)| *hash_value == *hash) @@ -306,7 +319,6 @@ pub trait JoinHashMapType { next_chain, row_idx, index, - deleted_offset, remaining_output ); } @@ -338,6 +350,11 @@ impl JoinHashMapType for JoinHashMap { fn get_list(&self) -> &Self::NextType { &self.next } + + /// Check if the values in the hashmap are distinct. + fn is_distinct(&self) -> bool { + self.map.len() == self.next.len() + } } impl Debug for JoinHashMap { diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index b90279595096..fcc1107a0e26 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -259,10 +259,10 @@ impl NestedLoopJoinExec { None, // No on columns in nested loop join &[], - ); + )?; let mut output_partitioning = - asymmetric_join_output_partitioning(left, right, &join_type); + asymmetric_join_output_partitioning(left, right, &join_type)?; let emission_type = if left.boundedness().is_unbounded() { EmissionType::Final @@ -274,7 +274,8 @@ impl NestedLoopJoinExec { | JoinType::LeftSemi | JoinType::RightSemi | JoinType::Right - | JoinType::RightAnti => EmissionType::Incremental, + | JoinType::RightAnti + | JoinType::RightMark => EmissionType::Incremental, // If we need to generate unmatched rows from the *build side*, // we need to emit them at the end. 
JoinType::Left @@ -567,9 +568,16 @@ impl ExecutionPlan for NestedLoopJoinExec { } fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_some() { + return Ok(Statistics::new_unknown(&self.schema())); + } estimate_join_statistics( - Arc::clone(&self.left), - Arc::clone(&self.right), + self.left.partition_statistics(None)?, + self.right.partition_statistics(None)?, vec![], &self.join_type, &self.join_schema, @@ -712,7 +720,7 @@ struct NestedLoopJoinStream { /// Information of index and left / right placement of columns column_indices: Vec, // TODO: support null aware equal - // null_equals_null: bool + // null_equality: NullEquality, /// Join execution metrics join_metrics: BuildProbeJoinMetrics, /// Cache for join indices calculations @@ -1002,15 +1010,30 @@ fn join_left_and_right_batch( right_side_ordered, )?; - build_batch_from_indices( - schema, - left_batch, - right_batch, - &left_side, - &right_side, - column_indices, - JoinSide::Left, - ) + // Switch around the build side and probe side for `JoinType::RightMark` + // because in a RightMark join, we want to mark rows on the right table + // by looking for matches in the left. + if join_type == JoinType::RightMark { + build_batch_from_indices( + schema, + right_batch, + left_batch, + &left_side, + &right_side, + column_indices, + JoinSide::Right, + ) + } else { + build_batch_from_indices( + schema, + left_batch, + right_batch, + &left_side, + &right_side, + column_indices, + JoinSide::Left, + ) + } } impl Stream for NestedLoopJoinStream { @@ -1081,22 +1104,18 @@ pub(crate) mod tests { vec![batch] }; - let mut source = - TestMemoryExec::try_new(&[batches], Arc::clone(&schema), None).unwrap(); - if !sorted_column_names.is_empty() { - let mut sort_info = LexOrdering::default(); - for name in sorted_column_names { - let index = schema.index_of(name).unwrap(); - let sort_expr = PhysicalSortExpr { - expr: Arc::new(Column::new(name, index)), - options: SortOptions { - descending: false, - nulls_first: false, - }, - }; - sort_info.push(sort_expr); - } - source = source.try_with_sort_information(vec![sort_info]).unwrap(); + let mut sort_info = vec![]; + for name in sorted_column_names { + let index = schema.index_of(name).unwrap(); + let sort_expr = PhysicalSortExpr::new( + Arc::new(Column::new(name, index)), + SortOptions::new(false, false), + ); + sort_info.push(sort_expr); + } + let mut source = TestMemoryExec::try_new(&[batches], schema, None).unwrap(); + if let Some(ordering) = LexOrdering::new(sort_info) { + source = source.try_with_sort_information(vec![ordering]).unwrap(); } Arc::new(TestMemoryExec::update_cache(Arc::new(source))) @@ -1457,6 +1476,36 @@ pub(crate) mod tests { Ok(()) } + #[tokio::test] + async fn join_right_mark_with_filter() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let left = build_left_table(); + let right = build_right_table(); + + let filter = prepare_join_filter(); + let (columns, batches) = multi_partitioned_join_collect( + left, + right, + &JoinType::RightMark, + Some(filter), + task_ctx, + ) + .await?; + assert_eq!(columns, vec!["a2", "b2", "c2", "mark"]); + + assert_snapshot!(batches_to_sort_string(&batches), @r#" + +----+----+-----+-------+ + | a2 | b2 | c2 | mark | + +----+----+-----+-------+ + | 10 | 10 | 100 | false | + | 12 | 10 | 40 | false | + | 2 | 2 | 80 | true | + +----+----+-----+-------+ + "#); + + Ok(()) + } + #[tokio::test] async fn test_overallocation() -> Result<()> { let 
left = build_table( @@ -1485,6 +1534,7 @@ pub(crate) mod tests { JoinType::LeftMark, JoinType::RightSemi, JoinType::RightAnti, + JoinType::RightMark, ]; for join_type in join_types { @@ -1506,7 +1556,7 @@ pub(crate) mod tests { assert_contains!( err.to_string(), - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: NestedLoopJoinLoad[0]" + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n NestedLoopJoinLoad[0]" ); } @@ -1663,11 +1713,7 @@ pub(crate) mod tests { .into_iter() .zip(prev_values) .all(|(current, prev)| current >= prev), - "batch_index: {} row: {} current: {:?}, prev: {:?}", - batch_index, - row, - current_values, - prev_values + "batch_index: {batch_index} row: {row} current: {current_values:?}, prev: {prev_values:?}" ); prev_values = current_values; } diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 89f2e3c911f8..a8c209a492ba 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -16,9 +16,8 @@ // under the License. //! Defines the Sort-Merge join execution plan. -//! A Sort-Merge join plan consumes two sorted children plan and produces +//! A Sort-Merge join plan consumes two sorted children plans and produces //! joined output by given join type and other options. -//! Sort-Merge join feature is currently experimental. use std::any::Any; use std::cmp::Ordering; @@ -62,18 +61,18 @@ use arrow::compute::{ use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; use arrow::error::ArrowError; use arrow::ipc::reader::StreamReader; +use datafusion_common::config::SpillCompression; use datafusion_common::{ exec_err, internal_err, not_impl_err, plan_err, DataFusionError, HashSet, JoinSide, - JoinType, Result, + JoinType, NullEquality, Result, }; use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; use datafusion_physical_expr::equivalence::join_equivalence_properties; -use datafusion_physical_expr::PhysicalExprRef; -use datafusion_physical_expr_common::physical_expr::fmt_sql; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; +use datafusion_physical_expr_common::physical_expr::{fmt_sql, PhysicalExprRef}; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements}; use futures::{Stream, StreamExt}; @@ -147,8 +146,8 @@ pub struct SortMergeJoinExec { right_sort_exprs: LexOrdering, /// Sort options of join columns used in sorting left and right execution plans pub sort_options: Vec, - /// If null_equals_null is true, null == null else null != null - pub null_equals_null: bool, + /// Defines the null equality for the join. + pub null_equality: NullEquality, /// Cache holding plan properties like equivalences, output partitioning etc. 
cache: PlanProperties, } @@ -165,17 +164,11 @@ impl SortMergeJoinExec { filter: Option, join_type: JoinType, sort_options: Vec, - null_equals_null: bool, + null_equality: NullEquality, ) -> Result { let left_schema = left.schema(); let right_schema = right.schema(); - if join_type == JoinType::RightSemi { - return not_impl_err!( - "SortMergeJoinExec does not support JoinType::RightSemi" - ); - } - check_join_is_valid(&left_schema, &right_schema, &on)?; if sort_options.len() != on.len() { return plan_err!( @@ -200,11 +193,21 @@ impl SortMergeJoinExec { (left, right) }) .unzip(); + let Some(left_sort_exprs) = LexOrdering::new(left_sort_exprs) else { + return plan_err!( + "SortMergeJoinExec requires valid sort expressions for its left side" + ); + }; + let Some(right_sort_exprs) = LexOrdering::new(right_sort_exprs) else { + return plan_err!( + "SortMergeJoinExec requires valid sort expressions for its right side" + ); + }; let schema = Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0); let cache = - Self::compute_properties(&left, &right, Arc::clone(&schema), join_type, &on); + Self::compute_properties(&left, &right, Arc::clone(&schema), join_type, &on)?; Ok(Self { left, right, @@ -213,10 +216,10 @@ impl SortMergeJoinExec { join_type, schema, metrics: ExecutionPlanMetricsSet::new(), - left_sort_exprs: LexOrdering::new(left_sort_exprs), - right_sort_exprs: LexOrdering::new(right_sort_exprs), + left_sort_exprs, + right_sort_exprs, sort_options, - null_equals_null, + null_equality, cache, }) } @@ -227,9 +230,11 @@ impl SortMergeJoinExec { // When output schema contains only the right side, probe side is right. // Otherwise probe side is the left side. match join_type { - JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => { - JoinSide::Right - } + // TODO: sort merge support for right mark (tracked here: https://github.com/apache/datafusion/issues/16226) + JoinType::Right + | JoinType::RightSemi + | JoinType::RightAnti + | JoinType::RightMark => JoinSide::Right, JoinType::Inner | JoinType::Left | JoinType::Full @@ -247,7 +252,10 @@ impl SortMergeJoinExec { | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => vec![true, false], - JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => { + JoinType::Right + | JoinType::RightSemi + | JoinType::RightAnti + | JoinType::RightMark => { vec![false, true] } _ => vec![false, false], @@ -284,9 +292,9 @@ impl SortMergeJoinExec { &self.sort_options } - /// Null equals null - pub fn null_equals_null(&self) -> bool { - self.null_equals_null + /// Null equality + pub fn null_equality(&self) -> NullEquality { + self.null_equality } /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. 
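Across these join operators the `null_equals_null: bool` flag is replaced by the `NullEquality` enum from `datafusion_common`, with `false` mapping to `NullEquality::NullEqualsNothing` and `true` to `NullEquality::NullEqualsNull`. A hedged sketch of a migration shim for external callers (the helper function is hypothetical, not part of this patch):

use datafusion_common::NullEquality;

// Translate the old boolean flag into the new enum, matching the substitutions
// made throughout this diff (false -> NullEqualsNothing, true -> NullEqualsNull).
fn null_equality_from_bool(null_equals_null: bool) -> NullEquality {
    if null_equals_null {
        NullEquality::NullEqualsNull
    } else {
        NullEquality::NullEqualsNothing
    }
}

With this, an old call such as `join_collect(left, right, on, &JoinType::Inner, false, task_ctx)` becomes the `NullEquality::NullEqualsNothing` form used in the rewritten tests above.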
@@ -296,7 +304,7 @@ impl SortMergeJoinExec { schema: SchemaRef, join_type: JoinType, join_on: JoinOnRef, - ) -> PlanProperties { + ) -> Result { // Calculate equivalence properties: let eq_properties = join_equivalence_properties( left.equivalence_properties().clone(), @@ -306,17 +314,17 @@ impl SortMergeJoinExec { &Self::maintains_input_order(join_type), Some(Self::probe_side(&join_type)), join_on, - ); + )?; let output_partitioning = - symmetric_join_output_partitioning(left, right, &join_type); + symmetric_join_output_partitioning(left, right, &join_type)?; - PlanProperties::new( + Ok(PlanProperties::new( eq_properties, output_partitioning, EmissionType::Incremental, boundedness_from_children([left, right]), - ) + )) } pub fn swap_inputs(&self) -> Result> { @@ -332,7 +340,7 @@ impl SortMergeJoinExec { self.filter().as_ref().map(JoinFilter::swap), self.join_type().swap(), self.sort_options.clone(), - self.null_equals_null, + self.null_equality, )?; // TODO: OR this condition with having a built-in projection (like @@ -358,7 +366,7 @@ impl DisplayAs for SortMergeJoinExec { let on = self .on .iter() - .map(|(c1, c2)| format!("({}, {})", c1, c2)) + .map(|(c1, c2)| format!("({c1}, {c2})")) .collect::>() .join(", "); write!( @@ -385,7 +393,7 @@ impl DisplayAs for SortMergeJoinExec { if self.join_type() != JoinType::Inner { writeln!(f, "join_type={:?}", self.join_type)?; } - writeln!(f, "on={}", on) + writeln!(f, "on={on}") } } } @@ -416,10 +424,10 @@ impl ExecutionPlan for SortMergeJoinExec { ] } - fn required_input_ordering(&self) -> Vec> { + fn required_input_ordering(&self) -> Vec> { vec![ - Some(LexRequirement::from(self.left_sort_exprs.clone())), - Some(LexRequirement::from(self.right_sort_exprs.clone())), + Some(OrderingRequirements::from(self.left_sort_exprs.clone())), + Some(OrderingRequirements::from(self.right_sort_exprs.clone())), ] } @@ -443,7 +451,7 @@ impl ExecutionPlan for SortMergeJoinExec { self.filter.clone(), self.join_type, self.sort_options.clone(), - self.null_equals_null, + self.null_equality, )?)), _ => internal_err!("SortMergeJoin wrong number of children"), } @@ -493,9 +501,10 @@ impl ExecutionPlan for SortMergeJoinExec { // create join stream Ok(Box::pin(SortMergeJoinStream::try_new( + context.session_config().spill_compression(), Arc::clone(&self.schema), self.sort_options.clone(), - self.null_equals_null, + self.null_equality, streamed, buffered, on_streamed, @@ -514,12 +523,19 @@ impl ExecutionPlan for SortMergeJoinExec { } fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_some() { + return Ok(Statistics::new_unknown(&self.schema())); + } // TODO stats: it is not possible in general to know the output size of joins // There are some special cases though, for example: // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)` estimate_join_statistics( - Arc::clone(&self.left), - Arc::clone(&self.right), + self.left.partition_statistics(None)?, + self.right.partition_statistics(None)?, self.on.clone(), &self.join_type, &self.schema, @@ -577,7 +593,7 @@ impl ExecutionPlan for SortMergeJoinExec { self.filter.clone(), self.join_type, self.sort_options.clone(), - self.null_equals_null, + self.null_equality, )?))) } } @@ -830,8 +846,8 @@ struct SortMergeJoinStream { // ======================================================================== /// Output schema pub schema: SchemaRef, - /// null == null? 
- pub null_equals_null: bool, + /// Defines the null equality for the join. + pub null_equality: NullEquality, /// Sort options of join columns used to sort streamed and buffered data stream pub sort_options: Vec, /// optional join filter @@ -916,7 +932,7 @@ struct JoinedRecordBatches { pub batches: Vec, /// Filter match mask for each row(matched/non-matched) pub filter_mask: BooleanBuilder, - /// Row indices to glue together rows in `batches` and `filter_mask` + /// Left row indices to glue together rows in `batches` and `filter_mask` pub row_indices: UInt64Builder, /// Which unique batch id the row belongs to /// It is necessary to differentiate rows that are distributed the way when they point to the same @@ -1016,7 +1032,7 @@ fn get_corrected_filter_mask( corrected_mask.append_n(expected_size - corrected_mask.len(), false); Some(corrected_mask.finish()) } - JoinType::LeftSemi => { + JoinType::LeftSemi | JoinType::RightSemi => { for i in 0..row_indices_length { let last_index = last_index_for_row(i, row_indices, batch_ids, row_indices_length); @@ -1145,6 +1161,7 @@ impl Stream for SortMergeJoinStream { | JoinType::LeftSemi | JoinType::LeftMark | JoinType::Right + | JoinType::RightSemi | JoinType::LeftAnti | JoinType::RightAnti | JoinType::Full @@ -1250,6 +1267,7 @@ impl Stream for SortMergeJoinStream { JoinType::Left | JoinType::LeftSemi | JoinType::Right + | JoinType::RightSemi | JoinType::LeftAnti | JoinType::RightAnti | JoinType::LeftMark @@ -1275,6 +1293,7 @@ impl Stream for SortMergeJoinStream { JoinType::Left | JoinType::LeftSemi | JoinType::Right + | JoinType::RightSemi | JoinType::LeftAnti | JoinType::RightAnti | JoinType::Full @@ -1307,9 +1326,11 @@ impl Stream for SortMergeJoinStream { impl SortMergeJoinStream { #[allow(clippy::too_many_arguments)] pub fn try_new( + // Configured via `datafusion.execution.spill_compression`. + spill_compression: SpillCompression, schema: SchemaRef, sort_options: Vec, - null_equals_null: bool, + null_equality: NullEquality, streamed: SendableRecordBatchStream, buffered: SendableRecordBatchStream, on_streamed: Vec>, @@ -1327,11 +1348,12 @@ impl SortMergeJoinStream { Arc::clone(&runtime_env), join_metrics.spill_metrics.clone(), Arc::clone(&buffered_schema), - ); + ) + .with_compression_type(spill_compression); Ok(Self { state: SortMergeJoinState::Init, sort_options, - null_equals_null, + null_equality, schema: Arc::clone(&schema), streamed_schema: Arc::clone(&streamed_schema), buffered_schema, @@ -1576,7 +1598,7 @@ impl SortMergeJoinStream { &self.buffered_data.head_batch().join_arrays, self.buffered_data.head_batch().range.start, &self.sort_options, - self.null_equals_null, + self.null_equality, ) } @@ -1597,7 +1619,6 @@ impl SortMergeJoinStream { self.join_type, JoinType::Left | JoinType::Right - | JoinType::RightSemi | JoinType::Full | JoinType::LeftAnti | JoinType::RightAnti @@ -1607,7 +1628,10 @@ impl SortMergeJoinStream { } } Ordering::Equal => { - if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftMark) { + if matches!( + self.join_type, + JoinType::LeftSemi | JoinType::LeftMark | JoinType::RightSemi + ) { mark_row_as_match = matches!(self.join_type, JoinType::LeftMark); // if the join filter is specified then its needed to output the streamed index // only if it has not been emitted before @@ -1827,7 +1851,10 @@ impl SortMergeJoinStream { vec![Arc::new(is_not_null(&right_indices)?) 
as ArrayRef] } else if matches!( self.join_type, - JoinType::LeftSemi | JoinType::LeftAnti | JoinType::RightAnti + JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::RightSemi ) { vec![] } else if let Some(buffered_idx) = chunk.buffered_batch_idx { @@ -1861,7 +1888,10 @@ impl SortMergeJoinStream { )?; get_filter_column(&self.filter, &left_columns, &right_cols) - } else if matches!(self.join_type, JoinType::RightAnti) { + } else if matches!( + self.join_type, + JoinType::RightAnti | JoinType::RightSemi + ) { let right_cols = fetch_right_columns_by_idxs( &self.buffered_data, chunk.buffered_batch_idx.unwrap(), @@ -1922,6 +1952,7 @@ impl SortMergeJoinStream { JoinType::Left | JoinType::LeftSemi | JoinType::Right + | JoinType::RightSemi | JoinType::LeftAnti | JoinType::RightAnti | JoinType::LeftMark @@ -2019,6 +2050,7 @@ impl SortMergeJoinStream { JoinType::Left | JoinType::LeftSemi | JoinType::Right + | JoinType::RightSemi | JoinType::LeftAnti | JoinType::RightAnti | JoinType::LeftMark @@ -2128,7 +2160,7 @@ impl SortMergeJoinStream { let output_column_indices = (0..left_columns_length).collect::>(); filtered_record_batch = filtered_record_batch.project(&output_column_indices)?; - } else if matches!(self.join_type, JoinType::RightAnti) { + } else if matches!(self.join_type, JoinType::RightAnti | JoinType::RightSemi) { let output_column_indices = (0..right_columns_length).collect::>(); filtered_record_batch = filtered_record_batch.project(&output_column_indices)?; @@ -2407,7 +2439,7 @@ fn compare_join_arrays( right_arrays: &[ArrayRef], right: usize, sort_options: &[SortOptions], - null_equals_null: bool, + null_equality: NullEquality, ) -> Result { let mut res = Ordering::Equal; for ((left_array, right_array), sort_options) in @@ -2441,10 +2473,9 @@ fn compare_join_arrays( }; } _ => { - res = if null_equals_null { - Ordering::Equal - } else { - Ordering::Less + res = match null_equality { + NullEquality::NullEqualsNothing => Ordering::Less, + NullEquality::NullEqualsNull => Ordering::Equal, }; } } @@ -2465,6 +2496,7 @@ fn compare_join_arrays( DataType::Float32 => compare_value!(Float32Array), DataType::Float64 => compare_value!(Float64Array), DataType::Utf8 => compare_value!(StringArray), + DataType::Utf8View => compare_value!(StringViewArray), DataType::LargeUtf8 => compare_value!(LargeStringArray), DataType::Decimal128(..) => compare_value!(Decimal128Array), DataType::Timestamp(time_unit, None) => match time_unit { @@ -2532,6 +2564,7 @@ fn is_join_arrays_equal( DataType::Float32 => compare_value!(Float32Array), DataType::Float64 => compare_value!(Float64Array), DataType::Utf8 => compare_value!(StringArray), + DataType::Utf8View => compare_value!(StringViewArray), DataType::LargeUtf8 => compare_value!(LargeStringArray), DataType::Decimal128(..) 
=> compare_value!(Decimal128Array), DataType::Timestamp(time_unit, None) => match time_unit { @@ -2568,13 +2601,15 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::JoinType::*; - use datafusion_common::{assert_batches_eq, assert_contains, JoinType, Result}; + use datafusion_common::{ + assert_batches_eq, assert_contains, JoinType, NullEquality, Result, + }; use datafusion_common::{ test_util::{batches_to_sort_string, batches_to_string}, JoinSide, }; use datafusion_execution::config::SessionConfig; - use datafusion_execution::disk_manager::DiskManagerConfig; + use datafusion_execution::disk_manager::{DiskManagerBuilder, DiskManagerMode}; use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_execution::TaskContext; use datafusion_expr::Operator; @@ -2693,7 +2728,15 @@ mod tests { join_type: JoinType, ) -> Result { let sort_options = vec![SortOptions::default(); on.len()]; - SortMergeJoinExec::try_new(left, right, on, None, join_type, sort_options, false) + SortMergeJoinExec::try_new( + left, + right, + on, + None, + join_type, + sort_options, + NullEquality::NullEqualsNothing, + ) } fn join_with_options( @@ -2702,7 +2745,7 @@ mod tests { on: JoinOn, join_type: JoinType, sort_options: Vec, - null_equals_null: bool, + null_equality: NullEquality, ) -> Result { SortMergeJoinExec::try_new( left, @@ -2711,7 +2754,7 @@ mod tests { None, join_type, sort_options, - null_equals_null, + null_equality, ) } @@ -2722,7 +2765,7 @@ mod tests { filter: JoinFilter, join_type: JoinType, sort_options: Vec, - null_equals_null: bool, + null_equality: NullEquality, ) -> Result { SortMergeJoinExec::try_new( left, @@ -2731,7 +2774,7 @@ mod tests { Some(filter), join_type, sort_options, - null_equals_null, + null_equality, ) } @@ -2742,7 +2785,15 @@ mod tests { join_type: JoinType, ) -> Result<(Vec, Vec)> { let sort_options = vec![SortOptions::default(); on.len()]; - join_collect_with_options(left, right, on, join_type, sort_options, false).await + join_collect_with_options( + left, + right, + on, + join_type, + sort_options, + NullEquality::NullEqualsNothing, + ) + .await } async fn join_collect_with_filter( @@ -2755,8 +2806,15 @@ mod tests { let sort_options = vec![SortOptions::default(); on.len()]; let task_ctx = Arc::new(TaskContext::default()); - let join = - join_with_filter(left, right, on, filter, join_type, sort_options, false)?; + let join = join_with_filter( + left, + right, + on, + filter, + join_type, + sort_options, + NullEquality::NullEqualsNothing, + )?; let columns = columns(&join.schema()); let stream = join.execute(0, task_ctx)?; @@ -2770,17 +2828,11 @@ mod tests { on: JoinOn, join_type: JoinType, sort_options: Vec, - null_equals_null: bool, + null_equality: NullEquality, ) -> Result<(Vec, Vec)> { let task_ctx = Arc::new(TaskContext::default()); - let join = join_with_options( - left, - right, - on, - join_type, - sort_options, - null_equals_null, - )?; + let join = + join_with_options(left, right, on, join_type, sort_options, null_equality)?; let columns = columns(&join.schema()); let stream = join.execute(0, task_ctx)?; @@ -2986,7 +3038,7 @@ mod tests { }; 2 ], - true, + NullEquality::NullEqualsNull, ) .await?; // The output order is important as SMJ preserves sortedness @@ -3409,7 +3461,7 @@ mod tests { }; 2 ], - true, + NullEquality::NullEqualsNull, ) .await?; @@ -3467,7 +3519,7 @@ mod tests { } #[tokio::test] - async fn join_semi() -> Result<()> { + async fn join_left_semi() -> Result<()> { let left = build_table( ("a1", &vec![1, 2, 2, 
3]), ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right @@ -3497,6 +3549,255 @@ mod tests { Ok(()) } + #[tokio::test] + async fn join_right_semi_one() -> Result<()> { + let left = build_table( + ("a1", &vec![10, 20, 30, 40]), + ("b1", &vec![4, 5, 5, 6]), + ("c1", &vec![70, 80, 90, 100]), + ); + let right = build_table( + ("a2", &vec![1, 2, 2, 3]), + ("b1", &vec![4, 5, 5, 7]), + ("c2", &vec![7, 8, 8, 9]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + )]; + + let (_, batches) = join_collect(left, right, on, RightSemi).await?; + let expected = [ + "+----+----+----+", + "| a2 | b1 | c2 |", + "+----+----+----+", + "| 1 | 4 | 7 |", + "| 2 | 5 | 8 |", + "| 2 | 5 | 8 |", + "+----+----+----+", + ]; + assert_batches_eq!(expected, &batches); + Ok(()) + } + + #[tokio::test] + async fn join_right_semi_two() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2, 3]), + ("b1", &vec![4, 5, 5, 6]), + ("c1", &vec![70, 80, 90, 100]), + ); + let right = build_table( + ("a1", &vec![1, 2, 2, 3]), + ("b1", &vec![4, 5, 5, 7]), + ("c2", &vec![7, 8, 8, 9]), + ); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + + let (_, batches) = join_collect(left, right, on, RightSemi).await?; + let expected = [ + "+----+----+----+", + "| a1 | b1 | c2 |", + "+----+----+----+", + "| 1 | 4 | 7 |", + "| 2 | 5 | 8 |", + "| 2 | 5 | 8 |", + "+----+----+----+", + ]; + assert_batches_eq!(expected, &batches); + Ok(()) + } + + #[tokio::test] + async fn join_right_semi_two_with_filter() -> Result<()> { + let left = build_table(("a1", &vec![1]), ("b1", &vec![10]), ("c1", &vec![30])); + let right = build_table(("a1", &vec![1]), ("b1", &vec![10]), ("c2", &vec![20])); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) 
as _, + ), + ]; + let filter = JoinFilter::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c2", 1)), + Operator::Lt, + Arc::new(Column::new("c1", 0)), + )), + vec![ + ColumnIndex { + index: 2, + side: JoinSide::Left, + }, + ColumnIndex { + index: 2, + side: JoinSide::Right, + }, + ], + Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", DataType::Int32, true), + ])), + ); + let (_, batches) = + join_collect_with_filter(left, right, on, filter, RightSemi).await?; + let expected = [ + "+----+----+----+", + "| a1 | b1 | c2 |", + "+----+----+----+", + "| 1 | 10 | 20 |", + "+----+----+----+", + ]; + assert_batches_eq!(expected, &batches); + Ok(()) + } + + #[tokio::test] + async fn join_right_semi_with_nulls() -> Result<()> { + let left = build_table_i32_nullable( + ("a1", &vec![Some(0), Some(1), Some(2), Some(2), Some(3)]), + ("b1", &vec![Some(3), Some(4), Some(5), None, Some(6)]), + ("c2", &vec![Some(60), None, Some(80), Some(85), Some(90)]), + ); + let right = build_table_i32_nullable( + ("a1", &vec![Some(1), Some(2), Some(2), Some(3)]), + ("b1", &vec![Some(4), Some(5), None, Some(6)]), // null in key field + ("c2", &vec![Some(7), Some(8), Some(8), None]), // null in non-key field + ); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + + let (_, batches) = join_collect(left, right, on, RightSemi).await?; + let expected = [ + "+----+----+----+", + "| a1 | b1 | c2 |", + "+----+----+----+", + "| 1 | 4 | 7 |", + "| 2 | 5 | 8 |", + "| 3 | 6 | |", + "+----+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_eq!(expected, &batches); + Ok(()) + } + + #[tokio::test] + async fn join_right_semi_with_nulls_with_options() -> Result<()> { + let left = build_table_i32_nullable( + ("a1", &vec![Some(3), Some(2), Some(1), Some(0), Some(2)]), + ("b1", &vec![None, Some(5), Some(4), None, Some(5)]), + ("c2", &vec![Some(90), Some(80), Some(70), Some(60), None]), + ); + let right = build_table_i32_nullable( + ("a1", &vec![Some(3), Some(2), Some(2), Some(1)]), + ("b1", &vec![None, Some(5), Some(5), Some(4)]), // null in key field + ("c2", &vec![Some(9), None, Some(8), Some(7)]), // null in non-key field + ); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) 
as _, + ), + ]; + + let (_, batches) = join_collect_with_options( + left, + right, + on, + RightSemi, + vec![ + SortOptions { + descending: true, + nulls_first: false, + }; + 2 + ], + NullEquality::NullEqualsNull, + ) + .await?; + + let expected = [ + "+----+----+----+", + "| a1 | b1 | c2 |", + "+----+----+----+", + "| 3 | | 9 |", + "| 2 | 5 | |", + "| 2 | 5 | 8 |", + "| 1 | 4 | 7 |", + "+----+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_eq!(expected, &batches); + Ok(()) + } + + #[tokio::test] + async fn join_right_semi_output_two_batches() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2, 3]), + ("b1", &vec![4, 5, 5, 6]), + ("c1", &vec![70, 80, 90, 100]), + ); + let right = build_table( + ("a1", &vec![1, 2, 2, 3]), + ("b1", &vec![4, 5, 5, 7]), + ("c2", &vec![7, 8, 8, 9]), + ); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ), + ]; + + let (_, batches) = + join_collect_batch_size_equals_two(left, right, on, RightSemi).await?; + let expected = [ + "+----+----+----+", + "| a1 | b1 | c2 |", + "+----+----+----+", + "| 1 | 4 | 7 |", + "| 2 | 5 | 8 |", + "| 2 | 5 | 8 |", + "+----+----+----+", + ]; + assert_eq!(batches.len(), 2); + assert_eq!(batches[0].num_rows(), 2); + assert_eq!(batches[1].num_rows(), 1); + assert_batches_eq!(expected, &batches); + Ok(()) + } + #[tokio::test] async fn join_left_mark() -> Result<()> { let left = build_table( @@ -3856,12 +4157,16 @@ mod tests { )]; let sort_options = vec![SortOptions::default(); on.len()]; - let join_types = vec![Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark]; + let join_types = vec![ + Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark, + ]; // Disable DiskManager to prevent spilling let runtime = RuntimeEnvBuilder::new() .with_memory_limit(100, 1.0) - .with_disk_manager(DiskManagerConfig::Disabled) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::Disabled), + ) .build_arc()?; let session_config = SessionConfig::default().with_batch_size(50); @@ -3877,7 +4182,7 @@ mod tests { on.clone(), join_type, sort_options.clone(), - false, + NullEquality::NullEqualsNothing, )?; let stream = join.execute(0, task_ctx)?; @@ -3934,12 +4239,16 @@ mod tests { )]; let sort_options = vec![SortOptions::default(); on.len()]; - let join_types = vec![Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark]; + let join_types = vec![ + Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark, + ]; // Disable DiskManager to prevent spilling let runtime = RuntimeEnvBuilder::new() .with_memory_limit(100, 1.0) - .with_disk_manager(DiskManagerConfig::Disabled) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::Disabled), + ) .build_arc()?; let session_config = SessionConfig::default().with_batch_size(50); @@ -3954,7 +4263,7 @@ mod tests { on.clone(), join_type, sort_options.clone(), - false, + NullEquality::NullEqualsNothing, )?; let stream = join.execute(0, task_ctx)?; @@ -3990,12 +4299,16 @@ mod tests { )]; let sort_options = vec![SortOptions::default(); on.len()]; - let join_types = [Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark]; + let join_types = [ + Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark, + ]; // Enable 
DiskManager to allow spilling let runtime = RuntimeEnvBuilder::new() .with_memory_limit(100, 1.0) - .with_disk_manager(DiskManagerConfig::NewOs) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory), + ) .build_arc()?; for batch_size in [1, 50] { @@ -4013,7 +4326,7 @@ mod tests { on.clone(), *join_type, sort_options.clone(), - false, + NullEquality::NullEqualsNothing, )?; let stream = join.execute(0, task_ctx)?; @@ -4035,7 +4348,7 @@ mod tests { on.clone(), *join_type, sort_options.clone(), - false, + NullEquality::NullEqualsNothing, )?; let stream = join.execute(0, task_ctx_no_spill)?; let no_spilled_join_result = common::collect(stream).await.unwrap(); @@ -4091,12 +4404,16 @@ mod tests { )]; let sort_options = vec![SortOptions::default(); on.len()]; - let join_types = [Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark]; + let join_types = [ + Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark, + ]; // Enable DiskManager to allow spilling let runtime = RuntimeEnvBuilder::new() .with_memory_limit(500, 1.0) - .with_disk_manager(DiskManagerConfig::NewOs) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory), + ) .build_arc()?; for batch_size in [1, 50] { @@ -4113,7 +4430,7 @@ mod tests { on.clone(), *join_type, sort_options.clone(), - false, + NullEquality::NullEqualsNothing, )?; let stream = join.execute(0, task_ctx)?; @@ -4134,7 +4451,7 @@ mod tests { on.clone(), *join_type, sort_options.clone(), - false, + NullEquality::NullEqualsNothing, )?; let stream = join.execute(0, task_ctx_no_spill)?; let no_spilled_join_result = common::collect(stream).await.unwrap(); @@ -4475,169 +4792,177 @@ mod tests { } #[tokio::test] - async fn test_left_semi_join_filtered_mask() -> Result<()> { - let mut joined_batches = build_joined_record_batches()?; - let schema = joined_batches.batches.first().unwrap().schema(); + async fn test_semi_join_filtered_mask() -> Result<()> { + for join_type in [LeftSemi, RightSemi] { + let mut joined_batches = build_joined_record_batches()?; + let schema = joined_batches.batches.first().unwrap().schema(); - let output = concat_batches(&schema, &joined_batches.batches)?; - let out_mask = joined_batches.filter_mask.finish(); - let out_indices = joined_batches.row_indices.finish(); + let output = concat_batches(&schema, &joined_batches.batches)?; + let out_mask = joined_batches.filter_mask.finish(); + let out_indices = joined_batches.row_indices.finish(); - assert_eq!( - get_corrected_filter_mask( - LeftSemi, - &UInt64Array::from(vec![0]), - &[0usize], - &BooleanArray::from(vec![true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![true]) - ); + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![true]) + ); - assert_eq!( - get_corrected_filter_mask( - LeftSemi, - &UInt64Array::from(vec![0]), - &[0usize], - &BooleanArray::from(vec![false]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![None]) - ); + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![false]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![None]) + ); - assert_eq!( - get_corrected_filter_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0]), - &[0usize; 2], - &BooleanArray::from(vec![true, true]), - output.num_rows() - ) - .unwrap(), - 
BooleanArray::from(vec![Some(true), None]) - ); + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0]), + &[0usize; 2], + &BooleanArray::from(vec![true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![Some(true), None]) + ); - assert_eq!( - get_corrected_filter_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![true, true, true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![Some(true), None, None]) - ); + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![Some(true), None, None]) + ); - assert_eq!( - get_corrected_filter_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![true, false, true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![Some(true), None, None]) - ); + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, false, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![Some(true), None, None]) + ); - assert_eq!( - get_corrected_filter_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![false, false, true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![None, None, Some(true),]) - ); + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![None, None, Some(true),]) + ); - assert_eq!( - get_corrected_filter_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![false, true, true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![None, Some(true), None]) - ); + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![None, Some(true), None]) + ); - assert_eq!( - get_corrected_filter_mask( - LeftSemi, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![false, false, false]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![None, None, None]) - ); + assert_eq!( + get_corrected_filter_mask( + join_type, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, false]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![None, None, None]) + ); - let corrected_mask = get_corrected_filter_mask( - LeftSemi, - &out_indices, - &joined_batches.batch_ids, - &out_mask, - output.num_rows(), - ) - .unwrap(); + let corrected_mask = get_corrected_filter_mask( + join_type, + &out_indices, + &joined_batches.batch_ids, + &out_mask, + output.num_rows(), + ) + .unwrap(); - assert_eq!( - corrected_mask, - BooleanArray::from(vec![ - Some(true), - None, - Some(true), - None, - Some(true), - None, - None, - None - ]) - ); + assert_eq!( + corrected_mask, + BooleanArray::from(vec![ + Some(true), + None, + Some(true), + None, + Some(true), + None, + None, + None + ]) + ); - let filtered_rb = filter_record_batch(&output, &corrected_mask)?; + let filtered_rb = filter_record_batch(&output, &corrected_mask)?; - 
assert_snapshot!(batches_to_string(&[filtered_rb]), @r#" - +---+----+---+----+ - | a | b | x | y | - +---+----+---+----+ - | 1 | 10 | 1 | 11 | - | 1 | 11 | 1 | 12 | - | 1 | 12 | 1 | 13 | - +---+----+---+----+ - "#); + assert_batches_eq!( + &[ + "+---+----+---+----+", + "| a | b | x | y |", + "+---+----+---+----+", + "| 1 | 10 | 1 | 11 |", + "| 1 | 11 | 1 | 12 |", + "| 1 | 12 | 1 | 13 |", + "+---+----+---+----+", + ], + &[filtered_rb] + ); - // output null rows - let null_mask = arrow::compute::not(&corrected_mask)?; - assert_eq!( - null_mask, - BooleanArray::from(vec![ - Some(false), - None, - Some(false), - None, - Some(false), - None, - None, - None - ]) - ); + // output null rows + let null_mask = arrow::compute::not(&corrected_mask)?; + assert_eq!( + null_mask, + BooleanArray::from(vec![ + Some(false), + None, + Some(false), + None, + Some(false), + None, + None, + None + ]) + ); - let null_joined_batch = filter_record_batch(&output, &null_mask)?; + let null_joined_batch = filter_record_batch(&output, &null_mask)?; - assert_snapshot!(batches_to_string(&[null_joined_batch]), @r#" - +---+---+---+---+ - | a | b | x | y | - +---+---+---+---+ - +---+---+---+---+ - "#); + assert_batches_eq!( + &[ + "+---+---+---+---+", + "| a | b | x | y |", + "+---+---+---+---+", + "+---+---+---+---+", + ], + &[null_joined_batch] + ); + } Ok(()) } diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 0dcb42169e00..6dbe75cc0ae4 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -67,15 +67,16 @@ use arrow::datatypes::{ArrowNativeType, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::hash_utils::create_hashes; use datafusion_common::utils::bisect; -use datafusion_common::{internal_err, plan_err, HashSet, JoinSide, JoinType, Result}; +use datafusion_common::{ + internal_err, plan_err, HashSet, JoinSide, JoinType, NullEquality, Result, +}; use datafusion_execution::memory_pool::MemoryConsumer; use datafusion_execution::TaskContext; use datafusion_expr::interval_arithmetic::Interval; use datafusion_physical_expr::equivalence::join_equivalence_properties; use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph; -use datafusion_physical_expr::PhysicalExprRef; -use datafusion_physical_expr_common::physical_expr::fmt_sql; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; +use datafusion_physical_expr_common::physical_expr::{fmt_sql, PhysicalExprRef}; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements}; use ahash::RandomState; use futures::{ready, Stream, StreamExt}; @@ -186,8 +187,8 @@ pub struct SymmetricHashJoinExec { metrics: ExecutionPlanMetricsSet, /// Information of index and left / right placement of columns column_indices: Vec, - /// If null_equals_null is true, null == null else null != null - pub(crate) null_equals_null: bool, + /// Defines the null equality for the join. 
+ pub(crate) null_equality: NullEquality, /// Left side sort expression(s) pub(crate) left_sort_exprs: Option, /// Right side sort expression(s) @@ -212,7 +213,7 @@ impl SymmetricHashJoinExec { on: JoinOn, filter: Option, join_type: &JoinType, - null_equals_null: bool, + null_equality: NullEquality, left_sort_exprs: Option, right_sort_exprs: Option, mode: StreamJoinPartitionMode, @@ -237,8 +238,7 @@ impl SymmetricHashJoinExec { // Initialize the random state for the join operation: let random_state = RandomState::with_seeds(0, 0, 0, 0); let schema = Arc::new(schema); - let cache = - Self::compute_properties(&left, &right, Arc::clone(&schema), *join_type, &on); + let cache = Self::compute_properties(&left, &right, schema, *join_type, &on)?; Ok(SymmetricHashJoinExec { left, right, @@ -248,7 +248,7 @@ impl SymmetricHashJoinExec { random_state, metrics: ExecutionPlanMetricsSet::new(), column_indices, - null_equals_null, + null_equality, left_sort_exprs, right_sort_exprs, mode, @@ -263,7 +263,7 @@ impl SymmetricHashJoinExec { schema: SchemaRef, join_type: JoinType, join_on: JoinOnRef, - ) -> PlanProperties { + ) -> Result { // Calculate equivalence properties: let eq_properties = join_equivalence_properties( left.equivalence_properties().clone(), @@ -274,17 +274,17 @@ impl SymmetricHashJoinExec { // Has alternating probe side None, join_on, - ); + )?; let output_partitioning = - symmetric_join_output_partitioning(left, right, &join_type); + symmetric_join_output_partitioning(left, right, &join_type)?; - PlanProperties::new( + Ok(PlanProperties::new( eq_properties, output_partitioning, emission_type_from_children([left, right]), boundedness_from_children([left, right]), - ) + )) } /// left stream @@ -312,9 +312,9 @@ impl SymmetricHashJoinExec { &self.join_type } - /// Get null_equals_null - pub fn null_equals_null(&self) -> bool { - self.null_equals_null + /// Get null_equality + pub fn null_equality(&self) -> NullEquality { + self.null_equality } /// Get partition mode @@ -372,7 +372,7 @@ impl DisplayAs for SymmetricHashJoinExec { let on = self .on .iter() - .map(|(c1, c2)| format!("({}, {})", c1, c2)) + .map(|(c1, c2)| format!("({c1}, {c2})")) .collect::>() .join(", "); write!( @@ -395,7 +395,7 @@ impl DisplayAs for SymmetricHashJoinExec { if *self.join_type() != JoinType::Inner { writeln!(f, "join_type={:?}", self.join_type)?; } - writeln!(f, "on={}", on) + writeln!(f, "on={on}") } } } @@ -433,16 +433,14 @@ impl ExecutionPlan for SymmetricHashJoinExec { } } - fn required_input_ordering(&self) -> Vec> { + fn required_input_ordering(&self) -> Vec> { vec![ self.left_sort_exprs .as_ref() - .cloned() - .map(LexRequirement::from), + .map(|e| OrderingRequirements::from(e.clone())), self.right_sort_exprs .as_ref() - .cloned() - .map(LexRequirement::from), + .map(|e| OrderingRequirements::from(e.clone())), ] } @@ -460,7 +458,7 @@ impl ExecutionPlan for SymmetricHashJoinExec { self.on.clone(), self.filter.clone(), &self.join_type, - self.null_equals_null, + self.null_equality, self.left_sort_exprs.clone(), self.right_sort_exprs.clone(), self.mode, @@ -549,7 +547,7 @@ impl ExecutionPlan for SymmetricHashJoinExec { graph, left_sorted_filter_expr, right_sorted_filter_expr, - null_equals_null: self.null_equals_null, + null_equality: self.null_equality, state: SHJStreamState::PullRight, reservation, batch_transformer: BatchSplitter::new(batch_size), @@ -569,7 +567,7 @@ impl ExecutionPlan for SymmetricHashJoinExec { graph, left_sorted_filter_expr, right_sorted_filter_expr, - null_equals_null: 
self.null_equals_null, + null_equality: self.null_equality, state: SHJStreamState::PullRight, reservation, batch_transformer: NoopBatchTransformer::new(), @@ -635,21 +633,18 @@ impl ExecutionPlan for SymmetricHashJoinExec { self.right(), )?; - Ok(Some(Arc::new(SymmetricHashJoinExec::try_new( + SymmetricHashJoinExec::try_new( Arc::new(new_left), Arc::new(new_right), new_on, new_filter, self.join_type(), - self.null_equals_null(), - self.right() - .output_ordering() - .map(|p| LexOrdering::new(p.to_vec())), - self.left() - .output_ordering() - .map(|p| LexOrdering::new(p.to_vec())), + self.null_equality(), + self.right().output_ordering().cloned(), + self.left().output_ordering().cloned(), self.partition_mode(), - )?))) + ) + .map(|e| Some(Arc::new(e) as _)) } } @@ -678,8 +673,8 @@ struct SymmetricHashJoinStream { right_sorted_filter_expr: Option, /// Random state used for hashing initialization random_state: RandomState, - /// If null_equals_null is true, null == null else null != null - null_equals_null: bool, + /// Defines the null equality for the join. + null_equality: NullEquality, /// Metrics metrics: StreamJoinMetrics, /// Memory reservation @@ -777,7 +772,11 @@ fn need_to_produce_result_in_final(build_side: JoinSide, join_type: JoinType) -> } else { matches!( join_type, - JoinType::Right | JoinType::RightAnti | JoinType::Full | JoinType::RightSemi + JoinType::Right + | JoinType::RightAnti + | JoinType::Full + | JoinType::RightSemi + | JoinType::RightMark ) } } @@ -811,6 +810,21 @@ where { // Store the result in a tuple let result = match (build_side, join_type) { + // For a mark join we “mark” each build‐side row with a dummy 0 in the probe‐side index + // if it ever matched. For example, if + // + // prune_length = 5 + // deleted_offset = 0 + // visited_rows = {1, 3} + // + // then we produce: + // + // build_indices = [0, 1, 2, 3, 4] + // probe_indices = [None, Some(0), None, Some(0), None] + // + // Example: for each build row i in [0..5): + // – We always output its own index i in `build_indices` + // – We output `Some(0)` in `probe_indices[i]` if row i was ever visited, else `None` (JoinSide::Left, JoinType::LeftMark) => { let build_indices = (0..prune_length) .map(L::Native::from_usize) @@ -825,6 +839,20 @@ where .collect(); (build_indices, probe_indices) } + (JoinSide::Right, JoinType::RightMark) => { + let build_indices = (0..prune_length) + .map(L::Native::from_usize) + .collect::>(); + let probe_indices = (0..prune_length) + .map(|idx| { + // For mark join we output a dummy index 0 to indicate the row had a match + visited_rows + .contains(&(idx + deleted_offset)) + .then_some(R::Native::from_usize(0).unwrap()) + }) + .collect(); + (build_indices, probe_indices) + } // In the case of `Left` or `Right` join, or `Full` join, get the anti indices (JoinSide::Left, JoinType::Left | JoinType::LeftAnti) | (JoinSide::Right, JoinType::Right | JoinType::RightAnti) @@ -923,7 +951,7 @@ pub(crate) fn build_side_determined_results( /// * `probe_batch` - The second record batch to be joined. /// * `column_indices` - An array of columns to be selected for the result of the join. /// * `random_state` - The random state for the join. -/// * `null_equals_null` - A boolean indicating whether NULL values should be treated as equal when joining. +/// * `null_equality` - Indicates whether NULL values should be treated as equal when joining. 
/// /// # Returns /// @@ -939,7 +967,7 @@ pub(crate) fn join_with_probe_batch( probe_batch: &RecordBatch, column_indices: &[ColumnIndex], random_state: &RandomState, - null_equals_null: bool, + null_equality: NullEquality, ) -> Result> { if build_hash_joiner.input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 { return Ok(None); @@ -951,7 +979,7 @@ pub(crate) fn join_with_probe_batch( &build_hash_joiner.on, &probe_hash_joiner.on, random_state, - null_equals_null, + null_equality, &mut build_hash_joiner.hashes_buffer, Some(build_hash_joiner.deleted_offset), )?; @@ -1017,7 +1045,7 @@ pub(crate) fn join_with_probe_batch( /// * `build_on` - An array of columns on which the join will be performed. The columns are from the build side of the join. /// * `probe_on` - An array of columns on which the join will be performed. The columns are from the probe side of the join. /// * `random_state` - The random state for the join. -/// * `null_equals_null` - A boolean indicating whether NULL values should be treated as equal when joining. +/// * `null_equality` - Indicates whether NULL values should be treated as equal when joining. /// * `hashes_buffer` - Buffer used for probe side keys hash calculation. /// * `deleted_offset` - deleted offset for build side data. /// @@ -1033,7 +1061,7 @@ fn lookup_join_hashmap( build_on: &[PhysicalExprRef], probe_on: &[PhysicalExprRef], random_state: &RandomState, - null_equals_null: bool, + null_equality: NullEquality, hashes_buffer: &mut Vec, deleted_offset: Option, ) -> Result<(UInt64Array, UInt32Array)> { @@ -1094,7 +1122,7 @@ fn lookup_join_hashmap( &probe_indices, &build_join_values, &keys_values, - null_equals_null, + null_equality, )?; Ok((build_indices, probe_indices)) @@ -1591,7 +1619,7 @@ impl SymmetricHashJoinStream { size += size_of_val(&self.left_sorted_filter_expr); size += size_of_val(&self.right_sorted_filter_expr); size += size_of_val(&self.random_state); - size += size_of_val(&self.null_equals_null); + size += size_of_val(&self.null_equality); size += size_of_val(&self.metrics); size } @@ -1646,7 +1674,7 @@ impl SymmetricHashJoinStream { &probe_batch, &self.column_indices, &self.random_state, - self.null_equals_null, + self.null_equality, )?; // Increment the offset for the probe hash joiner: probe_hash_joiner.offset += probe_batch.num_rows(); @@ -1743,7 +1771,7 @@ mod tests { use datafusion_execution::config::SessionConfig; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{binary, col, lit, Column}; - use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; + use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use rstest::*; @@ -1802,12 +1830,18 @@ mod tests { on.clone(), filter.clone(), &join_type, - false, + NullEquality::NullEqualsNothing, Arc::clone(&task_ctx), ) .await?; let second_batches = partitioned_hash_join_with_filter( - left, right, on, filter, &join_type, false, task_ctx, + left, + right, + on, + filter, + &join_type, + NullEquality::NullEqualsNothing, + task_ctx, ) .await?; compare_batches(&first_batches, &second_batches); @@ -1843,7 +1877,7 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { + let left_sorted = [PhysicalSortExpr { expr: binary( col("la1", left_schema)?, Operator::Plus, @@ -1851,11 +1885,13 @@ mod tests { left_schema, )?, options: SortOptions::default(), - }]); - let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { + }] + 
.into(); + let right_sorted = [PhysicalSortExpr { expr: col("ra1", right_schema)?, options: SortOptions::default(), - }]); + }] + .into(); let (left, right) = create_memory_table( left_partition, right_partition, @@ -1923,14 +1959,16 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { + let left_sorted = [PhysicalSortExpr { expr: col("la1", left_schema)?, options: SortOptions::default(), - }]); - let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { + }] + .into(); + let right_sorted = [PhysicalSortExpr { expr: col("ra1", right_schema)?, options: SortOptions::default(), - }]); + }] + .into(); let (left, right) = create_memory_table( left_partition, right_partition, @@ -2068,20 +2106,22 @@ mod tests { let (left_partition, right_partition) = get_or_create_table((11, 21), 8)?; let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { + let left_sorted = [PhysicalSortExpr { expr: col("la1_des", left_schema)?, options: SortOptions { descending: true, nulls_first: true, }, - }]); - let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { + }] + .into(); + let right_sorted = [PhysicalSortExpr { expr: col("ra1_des", right_schema)?, options: SortOptions { descending: true, nulls_first: true, }, - }]); + }] + .into(); let (left, right) = create_memory_table( left_partition, right_partition, @@ -2127,20 +2167,22 @@ mod tests { let (left_partition, right_partition) = get_or_create_table((10, 11), 8)?; let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { + let left_sorted = [PhysicalSortExpr { expr: col("l_asc_null_first", left_schema)?, options: SortOptions { descending: false, nulls_first: true, }, - }]); - let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { + }] + .into(); + let right_sorted = [PhysicalSortExpr { expr: col("r_asc_null_first", right_schema)?, options: SortOptions { descending: false, nulls_first: true, }, - }]); + }] + .into(); let (left, right) = create_memory_table( left_partition, right_partition, @@ -2186,20 +2228,22 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { + let left_sorted = [PhysicalSortExpr { expr: col("l_asc_null_last", left_schema)?, options: SortOptions { descending: false, nulls_first: false, }, - }]); - let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { + }] + .into(); + let right_sorted = [PhysicalSortExpr { expr: col("r_asc_null_last", right_schema)?, options: SortOptions { descending: false, nulls_first: false, }, - }]); + }] + .into(); let (left, right) = create_memory_table( left_partition, right_partition, @@ -2247,20 +2291,22 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { + let left_sorted = [PhysicalSortExpr { expr: col("l_desc_null_first", left_schema)?, options: SortOptions { descending: true, nulls_first: true, }, - }]); - let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { + }] + .into(); + let right_sorted = [PhysicalSortExpr { expr: col("r_desc_null_first", right_schema)?, options: SortOptions { descending: true, nulls_first: true, }, - }]); + }] + 
.into(); let (left, right) = create_memory_table( left_partition, right_partition, @@ -2309,15 +2355,16 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { + let left_sorted = [PhysicalSortExpr { expr: col("la1", left_schema)?, options: SortOptions::default(), - }]); - - let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { + }] + .into(); + let right_sorted = [PhysicalSortExpr { expr: col("ra1", right_schema)?, options: SortOptions::default(), - }]); + }] + .into(); let (left, right) = create_memory_table( left_partition, right_partition, @@ -2368,20 +2415,23 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); let left_sorted = vec![ - LexOrdering::new(vec![PhysicalSortExpr { + [PhysicalSortExpr { expr: col("la1", left_schema)?, options: SortOptions::default(), - }]), - LexOrdering::new(vec![PhysicalSortExpr { + }] + .into(), + [PhysicalSortExpr { expr: col("la2", left_schema)?, options: SortOptions::default(), - }]), + }] + .into(), ]; - let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { + let right_sorted = [PhysicalSortExpr { expr: col("ra1", right_schema)?, options: SortOptions::default(), - }]); + }] + .into(); let (left, right) = create_memory_table( left_partition, @@ -2449,20 +2499,22 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; - let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { + let left_sorted = [PhysicalSortExpr { expr: col("lt1", left_schema)?, options: SortOptions { descending: false, nulls_first: true, }, - }]); - let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { + }] + .into(); + let right_sorted = [PhysicalSortExpr { expr: col("rt1", right_schema)?, options: SortOptions { descending: false, nulls_first: true, }, - }]); + }] + .into(); let (left, right) = create_memory_table( left_partition, right_partition, @@ -2532,20 +2584,22 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; - let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { + let left_sorted = [PhysicalSortExpr { expr: col("li1", left_schema)?, options: SortOptions { descending: false, nulls_first: true, }, - }]); - let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { + }] + .into(); + let right_sorted = [PhysicalSortExpr { expr: col("ri1", right_schema)?, options: SortOptions { descending: false, nulls_first: true, }, - }]); + }] + .into(); let (left, right) = create_memory_table( left_partition, right_partition, @@ -2608,14 +2662,16 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { + let left_sorted = [PhysicalSortExpr { expr: col("l_float", left_schema)?, options: SortOptions::default(), - }]); - let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { + }] + .into(); + let right_sorted = [PhysicalSortExpr { expr: col("r_float", right_schema)?, options: SortOptions::default(), - }]); + }] + .into(); let (left, right) = create_memory_table( left_partition, right_partition, diff --git a/datafusion/physical-plan/src/joins/test_utils.rs b/datafusion/physical-plan/src/joins/test_utils.rs index 
d38637dae028..ea893cc9338e 100644 --- a/datafusion/physical-plan/src/joins/test_utils.rs +++ b/datafusion/physical-plan/src/joins/test_utils.rs @@ -33,7 +33,7 @@ use arrow::array::{ }; use arrow::datatypes::{DataType, Schema}; use arrow::util::pretty::pretty_format_batches; -use datafusion_common::{Result, ScalarValue}; +use datafusion_common::{NullEquality, Result, ScalarValue}; use datafusion_execution::TaskContext; use datafusion_expr::{JoinType, Operator}; use datafusion_physical_expr::expressions::{binary, cast, col, lit}; @@ -74,7 +74,7 @@ pub async fn partitioned_sym_join_with_filter( on: JoinOn, filter: Option, join_type: &JoinType, - null_equals_null: bool, + null_equality: NullEquality, context: Arc, ) -> Result> { let partition_count = 4; @@ -101,11 +101,9 @@ pub async fn partitioned_sym_join_with_filter( on, filter, join_type, - null_equals_null, - left.output_ordering().map(|p| LexOrdering::new(p.to_vec())), - right - .output_ordering() - .map(|p| LexOrdering::new(p.to_vec())), + null_equality, + left.output_ordering().cloned(), + right.output_ordering().cloned(), StreamJoinPartitionMode::Partitioned, )?; @@ -130,7 +128,7 @@ pub async fn partitioned_hash_join_with_filter( on: JoinOn, filter: Option, join_type: &JoinType, - null_equals_null: bool, + null_equality: NullEquality, context: Arc, ) -> Result> { let partition_count = 4; @@ -153,7 +151,7 @@ pub async fn partitioned_hash_join_with_filter( join_type, None, PartitionMode::Partitioned, - null_equals_null, + null_equality, )?); let mut batches = vec![]; @@ -195,7 +193,7 @@ struct AscendingRandomFloatIterator { impl AscendingRandomFloatIterator { fn new(min: f64, max: f64) -> Self { let mut rng = StdRng::seed_from_u64(42); - let initial = rng.gen_range(min..max); + let initial = rng.random_range(min..max); AscendingRandomFloatIterator { prev: initial, max, @@ -208,7 +206,7 @@ impl Iterator for AscendingRandomFloatIterator { type Item = f64; fn next(&mut self) -> Option { - let value = self.rng.gen_range(self.prev..self.max); + let value = self.rng.random_range(self.prev..self.max); self.prev = value; Some(value) } diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 5516f172d510..c5f7087ac195 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -25,7 +25,9 @@ use std::ops::Range; use std::sync::Arc; use std::task::{Context, Poll}; +use crate::joins::SharedBitmapBuilder; use crate::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder}; +use crate::projection::ProjectionExec; use crate::{ ColumnStatistics, ExecutionPlan, ExecutionPlanProperties, Partitioning, Statistics, }; @@ -39,26 +41,24 @@ use arrow::array::{ BooleanBufferBuilder, NativeAdapter, PrimitiveArray, RecordBatch, RecordBatchOptions, UInt32Array, UInt32Builder, UInt64Array, }; +use arrow::buffer::NullBuffer; use arrow::compute; use arrow::datatypes::{ ArrowNativeType, Field, Schema, SchemaBuilder, UInt32Type, UInt64Type, }; use datafusion_common::cast::as_boolean_array; use datafusion_common::stats::Precision; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{ plan_err, DataFusionError, JoinSide, JoinType, Result, SharedResult, }; use datafusion_expr::interval_arithmetic::Interval; -use datafusion_physical_expr::equivalence::add_offset_to_expr; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::utils::{collect_columns, merge_vectors}; +use 
datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{ - LexOrdering, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, + add_offset_to_expr, add_offset_to_physical_sort_exprs, LexOrdering, PhysicalExpr, + PhysicalExprRef, }; -use crate::joins::SharedBitmapBuilder; -use crate::projection::ProjectionExec; use futures::future::{BoxFuture, Shared}; use futures::{ready, FutureExt}; use parking_lot::Mutex; @@ -114,113 +114,84 @@ fn check_join_set_is_valid( pub fn adjust_right_output_partitioning( right_partitioning: &Partitioning, left_columns_len: usize, -) -> Partitioning { - match right_partitioning { +) -> Result { + let result = match right_partitioning { Partitioning::Hash(exprs, size) => { let new_exprs = exprs .iter() - .map(|expr| add_offset_to_expr(Arc::clone(expr), left_columns_len)) - .collect(); + .map(|expr| add_offset_to_expr(Arc::clone(expr), left_columns_len as _)) + .collect::>()?; Partitioning::Hash(new_exprs, *size) } result => result.clone(), - } -} - -/// Replaces the right column (first index in the `on_column` tuple) with -/// the left column (zeroth index in the tuple) inside `right_ordering`. -fn replace_on_columns_of_right_ordering( - on_columns: &[(PhysicalExprRef, PhysicalExprRef)], - right_ordering: &mut LexOrdering, -) -> Result<()> { - for (left_col, right_col) in on_columns { - right_ordering.transform(|item| { - let new_expr = Arc::clone(&item.expr) - .transform(|e| { - if e.eq(right_col) { - Ok(Transformed::yes(Arc::clone(left_col))) - } else { - Ok(Transformed::no(e)) - } - }) - .data() - .expect("closure is infallible"); - item.expr = new_expr; - }); - } - Ok(()) -} - -fn offset_ordering( - ordering: &LexOrdering, - join_type: &JoinType, - offset: usize, -) -> LexOrdering { - match join_type { - // In the case below, right ordering should be offsetted with the left - // side length, since we append the right table to the left table. - JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => ordering - .iter() - .map(|sort_expr| PhysicalSortExpr { - expr: add_offset_to_expr(Arc::clone(&sort_expr.expr), offset), - options: sort_expr.options, - }) - .collect(), - _ => ordering.clone(), - } + }; + Ok(result) } /// Calculate the output ordering of a given join operation. pub fn calculate_join_output_ordering( - left_ordering: &LexOrdering, - right_ordering: &LexOrdering, + left_ordering: Option<&LexOrdering>, + right_ordering: Option<&LexOrdering>, join_type: JoinType, - on_columns: &[(PhysicalExprRef, PhysicalExprRef)], left_columns_len: usize, maintains_input_order: &[bool], probe_side: Option, -) -> Option { - let output_ordering = match maintains_input_order { +) -> Result> { + match maintains_input_order { [true, false] => { // Special case, we can prefix ordering of right side with the ordering of left side. 
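// [Editor's note — illustrative sketch, not part of this patch.] The branch below keeps the
// left ordering and, for an inner join probed from the left side, appends the right ordering
// with every column index shifted by `left_columns_len`, since right-side columns come after
// the left-side columns in the join output schema. The simplified `SortKey` type here is a
// stand-in for `PhysicalSortExpr` and is assumed only for illustration; the values mirror the
// `test_calculate_join_output_ordering` case further down in this patch.
#[derive(Debug, Clone, PartialEq)]
struct SortKey {
    name: &'static str,
    index: usize,
}

fn prefixed_output_ordering(
    left: &[SortKey],
    right: &[SortKey],
    left_columns_len: usize,
) -> Vec<SortKey> {
    // Shift right-side column indices past the left-side columns, then append them.
    let shifted_right = right.iter().map(|k| SortKey {
        name: k.name,
        index: k.index + left_columns_len,
    });
    left.iter().cloned().chain(shifted_right).collect()
}

fn main() {
    let left = [
        SortKey { name: "a", index: 0 },
        SortKey { name: "c", index: 2 },
        SortKey { name: "d", index: 3 },
    ];
    let right = [
        SortKey { name: "z", index: 2 },
        SortKey { name: "y", index: 1 },
    ];
    // With 5 left columns, z@2 becomes z@7 and y@1 becomes y@6 in the combined ordering.
    assert_eq!(
        prefixed_output_ordering(&left, &right, 5),
        vec![
            SortKey { name: "a", index: 0 },
            SortKey { name: "c", index: 2 },
            SortKey { name: "d", index: 3 },
            SortKey { name: "z", index: 7 },
            SortKey { name: "y", index: 6 },
        ]
    );
}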
if join_type == JoinType::Inner && probe_side == Some(JoinSide::Left) { - replace_on_columns_of_right_ordering( - on_columns, - &mut right_ordering.clone(), - ) - .ok()?; - merge_vectors( - left_ordering, - offset_ordering(right_ordering, &join_type, left_columns_len) - .as_ref(), - ) - } else { - left_ordering.clone() + if let Some(right_ordering) = right_ordering.cloned() { + let right_offset = add_offset_to_physical_sort_exprs( + right_ordering, + left_columns_len as _, + )?; + return if let Some(left_ordering) = left_ordering { + let mut result = left_ordering.clone(); + result.extend(right_offset); + Ok(Some(result)) + } else { + Ok(LexOrdering::new(right_offset)) + }; + } } + Ok(left_ordering.cloned()) } [false, true] => { // Special case, we can prefix ordering of left side with the ordering of right side. if join_type == JoinType::Inner && probe_side == Some(JoinSide::Right) { - replace_on_columns_of_right_ordering( - on_columns, - &mut right_ordering.clone(), - ) - .ok()?; - merge_vectors( - offset_ordering(right_ordering, &join_type, left_columns_len) - .as_ref(), - left_ordering, - ) - } else { - offset_ordering(right_ordering, &join_type, left_columns_len) + return if let Some(right_ordering) = right_ordering.cloned() { + let mut right_offset = add_offset_to_physical_sort_exprs( + right_ordering, + left_columns_len as _, + )?; + if let Some(left_ordering) = left_ordering { + right_offset.extend(left_ordering.clone()); + } + Ok(LexOrdering::new(right_offset)) + } else { + Ok(left_ordering.cloned()) + }; + } + let Some(right_ordering) = right_ordering else { + return Ok(None); + }; + match join_type { + JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { + add_offset_to_physical_sort_exprs( + right_ordering.clone(), + left_columns_len as _, + ) + .map(LexOrdering::new) + } + _ => Ok(Some(right_ordering.clone())), } } // Doesn't maintain ordering, output ordering is None. - [false, false] => return None, + [false, false] => Ok(None), [true, true] => unreachable!("Cannot maintain ordering of both sides"), _ => unreachable!("Join operators can not have more than two children"), - }; - (!output_ordering.is_empty()).then_some(output_ordering) + } } /// Information about the index and placement (left or right) of the columns @@ -246,6 +217,7 @@ fn output_join_field(old_field: &Field, join_type: &JoinType, is_left: bool) -> JoinType::LeftAnti => false, // doesn't introduce nulls (or can it??) JoinType::RightAnti => false, // doesn't introduce nulls (or can it??) JoinType::LeftMark => false, + JoinType::RightMark => false, }; if force_nullable { @@ -312,14 +284,30 @@ pub fn build_join_schema( left_fields().chain(right_field).unzip() } JoinType::RightSemi | JoinType::RightAnti => right_fields().unzip(), + JoinType::RightMark => { + let left_field = once(( + Field::new("mark", arrow_schema::DataType::Boolean, false), + ColumnIndex { + index: 0, + side: JoinSide::None, + }, + )); + right_fields().chain(left_field).unzip() + } + }; + + let (schema1, schema2) = match join_type { + JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => (left, right), + _ => (right, left), }; - let metadata = left + let metadata = schema1 .metadata() .clone() .into_iter() - .chain(right.metadata().clone()) + .chain(schema2.metadata().clone()) .collect(); + (fields.finish().with_metadata(metadata), column_indices) } @@ -403,15 +391,12 @@ struct PartialJoinStatistics { /// Estimate the statistics for the given join's output. 
pub(crate) fn estimate_join_statistics( - left: Arc, - right: Arc, + left_stats: Statistics, + right_stats: Statistics, on: JoinOn, join_type: &JoinType, schema: &Schema, ) -> Result { - let left_stats = left.statistics()?; - let right_stats = right.statistics()?; - let join_stats = estimate_join_cardinality(join_type, left_stats, right_stats, &on); let (num_rows, column_statistics) = match join_stats { Some(stats) => (Precision::Inexact(stats.num_rows), stats.column_statistics), @@ -536,6 +521,15 @@ fn estimate_join_cardinality( column_statistics, }) } + JoinType::RightMark => { + let num_rows = *right_stats.num_rows.get_value()?; + let mut column_statistics = right_stats.column_statistics; + column_statistics.push(ColumnStatistics::new_unknown()); + Some(PartialJoinStatistics { + num_rows, + column_statistics, + }) + } } } @@ -907,7 +901,7 @@ pub(crate) fn build_batch_from_indices( for column_index in column_indices { let array = if column_index.side == JoinSide::None { - // LeftMark join, the mark column is a true if the indices is not null, otherwise it will be false + // For mark joins, the mark column is a true if the indices is not null, otherwise it will be false Arc::new(compute::is_not_null(probe_indices)?) } else if column_index.side == build_side { let array = build_input_buffer.column(column_index.index); @@ -978,6 +972,12 @@ pub(crate) fn adjust_indices_by_join_type( // the left_indices will not be used later for the `right anti` join Ok((left_indices, right_indices)) } + JoinType::RightMark => { + let right_indices = get_mark_indices(&adjust_range, &right_indices); + let left_indices_vec: Vec = adjust_range.map(|i| i as u64).collect(); + let left_indices = UInt64Array::from(left_indices_vec); + Ok((left_indices, right_indices)) + } JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { // matched or unmatched left row will be produced in the end of loop // When visit the right batch, we can output the matched left row and don't need to wait the end of loop @@ -1079,17 +1079,7 @@ pub(crate) fn get_anti_indices( where NativeAdapter: From<::Native>, { - let mut bitmap = BooleanBufferBuilder::new(range.len()); - bitmap.append_n(range.len(), false); - input_indices - .iter() - .flatten() - .map(|v| v.as_usize()) - .filter(|v| range.contains(v)) - .for_each(|v| { - bitmap.set_bit(v - range.start, true); - }); - + let bitmap = build_range_bitmap(&range, input_indices); let offset = range.start; // get the anti index @@ -1108,19 +1098,8 @@ pub(crate) fn get_semi_indices( where NativeAdapter: From<::Native>, { - let mut bitmap = BooleanBufferBuilder::new(range.len()); - bitmap.append_n(range.len(), false); - input_indices - .iter() - .flatten() - .map(|v| v.as_usize()) - .filter(|v| range.contains(v)) - .for_each(|v| { - bitmap.set_bit(v - range.start, true); - }); - + let bitmap = build_range_bitmap(&range, input_indices); let offset = range.start; - // get the semi index (range) .filter_map(|idx| { @@ -1129,6 +1108,37 @@ where .collect() } +pub(crate) fn get_mark_indices( + range: &Range, + input_indices: &PrimitiveArray, +) -> PrimitiveArray +where + NativeAdapter: From<::Native>, +{ + let mut bitmap = build_range_bitmap(range, input_indices); + PrimitiveArray::new( + vec![0; range.len()].into(), + Some(NullBuffer::new(bitmap.finish())), + ) +} + +fn build_range_bitmap( + range: &Range, + input: &PrimitiveArray, +) -> BooleanBufferBuilder { + let mut builder = BooleanBufferBuilder::new(range.len()); + builder.append_n(range.len(), false); + + 
input.iter().flatten().for_each(|v| { + let idx = v.as_usize(); + if range.contains(&idx) { + builder.set_bit(idx - range.start, true); + } + }); + + builder +} + /// Appends probe indices in order by considering the given build indices. /// /// This function constructs new build and probe indices by iterating through @@ -1296,36 +1306,41 @@ pub(crate) fn symmetric_join_output_partitioning( left: &Arc, right: &Arc, join_type: &JoinType, -) -> Partitioning { +) -> Result { let left_columns_len = left.schema().fields.len(); let left_partitioning = left.output_partitioning(); let right_partitioning = right.output_partitioning(); - match join_type { + let result = match join_type { JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { left_partitioning.clone() } - JoinType::RightSemi | JoinType::RightAnti => right_partitioning.clone(), + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => { + right_partitioning.clone() + } JoinType::Inner | JoinType::Right => { - adjust_right_output_partitioning(right_partitioning, left_columns_len) + adjust_right_output_partitioning(right_partitioning, left_columns_len)? } JoinType::Full => { // We could also use left partition count as they are necessarily equal. Partitioning::UnknownPartitioning(right_partitioning.partition_count()) } - } + }; + Ok(result) } pub(crate) fn asymmetric_join_output_partitioning( left: &Arc, right: &Arc, join_type: &JoinType, -) -> Partitioning { - match join_type { +) -> Result { + let result = match join_type { JoinType::Inner | JoinType::Right => adjust_right_output_partitioning( right.output_partitioning(), left.schema().fields().len(), - ), - JoinType::RightSemi | JoinType::RightAnti => right.output_partitioning().clone(), + )?, + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => { + right.output_partitioning().clone() + } JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti @@ -1333,7 +1348,8 @@ pub(crate) fn asymmetric_join_output_partitioning( | JoinType::LeftMark => Partitioning::UnknownPartitioning( right.output_partitioning().partition_count(), ), - } + }; + Ok(result) } /// Trait for incrementally generating Join output. 
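// [Editor's note — illustrative sketch, not part of this patch.] The LeftMark/RightMark handling
// above follows one convention: each row of the marked side keeps its own index, while the other
// side carries a dummy `Some(0)` when the row ever matched and `None` otherwise; the boolean
// "mark" column is then just `is_not_null` over those dummy indices (see
// `build_batch_from_indices` and `get_mark_indices`). The helper below is a minimal standalone
// rephrasing of that idea, assuming only the `arrow` crate; it is not a function from this patch.
use std::collections::HashSet;

use arrow::array::{BooleanArray, UInt32Array};
use arrow::compute::is_not_null;
use arrow::error::ArrowError;

fn mark_column(
    num_rows: usize,
    visited_rows: &HashSet<usize>,
) -> Result<BooleanArray, ArrowError> {
    // Some(0) means "this row had at least one match", None means "never matched".
    let dummy_probe_indices: UInt32Array = (0..num_rows)
        .map(|row| visited_rows.contains(&row).then_some(0u32))
        .collect();
    is_not_null(&dummy_probe_indices)
}

fn main() -> Result<(), ArrowError> {
    // Mirrors the comment in the symmetric hash join change: prune_length = 5 and
    // visited_rows = {1, 3} produce probe indices [None, Some(0), None, Some(0), None],
    // i.e. a mark column of [false, true, false, true, false].
    let visited_rows: HashSet<usize> = [1, 3].into_iter().collect();
    assert_eq!(
        mark_column(5, &visited_rows)?,
        BooleanArray::from(vec![false, true, false, true, false])
    );
    Ok(())
}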
@@ -1500,15 +1516,17 @@ pub(super) fn swap_join_projection( #[cfg(test)] mod tests { - use super::*; + use std::collections::HashMap; use std::pin::Pin; + use super::*; + use arrow::array::Int32Array; - use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Fields}; use arrow::error::{ArrowError, Result as ArrowResult}; use datafusion_common::stats::Precision::{Absent, Exact, Inexact}; use datafusion_common::{arrow_datafusion_err, arrow_err, ScalarValue}; + use datafusion_physical_expr::PhysicalSortExpr; use rstest::rstest; @@ -2245,8 +2263,7 @@ mod tests { assert_eq!( output_cardinality, expected, - "failure for join_type: {}", - join_type + "failure for join_type: {join_type}" ); } @@ -2319,85 +2336,35 @@ mod tests { #[test] fn test_calculate_join_output_ordering() -> Result<()> { - let options = SortOptions::default(); let left_ordering = LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("c", 2)), - options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("d", 3)), - options, - }, + PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))), + PhysicalSortExpr::new_default(Arc::new(Column::new("c", 2))), + PhysicalSortExpr::new_default(Arc::new(Column::new("d", 3))), ]); let right_ordering = LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("z", 2)), - options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("y", 1)), - options, - }, + PhysicalSortExpr::new_default(Arc::new(Column::new("z", 2))), + PhysicalSortExpr::new_default(Arc::new(Column::new("y", 1))), ]); let join_type = JoinType::Inner; - let on_columns = [( - Arc::new(Column::new("b", 1)) as _, - Arc::new(Column::new("x", 0)) as _, - )]; let left_columns_len = 5; let maintains_input_orders = [[true, false], [false, true]]; let probe_sides = [Some(JoinSide::Left), Some(JoinSide::Right)]; let expected = [ - Some(LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("c", 2)), - options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("d", 3)), - options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("z", 7)), - options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("y", 6)), - options, - }, - ])), - Some(LexOrdering::new(vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("z", 7)), - options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("y", 6)), - options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("c", 2)), - options, - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("d", 3)), - options, - }, - ])), + LexOrdering::new(vec![ + PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))), + PhysicalSortExpr::new_default(Arc::new(Column::new("c", 2))), + PhysicalSortExpr::new_default(Arc::new(Column::new("d", 3))), + PhysicalSortExpr::new_default(Arc::new(Column::new("z", 7))), + PhysicalSortExpr::new_default(Arc::new(Column::new("y", 6))), + ]), + LexOrdering::new(vec![ + PhysicalSortExpr::new_default(Arc::new(Column::new("z", 7))), + PhysicalSortExpr::new_default(Arc::new(Column::new("y", 6))), + PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))), + PhysicalSortExpr::new_default(Arc::new(Column::new("c", 2))), + PhysicalSortExpr::new_default(Arc::new(Column::new("d", 3))), + ]), ]; for (i, (maintains_input_order, probe_side)) in @@ -2408,11 +2375,10 
@@ mod tests { left_ordering.as_ref(), right_ordering.as_ref(), join_type, - &on_columns, left_columns_len, maintains_input_order, probe_side, - ), + )?, expected[i] ); } @@ -2499,4 +2465,28 @@ mod tests { assert_eq!(col.name(), name); assert_eq!(col.index(), index); } + + #[test] + fn test_join_metadata() -> Result<()> { + let left_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]) + .with_metadata(HashMap::from([("key".to_string(), "left".to_string())])); + + let right_schema = Schema::new(vec![Field::new("b", DataType::Int32, false)]) + .with_metadata(HashMap::from([("key".to_string(), "right".to_string())])); + + let (join_schema, _) = + build_join_schema(&left_schema, &right_schema, &JoinType::Left); + assert_eq!( + join_schema.metadata(), + &HashMap::from([("key".to_string(), "left".to_string())]) + ); + let (join_schema, _) = + build_join_schema(&left_schema, &right_schema, &JoinType::Right); + assert_eq!( + join_schema.metadata(), + &HashMap::from([("key".to_string(), "right".to_string())]) + ); + + Ok(()) + } } diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index b256e615b232..5c0b231915cc 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -50,6 +50,7 @@ pub use crate::ordering::InputOrderMode; pub use crate::stream::EmptyRecordBatchStream; pub use crate::topk::TopK; pub use crate::visitor::{accept, visit_execution_plan, ExecutionPlanVisitor}; +pub use crate::work_table::WorkTable; pub use spill::spill_manager::SpillManager; mod ordering; @@ -59,14 +60,18 @@ mod visitor; pub mod aggregates; pub mod analyze; +pub mod async_func; +pub mod coalesce; pub mod coalesce_batches; pub mod coalesce_partitions; pub mod common; +pub mod coop; pub mod display; pub mod empty; pub mod execution_plan; pub mod explain; pub mod filter; +pub mod filter_pushdown; pub mod joins; pub mod limit; pub mod memory; @@ -90,6 +95,4 @@ pub mod udaf { pub use datafusion_physical_expr::aggregate::AggregateFunctionExpr; } -pub mod coalesce; -#[cfg(test)] pub mod test; diff --git a/datafusion/physical-plan/src/limit.rs b/datafusion/physical-plan/src/limit.rs index 89cf47a6d650..2224f85cc122 100644 --- a/datafusion/physical-plan/src/limit.rs +++ b/datafusion/physical-plan/src/limit.rs @@ -110,7 +110,7 @@ impl DisplayAs for GlobalLimitExec { } DisplayFormatType::TreeRender => { if let Some(fetch) = self.fetch { - writeln!(f, "limit={}", fetch)?; + writeln!(f, "limit={fetch}")?; } write!(f, "skip={}", self.skip) } @@ -164,10 +164,7 @@ impl ExecutionPlan for GlobalLimitExec { partition: usize, context: Arc, ) -> Result { - trace!( - "Start GlobalLimitExec::execute for partition: {}", - partition - ); + trace!("Start GlobalLimitExec::execute for partition: {partition}"); // GlobalLimitExec has a single output partition if 0 != partition { return internal_err!("GlobalLimitExec invalid partition {partition}"); @@ -193,8 +190,11 @@ impl ExecutionPlan for GlobalLimitExec { } fn statistics(&self) -> Result { - Statistics::with_fetch( - self.input.statistics()?, + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + self.input.partition_statistics(partition)?.with_fetch( self.schema(), self.fetch, self.skip, @@ -334,8 +334,11 @@ impl ExecutionPlan for LocalLimitExec { } fn statistics(&self) -> Result { - Statistics::with_fetch( - self.input.statistics()?, + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + 
self.input.partition_statistics(partition)?.with_fetch( self.schema(), Some(self.fetch), 0, @@ -765,7 +768,7 @@ mod tests { let offset = GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(csv)), skip, fetch); - Ok(offset.statistics()?.num_rows) + Ok(offset.partition_statistics(None)?.num_rows) } pub fn build_group_by( @@ -805,7 +808,7 @@ mod tests { fetch, ); - Ok(offset.statistics()?.num_rows) + Ok(offset.partition_statistics(None)?.num_rows) } async fn row_number_statistics_for_local_limit( @@ -818,7 +821,7 @@ mod tests { let offset = LocalLimitExec::new(csv, fetch); - Ok(offset.statistics()?.num_rows) + Ok(offset.partition_statistics(None)?.num_rows) } /// Return a RecordBatch with a single array with row_count sz diff --git a/datafusion/physical-plan/src/memory.rs b/datafusion/physical-plan/src/memory.rs index 1bc872a56e76..3e5ea32a4cab 100644 --- a/datafusion/physical-plan/src/memory.rs +++ b/datafusion/physical-plan/src/memory.rs @@ -22,7 +22,9 @@ use std::fmt; use std::sync::Arc; use std::task::{Context, Poll}; -use crate::execution_plan::{Boundedness, EmissionType}; +use crate::coop::cooperative; +use crate::execution_plan::{Boundedness, EmissionType, SchedulingType}; +use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use crate::{ DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, @@ -35,6 +37,7 @@ use datafusion_execution::memory_pool::MemoryReservation; use datafusion_execution::TaskContext; use datafusion_physical_expr::EquivalenceProperties; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use futures::Stream; use parking_lot::RwLock; @@ -131,6 +134,10 @@ impl RecordBatchStream for MemoryStream { } pub trait LazyBatchGenerator: Send + Sync + fmt::Debug + fmt::Display { + fn boundedness(&self) -> Boundedness { + Boundedness::Bounded + } + /// Generate the next batch, return `None` when no more batches are available fn generate_next_batch(&mut self) -> Result>; } @@ -146,6 +153,8 @@ pub struct LazyMemoryExec { batch_generators: Vec>>, /// Plan properties cache storing equivalence properties, partitioning, and execution mode cache: PlanProperties, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, } impl LazyMemoryExec { @@ -154,18 +163,62 @@ impl LazyMemoryExec { schema: SchemaRef, generators: Vec>>, ) -> Result { + let boundedness = generators + .iter() + .map(|g| g.read().boundedness()) + .reduce(|acc, b| match acc { + Boundedness::Bounded => b, + Boundedness::Unbounded { + requires_infinite_memory, + } => { + let acc_infinite_memory = requires_infinite_memory; + match b { + Boundedness::Bounded => acc, + Boundedness::Unbounded { + requires_infinite_memory, + } => Boundedness::Unbounded { + requires_infinite_memory: requires_infinite_memory + || acc_infinite_memory, + }, + } + } + }) + .unwrap_or(Boundedness::Bounded); + let cache = PlanProperties::new( EquivalenceProperties::new(Arc::clone(&schema)), Partitioning::RoundRobinBatch(generators.len()), EmissionType::Incremental, - Boundedness::Bounded, - ); + boundedness, + ) + .with_scheduling_type(SchedulingType::Cooperative); + Ok(Self { schema, batch_generators: generators, cache, + metrics: ExecutionPlanMetricsSet::new(), }) } + + pub fn try_set_partitioning(&mut self, partitioning: Partitioning) -> Result<()> { + if partitioning.partition_count() != self.batch_generators.len() { + internal_err!( + "Partition count must match generator count: {} != {}", + partitioning.partition_count(), 
+ self.batch_generators.len() + ) + } else { + self.cache.partitioning = partitioning; + Ok(()) + } + } + + pub fn add_ordering(&mut self, ordering: impl IntoIterator) { + self.cache + .eq_properties + .add_orderings(std::iter::once(ordering)); + } } impl fmt::Debug for LazyMemoryExec { @@ -254,10 +307,18 @@ impl ExecutionPlan for LazyMemoryExec { ); } - Ok(Box::pin(LazyMemoryStream { + let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); + + let stream = LazyMemoryStream { schema: Arc::clone(&self.schema), generator: Arc::clone(&self.batch_generators[partition]), - })) + baseline_metrics, + }; + Ok(Box::pin(cooperative(stream))) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) } fn statistics(&self) -> Result { @@ -276,6 +337,8 @@ pub struct LazyMemoryStream { /// parallel execution. /// Sharing generators between streams should be used with caution. generator: Arc>, + /// Execution metrics + baseline_metrics: BaselineMetrics, } impl Stream for LazyMemoryStream { @@ -285,13 +348,16 @@ impl Stream for LazyMemoryStream { self: std::pin::Pin<&mut Self>, _: &mut Context<'_>, ) -> Poll> { + let _timer_guard = self.baseline_metrics.elapsed_compute().timer(); let batch = self.generator.write().generate_next_batch(); - match batch { + let poll = match batch { Ok(Some(batch)) => Poll::Ready(Some(Ok(batch))), Ok(None) => Poll::Ready(None), Err(e) => Poll::Ready(Some(Err(e))), - } + }; + + self.baseline_metrics.record_poll(poll) } } @@ -304,6 +370,7 @@ impl RecordBatchStream for LazyMemoryStream { #[cfg(test)] mod lazy_memory_tests { use super::*; + use crate::common::collect; use arrow::array::Int64Array; use arrow::datatypes::{DataType, Field, Schema}; use futures::StreamExt; @@ -419,4 +486,45 @@ mod lazy_memory_tests { Ok(()) } + + #[tokio::test] + async fn test_generate_series_metrics_integration() -> Result<()> { + // Test LazyMemoryExec metrics with different configurations + let test_cases = vec![ + (10, 2, 10), // 10 rows, batch size 2, expected 10 rows + (100, 10, 100), // 100 rows, batch size 10, expected 100 rows + (5, 1, 5), // 5 rows, batch size 1, expected 5 rows + ]; + + for (total_rows, batch_size, expected_rows) in test_cases { + let schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + let generator = TestGenerator { + counter: 0, + max_batches: (total_rows + batch_size - 1) / batch_size, // ceiling division + batch_size: batch_size as usize, + schema: Arc::clone(&schema), + }; + + let exec = + LazyMemoryExec::try_new(schema, vec![Arc::new(RwLock::new(generator))])?; + let task_ctx = Arc::new(TaskContext::default()); + + let stream = exec.execute(0, task_ctx)?; + let batches = collect(stream).await?; + + // Verify metrics exist with actual expected numbers + let metrics = exec.metrics().unwrap(); + + // Count actual rows returned + let actual_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(actual_rows, expected_rows); + + // Verify metrics match actual output + assert_eq!(metrics.output_rows().unwrap(), expected_rows); + assert!(metrics.elapsed_compute().unwrap() > 0); + } + + Ok(()) + } } diff --git a/datafusion/physical-plan/src/metrics/baseline.rs b/datafusion/physical-plan/src/metrics/baseline.rs index a4a83b84b655..a52336108a87 100644 --- a/datafusion/physical-plan/src/metrics/baseline.rs +++ b/datafusion/physical-plan/src/metrics/baseline.rs @@ -45,7 +45,7 @@ use datafusion_common::Result; /// ``` #[derive(Debug, Clone)] pub struct BaselineMetrics { - /// end_time is set when 
`ExecutionMetrics::done()` is called + /// end_time is set when `BaselineMetrics::done()` is called end_time: Timestamp, /// amount of time the operator was actively trying to use the CPU @@ -117,9 +117,10 @@ impl BaselineMetrics { } } - /// Process a poll result of a stream producing output for an - /// operator, recording the output rows and stream done time and - /// returning the same poll result + /// Process a poll result of a stream producing output for an operator. + /// + /// Note: this method only updates `output_rows` and `end_time` metrics. + /// Remember to update `elapsed_compute` and other metrics manually. pub fn record_poll( &self, poll: Poll>>, @@ -150,7 +151,7 @@ pub struct SpillMetrics { /// count of spills during the execution of the operator pub spill_file_count: Count, - /// total spilled bytes during the execution of the operator + /// total bytes actually written to disk during the execution of the operator pub spilled_bytes: Count, /// total spilled rows during the execution of the operator diff --git a/datafusion/physical-plan/src/metrics/custom.rs b/datafusion/physical-plan/src/metrics/custom.rs new file mode 100644 index 000000000000..546af6f3335e --- /dev/null +++ b/datafusion/physical-plan/src/metrics/custom.rs @@ -0,0 +1,113 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Custom metric value type. + +use std::{any::Any, fmt::Debug, fmt::Display, sync::Arc}; + +/// A trait for implementing custom metric values. +/// +/// This trait enables defining application- or operator-specific metric types +/// that can be aggregated and displayed alongside standard metrics. These +/// custom metrics integrate with [`MetricValue::Custom`] and support +/// aggregation logic, introspection, and optional numeric representation. +/// +/// # Requirements +/// Implementations of `CustomMetricValue` must satisfy the following: +/// +/// 1. [`Self::aggregate`]: Defines how two metric values are combined +/// 2. [`Self::new_empty`]: Returns a new, zero-value instance for accumulation +/// 3. [`Self::as_any`]: Enables dynamic downcasting for type-specific operations +/// 4. [`Self::as_usize`]: Optionally maps the value to a `usize` (for sorting, display, etc.) +/// 5. 
[`Self::is_eq`]: Implements comparison between two values, this isn't reusing the std +/// PartialEq trait because this trait is used dynamically in the context of +/// [`MetricValue::Custom`] +/// +/// # Examples +/// ``` +/// # use std::sync::Arc; +/// # use std::fmt::{Debug, Display}; +/// # use std::any::Any; +/// # use std::sync::atomic::{AtomicUsize, Ordering}; +/// +/// # use datafusion_physical_plan::metrics::CustomMetricValue; +/// +/// #[derive(Debug, Default)] +/// struct MyCounter { +/// count: AtomicUsize, +/// } +/// +/// impl Display for MyCounter { +/// fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { +/// write!(f, "count: {}", self.count.load(Ordering::Relaxed)) +/// } +/// } +/// +/// impl CustomMetricValue for MyCounter { +/// fn new_empty(&self) -> Arc { +/// Arc::new(Self::default()) +/// } +/// +/// fn aggregate(&self, other: Arc) { +/// let other = other.as_any().downcast_ref::().unwrap(); +/// self.count.fetch_add(other.count.load(Ordering::Relaxed), Ordering::Relaxed); +/// } +/// +/// fn as_any(&self) -> &dyn Any { +/// self +/// } +/// +/// fn as_usize(&self) -> usize { +/// self.count.load(Ordering::Relaxed) +/// } +/// +/// fn is_eq(&self, other: &Arc) -> bool { +/// let Some(other) = other.as_any().downcast_ref::() else { +/// return false; +/// }; +/// +/// self.count.load(Ordering::Relaxed) == other.count.load(Ordering::Relaxed) +/// } +/// } +/// ``` +/// +/// [`MetricValue::Custom`]: super::MetricValue::Custom +pub trait CustomMetricValue: Display + Debug + Send + Sync { + /// Returns a new, zero-initialized version of this metric value. + /// + /// This value is used during metric aggregation to accumulate results. + fn new_empty(&self) -> Arc; + + /// Merges another metric value into this one. + /// + /// The type of `other` could be of a different custom type as long as it's aggregatable into self. + fn aggregate(&self, other: Arc); + + /// Returns this value as a [`Any`] to support dynamic downcasting. + fn as_any(&self) -> &dyn Any; + + /// Optionally returns a numeric representation of the value, if meaningful. + /// Otherwise will default to zero. + /// + /// This is used for sorting and summarizing metrics. + fn as_usize(&self) -> usize { + 0 + } + + /// Compares this value with another custom value. + fn is_eq(&self, other: &Arc) -> bool; +} diff --git a/datafusion/physical-plan/src/metrics/mod.rs b/datafusion/physical-plan/src/metrics/mod.rs index 2ac7ac1299a0..87783eada8b0 100644 --- a/datafusion/physical-plan/src/metrics/mod.rs +++ b/datafusion/physical-plan/src/metrics/mod.rs @@ -19,6 +19,7 @@ mod baseline; mod builder; +mod custom; mod value; use parking_lot::Mutex; @@ -33,6 +34,7 @@ use datafusion_common::HashMap; // public exports pub use baseline::{BaselineMetrics, RecordOutput, SpillMetrics}; pub use builder::MetricBuilder; +pub use custom::CustomMetricValue; pub use value::{Count, Gauge, MetricValue, ScopedTimerGuard, Time, Timestamp}; /// Something that tracks a value of interest (metric) of a DataFusion @@ -263,6 +265,7 @@ impl MetricsSet { MetricValue::Gauge { name, .. } => name == metric_name, MetricValue::StartTimestamp(_) => false, MetricValue::EndTimestamp(_) => false, + MetricValue::Custom { .. } => false, }) } diff --git a/datafusion/physical-plan/src/metrics/value.rs b/datafusion/physical-plan/src/metrics/value.rs index decf77369db4..1cc4a4fbcb05 100644 --- a/datafusion/physical-plan/src/metrics/value.rs +++ b/datafusion/physical-plan/src/metrics/value.rs @@ -17,9 +17,14 @@ //! 
Value representation of metrics +use super::CustomMetricValue; +use chrono::{DateTime, Utc}; +use datafusion_common::instant::Instant; +use datafusion_execution::memory_pool::human_readable_size; +use parking_lot::Mutex; use std::{ borrow::{Borrow, Cow}, - fmt::Display, + fmt::{Debug, Display}, sync::{ atomic::{AtomicUsize, Ordering}, Arc, @@ -27,10 +32,6 @@ use std::{ time::Duration, }; -use chrono::{DateTime, Utc}; -use datafusion_common::instant::Instant; -use parking_lot::Mutex; - /// A counter to record things such as number of input or output rows /// /// Note `clone`ing counters update the same underlying metrics @@ -343,7 +344,7 @@ impl Drop for ScopedTimerGuard<'_> { /// Among other differences, the metric types have different ways to /// logically interpret their underlying values and some metrics are /// so common they are given special treatment. -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone)] pub enum MetricValue { /// Number of output rows produced: "output_rows" metric OutputRows(Count), @@ -400,6 +401,78 @@ pub enum MetricValue { StartTimestamp(Timestamp), /// The time at which execution ended EndTimestamp(Timestamp), + Custom { + /// The provided name of this metric + name: Cow<'static, str>, + /// A custom implementation of the metric value. + value: Arc, + }, +} + +// Manually implement PartialEq for `MetricValue` because it contains CustomMetricValue in its +// definition which is a dyn trait. This wouldn't allow us to just derive PartialEq. +impl PartialEq for MetricValue { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (MetricValue::OutputRows(count), MetricValue::OutputRows(other)) => { + count == other + } + (MetricValue::ElapsedCompute(time), MetricValue::ElapsedCompute(other)) => { + time == other + } + (MetricValue::SpillCount(count), MetricValue::SpillCount(other)) => { + count == other + } + (MetricValue::SpilledBytes(count), MetricValue::SpilledBytes(other)) => { + count == other + } + (MetricValue::SpilledRows(count), MetricValue::SpilledRows(other)) => { + count == other + } + ( + MetricValue::CurrentMemoryUsage(gauge), + MetricValue::CurrentMemoryUsage(other), + ) => gauge == other, + ( + MetricValue::Count { name, count }, + MetricValue::Count { + name: other_name, + count: other_count, + }, + ) => name == other_name && count == other_count, + ( + MetricValue::Gauge { name, gauge }, + MetricValue::Gauge { + name: other_name, + gauge: other_gauge, + }, + ) => name == other_name && gauge == other_gauge, + ( + MetricValue::Time { name, time }, + MetricValue::Time { + name: other_name, + time: other_time, + }, + ) => name == other_name && time == other_time, + + ( + MetricValue::StartTimestamp(timestamp), + MetricValue::StartTimestamp(other), + ) => timestamp == other, + (MetricValue::EndTimestamp(timestamp), MetricValue::EndTimestamp(other)) => { + timestamp == other + } + ( + MetricValue::Custom { name, value }, + MetricValue::Custom { + name: other_name, + value: other_value, + }, + ) => name == other_name && value.is_eq(other_value), + // Default case when the two sides do not have the same type. + _ => false, + } + } } impl MetricValue { @@ -417,6 +490,7 @@ impl MetricValue { Self::Time { name, .. } => name.borrow(), Self::StartTimestamp(_) => "start_timestamp", Self::EndTimestamp(_) => "end_timestamp", + Self::Custom { name, .. } => name.borrow(), } } @@ -442,6 +516,7 @@ impl MetricValue { .and_then(|ts| ts.timestamp_nanos_opt()) .map(|nanos| nanos as usize) .unwrap_or(0), + Self::Custom { value, .. 
} => value.as_usize(), } } @@ -469,6 +544,10 @@ impl MetricValue { }, Self::StartTimestamp(_) => Self::StartTimestamp(Timestamp::new()), Self::EndTimestamp(_) => Self::EndTimestamp(Timestamp::new()), + Self::Custom { name, value } => Self::Custom { + name: name.clone(), + value: value.new_empty(), + }, } } @@ -515,6 +594,14 @@ impl MetricValue { (Self::EndTimestamp(timestamp), Self::EndTimestamp(other_timestamp)) => { timestamp.update_to_max(other_timestamp); } + ( + Self::Custom { value, .. }, + Self::Custom { + value: other_value, .. + }, + ) => { + value.aggregate(Arc::clone(other_value)); + } m @ (_, _) => { panic!( "Mismatched metric types. Can not aggregate {:?} with value {:?}", @@ -539,6 +626,7 @@ impl MetricValue { Self::Time { .. } => 8, Self::StartTimestamp(_) => 9, // show timestamps last Self::EndTimestamp(_) => 10, + Self::Custom { .. } => 11, } } @@ -554,11 +642,14 @@ impl Display for MetricValue { match self { Self::OutputRows(count) | Self::SpillCount(count) - | Self::SpilledBytes(count) | Self::SpilledRows(count) | Self::Count { count, .. } => { write!(f, "{count}") } + Self::SpilledBytes(count) => { + let readable_count = human_readable_size(count.value()); + write!(f, "{readable_count}") + } Self::CurrentMemoryUsage(gauge) | Self::Gauge { gauge, .. } => { write!(f, "{gauge}") } @@ -574,16 +665,103 @@ impl Display for MetricValue { Self::StartTimestamp(timestamp) | Self::EndTimestamp(timestamp) => { write!(f, "{timestamp}") } + Self::Custom { name, value } => { + write!(f, "name:{name} {value}") + } } } } #[cfg(test)] mod tests { + use std::any::Any; + use chrono::TimeZone; + use datafusion_execution::memory_pool::units::MB; use super::*; + #[derive(Debug, Default)] + pub struct CustomCounter { + count: AtomicUsize, + } + + impl Display for CustomCounter { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "count: {}", self.count.load(Ordering::Relaxed)) + } + } + + impl CustomMetricValue for CustomCounter { + fn new_empty(&self) -> Arc { + Arc::new(CustomCounter::default()) + } + + fn aggregate(&self, other: Arc) { + let other = other.as_any().downcast_ref::().unwrap(); + self.count + .fetch_add(other.count.load(Ordering::Relaxed), Ordering::Relaxed); + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn is_eq(&self, other: &Arc) -> bool { + let Some(other) = other.as_any().downcast_ref::() else { + return false; + }; + + self.count.load(Ordering::Relaxed) == other.count.load(Ordering::Relaxed) + } + } + + fn new_custom_counter(name: &'static str, value: usize) -> MetricValue { + let custom_counter = CustomCounter::default(); + custom_counter.count.fetch_add(value, Ordering::Relaxed); + let custom_val = MetricValue::Custom { + name: Cow::Borrowed(name), + value: Arc::new(custom_counter), + }; + + custom_val + } + + #[test] + fn test_custom_metric_with_mismatching_names() { + let mut custom_val = new_custom_counter("Hi", 1); + let other_custom_val = new_custom_counter("Hello", 1); + + // Not equal since the name differs. + assert!(other_custom_val != custom_val); + + // Should work even though the name differs + custom_val.aggregate(&other_custom_val); + + let expected_val = new_custom_counter("Hi", 2); + assert!(expected_val == custom_val); + } + + #[test] + fn test_custom_metric() { + let mut custom_val = new_custom_counter("hi", 11); + let other_custom_val = new_custom_counter("hi", 20); + + custom_val.aggregate(&other_custom_val); + + assert!(custom_val != other_custom_val); + + if let MetricValue::Custom { value, .. 
} = custom_val { + let counter = value + .as_any() + .downcast_ref::() + .expect("Expected CustomCounter"); + assert_eq!(counter.count.load(Ordering::Relaxed), 31); + } else { + panic!("Unexpected value"); + } + } + #[test] fn test_display_output_rows() { let count = Count::new(); @@ -605,6 +783,20 @@ mod tests { } } + #[test] + fn test_display_spilled_bytes() { + let count = Count::new(); + let spilled_byte = MetricValue::SpilledBytes(count.clone()); + + assert_eq!("0.0 B", spilled_byte.to_string()); + + count.add((100 * MB) as usize); + assert_eq!("100.0 MB", spilled_byte.to_string()); + + count.add((0.5 * MB as f64) as usize); + assert_eq!("100.5 MB", spilled_byte.to_string()); + } + #[test] fn test_display_time() { let time = Time::new(); diff --git a/datafusion/physical-plan/src/placeholder_row.rs b/datafusion/physical-plan/src/placeholder_row.rs index eecd980d09f8..6cd581700a88 100644 --- a/datafusion/physical-plan/src/placeholder_row.rs +++ b/datafusion/physical-plan/src/placeholder_row.rs @@ -20,12 +20,15 @@ use std::any::Any; use std::sync::Arc; -use crate::execution_plan::{Boundedness, EmissionType}; +use crate::coop::cooperative; +use crate::execution_plan::{Boundedness, EmissionType, SchedulingType}; use crate::memory::MemoryStream; -use crate::{common, DisplayAs, PlanProperties, SendableRecordBatchStream, Statistics}; -use crate::{DisplayFormatType, ExecutionPlan, Partitioning}; -use arrow::array::{ArrayRef, NullArray}; -use arrow::array::{RecordBatch, RecordBatchOptions}; +use crate::{ + common, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, + SendableRecordBatchStream, Statistics, +}; + +use arrow::array::{ArrayRef, NullArray, RecordBatch, RecordBatchOptions}; use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; use datafusion_common::{internal_err, Result}; use datafusion_execution::TaskContext; @@ -99,6 +102,7 @@ impl PlaceholderRowExec { EmissionType::Incremental, Boundedness::Bounded, ) + .with_scheduling_type(SchedulingType::Cooperative) } } @@ -158,14 +162,18 @@ impl ExecutionPlan for PlaceholderRowExec { ); } - Ok(Box::pin(MemoryStream::try_new( - self.data()?, - Arc::clone(&self.schema), - None, - )?)) + let ms = MemoryStream::try_new(self.data()?, Arc::clone(&self.schema), None)?; + Ok(Box::pin(cooperative(ms))) } fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_some() { + return Ok(Statistics::new_unknown(&self.schema())); + } let batch = self .data() .expect("Create single row placeholder RecordBatch should not fail"); diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index 72934c74446e..a29f4aeb4090 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -26,7 +26,7 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use super::expressions::{CastExpr, Column, Literal}; +use super::expressions::{Column, Literal}; use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use super::{ DisplayAs, ExecutionPlanProperties, PlanProperties, RecordBatchStream, @@ -46,11 +46,10 @@ use datafusion_common::{internal_err, JoinSide, Result}; use datafusion_execution::TaskContext; use datafusion_physical_expr::equivalence::ProjectionMapping; use datafusion_physical_expr::utils::collect_columns; -use datafusion_physical_expr::PhysicalExprRef; +use 
datafusion_physical_expr_common::physical_expr::{fmt_sql, PhysicalExprRef}; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; -use datafusion_physical_expr_common::physical_expr::fmt_sql; use futures::stream::{Stream, StreamExt}; -use itertools::Itertools; use log::trace; /// Execution plan for a projection @@ -79,14 +78,14 @@ impl ProjectionExec { let fields: Result> = expr .iter() .map(|(e, name)| { - let mut field = Field::new( + let metadata = e.return_field(&input_schema)?.metadata().clone(); + + let field = Field::new( name, e.data_type(&input_schema)?, e.nullable(&input_schema)?, - ); - field.set_metadata( - get_field_metadata(e, &input_schema).unwrap_or_default(), - ); + ) + .with_metadata(metadata); Ok(field) }) @@ -98,7 +97,7 @@ impl ProjectionExec { )); // Construct a map from the input expressions to the output expression of the Projection - let projection_mapping = ProjectionMapping::try_new(&expr, &input_schema)?; + let projection_mapping = ProjectionMapping::try_new(expr.clone(), &input_schema)?; let cache = Self::compute_properties(&input, &projection_mapping, Arc::clone(&schema))?; Ok(Self { @@ -127,14 +126,12 @@ impl ProjectionExec { schema: SchemaRef, ) -> Result { // Calculate equivalence properties: - let mut input_eq_properties = input.equivalence_properties().clone(); - input_eq_properties.substitute_oeq_class(projection_mapping)?; + let input_eq_properties = input.equivalence_properties(); let eq_properties = input_eq_properties.project(projection_mapping, schema); - // Calculate output partitioning, which needs to respect aliases: - let input_partition = input.output_partitioning(); - let output_partitioning = - input_partition.project(projection_mapping, &input_eq_properties); + let output_partitioning = input + .output_partitioning() + .project(projection_mapping, input_eq_properties); Ok(PlanProperties::new( eq_properties, @@ -198,23 +195,11 @@ impl ExecutionPlan for ProjectionExec { &self.cache } - fn children(&self) -> Vec<&Arc> { - vec![&self.input] - } - fn maintains_input_order(&self) -> Vec { // Tell optimizer this operator doesn't reorder its input vec![true] } - fn with_new_children( - self: Arc, - mut children: Vec>, - ) -> Result> { - ProjectionExec::try_new(self.expr.clone(), children.swap_remove(0)) - .map(|p| Arc::new(p) as _) - } - fn benefits_from_input_partitioning(&self) -> Vec { let all_simple_exprs = self .expr @@ -225,6 +210,18 @@ impl ExecutionPlan for ProjectionExec { vec![!all_simple_exprs] } + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn with_new_children( + self: Arc, + mut children: Vec>, + ) -> Result> { + ProjectionExec::try_new(self.expr.clone(), children.swap_remove(0)) + .map(|p| Arc::new(p) as _) + } + fn execute( &self, partition: usize, @@ -244,8 +241,13 @@ impl ExecutionPlan for ProjectionExec { } fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + let input_stats = self.input.partition_statistics(partition)?; Ok(stats_projection( - self.input.statistics()?, + input_stats, self.expr.iter().map(|(e, _)| Arc::clone(e)), Arc::clone(&self.schema), )) @@ -273,24 +275,6 @@ impl ExecutionPlan for ProjectionExec { } } -/// If 'e' is a direct column reference, returns the field level -/// metadata for that field, if any. 
Otherwise returns None -pub(crate) fn get_field_metadata( - e: &Arc, - input_schema: &Schema, -) -> Option> { - if let Some(cast) = e.as_any().downcast_ref::() { - return get_field_metadata(cast.expr(), input_schema); - } - - // Look up field by index in schema (not NAME as there can be more than one - // column with the same name) - e.as_any() - .downcast_ref::() - .map(|column| input_schema.field(column.index()).metadata()) - .cloned() -} - fn stats_projection( mut stats: Statistics, exprs: impl Iterator>, @@ -538,7 +522,7 @@ pub fn remove_unnecessary_projections( } else { return Ok(Transformed::no(plan)); }; - Ok(maybe_modified.map_or(Transformed::no(plan), Transformed::yes)) + Ok(maybe_modified.map_or_else(|| Transformed::no(plan), Transformed::yes)) } /// Compare the inputs and outputs of the projection. All expressions must be @@ -635,7 +619,7 @@ pub fn update_expr( let mut state = RewriteState::Unchanged; let new_expr = Arc::clone(expr) - .transform_up(|expr: Arc| { + .transform_up(|expr| { if state == RewriteState::RewrittenInvalid { return Ok(Transformed::no(expr)); } @@ -679,6 +663,42 @@ pub fn update_expr( new_expr.map(|e| (state == RewriteState::RewrittenValid).then_some(e)) } +/// Updates the given lexicographic ordering according to given projected +/// expressions using the [`update_expr`] function. +pub fn update_ordering( + ordering: LexOrdering, + projected_exprs: &[(Arc, String)], +) -> Result> { + let mut updated_exprs = vec![]; + for mut sort_expr in ordering.into_iter() { + let Some(updated_expr) = update_expr(&sort_expr.expr, projected_exprs, false)? + else { + return Ok(None); + }; + sort_expr.expr = updated_expr; + updated_exprs.push(sort_expr); + } + Ok(LexOrdering::new(updated_exprs)) +} + +/// Updates the given lexicographic requirement according to given projected +/// expressions using the [`update_expr`] function. +pub fn update_ordering_requirement( + reqs: LexRequirement, + projected_exprs: &[(Arc, String)], +) -> Result> { + let mut updated_exprs = vec![]; + for mut sort_expr in reqs.into_iter() { + let Some(updated_expr) = update_expr(&sort_expr.expr, projected_exprs, false)? + else { + return Ok(None); + }; + sort_expr.expr = updated_expr; + updated_exprs.push(sort_expr); + } + Ok(LexRequirement::new(updated_exprs)) +} + /// Downcasts all the expressions in `exprs` to `Column`s. If any of the given /// expressions is not a `Column`, returns `None`. pub fn physical_to_column_exprs( @@ -713,7 +733,7 @@ pub fn new_join_children( alias.clone(), ) }) - .collect_vec(), + .collect(), Arc::clone(left_child), )?; let left_size = left_child.schema().fields().len() as i32; @@ -731,7 +751,7 @@ pub fn new_join_children( alias.clone(), ) }) - .collect_vec(), + .collect(), Arc::clone(right_child), )?; @@ -1093,13 +1113,11 @@ mod tests { let task_ctx = Arc::new(TaskContext::default()); let exec = test::scan_partitioned(1); - let expected = collect(exec.execute(0, Arc::clone(&task_ctx))?) 
- .await - .unwrap(); + let expected = collect(exec.execute(0, Arc::clone(&task_ctx))?).await?; let projection = ProjectionExec::try_new(vec![], exec)?; let stream = projection.execute(0, Arc::clone(&task_ctx))?; - let output = collect(stream).await.unwrap(); + let output = collect(stream).await?; assert_eq!(output.len(), expected.len()); Ok(()) diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs index 7268735ea457..99b460dfcfdc 100644 --- a/datafusion/physical-plan/src/recursive_query.rs +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -184,8 +184,7 @@ impl ExecutionPlan for RecursiveQueryExec { // TODO: we might be able to handle multiple partitions in the future. if partition != 0 { return Err(DataFusionError::Internal(format!( - "RecursiveQueryExec got an invalid partition {} (expected 0)", - partition + "RecursiveQueryExec got an invalid partition {partition} (expected 0)" ))); } @@ -352,16 +351,16 @@ fn assign_work_table( ) -> Result> { let mut work_table_refs = 0; plan.transform_down(|plan| { - if let Some(exec) = plan.as_any().downcast_ref::() { + if let Some(new_plan) = + plan.with_new_state(Arc::clone(&work_table) as Arc) + { if work_table_refs > 0 { not_impl_err!( "Multiple recursive references to the same CTE are not supported" ) } else { work_table_refs += 1; - Ok(Transformed::yes(Arc::new( - exec.with_work_table(Arc::clone(&work_table)), - ))) + Ok(Transformed::yes(new_plan)) } } else if plan.as_any().is::() { not_impl_err!("Recursive queries cannot be nested") diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index f7055d814d02..620bfa2809a9 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -19,6 +19,7 @@ //! partitions to M output partitions based on a partitioning scheme, optionally //! maintaining the order of the input rows in the output. 
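// A minimal sketch of how the operator described above is typically constructed;
// the column name "a" and the partition count 8 are placeholders, and the crate
// paths assume the public `datafusion_physical_plan` / `datafusion_physical_expr` APIs.
use std::sync::Arc;

use datafusion_common::Result;
use datafusion_physical_expr::expressions::col;
use datafusion_physical_plan::repartition::RepartitionExec;
use datafusion_physical_plan::{ExecutionPlan, Partitioning};

fn repartition_example(input: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
    // Spread the input partitions round-robin over 8 output partitions.
    let _round_robin = RepartitionExec::try_new(
        Arc::clone(&input),
        Partitioning::RoundRobinBatch(8),
    )?;

    // Or hash-partition on column "a" so equal keys land in the same output
    // partition; `with_preserve_order` keeps sorted inputs sorted by merging them.
    let hashed = RepartitionExec::try_new(
        Arc::clone(&input),
        Partitioning::Hash(vec![col("a", &input.schema())?], 8),
    )?
    .with_preserve_order();

    Ok(Arc::new(hashed))
}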
+use std::fmt::{Debug, Formatter}; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; @@ -29,7 +30,7 @@ use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; use super::{ DisplayAs, ExecutionPlanProperties, RecordBatchStream, SendableRecordBatchStream, }; -use crate::execution_plan::CardinalityEffect; +use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType}; use crate::hash_utils::create_hashes; use crate::metrics::BaselineMetrics; use crate::projection::{all_columns, make_with_child, update_expr, ProjectionExec}; @@ -43,8 +44,9 @@ use crate::{DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, Stat use arrow::array::{PrimitiveArray, RecordBatch, RecordBatchOptions}; use arrow::compute::take_arrays; use arrow::datatypes::{SchemaRef, UInt32Type}; +use datafusion_common::config::ConfigOptions; use datafusion_common::utils::transpose; -use datafusion_common::HashMap; +use datafusion_common::{internal_err, HashMap}; use datafusion_common::{not_impl_err, DataFusionError, Result}; use datafusion_common_runtime::SpawnedTask; use datafusion_execution::memory_pool::MemoryConsumer; @@ -52,6 +54,10 @@ use datafusion_execution::TaskContext; use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr}; use datafusion_physical_expr_common::sort_expr::LexOrdering; +use crate::filter_pushdown::{ + ChildPushdownResult, FilterDescription, FilterPushdownPhase, + FilterPushdownPropagation, +}; use futures::stream::Stream; use futures::{FutureExt, StreamExt, TryStreamExt}; use log::trace; @@ -63,9 +69,8 @@ type MaybeBatch = Option>; type InputPartitionsToCurrentPartitionSender = Vec>; type InputPartitionsToCurrentPartitionReceiver = Vec>; -/// Inner state of [`RepartitionExec`]. #[derive(Debug)] -struct RepartitionExecState { +struct ConsumingInputStreamsState { /// Channels for sending batches from input partitions to output partitions. /// Key is the partition number. channels: HashMap< @@ -81,16 +86,97 @@ struct RepartitionExecState { abort_helper: Arc>>, } +/// Inner state of [`RepartitionExec`]. +enum RepartitionExecState { + /// Not initialized yet. This is the default state stored in the RepartitionExec node + /// upon instantiation. + NotInitialized, + /// Input streams are initialized, but they are still not being consumed. The node + /// transitions to this state when the arrow's RecordBatch stream is created in + /// RepartitionExec::execute(), but before any message is polled. + InputStreamsInitialized(Vec<(SendableRecordBatchStream, RepartitionMetrics)>), + /// The input streams are being consumed. The node transitions to this state when + /// the first message in the arrow's RecordBatch stream is consumed. 
+ ConsumingInputStreams(ConsumingInputStreamsState), +} + +impl Default for RepartitionExecState { + fn default() -> Self { + Self::NotInitialized + } +} + +impl Debug for RepartitionExecState { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + RepartitionExecState::NotInitialized => write!(f, "NotInitialized"), + RepartitionExecState::InputStreamsInitialized(v) => { + write!(f, "InputStreamsInitialized({:?})", v.len()) + } + RepartitionExecState::ConsumingInputStreams(v) => { + write!(f, "ConsumingInputStreams({v:?})") + } + } + } +} + impl RepartitionExecState { - fn new( + fn ensure_input_streams_initialized( + &mut self, input: Arc, - partitioning: Partitioning, metrics: ExecutionPlanMetricsSet, + output_partitions: usize, + ctx: Arc, + ) -> Result<()> { + if !matches!(self, RepartitionExecState::NotInitialized) { + return Ok(()); + } + + let num_input_partitions = input.output_partitioning().partition_count(); + let mut streams_and_metrics = Vec::with_capacity(num_input_partitions); + + for i in 0..num_input_partitions { + let metrics = RepartitionMetrics::new(i, output_partitions, &metrics); + + let timer = metrics.fetch_time.timer(); + let stream = input.execute(i, Arc::clone(&ctx))?; + timer.done(); + + streams_and_metrics.push((stream, metrics)); + } + *self = RepartitionExecState::InputStreamsInitialized(streams_and_metrics); + Ok(()) + } + + fn consume_input_streams( + &mut self, + input: Arc, + metrics: ExecutionPlanMetricsSet, + partitioning: Partitioning, preserve_order: bool, name: String, context: Arc, - ) -> Self { - let num_input_partitions = input.output_partitioning().partition_count(); + ) -> Result<&mut ConsumingInputStreamsState> { + let streams_and_metrics = match self { + RepartitionExecState::NotInitialized => { + self.ensure_input_streams_initialized( + input, + metrics, + partitioning.partition_count(), + Arc::clone(&context), + )?; + let RepartitionExecState::InputStreamsInitialized(value) = self else { + // This cannot happen, as ensure_input_streams_initialized() was just called, + // but the compiler does not know. 
+ return internal_err!("Programming error: RepartitionExecState must be in the InputStreamsInitialized state after calling RepartitionExecState::ensure_input_streams_initialized"); + }; + value + } + RepartitionExecState::ConsumingInputStreams(value) => return Ok(value), + RepartitionExecState::InputStreamsInitialized(value) => value, + }; + + let num_input_partitions = streams_and_metrics.len(); let num_output_partitions = partitioning.partition_count(); let (txs, rxs) = if preserve_order { @@ -117,7 +203,7 @@ impl RepartitionExecState { let mut channels = HashMap::with_capacity(txs.len()); for (partition, (tx, rx)) in txs.into_iter().zip(rxs).enumerate() { let reservation = Arc::new(Mutex::new( - MemoryConsumer::new(format!("{}[{partition}]", name)) + MemoryConsumer::new(format!("{name}[{partition}]")) .register(context.memory_pool()), )); channels.insert(partition, (tx, rx, reservation)); @@ -125,7 +211,9 @@ impl RepartitionExecState { // launch one async task per *input* partition let mut spawned_tasks = Vec::with_capacity(num_input_partitions); - for i in 0..num_input_partitions { + for (i, (stream, metrics)) in + std::mem::take(streams_and_metrics).into_iter().enumerate() + { let txs: HashMap<_, _> = channels .iter() .map(|(partition, (tx, _rx, reservation))| { @@ -133,15 +221,11 @@ impl RepartitionExecState { }) .collect(); - let r_metrics = RepartitionMetrics::new(i, num_output_partitions, &metrics); - let input_task = SpawnedTask::spawn(RepartitionExec::pull_from_input( - Arc::clone(&input), - i, + stream, txs.clone(), partitioning.clone(), - r_metrics, - Arc::clone(&context), + metrics, )); // In a separate task, wait for each input to be done @@ -154,28 +238,17 @@ impl RepartitionExecState { )); spawned_tasks.push(wait_for_task); } - - Self { + *self = Self::ConsumingInputStreams(ConsumingInputStreamsState { channels, abort_helper: Arc::new(spawned_tasks), + }); + match self { + RepartitionExecState::ConsumingInputStreams(value) => Ok(value), + _ => unreachable!(), } } } -/// Lazily initialized state -/// -/// Note that the state is initialized ONCE for all partitions by a single task(thread). -/// This may take a short while. It is also like that multiple threads -/// call execute at the same time, because we have just started "target partitions" tasks -/// which is commonly set to the number of CPU cores and all call execute at the same time. -/// -/// Thus, use a **tokio** `OnceCell` for this initialization so as not to waste CPU cycles -/// in a mutex lock but instead allow other threads to do something useful. -/// -/// Uses a parking_lot `Mutex` to control other accesses as they are very short duration -/// (e.g. removing channels on completion) where the overhead of `await` is not warranted. -type LazyState = Arc>>; - /// A utility that can be used to partition batches based on [`Partitioning`] pub struct BatchPartitioner { state: BatchPartitionerState, @@ -402,8 +475,9 @@ impl BatchPartitioner { pub struct RepartitionExec { /// Input execution plan input: Arc, - /// Inner state that is initialized when the first output stream is created. - state: LazyState, + /// Inner state that is initialized when the parent calls .execute() on this node + /// and consumed as soon as the parent starts consuming this node. + state: Arc>, /// Execution metrics metrics: ExecutionPlanMetricsSet, /// Boolean flag to decide whether to preserve ordering. 
If true means @@ -482,11 +556,7 @@ impl RepartitionExec { } impl DisplayAs for RepartitionExec { - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { write!( @@ -508,11 +578,17 @@ impl DisplayAs for RepartitionExec { } DisplayFormatType::TreeRender => { writeln!(f, "partitioning_scheme={}", self.partitioning(),)?; + + let input_partition_count = + self.input.output_partitioning().partition_count(); + let output_partition_count = self.partitioning().partition_count(); + let input_to_output_partition_str = + format!("{input_partition_count} -> {output_partition_count}"); writeln!( f, - "input_partition_count={}", - self.input.output_partitioning().partition_count() + "partition_count(in->out)={input_to_output_partition_str}" )?; + if self.preserve_order { writeln!(f, "preserve_order={}", self.preserve_order)?; } @@ -573,42 +649,42 @@ impl ExecutionPlan for RepartitionExec { partition ); - let lazy_state = Arc::clone(&self.state); let input = Arc::clone(&self.input); let partitioning = self.partitioning().clone(); let metrics = self.metrics.clone(); - let preserve_order = self.preserve_order; + let preserve_order = self.sort_exprs().is_some(); let name = self.name().to_owned(); let schema = self.schema(); let schema_captured = Arc::clone(&schema); // Get existing ordering to use for merging - let sort_exprs = self.sort_exprs().cloned().unwrap_or_default(); + let sort_exprs = self.sort_exprs().cloned(); + + let state = Arc::clone(&self.state); + if let Some(mut state) = state.try_lock() { + state.ensure_input_streams_initialized( + Arc::clone(&input), + metrics.clone(), + partitioning.partition_count(), + Arc::clone(&context), + )?; + } let stream = futures::stream::once(async move { let num_input_partitions = input.output_partitioning().partition_count(); - let input_captured = Arc::clone(&input); - let metrics_captured = metrics.clone(); - let name_captured = name.clone(); - let context_captured = Arc::clone(&context); - let state = lazy_state - .get_or_init(|| async move { - Mutex::new(RepartitionExecState::new( - input_captured, - partitioning, - metrics_captured, - preserve_order, - name_captured, - context_captured, - )) - }) - .await; - // lock scope let (mut rx, reservation, abort_helper) = { // lock mutexes let mut state = state.lock(); + let state = state.consume_input_streams( + Arc::clone(&input), + metrics.clone(), + partitioning, + preserve_order, + name.clone(), + Arc::clone(&context), + )?; // now return stream for the specified *output* partition which will // read from the channel @@ -621,9 +697,7 @@ impl ExecutionPlan for RepartitionExec { }; trace!( - "Before returning stream in {}::execute for partition: {}", - name, - partition + "Before returning stream in {name}::execute for partition: {partition}" ); if preserve_order { @@ -645,12 +719,12 @@ impl ExecutionPlan for RepartitionExec { // input partitions to this partition: let fetch = None; let merge_reservation = - MemoryConsumer::new(format!("{}[Merge {partition}]", name)) + MemoryConsumer::new(format!("{name}[Merge {partition}]")) .register(context.memory_pool()); StreamingMergeBuilder::new() .with_streams(input_streams) .with_schema(schema_captured) - .with_expressions(&sort_exprs) + .with_expressions(&sort_exprs.unwrap()) .with_metrics(BaselineMetrics::new(&metrics, partition)) 
.with_batch_size(context.session_config().batch_size()) .with_fetch(fetch) @@ -677,7 +751,15 @@ impl ExecutionPlan for RepartitionExec { } fn statistics(&self) -> Result { - self.input.statistics() + self.input.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_none() { + self.input.partition_statistics(None) + } else { + Ok(Statistics::new_unknown(&self.schema())) + } } fn cardinality_effect(&self) -> CardinalityEffect { @@ -723,6 +805,27 @@ impl ExecutionPlan for RepartitionExec { new_partitioning, )?))) } + + fn gather_filters_for_pushdown( + &self, + _phase: FilterPushdownPhase, + parent_filters: Vec>, + _config: &ConfigOptions, + ) -> Result { + Ok(FilterDescription::new_with_child_count(1) + .all_parent_filters_supported(parent_filters)) + } + + fn handle_child_pushdown_result( + &self, + _phase: FilterPushdownPhase, + child_pushdown_result: ChildPushdownResult, + _config: &ConfigOptions, + ) -> Result>> { + Ok(FilterPushdownPropagation::transparent( + child_pushdown_result, + )) + } } impl RepartitionExec { @@ -783,6 +886,8 @@ impl RepartitionExec { input.pipeline_behavior(), input.boundedness(), ) + .with_scheduling_type(SchedulingType::Cooperative) + .with_evaluation_type(EvaluationType::Eager) } /// Specify if this repartitioning operation should preserve the order of @@ -818,24 +923,17 @@ impl RepartitionExec { /// /// txs hold the output sending channels for each output partition async fn pull_from_input( - input: Arc, - partition: usize, + mut stream: SendableRecordBatchStream, mut output_channels: HashMap< usize, (DistributionSender, SharedMemoryReservation), >, partitioning: Partitioning, metrics: RepartitionMetrics, - context: Arc, ) -> Result<()> { let mut partitioner = BatchPartitioner::try_new(partitioning, metrics.repartition_time.clone())?; - // execute the child operator - let timer = metrics.fetch_time.timer(); - let mut stream = input.execute(partition, context)?; - timer.done(); - // While there are still outputs to send to, keep pulling inputs let mut batches_until_yield = partitioner.num_partitions(); while !output_channels.is_empty() { @@ -1083,6 +1181,7 @@ mod tests { use datafusion_common_runtime::JoinSet; use datafusion_execution::runtime_env::RuntimeEnvBuilder; use insta::assert_snapshot; + use itertools::Itertools; #[tokio::test] async fn one_to_many_round_robin() -> Result<()> { @@ -1263,15 +1362,9 @@ mod tests { let partitioning = Partitioning::RoundRobinBatch(1); let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap(); - // Note: this should pass (the stream can be created) but the - // error when the input is executed should get passed back - let output_stream = exec.execute(0, task_ctx).unwrap(); - // Expect that an error is returned - let result_string = crate::common::collect(output_stream) - .await - .unwrap_err() - .to_string(); + let result_string = exec.execute(0, task_ctx).err().unwrap().to_string(); + assert!( result_string.contains("ErrorExec, unsurprisingly, errored in partition 0"), "actual: {result_string}" @@ -1461,7 +1554,14 @@ mod tests { }); let batches_with_drop = crate::common::collect(output_stream1).await.unwrap(); - assert_eq!(batches_without_drop, batches_with_drop); + fn sort(batch: Vec) -> Vec { + batch + .into_iter() + .sorted_by_key(|b| format!("{b:?}")) + .collect() + } + + assert_eq!(sort(batches_without_drop), sort(batches_with_drop)); } fn str_batches_to_vec(batches: &[RecordBatch]) -> Vec<&str> { @@ -1715,11 +1815,11 @@ mod test { } fn 
sort_exprs(schema: &Schema) -> LexOrdering { - let options = SortOptions::default(); - LexOrdering::new(vec![PhysicalSortExpr { + [PhysicalSortExpr { expr: col("c0", schema).unwrap(), - options, - }]) + options: SortOptions::default(), + }] + .into() } fn memory_exec(schema: &SchemaRef) -> Arc { diff --git a/datafusion/physical-plan/src/sorts/cursor.rs b/datafusion/physical-plan/src/sorts/cursor.rs index efb9c0a47bf5..17033e6a3142 100644 --- a/datafusion/physical-plan/src/sorts/cursor.rs +++ b/datafusion/physical-plan/src/sorts/cursor.rs @@ -293,14 +293,19 @@ impl CursorValues for StringViewArray { self.views().len() } + #[inline(always)] fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool { // SAFETY: Both l_idx and r_idx are guaranteed to be within bounds, // and any null-checks are handled in the outer layers. // Fast path: Compare the lengths before full byte comparison. - let l_view = unsafe { l.views().get_unchecked(l_idx) }; - let l_len = *l_view as u32; let r_view = unsafe { r.views().get_unchecked(r_idx) }; + + if l.data_buffers().is_empty() && r.data_buffers().is_empty() { + return l_view.eq(r_view); + } + + let l_len = *l_view as u32; let r_len = *r_view as u32; if l_len != r_len { return false; @@ -309,13 +314,19 @@ impl CursorValues for StringViewArray { unsafe { GenericByteViewArray::compare_unchecked(l, l_idx, r, r_idx).is_eq() } } + #[inline(always)] fn eq_to_previous(cursor: &Self, idx: usize) -> bool { // SAFETY: The caller guarantees that idx > 0 and the indices are valid. // Already checked it in is_eq_to_prev_one function // Fast path: Compare the lengths of the current and previous views. let l_view = unsafe { cursor.views().get_unchecked(idx) }; - let l_len = *l_view as u32; let r_view = unsafe { cursor.views().get_unchecked(idx - 1) }; + if cursor.data_buffers().is_empty() { + return l_view.eq(r_view); + } + + let l_len = *l_view as u32; + let r_len = *r_view as u32; if l_len != r_len { return false; @@ -326,10 +337,21 @@ impl CursorValues for StringViewArray { } } + #[inline(always)] fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering { // SAFETY: Prior assertions guarantee that l_idx and r_idx are valid indices. // Null-checks are assumed to have been handled in the wrapper (e.g., ArrayValues). // And the bound is checked in is_finished, it is safe to call get_unchecked + if l.data_buffers().is_empty() && r.data_buffers().is_empty() { + let l_view = unsafe { l.views().get_unchecked(l_idx) }; + let r_view = unsafe { r.views().get_unchecked(r_idx) }; + let l_len = *l_view as u32; + let r_len = *r_view as u32; + let l_data = unsafe { StringViewArray::inline_value(l_view, l_len as usize) }; + let r_data = unsafe { StringViewArray::inline_value(r_view, r_len as usize) }; + return l_data.cmp(r_data); + } + unsafe { GenericByteViewArray::compare_unchecked(l, l_idx, r, r_idx) } } } diff --git a/datafusion/physical-plan/src/sorts/merge.rs b/datafusion/physical-plan/src/sorts/merge.rs index 2b42457635f7..0c18a3b6c703 100644 --- a/datafusion/physical-plan/src/sorts/merge.rs +++ b/datafusion/physical-plan/src/sorts/merge.rs @@ -18,7 +18,6 @@ //! Merge that deals with an arbitrary size of streaming inputs. //! This is an order-preserving merge. -use std::collections::VecDeque; use std::pin::Pin; use std::sync::Arc; use std::task::{ready, Context, Poll}; @@ -143,11 +142,8 @@ pub(crate) struct SortPreservingMergeStream { /// number of rows produced produced: usize, - /// This queue contains partition indices in order. 
When a partition is polled and returns `Poll::Ready`, - /// it is removed from the vector. If a partition returns `Poll::Pending`, it is moved to the end of the - /// vector to ensure the next iteration starts with a different partition, preventing the same partition - /// from being continuously polled. - uninitiated_partitions: VecDeque, + /// This vector contains the indices of the partitions that have not started emitting yet. + uninitiated_partitions: Vec, } impl SortPreservingMergeStream { @@ -216,36 +212,50 @@ impl SortPreservingMergeStream { // Once all partitions have set their corresponding cursors for the loser tree, // we skip the following block. Until then, this function may be called multiple // times and can return Poll::Pending if any partition returns Poll::Pending. + if self.loser_tree.is_empty() { - while let Some(&partition_idx) = self.uninitiated_partitions.front() { + // Manual indexing since we're iterating over the vector and shrinking it in the loop + let mut idx = 0; + while idx < self.uninitiated_partitions.len() { + let partition_idx = self.uninitiated_partitions[idx]; match self.maybe_poll_stream(cx, partition_idx) { Poll::Ready(Err(e)) => { self.aborted = true; return Poll::Ready(Some(Err(e))); } Poll::Pending => { - // If a partition returns Poll::Pending, to avoid continuously polling it - // and potentially increasing upstream buffer sizes, we move it to the - // back of the polling queue. - self.uninitiated_partitions.rotate_left(1); - - // This function could remain in a pending state, so we manually wake it here. - // However, this approach can be investigated further to find a more natural way - // to avoid disrupting the runtime scheduler. - cx.waker().wake_by_ref(); - return Poll::Pending; + // The polled stream is pending which means we're already set up to + // be woken when necessary + // Try the next stream + idx += 1; } _ => { - // If the polling result is Poll::Ready(Some(batch)) or Poll::Ready(None), - // we remove this partition from the queue so it is not polled again. - self.uninitiated_partitions.pop_front(); + // The polled stream is ready + // Remove it from uninitiated_partitions + // Don't bump idx here, since a new element will have taken its + // place which we'll try in the next loop iteration + // swap_remove will change the partition poll order, but that shouldn't + // make a difference since we're waiting for all streams to be ready. + self.uninitiated_partitions.swap_remove(idx); } } } - // Claim the memory for the uninitiated partitions - self.uninitiated_partitions.shrink_to_fit(); - self.init_loser_tree(); + if self.uninitiated_partitions.is_empty() { + // If there are no more uninitiated partitions, set up the loser tree and continue + // to the next phase. + + // Claim the memory for the uninitiated partitions + self.uninitiated_partitions.shrink_to_fit(); + self.init_loser_tree(); + } else { + // There are still uninitiated partitions so return pending. + // We only get here if we've polled all uninitiated streams and at least one of them + // returned pending itself. That means we will be woken as soon as one of the + // streams would like to be polled again. + // There is no need to reschedule ourselves eagerly. 
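// A minimal, self-contained sketch of the waker contract relied on here: a stream
// that returns `Poll::Pending` is responsible for arranging its own wakeup, which is
// why the merge loop above no longer wakes or reschedules itself manually.
// `PendingOnce` is an illustrative type, not part of this patch.
use std::pin::Pin;
use std::task::{Context, Poll};

use futures::Stream;

#[derive(Default)]
struct PendingOnce {
    polled: bool,
}

impl Stream for PendingOnce {
    type Item = ();

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<()>> {
        if self.polled {
            Poll::Ready(None)
        } else {
            self.polled = true;
            // A well-behaved stream registers (or triggers) its waker before
            // returning Pending, so it will be polled again without the caller
            // having to call `wake_by_ref` itself.
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }
}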
+ return Poll::Pending; + } } // NB timer records time taken on drop, so there are no diff --git a/datafusion/physical-plan/src/sorts/mod.rs b/datafusion/physical-plan/src/sorts/mod.rs index c7ffae4061c0..9c72e34fe343 100644 --- a/datafusion/physical-plan/src/sorts/mod.rs +++ b/datafusion/physical-plan/src/sorts/mod.rs @@ -20,6 +20,7 @@ mod builder; mod cursor; mod merge; +mod multi_level_merge; pub mod partial_sort; pub mod sort; pub mod sort_preserving_merge; diff --git a/datafusion/physical-plan/src/sorts/multi_level_merge.rs b/datafusion/physical-plan/src/sorts/multi_level_merge.rs new file mode 100644 index 000000000000..fc55465b0e01 --- /dev/null +++ b/datafusion/physical-plan/src/sorts/multi_level_merge.rs @@ -0,0 +1,342 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Create a stream that do a multi level merge stream + +use crate::metrics::BaselineMetrics; +use crate::{EmptyRecordBatchStream, SpillManager}; +use arrow::array::RecordBatch; +use std::fmt::{Debug, Formatter}; +use std::mem; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use arrow::datatypes::SchemaRef; +use datafusion_common::Result; +use datafusion_execution::memory_pool::{ + MemoryConsumer, MemoryPool, MemoryReservation, UnboundedMemoryPool, +}; + +use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder}; +use crate::stream::RecordBatchStreamAdapter; +use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream}; +use datafusion_physical_expr_common::sort_expr::LexOrdering; +use futures::TryStreamExt; +use futures::{Stream, StreamExt}; + +/// Merges a stream of sorted cursors and record batches into a single sorted stream +pub(crate) struct MultiLevelMergeBuilder { + spill_manager: SpillManager, + schema: SchemaRef, + sorted_spill_files: Vec, + sorted_streams: Vec, + expr: LexOrdering, + metrics: BaselineMetrics, + batch_size: usize, + reservation: MemoryReservation, + fetch: Option, + enable_round_robin_tie_breaker: bool, + + // This is for avoiding double reservation of memory from our side and the sort preserving merge stream + // side. 
+ // and doing a lot of code changes to avoid accounting for the memory used by the streams + unbounded_memory_pool: Arc, +} + +impl Debug for MultiLevelMergeBuilder { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "MultiLevelMergeBuilder") + } +} + +impl MultiLevelMergeBuilder { + #[allow(clippy::too_many_arguments)] + pub(crate) fn new( + spill_manager: SpillManager, + schema: SchemaRef, + sorted_spill_files: Vec, + sorted_streams: Vec, + expr: LexOrdering, + metrics: BaselineMetrics, + batch_size: usize, + reservation: MemoryReservation, + fetch: Option, + enable_round_robin_tie_breaker: bool, + ) -> Self { + Self { + spill_manager, + schema, + sorted_spill_files, + sorted_streams, + expr, + metrics, + batch_size, + reservation, + enable_round_robin_tie_breaker, + fetch, + unbounded_memory_pool: Arc::new(UnboundedMemoryPool::default()), + } + } + + pub(crate) fn create_spillable_merge_stream(self) -> SendableRecordBatchStream { + Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&self.schema), + futures::stream::once(self.create_stream()).try_flatten(), + )) + } + + async fn create_stream(mut self) -> Result { + loop { + // Hold this for the lifetime of the stream + let mut current_memory_reservation = self.reservation.new_empty(); + let mut stream = + self.create_sorted_stream(&mut current_memory_reservation)?; + + // TODO - add a threshold for number of files to disk even if empty and reading from disk so + // we can avoid the memory reservation + + // If no spill files are left, we can return the stream as this is the last sorted run + // TODO - We can write to disk before reading it back to avoid having multiple streams in memory + if self.sorted_spill_files.is_empty() { + // Attach the memory reservation to the stream as we are done with it + // but because we replaced the memory reservation of the merge stream, we must hold + // this to make sure we have enough memory + return Ok(Box::pin(StreamAttachedReservation::new( + stream, + current_memory_reservation, + ))); + } + + // Need to sort to a spill file + let Some((spill_file, max_record_batch_memory)) = self + .spill_manager + .spill_record_batch_stream_by_size( + &mut stream, + self.batch_size, + "MultiLevelMergeBuilder intermediate spill", + ) + .await? 
+ else { + continue; + }; + + // Add the spill file + self.sorted_spill_files.push(SortedSpillFile { + file: spill_file, + max_record_batch_memory, + }); + } + } + + fn create_sorted_stream( + &mut self, + memory_reservation: &mut MemoryReservation, + ) -> Result { + match (self.sorted_spill_files.len(), self.sorted_streams.len()) { + // No data so empty batch + (0, 0) => Ok(Box::pin(EmptyRecordBatchStream::new(Arc::clone( + &self.schema, + )))), + + // Only in-memory stream, return that + (0, 1) => Ok(self.sorted_streams.remove(0)), + + // Only single sorted spill file so return it + (1, 0) => { + let spill_file = self.sorted_spill_files.remove(0); + + self.spill_manager.read_spill_as_stream(spill_file.file) + } + + // Only in memory streams, so merge them all in a single pass + (0, _) => { + let sorted_stream = mem::take(&mut self.sorted_streams); + self.create_new_merge_sort( + sorted_stream, + // If we have no sorted spill files left, this is the last run + true, + ) + } + + // Need to merge multiple streams + (_, _) => { + // Don't account for existing streams memory + // as we are not holding the memory for them + let mut sorted_streams = mem::take(&mut self.sorted_streams); + + let (sorted_spill_files, buffer_size) = self + .get_sorted_spill_files_to_merge( + 2, + // we must have at least 2 streams to merge + 2_usize.saturating_sub(sorted_streams.len()), + memory_reservation, + )?; + + for spill in sorted_spill_files { + let stream = self + .spill_manager + .clone() + .with_batch_read_buffer_capacity(buffer_size) + .read_spill_as_stream(spill.file)?; + sorted_streams.push(stream); + } + + self.create_new_merge_sort( + sorted_streams, + // If we have no sorted spill files left, this is the last run + self.sorted_spill_files.is_empty(), + ) + } + } + } + + fn create_new_merge_sort( + &mut self, + streams: Vec, + is_output: bool, + ) -> Result { + StreamingMergeBuilder::new() + .with_schema(Arc::clone(&self.schema)) + .with_expressions(&self.expr) + .with_batch_size(self.batch_size) + .with_fetch(self.fetch) + .with_metrics(if is_output { + // Only add the metrics to the last run + self.metrics.clone() + } else { + self.metrics.intermediate() + }) + .with_round_robin_tie_breaker(self.enable_round_robin_tie_breaker) + .with_streams(streams) + // Don't track memory used by this stream as we reserve that memory by worst case sceneries + // (reserving memory for the biggest batch in each stream) + // This is a hack + .with_reservation( + MemoryConsumer::new("merge stream mock memory") + .register(&self.unbounded_memory_pool), + ) + .build() + } + + /// Return the sorted spill files to use for the next phase, and the buffer size + /// This will try to get as many spill files as possible to merge, and if we don't have enough streams + /// it will try to reduce the buffer size until we have enough streams to merge + /// otherwise it will return an error + fn get_sorted_spill_files_to_merge( + &mut self, + buffer_size: usize, + minimum_number_of_required_streams: usize, + reservation: &mut MemoryReservation, + ) -> Result<(Vec, usize)> { + assert_ne!(buffer_size, 0, "Buffer size must be greater than 0"); + let mut number_of_spills_to_read_for_current_phase = 0; + + for spill in &self.sorted_spill_files { + // For memory pools that are not shared this is good, for other this is not + // and there should be some upper limit to memory reservation so we won't starve the system + match reservation.try_grow(spill.max_record_batch_memory * buffer_size) { + Ok(_) => { + 
number_of_spills_to_read_for_current_phase += 1; + } + // If we can't grow the reservation, we need to stop + Err(err) => { + // We must have at least 2 streams to merge, so if we don't have enough memory + // fail + if minimum_number_of_required_streams + > number_of_spills_to_read_for_current_phase + { + // Free the memory we reserved for this merge as we either try again or fail + reservation.free(); + if buffer_size > 1 { + // Try again with smaller buffer size, it will be slower but at least we can merge + return self.get_sorted_spill_files_to_merge( + buffer_size - 1, + minimum_number_of_required_streams, + reservation, + ); + } + + return Err(err); + } + + // We reached the maximum amount of memory we can use + // for this merge + break; + } + } + } + + let spills = self + .sorted_spill_files + .drain(..number_of_spills_to_read_for_current_phase) + .collect::>(); + + Ok((spills, buffer_size)) + } +} + +struct StreamAttachedReservation { + stream: SendableRecordBatchStream, + reservation: MemoryReservation, +} + +impl StreamAttachedReservation { + fn new(stream: SendableRecordBatchStream, reservation: MemoryReservation) -> Self { + Self { + stream, + reservation, + } + } +} + +impl Stream for StreamAttachedReservation { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let res = self.stream.poll_next_unpin(cx); + + match res { + Poll::Ready(res) => { + match res { + Some(Ok(batch)) => Poll::Ready(Some(Ok(batch))), + Some(Err(err)) => { + // Had an error so drop the data + self.reservation.free(); + Poll::Ready(Some(Err(err))) + } + None => { + // Stream is done so free the memory + self.reservation.free(); + + Poll::Ready(None) + } + } + } + Poll::Pending => Poll::Pending, + } + } +} + +impl RecordBatchStream for StreamAttachedReservation { + fn schema(&self) -> SchemaRef { + self.stream.schema() + } +} diff --git a/datafusion/physical-plan/src/sorts/partial_sort.rs b/datafusion/physical-plan/src/sorts/partial_sort.rs index 320fa21c8665..32b34a75cc76 100644 --- a/datafusion/physical-plan/src/sorts/partial_sort.rs +++ b/datafusion/physical-plan/src/sorts/partial_sort.rs @@ -105,7 +105,8 @@ impl PartialSortExec { ) -> Self { debug_assert!(common_prefix_length > 0); let preserve_partitioning = false; - let cache = Self::compute_properties(&input, expr.clone(), preserve_partitioning); + let cache = Self::compute_properties(&input, expr.clone(), preserve_partitioning) + .unwrap(); Self { input, expr, @@ -159,7 +160,7 @@ impl PartialSortExec { /// Sort expressions pub fn expr(&self) -> &LexOrdering { - self.expr.as_ref() + &self.expr } /// If `Some(fetch)`, limits output to only the first "fetch" items @@ -189,24 +190,22 @@ impl PartialSortExec { input: &Arc, sort_exprs: LexOrdering, preserve_partitioning: bool, - ) -> PlanProperties { + ) -> Result { // Calculate equivalence properties; i.e. 
reset the ordering equivalence // class with the new ordering: - let eq_properties = input - .equivalence_properties() - .clone() - .with_reorder(sort_exprs); + let mut eq_properties = input.equivalence_properties().clone(); + eq_properties.reorder(sort_exprs)?; // Get output partitioning: let output_partitioning = Self::output_partitioning_helper(input, preserve_partitioning); - PlanProperties::new( + Ok(PlanProperties::new( eq_properties, output_partitioning, input.pipeline_behavior(), input.boundedness(), - ) + )) } } @@ -296,10 +295,7 @@ impl ExecutionPlan for PartialSortExec { let input = self.input.execute(partition, Arc::clone(&context))?; - trace!( - "End PartialSortExec's input.execute for partition: {}", - partition - ); + trace!("End PartialSortExec's input.execute for partition: {partition}"); // Make sure common prefix length is larger than 0 // Otherwise, we should use SortExec. @@ -321,7 +317,11 @@ impl ExecutionPlan for PartialSortExec { } fn statistics(&self) -> Result { - self.input.statistics() + self.input.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + self.input.partition_statistics(partition) } } @@ -420,7 +420,7 @@ impl PartialSortStream { fn sort_in_mem_batches(self: &mut Pin<&mut Self>) -> Result { let input_batch = concat_batches(&self.schema(), &self.in_mem_batches)?; self.in_mem_batches.clear(); - let result = sort_batch(&input_batch, self.expr.as_ref(), self.fetch)?; + let result = sort_batch(&input_batch, &self.expr, self.fetch)?; if let Some(remaining_fetch) = self.fetch { // remaining_fetch - result.num_rows() is always be >= 0 // because result length of sort_batch with limit cannot be @@ -503,7 +503,7 @@ mod tests { }; let partial_sort_exec = Arc::new(PartialSortExec::new( - LexOrdering::new(vec![ + [ PhysicalSortExpr { expr: col("a", &schema)?, options: option_asc, @@ -516,10 +516,11 @@ mod tests { expr: col("c", &schema)?, options: option_asc, }, - ]), + ] + .into(), Arc::clone(&source), 2, - )) as Arc; + )); let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?; @@ -568,7 +569,7 @@ mod tests { for common_prefix_length in [1, 2] { let partial_sort_exec = Arc::new( PartialSortExec::new( - LexOrdering::new(vec![ + [ PhysicalSortExpr { expr: col("a", &schema)?, options: option_asc, @@ -581,12 +582,13 @@ mod tests { expr: col("c", &schema)?, options: option_asc, }, - ]), + ] + .into(), Arc::clone(&source), common_prefix_length, ) .with_fetch(Some(4)), - ) as Arc; + ); let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?; @@ -641,7 +643,7 @@ mod tests { [(1, &source_tables[0]), (2, &source_tables[1])] { let partial_sort_exec = Arc::new(PartialSortExec::new( - LexOrdering::new(vec![ + [ PhysicalSortExpr { expr: col("a", &schema)?, options: option_asc, @@ -654,7 +656,8 @@ mod tests { expr: col("c", &schema)?, options: option_asc, }, - ]), + ] + .into(), Arc::clone(source), common_prefix_length, )); @@ -730,8 +733,8 @@ mod tests { nulls_first: false, }; let schema = mem_exec.schema(); - let partial_sort_executor = PartialSortExec::new( - LexOrdering::new(vec![ + let partial_sort_exec = PartialSortExec::new( + [ PhysicalSortExpr { expr: col("a", &schema)?, options: option_asc, @@ -744,17 +747,16 @@ mod tests { expr: col("c", &schema)?, options: option_asc, }, - ]), + ] + .into(), Arc::clone(&mem_exec), 1, ); - let partial_sort_exec = - Arc::new(partial_sort_executor.clone()) as Arc; let sort_exec = Arc::new(SortExec::new( - partial_sort_executor.expr, - 
partial_sort_executor.input, - )) as Arc; - let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?; + partial_sort_exec.expr.clone(), + Arc::clone(&partial_sort_exec.input), + )); + let result = collect(Arc::new(partial_sort_exec), Arc::clone(&task_ctx)).await?; assert_eq!( result.iter().map(|r| r.num_rows()).collect_vec(), [125, 125, 150] @@ -791,8 +793,8 @@ mod tests { (Some(150), vec![125, 25]), (Some(250), vec![125, 125]), ] { - let partial_sort_executor = PartialSortExec::new( - LexOrdering::new(vec![ + let partial_sort_exec = PartialSortExec::new( + [ PhysicalSortExpr { expr: col("a", &schema)?, options: option_asc, @@ -805,19 +807,22 @@ mod tests { expr: col("c", &schema)?, options: option_asc, }, - ]), + ] + .into(), Arc::clone(&mem_exec), 1, ) .with_fetch(fetch_size); - let partial_sort_exec = - Arc::new(partial_sort_executor.clone()) as Arc; let sort_exec = Arc::new( - SortExec::new(partial_sort_executor.expr, partial_sort_executor.input) - .with_fetch(fetch_size), - ) as Arc; - let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?; + SortExec::new( + partial_sort_exec.expr.clone(), + Arc::clone(&partial_sort_exec.input), + ) + .with_fetch(fetch_size), + ); + let result = + collect(Arc::new(partial_sort_exec), Arc::clone(&task_ctx)).await?; assert_eq!( result.iter().map(|r| r.num_rows()).collect_vec(), expected_batch_num_rows @@ -846,8 +851,8 @@ mod tests { nulls_first: false, }; let fetch_size = Some(250); - let partial_sort_executor = PartialSortExec::new( - LexOrdering::new(vec![ + let partial_sort_exec = PartialSortExec::new( + [ PhysicalSortExpr { expr: col("a", &schema)?, options: option_asc, @@ -856,15 +861,14 @@ mod tests { expr: col("c", &schema)?, options: option_asc, }, - ]), + ] + .into(), Arc::clone(&mem_exec), 1, ) .with_fetch(fetch_size); - let partial_sort_exec = - Arc::new(partial_sort_executor.clone()) as Arc; - let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?; + let result = collect(Arc::new(partial_sort_exec), Arc::clone(&task_ctx)).await?; for rb in result { assert!(rb.num_rows() > 0); } @@ -897,10 +901,11 @@ mod tests { TestMemoryExec::try_new_exec(&[vec![batch]], Arc::clone(&schema), None)?; let partial_sort_exec = Arc::new(PartialSortExec::new( - LexOrdering::new(vec![PhysicalSortExpr { + [PhysicalSortExpr { expr: col("field_name", &schema)?, options: SortOptions::default(), - }]), + }] + .into(), input, 1, )); @@ -986,7 +991,7 @@ mod tests { )?; let partial_sort_exec = Arc::new(PartialSortExec::new( - LexOrdering::new(vec![ + [ PhysicalSortExpr { expr: col("a", &schema)?, options: option_asc, @@ -999,7 +1004,8 @@ mod tests { expr: col("c", &schema)?, options: option_desc, }, - ]), + ] + .into(), TestMemoryExec::try_new_exec(&[vec![batch]], schema, None)?, 2, )); @@ -1061,10 +1067,11 @@ mod tests { let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); let refs = blocking_exec.refs(); let sort_exec = Arc::new(PartialSortExec::new( - LexOrdering::new(vec![PhysicalSortExpr { + [PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions::default(), - }]), + }] + .into(), blocking_exec, 1, )); diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 8c0c6a7e8ea9..e977b214cfc8 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -27,13 +27,15 @@ use std::sync::Arc; use crate::common::spawn_buffered; use crate::execution_plan::{Boundedness, CardinalityEffect, EmissionType}; use 
crate::expressions::PhysicalSortExpr; +use crate::filter_pushdown::{FilterDescription, FilterPushdownPhase}; use crate::limit::LimitStream; use crate::metrics::{ BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, SpillMetrics, }; use crate::projection::{make_with_child, update_expr, ProjectionExec}; -use crate::sorts::streaming_merge::StreamingMergeBuilder; +use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder}; use crate::spill::get_record_batch_memory_size; +use crate::spill::get_size::GetActualSize; use crate::spill::in_progress_spill_file::InProgressSpillFile; use crate::spill::spill_manager::SpillManager; use crate::stream::RecordBatchStreamAdapter; @@ -44,6 +46,10 @@ use crate::{ Statistics, }; +use arrow::array::{Array, RecordBatch, RecordBatchOptions, StringViewArray}; +use arrow::datatypes::SchemaRef; +use datafusion_common::config::SpillCompression; +use datafusion_execution::disk_manager::RefCountedTempFile; use arrow::array::{ Array, RecordBatch, RecordBatchOptions, StringViewArray, UInt32Array, }; @@ -53,12 +59,12 @@ use arrow::row::{RowConverter, Rows, SortField}; use datafusion_common::{ exec_datafusion_err, internal_datafusion_err, internal_err, DataFusionError, Result, }; -use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; +use datafusion_physical_expr::expressions::{lit, DynamicFilterPhysicalExpr}; use datafusion_physical_expr::LexOrdering; -use datafusion_physical_expr_common::sort_expr::LexRequirement; +use datafusion_physical_expr::PhysicalExpr; use futures::{StreamExt, TryStreamExt}; use log::{debug, trace}; @@ -89,8 +95,9 @@ impl ExternalSorterMetrics { /// 1. get a non-empty new batch from input /// /// 2. check with the memory manager there is sufficient space to -/// buffer the batch in memory 2.1 if memory sufficient, buffer -/// batch in memory, go to 1. +/// buffer the batch in memory. +/// +/// 2.1 if memory is sufficient, buffer batch in memory, go to 1. /// /// 2.2 if no more memory is available, sort all buffered batches and /// spill to file. buffer the next batch in memory, go to 1. @@ -204,9 +211,7 @@ struct ExternalSorter { /// Schema of the output (and the input) schema: SchemaRef, /// Sort expressions - expr: Arc<[PhysicalSortExpr]>, - /// RowConverter corresponding to the sort expressions - sort_keys_row_converter: Arc, + expr: LexOrdering, /// The target number of rows for output batches batch_size: usize, /// If the in size of buffered memory batches is below this size, @@ -223,12 +228,12 @@ struct ExternalSorter { /// During external sorting, in-memory intermediate data will be appended to /// this file incrementally. Once finished, this file will be moved to [`Self::finished_spill_files`]. - in_progress_spill_file: Option, + in_progress_spill_file: Option<(InProgressSpillFile, usize)>, /// If data has previously been spilled, the locations of the spill files (in /// Arrow IPC format) /// Within the same spill file, the data might be chunked into multiple batches, /// and ordered by sort keys. - finished_spill_files: Vec, + finished_spill_files: Vec, // ======================================================================== // EXECUTION RESOURCES: @@ -262,6 +267,8 @@ impl ExternalSorter { batch_size: usize, sort_spill_reservation_bytes: usize, sort_in_place_threshold_bytes: usize, + // Configured via `datafusion.execution.spill_compression`. 
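// A minimal sketch of how this option might be set through the string config key
// named above; `"zstd"` is assumed to be one of the accepted values, and
// `ConfigOptions` comes from `datafusion_common::config`.
use datafusion_common::config::ConfigOptions;
use datafusion_common::Result;

fn spill_compression_config() -> Result<ConfigOptions> {
    let mut options = ConfigOptions::new();
    // Compress spill files written by external sorts when memory runs out.
    options.set("datafusion.execution.spill_compression", "zstd")?;
    Ok(options)
}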
+ spill_compression: SpillCompression, metrics: &ExecutionPlanMetricsSet, runtime: Arc, ) -> Result { @@ -274,35 +281,19 @@ impl ExternalSorter { MemoryConsumer::new(format!("ExternalSorterMerge[{partition_id}]")) .register(&runtime.memory_pool); - // Construct RowConverter for sort keys - let sort_fields = expr - .iter() - .map(|e| { - let data_type = e - .expr - .data_type(&schema) - .map_err(|e| e.context("Resolving sort expression data type"))?; - Ok(SortField::new_with_options(data_type, e.options)) - }) - .collect::>>()?; - - let converter = RowConverter::new(sort_fields).map_err(|e| { - exec_datafusion_err!("Failed to create RowConverter: {:?}", e) - })?; - let spill_manager = SpillManager::new( Arc::clone(&runtime), metrics.spill_metrics.clone(), Arc::clone(&schema), - ); + ) + .with_compression_type(spill_compression); Ok(Self { schema, in_mem_batches: vec![], in_progress_spill_file: None, finished_spill_files: vec![], - expr: expr.into(), - sort_keys_row_converter: Arc::new(converter), + expr, metrics, reservation, spill_manager, @@ -350,8 +341,6 @@ impl ExternalSorter { self.merge_reservation.free(); if self.spilled_before() { - let mut streams = vec![]; - // Sort `in_mem_batches` and spill it first. If there are many // `in_mem_batches` and the memory limit is almost reached, merging // them with the spilled files at the same time might cause OOM. @@ -359,6 +348,7 @@ impl ExternalSorter { self.sort_and_spill_in_mem_batches().await?; } + // TODO(ding-young) check whether streams are used for spill in self.finished_spill_files.drain(..) { if !spill.path().exists() { return internal_err!("Spill file {:?} does not exist", spill.path()); @@ -367,12 +357,11 @@ impl ExternalSorter { streams.push(stream); } - let expressions: LexOrdering = self.expr.iter().cloned().collect(); - StreamingMergeBuilder::new() - .with_streams(streams) + .with_sorted_spill_files(std::mem::take(&mut self.finished_spill_files)) + .with_spill_manager(self.spill_manager.clone()) .with_schema(Arc::clone(&self.schema)) - .with_expressions(expressions.as_ref()) + .with_expressions(&self.expr.clone()) .with_metrics(self.metrics.baseline.clone()) .with_batch_size(self.batch_size) .with_fetch(None) @@ -416,7 +405,7 @@ impl ExternalSorter { // Lazily initialize the in-progress spill file if self.in_progress_spill_file.is_none() { self.in_progress_spill_file = - Some(self.spill_manager.create_in_progress_file("Sorting")?); + Some((self.spill_manager.create_in_progress_file("Sorting")?, 0)); } Self::organize_stringview_arrays(globally_sorted_batches)?; @@ -426,12 +415,16 @@ impl ExternalSorter { let batches_to_spill = std::mem::take(globally_sorted_batches); self.reservation.free(); - let in_progress_file = self.in_progress_spill_file.as_mut().ok_or_else(|| { - internal_datafusion_err!("In-progress spill file should be initialized") - })?; + let (in_progress_file, max_record_batch_size) = + self.in_progress_spill_file.as_mut().ok_or_else(|| { + internal_datafusion_err!("In-progress spill file should be initialized") + })?; for batch in batches_to_spill { in_progress_file.append_batch(&batch)?; + + *max_record_batch_size = + (*max_record_batch_size).max(batch.get_actually_used_size()); } if !globally_sorted_batches.is_empty() { @@ -443,14 +436,17 @@ impl ExternalSorter { /// Finishes the in-progress spill file and moves it to the finished spill files. 
async fn spill_finish(&mut self) -> Result<()> { - let mut in_progress_file = + let (mut in_progress_file, max_record_batch_memory) = self.in_progress_spill_file.take().ok_or_else(|| { internal_datafusion_err!("Should be called after `spill_append`") })?; let spill_file = in_progress_file.finish()?; if let Some(spill_file) = spill_file { - self.finished_spill_files.push(spill_file); + self.finished_spill_files.push(SortedSpillFile { + file: spill_file, + max_record_batch_memory, + }); } Ok(()) @@ -697,12 +693,10 @@ impl ExternalSorter { }) .collect::>()?; - let expressions: LexOrdering = self.expr.iter().cloned().collect(); - StreamingMergeBuilder::new() .with_streams(streams) .with_schema(Arc::clone(&self.schema)) - .with_expressions(expressions.as_ref()) + .with_expressions(&self.expr.clone()) .with_metrics(metrics) .with_batch_size(self.batch_size) .with_fetch(None) @@ -726,23 +720,11 @@ impl ExternalSorter { ); let schema = batch.schema(); - let expressions: LexOrdering = self.expr.iter().cloned().collect(); - let row_converter = Arc::clone(&self.sort_keys_row_converter); + let expressions = self.expr.clone(); let stream = futures::stream::once(async move { let _timer = metrics.elapsed_compute().timer(); - let sort_columns = expressions - .iter() - .map(|expr| expr.evaluate_to_sort_column(&batch)) - .collect::>>()?; - - let sorted = if is_multi_column_with_lists(&sort_columns) { - // lex_sort_to_indices doesn't support List with more than one column - // https://github.com/apache/arrow-rs/issues/5454 - sort_batch_row_based(&batch, &expressions, row_converter, None)? - } else { - sort_batch(&batch, &expressions, None)? - }; + let sorted = sort_batch(&batch, &expressions, None)?; metrics.record_output(sorted.num_rows()); drop(batch); @@ -833,45 +815,6 @@ impl Debug for ExternalSorter { } } -/// Converts rows into a sorted array of indices based on their order. -/// This function returns the indices that represent the sorted order of the rows. -fn rows_to_indices(rows: Rows, limit: Option) -> Result { - let mut sort: Vec<_> = rows.iter().enumerate().collect(); - sort.sort_unstable_by(|(_, a), (_, b)| a.cmp(b)); - - let mut len = rows.num_rows(); - if let Some(limit) = limit { - len = limit.min(len); - } - let indices = - UInt32Array::from_iter_values(sort.iter().take(len).map(|(i, _)| *i as u32)); - Ok(indices) -} - -/// Sorts a `RecordBatch` by converting its sort columns into Arrow Row Format for faster comparison. -fn sort_batch_row_based( - batch: &RecordBatch, - expressions: &LexOrdering, - row_converter: Arc, - fetch: Option, -) -> Result { - let sort_columns = expressions - .iter() - .map(|expr| expr.evaluate_to_sort_column(batch).map(|col| col.values)) - .collect::>>()?; - let rows = row_converter.convert_columns(&sort_columns)?; - let indices = rows_to_indices(rows, fetch)?; - let columns = take_arrays(batch.columns(), &indices, None)?; - - let options = RecordBatchOptions::new().with_row_count(Some(indices.len())); - - Ok(RecordBatch::try_new_with_options( - batch.schema(), - columns, - &options, - )?) -} - pub fn sort_batch( batch: &RecordBatch, expressions: &LexOrdering, @@ -882,14 +825,7 @@ pub fn sort_batch( .map(|expr| expr.evaluate_to_sort_column(batch)) .collect::>>()?; - let indices = if is_multi_column_with_lists(&sort_columns) { - // lex_sort_to_indices doesn't support List with more than one column - // https://github.com/apache/arrow-rs/issues/5454 - lexsort_to_indices_multi_columns(sort_columns, fetch)? - } else { - lexsort_to_indices(&sort_columns, fetch)? 
- }; - + let indices = lexsort_to_indices(&sort_columns, fetch)?; let mut columns = take_arrays(batch.columns(), &indices, None)?; // The columns may be larger than the unsorted columns in `batch` especially for variable length @@ -908,50 +844,6 @@ pub fn sort_batch( )?) } -#[inline] -fn is_multi_column_with_lists(sort_columns: &[SortColumn]) -> bool { - sort_columns.iter().any(|c| { - matches!( - c.values.data_type(), - DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _) - ) - }) -} - -pub(crate) fn lexsort_to_indices_multi_columns( - sort_columns: Vec, - limit: Option, -) -> Result { - let (fields, columns) = sort_columns.into_iter().fold( - (vec![], vec![]), - |(mut fields, mut columns), sort_column| { - fields.push(SortField::new_with_options( - sort_column.values.data_type().clone(), - sort_column.options.unwrap_or_default(), - )); - columns.push(sort_column.values); - (fields, columns) - }, - ); - - // Note: row converter is reused through `sort_batch_row_based()`, this function - // is not used during normal sort execution, but it's kept temporarily because - // it's inside a public interface `sort_batch()`. - let converter = RowConverter::new(fields)?; - let rows = converter.convert_columns(&columns)?; - let mut sort: Vec<_> = rows.iter().enumerate().collect(); - sort.sort_unstable_by(|(_, a), (_, b)| a.cmp(b)); - - let mut len = rows.num_rows(); - if let Some(limit) = limit { - len = limit.min(len); - } - let indices = - UInt32Array::from_iter_values(sort.iter().take(len).map(|(i, _)| *i as u32)); - - Ok(indices) -} - /// Sort execution plan. /// /// Support sorting datasets that are larger than the memory allotted @@ -970,9 +862,11 @@ pub struct SortExec { /// Fetch highest/lowest n results fetch: Option, /// Normalized common sort prefix between the input and the sort expressions (only used with fetch) - common_sort_prefix: LexOrdering, + common_sort_prefix: Vec, /// Cache holding plan properties like equivalences, output partitioning etc. cache: PlanProperties, + /// Filter matching the state of the sort for dynamic filter pushdown + filter: Option>, } impl SortExec { @@ -981,7 +875,8 @@ impl SortExec { pub fn new(expr: LexOrdering, input: Arc) -> Self { let preserve_partitioning = false; let (cache, sort_prefix) = - Self::compute_properties(&input, expr.clone(), preserve_partitioning); + Self::compute_properties(&input, expr.clone(), preserve_partitioning) + .unwrap(); Self { expr, input, @@ -990,6 +885,7 @@ impl SortExec { fetch: None, common_sort_prefix: sort_prefix, cache, + filter: None, } } @@ -1035,6 +931,17 @@ impl SortExec { if fetch.is_some() && is_pipeline_friendly { cache = cache.with_boundedness(Boundedness::Bounded); } + let filter = fetch.is_some().then(|| { + // If we already have a filter, keep it. Otherwise, create a new one. 
+ self.filter.clone().unwrap_or_else(|| { + let children = self + .expr + .iter() + .map(|sort_expr| Arc::clone(&sort_expr.expr)) + .collect::>(); + Arc::new(DynamicFilterPhysicalExpr::new(children, lit(true))) + }) + }); SortExec { input: Arc::clone(&self.input), expr: self.expr.clone(), @@ -1043,6 +950,7 @@ impl SortExec { common_sort_prefix: self.common_sort_prefix.clone(), fetch, cache, + filter, } } @@ -1079,13 +987,10 @@ impl SortExec { input: &Arc, sort_exprs: LexOrdering, preserve_partitioning: bool, - ) -> (PlanProperties, LexOrdering) { - // Determine execution mode: - let requirement = LexRequirement::from(sort_exprs); - + ) -> Result<(PlanProperties, Vec)> { let (sort_prefix, sort_satisfied) = input .equivalence_properties() - .extract_common_sort_prefix(&requirement); + .extract_common_sort_prefix(sort_exprs.clone())?; // The emission type depends on whether the input is already sorted: // - If already fully sorted, we can emit results in the same way as the input @@ -1114,25 +1019,22 @@ impl SortExec { // Calculate equivalence properties; i.e. reset the ordering equivalence // class with the new ordering: - let sort_exprs = LexOrdering::from(requirement); - let eq_properties = input - .equivalence_properties() - .clone() - .with_reorder(sort_exprs); + let mut eq_properties = input.equivalence_properties().clone(); + eq_properties.reorder(sort_exprs)?; // Get output partitioning: let output_partitioning = Self::output_partitioning_helper(input, preserve_partitioning); - ( + Ok(( PlanProperties::new( eq_properties, output_partitioning, emission_type, boundedness, ), - LexOrdering::from(sort_prefix), - ) + sort_prefix, + )) } } @@ -1144,8 +1046,25 @@ impl DisplayAs for SortExec { match self.fetch { Some(fetch) => { write!(f, "SortExec: TopK(fetch={fetch}), expr=[{}], preserve_partitioning=[{preserve_partitioning}]", self.expr)?; + if let Some(filter) = &self.filter { + if let Ok(current) = filter.current() { + if !current.eq(&lit(true)) { + write!(f, ", filter=[{current}]")?; + } + } + } if !self.common_sort_prefix.is_empty() { - write!(f, ", sort_prefix=[{}]", self.common_sort_prefix) + write!(f, ", sort_prefix=[")?; + let mut first = true; + for sort_expr in &self.common_sort_prefix { + if first { + first = false; + } else { + write!(f, ", ")?; + } + write!(f, "{sort_expr}")?; + } + write!(f, "]") } else { Ok(()) } @@ -1204,9 +1123,10 @@ impl ExecutionPlan for SortExec { self: Arc, children: Vec>, ) -> Result> { - let new_sort = SortExec::new(self.expr.clone(), Arc::clone(&children[0])) + let mut new_sort = SortExec::new(self.expr.clone(), Arc::clone(&children[0])) .with_fetch(self.fetch) .with_preserve_partitioning(self.preserve_partitioning); + new_sort.filter = self.filter.clone(); Ok(Arc::new(new_sort)) } @@ -1222,14 +1142,12 @@ impl ExecutionPlan for SortExec { let execution_options = &context.session_config().options().execution; - trace!("End SortExec's input.execute for partition: {}", partition); - - let requirement = &LexRequirement::from(self.expr.clone()); + trace!("End SortExec's input.execute for partition: {partition}"); let sort_satisfied = self .input .equivalence_properties() - .ordering_satisfy_requirement(requirement); + .ordering_satisfy(self.expr.clone())?; match (sort_satisfied, self.fetch.as_ref()) { (true, Some(fetch)) => Ok(Box::pin(LimitStream::new( @@ -1249,6 +1167,7 @@ impl ExecutionPlan for SortExec { context.session_config().batch_size(), context.runtime_env(), &self.metrics_set, + self.filter.clone(), )?; 
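As an illustrative note (not part of this patch): the filter handed to the TopK construction above only reaches child scans when dynamic filter pushdown is enabled; `gather_filters_for_pushdown` further below checks `optimizer.enable_dynamic_filter_pushdown` before exposing it as a self filter. A minimal sketch, assuming the string-keyed `SessionConfig` setter:

use datafusion_execution::config::SessionConfig;

// Hypothetical helper: with this flag on, a TopK sort can progressively
// tighten its DynamicFilterPhysicalExpr and prune input at the scans.
fn session_config_with_topk_filter_pushdown() -> SessionConfig {
    SessionConfig::new()
        .set_bool("datafusion.optimizer.enable_dynamic_filter_pushdown", true)
}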
Ok(Box::pin(RecordBatchStreamAdapter::new( self.schema(), @@ -1273,6 +1192,7 @@ impl ExecutionPlan for SortExec { context.session_config().batch_size(), execution_options.sort_spill_reservation_bytes, execution_options.sort_in_place_threshold_bytes, + context.session_config().spill_compression(), &self.metrics_set, context.runtime_env(), )?; @@ -1296,7 +1216,24 @@ impl ExecutionPlan for SortExec { } fn statistics(&self) -> Result { - Statistics::with_fetch(self.input.statistics()?, self.schema(), self.fetch, 0, 1) + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if !self.preserve_partitioning() { + return self.input.partition_statistics(None)?.with_fetch( + self.schema(), + self.fetch, + 0, + 1, + ); + } + self.input.partition_statistics(partition)?.with_fetch( + self.schema(), + self.fetch, + 0, + 1, + ) } fn with_fetch(&self, limit: Option) -> Option> { @@ -1327,17 +1264,10 @@ impl ExecutionPlan for SortExec { return Ok(None); } - let mut updated_exprs = LexOrdering::default(); - for sort in self.expr() { - let Some(new_expr) = update_expr(&sort.expr, projection.expr(), false)? - else { - return Ok(None); - }; - updated_exprs.push(PhysicalSortExpr { - expr: new_expr, - options: sort.options, - }); - } + let Some(updated_exprs) = update_ordering(self.expr.clone(), projection.expr())? + else { + return Ok(None); + }; Ok(Some(Arc::new( SortExec::new(updated_exprs, make_with_child(projection, self.input())?) @@ -1345,6 +1275,28 @@ impl ExecutionPlan for SortExec { .with_preserve_partitioning(self.preserve_partitioning()), ))) } + + fn gather_filters_for_pushdown( + &self, + phase: FilterPushdownPhase, + parent_filters: Vec>, + config: &datafusion_common::config::ConfigOptions, + ) -> Result { + if !matches!(phase, FilterPushdownPhase::Post) { + return Ok(FilterDescription::new_with_child_count(1) + .all_parent_filters_supported(parent_filters)); + } + if let Some(filter) = &self.filter { + if config.optimizer.enable_dynamic_filter_pushdown { + let filter = Arc::clone(filter) as Arc; + return Ok(FilterDescription::new_with_child_count(1) + .all_parent_filters_supported(parent_filters) + .with_self_filter(filter)); + } + } + Ok(FilterDescription::new_with_child_count(1) + .all_parent_filters_supported(parent_filters)) + } } #[cfg(test)] @@ -1399,9 +1351,9 @@ mod tests { impl SortedUnboundedExec { fn compute_properties(schema: SchemaRef) -> PlanProperties { let mut eq_properties = EquivalenceProperties::new(schema); - eq_properties.add_new_orderings(vec![LexOrdering::new(vec![ - PhysicalSortExpr::new_default(Arc::new(Column::new("c1", 0))), - ])]); + eq_properties.add_ordering([PhysicalSortExpr::new_default(Arc::new( + Column::new("c1", 0), + ))]); PlanProperties::new( eq_properties, Partitioning::UnknownPartitioning(1), @@ -1501,10 +1453,11 @@ mod tests { let schema = csv.schema(); let sort_exec = Arc::new(SortExec::new( - LexOrdering::new(vec![PhysicalSortExpr { + [PhysicalSortExpr { expr: col("i", &schema)?, options: SortOptions::default(), - }]), + }] + .into(), Arc::new(CoalescePartitionsExec::new(csv)), )); @@ -1512,7 +1465,6 @@ mod tests { assert_eq!(result.len(), 1); assert_eq!(result[0].num_rows(), 400); - assert_eq!( task_ctx.runtime_env().memory_pool.reserved(), 0, @@ -1547,10 +1499,11 @@ mod tests { let schema = input.schema(); let sort_exec = Arc::new(SortExec::new( - LexOrdering::new(vec![PhysicalSortExpr { + [PhysicalSortExpr { expr: col("i", &schema)?, options: SortOptions::default(), - }]), + }] + .into(), 
Arc::new(CoalescePartitionsExec::new(input)), )); @@ -1576,14 +1529,13 @@ mod tests { // bytes. We leave a little wiggle room for the actual numbers. assert!((3..=10).contains(&spill_count)); assert!((9000..=10000).contains(&spilled_rows)); - assert!((38000..=42000).contains(&spilled_bytes)); + assert!((38000..=44000).contains(&spilled_bytes)); let columns = result[0].columns(); let i = as_primitive_array::(&columns[0])?; assert_eq!(i.value(0), 0); assert_eq!(i.value(i.len() - 1), 81); - assert_eq!( task_ctx.runtime_env().memory_pool.reserved(), 0, @@ -1626,31 +1578,22 @@ mod tests { } let sort_exec = Arc::new(SortExec::new( - LexOrdering::new(vec![PhysicalSortExpr { - expr: col("i", &plan.schema())?, - options: SortOptions::default(), - }]), + [PhysicalSortExpr::new_default(col("i", &plan.schema())?)].into(), plan, )); - let result = collect( - Arc::clone(&sort_exec) as Arc, - Arc::clone(&task_ctx), - ) - .await; + let result = collect(Arc::clone(&sort_exec) as _, Arc::clone(&task_ctx)).await; let err = result.unwrap_err(); assert!( matches!(err, DataFusionError::Context(..)), - "Assertion failed: expected a Context error, but got: {:?}", - err + "Assertion failed: expected a Context error, but got: {err:?}" ); // Assert that the context error is wrapping a resources exhausted error. assert!( matches!(err.find_root(), DataFusionError::ResourcesExhausted(_)), - "Assertion failed: expected a ResourcesExhausted error, but got: {:?}", - err + "Assertion failed: expected a ResourcesExhausted error, but got: {err:?}" ); Ok(()) @@ -1678,18 +1621,15 @@ mod tests { let schema = input.schema(); let sort_exec = Arc::new(SortExec::new( - LexOrdering::new(vec![PhysicalSortExpr { + [PhysicalSortExpr { expr: col("i", &schema)?, options: SortOptions::default(), - }]), + }] + .into(), Arc::new(CoalescePartitionsExec::new(input)), )); - let result = collect( - Arc::clone(&sort_exec) as Arc, - Arc::clone(&task_ctx), - ) - .await?; + let result = collect(Arc::clone(&sort_exec) as _, Arc::clone(&task_ctx)).await?; let num_rows = result.iter().map(|batch| batch.num_rows()).sum::(); assert_eq!(num_rows, 20000); @@ -1778,20 +1718,18 @@ mod tests { let sort_exec = Arc::new( SortExec::new( - LexOrdering::new(vec![PhysicalSortExpr { + [PhysicalSortExpr { expr: col("i", &schema)?, options: SortOptions::default(), - }]), + }] + .into(), Arc::new(CoalescePartitionsExec::new(csv)), ) .with_fetch(fetch), ); - let result = collect( - Arc::clone(&sort_exec) as Arc, - Arc::clone(&task_ctx), - ) - .await?; + let result = + collect(Arc::clone(&sort_exec) as _, Arc::clone(&task_ctx)).await?; assert_eq!(result.len(), 1); let metrics = sort_exec.metrics().unwrap(); @@ -1821,16 +1759,16 @@ mod tests { let data: ArrayRef = Arc::new(vec![3, 2, 1].into_iter().map(Some).collect::()); - let batch = RecordBatch::try_new(Arc::clone(&schema), vec![data]).unwrap(); + let batch = RecordBatch::try_new(Arc::clone(&schema), vec![data])?; let input = - TestMemoryExec::try_new_exec(&[vec![batch]], Arc::clone(&schema), None) - .unwrap(); + TestMemoryExec::try_new_exec(&[vec![batch]], Arc::clone(&schema), None)?; let sort_exec = Arc::new(SortExec::new( - LexOrdering::new(vec![PhysicalSortExpr { + [PhysicalSortExpr { expr: col("field_name", &schema)?, options: SortOptions::default(), - }]), + }] + .into(), input, )); @@ -1839,7 +1777,7 @@ mod tests { let expected_data: ArrayRef = Arc::new(vec![1, 2, 3].into_iter().map(Some).collect::()); let expected_batch = - RecordBatch::try_new(Arc::clone(&schema), vec![expected_data]).unwrap(); + 
RecordBatch::try_new(Arc::clone(&schema), vec![expected_data])?; // Data is correct assert_eq!(&vec![expected_batch], &result); @@ -1878,7 +1816,7 @@ mod tests { )?; let sort_exec = Arc::new(SortExec::new( - LexOrdering::new(vec![ + [ PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions { @@ -1893,7 +1831,8 @@ mod tests { nulls_first: false, }, }, - ]), + ] + .into(), TestMemoryExec::try_new_exec(&[vec![batch]], Arc::clone(&schema), None)?, )); @@ -1964,7 +1903,7 @@ mod tests { )?; let sort_exec = Arc::new(SortExec::new( - LexOrdering::new(vec![ + [ PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions { @@ -1979,7 +1918,8 @@ mod tests { nulls_first: false, }, }, - ]), + ] + .into(), TestMemoryExec::try_new_exec(&[vec![batch]], schema, None)?, )); @@ -2043,10 +1983,11 @@ mod tests { let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); let refs = blocking_exec.refs(); let sort_exec = Arc::new(SortExec::new( - LexOrdering::new(vec![PhysicalSortExpr { + [PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions::default(), - }]), + }] + .into(), blocking_exec, )); @@ -2074,12 +2015,13 @@ mod tests { RecordBatch::try_new_with_options(Arc::clone(&schema), vec![], &options) .unwrap(); - let expressions = LexOrdering::new(vec![PhysicalSortExpr { + let expressions = [PhysicalSortExpr { expr: Arc::new(Literal::new(ScalarValue::Int64(Some(1)))), options: SortOptions::default(), - }]); + }] + .into(); - let result = sort_batch(&batch, expressions.as_ref(), None).unwrap(); + let result = sort_batch(&batch, &expressions, None).unwrap(); assert_eq!(result.num_rows(), 1); } @@ -2093,9 +2035,10 @@ mod tests { cache: SortedUnboundedExec::compute_properties(Arc::new(schema.clone())), }; let mut plan = SortExec::new( - LexOrdering::new(vec![PhysicalSortExpr::new_default(Arc::new(Column::new( + [PhysicalSortExpr::new_default(Arc::new(Column::new( "c1", 0, - )))]), + )))] + .into(), Arc::new(source), ); plan = plan.with_fetch(Some(9)); diff --git a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs index b987dff36441..09ad71974e6c 100644 --- a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs +++ b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs @@ -23,7 +23,7 @@ use std::sync::Arc; use crate::common::spawn_buffered; use crate::limit::LimitStream; use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; -use crate::projection::{make_with_child, update_expr, ProjectionExec}; +use crate::projection::{make_with_child, update_ordering, ProjectionExec}; use crate::sorts::streaming_merge::StreamingMergeBuilder; use crate::{ DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, @@ -33,9 +33,9 @@ use crate::{ use datafusion_common::{internal_err, Result}; use datafusion_execution::memory_pool::MemoryConsumer; use datafusion_execution::TaskContext; -use datafusion_physical_expr::PhysicalSortExpr; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements}; +use crate::execution_plan::{EvaluationType, SchedulingType}; use log::{debug, trace}; /// Sort preserving merge execution plan @@ -144,7 +144,7 @@ impl SortPreservingMergeExec { /// Sort expressions pub fn expr(&self) -> &LexOrdering { - self.expr.as_ref() + &self.expr } /// Fetch @@ -158,15 +158,27 @@ impl SortPreservingMergeExec { input: &Arc, ordering: LexOrdering, ) 
-> PlanProperties { + let input_partitions = input.output_partitioning().partition_count(); + let (drive, scheduling) = if input_partitions > 1 { + (EvaluationType::Eager, SchedulingType::Cooperative) + } else { + ( + input.properties().evaluation_type, + input.properties().scheduling_type, + ) + }; + let mut eq_properties = input.equivalence_properties().clone(); eq_properties.clear_per_partition_constants(); - eq_properties.add_new_orderings(vec![ordering]); + eq_properties.add_ordering(ordering); PlanProperties::new( eq_properties, // Equivalence Properties Partitioning::UnknownPartitioning(1), // Output Partitioning input.pipeline_behavior(), // Pipeline Behavior input.boundedness(), // Boundedness ) + .with_evaluation_type(drive) + .with_scheduling_type(scheduling) } } @@ -240,8 +252,8 @@ impl ExecutionPlan for SortPreservingMergeExec { vec![false] } - fn required_input_ordering(&self) -> Vec> { - vec![Some(LexRequirement::from(self.expr.clone()))] + fn required_input_ordering(&self) -> Vec> { + vec![Some(OrderingRequirements::from(self.expr.clone()))] } fn maintains_input_order(&self) -> Vec { @@ -267,10 +279,7 @@ impl ExecutionPlan for SortPreservingMergeExec { partition: usize, context: Arc, ) -> Result { - trace!( - "Start SortPreservingMergeExec::execute for partition: {}", - partition - ); + trace!("Start SortPreservingMergeExec::execute for partition: {partition}"); if 0 != partition { return internal_err!( "SortPreservingMergeExec invalid partition {partition}" @@ -279,8 +288,7 @@ impl ExecutionPlan for SortPreservingMergeExec { let input_partitions = self.input.output_partitioning().partition_count(); trace!( - "Number of input partitions of SortPreservingMergeExec::execute: {}", - input_partitions + "Number of input partitions of SortPreservingMergeExec::execute: {input_partitions}" ); let schema = self.schema(); @@ -323,7 +331,7 @@ impl ExecutionPlan for SortPreservingMergeExec { let result = StreamingMergeBuilder::new() .with_streams(receivers) .with_schema(schema) - .with_expressions(self.expr.as_ref()) + .with_expressions(&self.expr) .with_metrics(BaselineMetrics::new(&self.metrics, partition)) .with_batch_size(context.session_config().batch_size()) .with_fetch(self.fetch) @@ -343,7 +351,11 @@ impl ExecutionPlan for SortPreservingMergeExec { } fn statistics(&self) -> Result { - self.input.statistics() + self.input.partition_statistics(None) + } + + fn partition_statistics(&self, _partition: Option) -> Result { + self.input.partition_statistics(None) } fn supports_limit_pushdown(&self) -> bool { @@ -362,17 +374,10 @@ impl ExecutionPlan for SortPreservingMergeExec { return Ok(None); } - let mut updated_exprs = LexOrdering::default(); - for sort in self.expr() { - let Some(updated_expr) = update_expr(&sort.expr, projection.expr(), false)? - else { - return Ok(None); - }; - updated_exprs.push(PhysicalSortExpr { - expr: updated_expr, - options: sort.options, - }); - } + let Some(updated_exprs) = update_ordering(self.expr.clone(), projection.expr())? 
+ else { + return Ok(None); + }; Ok(Some(Arc::new( SortPreservingMergeExec::new( @@ -386,10 +391,11 @@ impl ExecutionPlan for SortPreservingMergeExec { #[cfg(test)] mod tests { + use std::collections::HashSet; use std::fmt::Formatter; use std::pin::Pin; use std::sync::Mutex; - use std::task::{Context, Poll}; + use std::task::{ready, Context, Poll, Waker}; use std::time::Duration; use super::*; @@ -413,7 +419,7 @@ mod tests { use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::test_util::batches_to_string; - use datafusion_common::{assert_batches_eq, assert_contains, DataFusionError}; + use datafusion_common::{assert_batches_eq, exec_err}; use datafusion_common_runtime::SpawnedTask; use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnvBuilder; @@ -421,8 +427,8 @@ mod tests { use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::EquivalenceProperties; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; - use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; + use futures::{FutureExt, Stream, StreamExt}; use insta::assert_snapshot; use tokio::time::timeout; @@ -449,24 +455,25 @@ mod tests { let a: ArrayRef = Arc::new(Int32Array::from(vec![1; row_size])); let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"); row_size])); let c: ArrayRef = Arc::new(Int64Array::from_iter(vec![0; row_size])); - let rb = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); + let rb = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)])?; let rbs = (0..1024).map(|_| rb.clone()).collect::>(); let schema = rb.schema(); - let sort = LexOrdering::new(vec![ + let sort = [ PhysicalSortExpr { - expr: col("b", &schema).unwrap(), + expr: col("b", &schema)?, options: Default::default(), }, PhysicalSortExpr { - expr: col("c", &schema).unwrap(), + expr: col("c", &schema)?, options: Default::default(), }, - ]); + ] + .into(); let repartition_exec = RepartitionExec::try_new( - TestMemoryExec::try_new_exec(&[rbs], schema, None).unwrap(), + TestMemoryExec::try_new_exec(&[rbs], schema, None)?, Partitioning::RoundRobinBatch(2), )?; let coalesce_batches_exec = @@ -485,7 +492,7 @@ mod tests { async fn test_round_robin_tie_breaker_success() -> Result<()> { let task_ctx = generate_task_ctx_for_round_robin_tie_breaker()?; let spm = generate_spm_for_round_robin_tie_breaker(true)?; - let _collected = collect(spm, task_ctx).await.unwrap(); + let _collected = collect(spm, task_ctx).await?; Ok(()) } @@ -550,30 +557,6 @@ mod tests { .await; } - #[tokio::test] - async fn test_merge_no_exprs() { - let task_ctx = Arc::new(TaskContext::default()); - let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); - let batch = RecordBatch::try_from_iter(vec![("a", a)]).unwrap(); - - let schema = batch.schema(); - let sort = LexOrdering::default(); // no sort expressions - let exec = TestMemoryExec::try_new_exec( - &[vec![batch.clone()], vec![batch]], - schema, - None, - ) - .unwrap(); - - let merge = Arc::new(SortPreservingMergeExec::new(sort, exec)); - - let res = collect(merge, task_ctx).await.unwrap_err(); - assert_contains!( - res.to_string(), - "Internal error: Sort expressions cannot be empty for streaming merge" - ); - } - #[tokio::test] async fn test_merge_some_overlap() { let task_ctx = Arc::new(TaskContext::default()); @@ -741,7 +724,7 @@ mod tests { context: Arc, ) { let schema = partitions[0][0].schema(); - let sort = 
LexOrdering::new(vec![ + let sort = [ PhysicalSortExpr { expr: col("b", &schema).unwrap(), options: Default::default(), @@ -750,7 +733,8 @@ mod tests { expr: col("c", &schema).unwrap(), options: Default::default(), }, - ]); + ] + .into(); let exec = TestMemoryExec::try_new_exec(partitions, schema, None).unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, exec)); @@ -798,13 +782,14 @@ mod tests { let csv = test::scan_partitioned(partitions); let schema = csv.schema(); - let sort = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("i", &schema).unwrap(), + let sort: LexOrdering = [PhysicalSortExpr { + expr: col("i", &schema)?, options: SortOptions { descending: true, nulls_first: true, }, - }]); + }] + .into(); let basic = basic_sort(Arc::clone(&csv), sort.clone(), Arc::clone(&task_ctx)).await; @@ -859,17 +844,18 @@ mod tests { let sorted = basic_sort(csv, sort, context).await; let split: Vec<_> = sizes.iter().map(|x| split_batch(&sorted, *x)).collect(); - Ok(TestMemoryExec::try_new_exec(&split, sorted.schema(), None).unwrap()) + TestMemoryExec::try_new_exec(&split, sorted.schema(), None).map(|e| e as _) } #[tokio::test] async fn test_partition_sort_streaming_input() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); let schema = make_partition(11).schema(); - let sort = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("i", &schema).unwrap(), + let sort: LexOrdering = [PhysicalSortExpr { + expr: col("i", &schema)?, options: Default::default(), - }]); + }] + .into(); let input = sorted_partitioned_input(sort.clone(), &[10, 3, 11], Arc::clone(&task_ctx)) @@ -881,12 +867,9 @@ mod tests { assert_eq!(basic.num_rows(), 1200); assert_eq!(partition.num_rows(), 1200); - let basic = arrow::util::pretty::pretty_format_batches(&[basic]) - .unwrap() - .to_string(); - let partition = arrow::util::pretty::pretty_format_batches(&[partition]) - .unwrap() - .to_string(); + let basic = arrow::util::pretty::pretty_format_batches(&[basic])?.to_string(); + let partition = + arrow::util::pretty::pretty_format_batches(&[partition])?.to_string(); assert_eq!(basic, partition); @@ -896,10 +879,11 @@ mod tests { #[tokio::test] async fn test_partition_sort_streaming_input_output() -> Result<()> { let schema = make_partition(11).schema(); - let sort = LexOrdering::new(vec![PhysicalSortExpr { - expr: col("i", &schema).unwrap(), + let sort: LexOrdering = [PhysicalSortExpr { + expr: col("i", &schema)?, options: Default::default(), - }]); + }] + .into(); // Test streaming with default batch size let task_ctx = Arc::new(TaskContext::default()); @@ -914,19 +898,14 @@ mod tests { let task_ctx = Arc::new(task_ctx); let merge = Arc::new(SortPreservingMergeExec::new(sort, input)); - let merged = collect(merge, task_ctx).await.unwrap(); + let merged = collect(merge, task_ctx).await?; assert_eq!(merged.len(), 53); - assert_eq!(basic.num_rows(), 1200); assert_eq!(merged.iter().map(|x| x.num_rows()).sum::(), 1200); - let basic = arrow::util::pretty::pretty_format_batches(&[basic]) - .unwrap() - .to_string(); - let partition = arrow::util::pretty::pretty_format_batches(merged.as_slice()) - .unwrap() - .to_string(); + let basic = arrow::util::pretty::pretty_format_batches(&[basic])?.to_string(); + let partition = arrow::util::pretty::pretty_format_batches(&merged)?.to_string(); assert_eq!(basic, partition); @@ -971,7 +950,7 @@ mod tests { let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); let schema = b1.schema(); - let sort = LexOrdering::new(vec![ + let sort = [ 
PhysicalSortExpr { expr: col("b", &schema).unwrap(), options: SortOptions { @@ -986,7 +965,8 @@ mod tests { nulls_first: false, }, }, - ]); + ] + .into(); let exec = TestMemoryExec::try_new_exec(&[vec![b1], vec![b2]], schema, None).unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, exec)); @@ -1020,13 +1000,14 @@ mod tests { let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap(); let schema = batch.schema(); - let sort = LexOrdering::new(vec![PhysicalSortExpr { + let sort = [PhysicalSortExpr { expr: col("b", &schema).unwrap(), options: SortOptions { descending: false, nulls_first: true, }, - }]); + }] + .into(); let exec = TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, exec).with_fetch(Some(2))); @@ -1052,13 +1033,14 @@ mod tests { let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap(); let schema = batch.schema(); - let sort = LexOrdering::new(vec![PhysicalSortExpr { + let sort = [PhysicalSortExpr { expr: col("b", &schema).unwrap(), options: SortOptions { descending: false, nulls_first: true, }, - }]); + }] + .into(); let exec = TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, exec)); @@ -1082,10 +1064,11 @@ mod tests { async fn test_async() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); let schema = make_partition(11).schema(); - let sort = LexOrdering::new(vec![PhysicalSortExpr { + let sort: LexOrdering = [PhysicalSortExpr { expr: col("i", &schema).unwrap(), options: SortOptions::default(), - }]); + }] + .into(); let batches = sorted_partitioned_input(sort.clone(), &[5, 7, 3], Arc::clone(&task_ctx)) @@ -1121,7 +1104,7 @@ mod tests { let merge_stream = StreamingMergeBuilder::new() .with_streams(streams) .with_schema(batches.schema()) - .with_expressions(sort.as_ref()) + .with_expressions(&sort) .with_metrics(BaselineMetrics::new(&metrics, 0)) .with_batch_size(task_ctx.session_config().batch_size()) .with_fetch(fetch) @@ -1161,10 +1144,11 @@ mod tests { let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap(); let schema = b1.schema(); - let sort = LexOrdering::new(vec![PhysicalSortExpr { + let sort = [PhysicalSortExpr { expr: col("b", &schema).unwrap(), options: Default::default(), - }]); + }] + .into(); let exec = TestMemoryExec::try_new_exec(&[vec![b1], vec![b2]], schema, None).unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, exec)); @@ -1220,10 +1204,11 @@ mod tests { let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2)); let refs = blocking_exec.refs(); let sort_preserving_merge_exec = Arc::new(SortPreservingMergeExec::new( - LexOrdering::new(vec![PhysicalSortExpr { + [PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions::default(), - }]), + }] + .into(), blocking_exec, )); @@ -1268,13 +1253,14 @@ mod tests { let schema = partitions[0][0].schema(); - let sort = LexOrdering::new(vec![PhysicalSortExpr { + let sort = [PhysicalSortExpr { expr: col("value", &schema).unwrap(), options: SortOptions { descending: false, nulls_first: true, }, - }]); + }] + .into(); let exec = TestMemoryExec::try_new_exec(&partitions, schema, None).unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, exec)); @@ -1313,13 +1299,50 @@ mod tests { "#); } + #[derive(Debug)] + struct CongestionState { + wakers: Vec, + unpolled_partitions: HashSet, + } + + #[derive(Debug)] + struct Congestion { + congestion_state: Mutex, + 
} + + impl Congestion { + fn new(partition_count: usize) -> Self { + Congestion { + congestion_state: Mutex::new(CongestionState { + wakers: vec![], + unpolled_partitions: (0usize..partition_count).collect(), + }), + } + } + + fn check_congested(&self, partition: usize, cx: &mut Context<'_>) -> Poll<()> { + let mut state = self.congestion_state.lock().unwrap(); + + state.unpolled_partitions.remove(&partition); + + if state.unpolled_partitions.is_empty() { + state.wakers.iter().for_each(|w| w.wake_by_ref()); + state.wakers.clear(); + Poll::Ready(()) + } else { + state.wakers.push(cx.waker().clone()); + Poll::Pending + } + } + } + /// It returns pending for the 2nd partition until the 3rd partition is polled. The 1st /// partition is exhausted from the start, and if it is polled more than one, it panics. #[derive(Debug, Clone)] struct CongestedExec { schema: Schema, cache: PlanProperties, - congestion_cleared: Arc>, + congestion: Arc, } impl CongestedExec { @@ -1331,10 +1354,11 @@ mod tests { .map(|(i, f)| Arc::new(Column::new(f.name(), i)) as Arc) .collect::>(); let mut eq_properties = EquivalenceProperties::new(schema); - eq_properties.add_new_orderings(vec![columns - .iter() - .map(|expr| PhysicalSortExpr::new_default(Arc::clone(expr))) - .collect::()]); + eq_properties.add_ordering( + columns + .iter() + .map(|expr| PhysicalSortExpr::new_default(Arc::clone(expr))), + ); PlanProperties::new( eq_properties, Partitioning::Hash(columns, 3), @@ -1373,7 +1397,7 @@ mod tests { Ok(Box::pin(CongestedStream { schema: Arc::new(self.schema.clone()), none_polled_once: false, - congestion_cleared: Arc::clone(&self.congestion_cleared), + congestion: Arc::clone(&self.congestion), partition, })) } @@ -1400,7 +1424,7 @@ mod tests { pub struct CongestedStream { schema: SchemaRef, none_polled_once: bool, - congestion_cleared: Arc>, + congestion: Arc, partition: usize, } @@ -1408,31 +1432,22 @@ mod tests { type Item = Result; fn poll_next( mut self: Pin<&mut Self>, - _cx: &mut Context<'_>, + cx: &mut Context<'_>, ) -> Poll> { match self.partition { 0 => { + let _ = self.congestion.check_congested(self.partition, cx); if self.none_polled_once { - panic!("Exhausted stream is polled more than one") + panic!("Exhausted stream is polled more than once") } else { self.none_polled_once = true; Poll::Ready(None) } } - 1 => { - let cleared = self.congestion_cleared.lock().unwrap(); - if *cleared { - Poll::Ready(None) - } else { - Poll::Pending - } - } - 2 => { - let mut cleared = self.congestion_cleared.lock().unwrap(); - *cleared = true; + _ => { + ready!(self.congestion.check_congested(self.partition, cx)); Poll::Ready(None) } - _ => unreachable!(), } } } @@ -1447,15 +1462,22 @@ mod tests { async fn test_spm_congestion() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); let schema = Schema::new(vec![Field::new("c1", DataType::UInt64, false)]); + let properties = CongestedExec::compute_properties(Arc::new(schema.clone())); + let &partition_count = match properties.output_partitioning() { + Partitioning::RoundRobinBatch(partitions) => partitions, + Partitioning::Hash(_, partitions) => partitions, + Partitioning::UnknownPartitioning(partitions) => partitions, + }; let source = CongestedExec { schema: schema.clone(), - cache: CongestedExec::compute_properties(Arc::new(schema.clone())), - congestion_cleared: Arc::new(Mutex::new(false)), + cache: properties, + congestion: Arc::new(Congestion::new(partition_count)), }; let spm = SortPreservingMergeExec::new( - 
LexOrdering::new(vec![PhysicalSortExpr::new_default(Arc::new(Column::new( + [PhysicalSortExpr::new_default(Arc::new(Column::new( "c1", 0, - )))]), + )))] + .into(), Arc::new(source), ); let spm_task = SpawnedTask::spawn(collect(Arc::new(spm), task_ctx)); @@ -1464,12 +1486,8 @@ mod tests { match result { Ok(Ok(Ok(_batches))) => Ok(()), Ok(Ok(Err(e))) => Err(e), - Ok(Err(_)) => Err(DataFusionError::Execution( - "SortPreservingMerge task panicked or was cancelled".to_string(), - )), - Err(_) => Err(DataFusionError::Execution( - "SortPreservingMerge caused a deadlock".to_string(), - )), + Ok(Err(_)) => exec_err!("SortPreservingMerge task panicked or was cancelled"), + Err(_) => exec_err!("SortPreservingMerge caused a deadlock"), } } } diff --git a/datafusion/physical-plan/src/sorts/streaming_merge.rs b/datafusion/physical-plan/src/sorts/streaming_merge.rs index 3f022ec6095a..3c215caf41c3 100644 --- a/datafusion/physical-plan/src/sorts/streaming_merge.rs +++ b/datafusion/physical-plan/src/sorts/streaming_merge.rs @@ -19,15 +19,17 @@ //! This is an order-preserving merge. use crate::metrics::BaselineMetrics; +use crate::sorts::multi_level_merge::MultiLevelMergeBuilder; use crate::sorts::{ merge::SortPreservingMergeStream, stream::{FieldCursorStream, RowCursorStream}, }; -use crate::SendableRecordBatchStream; +use crate::{SendableRecordBatchStream, SpillManager}; use arrow::array::*; use arrow::datatypes::{DataType, SchemaRef}; use datafusion_common::{internal_err, Result}; -use datafusion_execution::memory_pool::MemoryReservation; +use datafusion_execution::disk_manager::RefCountedTempFile; +use datafusion_execution::memory_pool::{human_readable_size, MemoryReservation}; use datafusion_physical_expr_common::sort_expr::LexOrdering; macro_rules! primitive_merge_helper { @@ -52,10 +54,31 @@ macro_rules! 
merge_helper { }}; } +pub struct SortedSpillFile { + pub file: RefCountedTempFile, + + /// how much memory the largest memory batch is taking + pub max_record_batch_memory: usize, +} + +impl std::fmt::Debug for SortedSpillFile { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "SortedSpillFile({:?}) takes {}", + self.file.path(), + human_readable_size(self.max_record_batch_memory) + ) + } +} + +#[derive(Default)] pub struct StreamingMergeBuilder<'a> { streams: Vec, + sorted_spill_files: Vec, + spill_manager: Option, schema: Option, - expressions: &'a LexOrdering, + expressions: Option<&'a LexOrdering>, metrics: Option, batch_size: Option, fetch: Option, @@ -67,6 +90,8 @@ impl Default for StreamingMergeBuilder<'_> { fn default() -> Self { Self { streams: vec![], + sorted_spill_files: vec![], + spill_manager: None, schema: None, expressions: LexOrdering::empty(), metrics: None, @@ -91,13 +116,26 @@ impl<'a> StreamingMergeBuilder<'a> { self } + pub fn with_sorted_spill_files( + mut self, + sorted_spill_files: Vec, + ) -> Self { + self.sorted_spill_files = sorted_spill_files; + self + } + + pub fn with_spill_manager(mut self, spill_manager: SpillManager) -> Self { + self.spill_manager = Some(spill_manager); + self + } + pub fn with_schema(mut self, schema: SchemaRef) -> Self { self.schema = Some(schema); self } pub fn with_expressions(mut self, expressions: &'a LexOrdering) -> Self { - self.expressions = expressions; + self.expressions = Some(expressions); self } @@ -136,6 +174,8 @@ impl<'a> StreamingMergeBuilder<'a> { pub fn build(self) -> Result { let Self { streams, + sorted_spill_files, + spill_manager, schema, metrics, batch_size, @@ -144,7 +184,7 @@ impl<'a> StreamingMergeBuilder<'a> { expressions, enable_round_robin_tie_breaker, } = self; - + // Early return if streams or expressions are empty let checks = [ ( @@ -152,7 +192,7 @@ impl<'a> StreamingMergeBuilder<'a> { "Streams cannot be empty for streaming merge", ), ( - expressions.is_empty(), + expressions.is_none_or(|expr| expr.is_empty()), "Sort expressions cannot be empty for streaming merge", ), ]; @@ -161,30 +201,35 @@ impl<'a> StreamingMergeBuilder<'a> { { return internal_err!("{}", error_message); } - - // Unwrapping mandatory fields - let schema = schema.expect("Schema cannot be empty for streaming merge"); - let metrics = metrics.expect("Metrics cannot be empty for streaming merge"); - let batch_size = - batch_size.expect("Batch size cannot be empty for streaming merge"); - let reservation = - reservation.expect("Reservation cannot be empty for streaming merge"); - - // Special case single column comparisons with optimized cursor implementations - if expressions.len() == 1 { - let sort = expressions[0].clone(); - let data_type = sort.expr.data_type(schema.as_ref())?; - downcast_primitive! 
{ - data_type => (primitive_merge_helper, sort, streams, schema, metrics, batch_size, fetch, reservation, enable_round_robin_tie_breaker), - DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, metrics, batch_size, fetch, reservation, enable_round_robin_tie_breaker) - DataType::Utf8View => merge_helper!(StringViewArray, sort, streams, schema, metrics, batch_size, fetch, reservation, enable_round_robin_tie_breaker) - DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, metrics, batch_size, fetch, reservation, enable_round_robin_tie_breaker) - DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation, enable_round_robin_tie_breaker) - DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation, enable_round_robin_tie_breaker) - _ => {} - } + let Some(expressions) = expressions else { + return internal_err!("Sort expressions cannot be empty for streaming merge"); + }; + + if !sorted_spill_files.is_empty() { + // Unwrapping mandatory fields + let schema = schema.expect("Schema cannot be empty for streaming merge"); + let metrics = metrics.expect("Metrics cannot be empty for streaming merge"); + let batch_size = + batch_size.expect("Batch size cannot be empty for streaming merge"); + let reservation = + reservation.expect("Reservation cannot be empty for streaming merge"); + + return Ok(MultiLevelMergeBuilder::new( + spill_manager.expect("spill_manager should exist"), + schema, + sorted_spill_files, + streams, + expressions.clone(), + metrics, + batch_size, + reservation, + fetch, + enable_round_robin_tie_breaker, + ) + .create_spillable_merge_stream()); } + let streams = RowCursorStream::try_new( schema.as_ref(), expressions, diff --git a/datafusion/physical-plan/src/spill/get_size.rs b/datafusion/physical-plan/src/spill/get_size.rs new file mode 100644 index 000000000000..1cb56b2678ce --- /dev/null +++ b/datafusion/physical-plan/src/spill/get_size.rs @@ -0,0 +1,216 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
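Before the trait and implementations that follow, a small illustrative sketch (not part of this patch, helper name hypothetical) of the sizing problem this new module addresses: a zero-copy slice keeps its parent's buffers alive, so buffer-based accounting such as `get_array_memory_size` changes across an IPC spill round trip, whereas `get_actually_used_size` is meant to report the same value before and after:

use std::sync::Arc;

use arrow::array::{ArrayRef, Int64Array, RecordBatch};

// Hypothetical illustration only.
fn sliced_batch_keeps_parent_buffers() -> RecordBatch {
    let col: ArrayRef = Arc::new(Int64Array::from_iter_values(0..1_000_000i64));
    let batch = RecordBatch::try_from_iter(vec![("v", col)]).unwrap();
    // 10 rows, but the slice still references the full 1M-row backing buffer;
    // only the 10 referenced rows survive a spill-to-IPC-and-read-back cycle,
    // so buffer-based sizes shrink while the "actually used" size does not.
    batch.slice(0, 10)
}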
+ +use arrow::array::{ + Array, ArrayRef, ArrowPrimitiveType, BooleanArray, FixedSizeBinaryArray, + FixedSizeListArray, GenericByteArray, GenericListArray, OffsetSizeTrait, + PrimitiveArray, RecordBatch, StructArray, +}; +use arrow::buffer::{BooleanBuffer, Buffer, NullBuffer, OffsetBuffer}; +use arrow::datatypes::{ArrowNativeType, ByteArrayType}; +use arrow::downcast_primitive_array; +use arrow_schema::DataType; + +/// TODO - NEED TO MOVE THIS TO ARROW +/// this is needed as unlike `get_buffer_memory_size` or `get_array_memory_size` +/// or even `get_record_batch_memory_size` calling it on a batch before writing to spill and after reading it back +/// will return the same size (minus some optimization that IPC writer does for dictionaries) +/// +pub trait GetActualSize { + fn get_actually_used_size(&self) -> usize; +} + +impl GetActualSize for RecordBatch { + fn get_actually_used_size(&self) -> usize { + self.columns() + .iter() + .map(|c| c.get_actually_used_size()) + .sum() + } +} + +pub trait GetActualSizeArray: Array { + fn get_actually_used_size(&self) -> usize { + self.get_buffer_memory_size() + } +} + +impl GetActualSize for dyn Array { + fn get_actually_used_size(&self) -> usize { + use arrow::array::AsArray; + let array = self; + + // we can avoid this is we move this trait function to be in arrow + downcast_primitive_array!( + array => { + array.get_actually_used_size() + }, + DataType::Utf8 => { + array.as_string::().get_actually_used_size() + }, + DataType::LargeUtf8 => { + array.as_string::().get_actually_used_size() + }, + DataType::Binary => { + array.as_binary::().get_actually_used_size() + }, + DataType::LargeBinary => { + array.as_binary::().get_actually_used_size() + }, + DataType::FixedSizeBinary(_) => { + array.as_fixed_size_binary().get_actually_used_size() + }, + DataType::Struct(_) => { + array.as_struct().get_actually_used_size() + }, + DataType::List(_) => { + array.as_list::().get_actually_used_size() + }, + DataType::LargeList(_) => { + array.as_list::().get_actually_used_size() + }, + DataType::FixedSizeList(_, _) => { + array.as_fixed_size_list().get_actually_used_size() + }, + DataType::Boolean => { + array.as_boolean().get_actually_used_size() + }, + + _ => { + array.get_buffer_memory_size() + } + ) + } +} + +impl GetActualSizeArray for ArrayRef { + fn get_actually_used_size(&self) -> usize { + self.as_ref().get_actually_used_size() + } +} + +impl GetActualSize for Option<&NullBuffer> { + fn get_actually_used_size(&self) -> usize { + self.map(|b| b.get_actually_used_size()).unwrap_or(0) + } +} + +impl GetActualSize for NullBuffer { + fn get_actually_used_size(&self) -> usize { + // len return in bits + self.len() / 8 + } +} +impl GetActualSize for Buffer { + fn get_actually_used_size(&self) -> usize { + self.len() + } +} + +impl GetActualSize for BooleanBuffer { + fn get_actually_used_size(&self) -> usize { + // len return in bits + self.len() / 8 + } +} + +impl GetActualSize for OffsetBuffer { + fn get_actually_used_size(&self) -> usize { + self.inner().inner().get_actually_used_size() + } +} + +impl GetActualSizeArray for BooleanArray { + fn get_actually_used_size(&self) -> usize { + let null_size = self.nulls().get_actually_used_size(); + let values_size = self.values().get_actually_used_size(); + null_size + values_size + } +} + +impl GetActualSizeArray for GenericByteArray { + fn get_actually_used_size(&self) -> usize { + let null_size = self.nulls().get_actually_used_size(); + let offsets_size = self.offsets().get_actually_used_size(); + + let 
values_size = { + let first_offset = self.value_offsets()[0].as_usize(); + let last_offset = self.value_offsets()[self.len()].as_usize(); + last_offset - first_offset + }; + null_size + offsets_size + values_size + } +} + +impl GetActualSizeArray for FixedSizeBinaryArray { + fn get_actually_used_size(&self) -> usize { + let null_size = self.nulls().get_actually_used_size(); + let values_size = self.value_length() as usize * self.len(); + + null_size + values_size + } +} + +impl GetActualSizeArray for PrimitiveArray { + fn get_actually_used_size(&self) -> usize { + let null_size = self.nulls().get_actually_used_size(); + let values_size = self.values().inner().get_actually_used_size(); + + null_size + values_size + } +} + +impl GetActualSizeArray for GenericListArray { + fn get_actually_used_size(&self) -> usize { + let null_size = self.nulls().get_actually_used_size(); + let offsets_size = self.offsets().get_actually_used_size(); + + let values_size = { + let first_offset = self.value_offsets()[0].as_usize(); + let last_offset = self.value_offsets()[self.len()].as_usize(); + + self.values() + .slice(first_offset, last_offset - first_offset) + .get_actually_used_size() + }; + null_size + offsets_size + values_size + } +} + +impl GetActualSizeArray for FixedSizeListArray { + fn get_actually_used_size(&self) -> usize { + let null_size = self.nulls().get_actually_used_size(); + let values_size = self.values().get_actually_used_size(); + + null_size + values_size + } +} + +impl GetActualSizeArray for StructArray { + fn get_actually_used_size(&self) -> usize { + let null_size = self.nulls().get_actually_used_size(); + let values_size = self + .columns() + .iter() + .map(|array| array.get_actually_used_size()) + .sum::(); + + null_size + values_size + } +} + +// TODO - need to add to more arrays diff --git a/datafusion/physical-plan/src/spill/in_progress_spill_file.rs b/datafusion/physical-plan/src/spill/in_progress_spill_file.rs index 7617e0a22a50..14917e23b792 100644 --- a/datafusion/physical-plan/src/spill/in_progress_spill_file.rs +++ b/datafusion/physical-plan/src/spill/in_progress_spill_file.rs @@ -67,6 +67,7 @@ impl InProgressSpillFile { self.writer = Some(IPCStreamWriter::new( in_progress_file.path(), schema.as_ref(), + self.spill_writer.compression, )?); // Update metrics @@ -74,7 +75,7 @@ impl InProgressSpillFile { } } if let Some(writer) = &mut self.writer { - let (spilled_rows, spilled_bytes) = writer.write(batch)?; + let (spilled_rows, _) = writer.write(batch)?; if let Some(in_progress_file) = &mut self.in_progress_file { in_progress_file.update_disk_usage()?; } else { @@ -82,7 +83,6 @@ impl InProgressSpillFile { } // Update metrics - self.spill_writer.metrics.spilled_bytes.add(spilled_bytes); self.spill_writer.metrics.spilled_rows.add(spilled_rows); } Ok(()) @@ -97,6 +97,14 @@ impl InProgressSpillFile { return Ok(None); } + // Since spill files are append-only, add the file size to spilled_bytes + if let Some(in_progress_file) = &mut self.in_progress_file { + // Since writer.finish() writes continuation marker and message length at the end + in_progress_file.update_disk_usage()?; + let size = in_progress_file.current_disk_usage(); + self.spill_writer.metrics.spilled_bytes.add(size as usize); + } + Ok(self.in_progress_file.take()) } } diff --git a/datafusion/physical-plan/src/spill/mod.rs b/datafusion/physical-plan/src/spill/mod.rs index 1101616a4106..0542899a8918 100644 --- a/datafusion/physical-plan/src/spill/mod.rs +++ b/datafusion/physical-plan/src/spill/mod.rs @@ -17,6 +17,7 @@ 
//! Defines the spilling functions +pub(crate) mod get_size; pub(crate) mod in_progress_spill_file; pub(crate) mod spill_manager; @@ -30,9 +31,14 @@ use std::task::{Context, Poll}; use arrow::array::ArrayData; use arrow::datatypes::{Schema, SchemaRef}; -use arrow::ipc::{reader::StreamReader, writer::StreamWriter}; +use arrow::ipc::{ + reader::StreamReader, + writer::{IpcWriteOptions, StreamWriter}, + MetadataVersion, +}; use arrow::record_batch::RecordBatch; +use datafusion_common::config::SpillCompression; use datafusion_common::{exec_datafusion_err, DataFusionError, HashSet, Result}; use datafusion_common_runtime::SpawnedTask; use datafusion_execution::disk_manager::RefCountedTempFile; @@ -194,7 +200,8 @@ pub fn spill_record_batch_by_size( ) -> Result<()> { let mut offset = 0; let total_rows = batch.num_rows(); - let mut writer = IPCStreamWriter::new(&path, schema.as_ref())?; + let mut writer = + IPCStreamWriter::new(&path, schema.as_ref(), SpillCompression::Uncompressed)?; while offset < total_rows { let length = std::cmp::min(total_rows - offset, batch_size_rows); @@ -292,15 +299,27 @@ struct IPCStreamWriter { impl IPCStreamWriter { /// Create new writer - pub fn new(path: &Path, schema: &Schema) -> Result { + pub fn new( + path: &Path, + schema: &Schema, + compression_type: SpillCompression, + ) -> Result { let file = File::create(path).map_err(|e| { exec_datafusion_err!("Failed to create partition file at {path:?}: {e:?}") })?; + + let metadata_version = MetadataVersion::V5; + let alignment = 8; + let mut write_options = + IpcWriteOptions::try_new(alignment, false, metadata_version)?; + write_options = write_options.try_with_compression(compression_type.into())?; + + let writer = StreamWriter::try_new_with_options(file, schema, write_options)?; Ok(Self { num_batches: 0, num_rows: 0, num_bytes: 0, - writer: StreamWriter::try_new(file, schema)?, + writer, }) } @@ -332,7 +351,7 @@ mod tests { use crate::metrics::SpillMetrics; use crate::spill::spill_manager::SpillManager; use crate::test::build_table_i32; - use arrow::array::{Float64Array, Int32Array, ListArray, StringArray}; + use arrow::array::{ArrayRef, Float64Array, Int32Array, ListArray, StringArray}; use arrow::compute::cast; use arrow::datatypes::{DataType, Field, Int32Type, Schema}; use arrow::record_batch::RecordBatch; @@ -470,6 +489,113 @@ mod tests { Ok(()) } + fn build_compressible_batch() -> RecordBatch { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, true), + ])); + + let a: ArrayRef = Arc::new(StringArray::from_iter_values(std::iter::repeat_n( + "repeated", 100, + ))); + let b: ArrayRef = Arc::new(Int32Array::from(vec![1; 100])); + let c: ArrayRef = Arc::new(Int32Array::from(vec![2; 100])); + + RecordBatch::try_new(schema, vec![a, b, c]).unwrap() + } + + async fn validate( + spill_manager: &SpillManager, + spill_file: RefCountedTempFile, + num_rows: usize, + schema: SchemaRef, + batch_count: usize, + ) -> Result<()> { + let spilled_rows = spill_manager.metrics.spilled_rows.value(); + assert_eq!(spilled_rows, num_rows); + + let stream = spill_manager.read_spill_as_stream(spill_file)?; + assert_eq!(stream.schema(), schema); + + let batches = collect(stream).await?; + assert_eq!(batches.len(), batch_count); + + Ok(()) + } + + #[tokio::test] + async fn test_spill_compression() -> Result<()> { + let batch = build_compressible_batch(); + let num_rows = batch.num_rows(); + let schema = batch.schema(); + let 
batch_count = 1; + let batches = [batch]; + + // Construct SpillManager + let env = Arc::new(RuntimeEnv::default()); + let uncompressed_metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); + let lz4_metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); + let zstd_metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); + let uncompressed_spill_manager = SpillManager::new( + Arc::clone(&env), + uncompressed_metrics, + Arc::clone(&schema), + ); + let lz4_spill_manager = + SpillManager::new(Arc::clone(&env), lz4_metrics, Arc::clone(&schema)) + .with_compression_type(SpillCompression::Lz4Frame); + let zstd_spill_manager = + SpillManager::new(env, zstd_metrics, Arc::clone(&schema)) + .with_compression_type(SpillCompression::Zstd); + let uncompressed_spill_file = uncompressed_spill_manager + .spill_record_batch_and_finish(&batches, "Test")? + .unwrap(); + let lz4_spill_file = lz4_spill_manager + .spill_record_batch_and_finish(&batches, "Lz4_Test")? + .unwrap(); + let zstd_spill_file = zstd_spill_manager + .spill_record_batch_and_finish(&batches, "ZSTD_Test")? + .unwrap(); + assert!(uncompressed_spill_file.path().exists()); + assert!(lz4_spill_file.path().exists()); + assert!(zstd_spill_file.path().exists()); + + let lz4_spill_size = std::fs::metadata(lz4_spill_file.path())?.len(); + let zstd_spill_size = std::fs::metadata(zstd_spill_file.path())?.len(); + let uncompressed_spill_size = + std::fs::metadata(uncompressed_spill_file.path())?.len(); + + assert!(uncompressed_spill_size > lz4_spill_size); + assert!(uncompressed_spill_size > zstd_spill_size); + + validate( + &lz4_spill_manager, + lz4_spill_file, + num_rows, + Arc::clone(&schema), + batch_count, + ) + .await?; + validate( + &zstd_spill_manager, + zstd_spill_file, + num_rows, + Arc::clone(&schema), + batch_count, + ) + .await?; + validate( + &uncompressed_spill_manager, + uncompressed_spill_file, + num_rows, + schema, + batch_count, + ) + .await?; + Ok(()) + } + #[test] fn test_get_record_batch_memory_size() { // Create a simple record batch with two columns @@ -684,12 +810,13 @@ mod tests { Arc::new(StringArray::from(vec!["d", "e", "f"])), ], )?; - + // After appending each batch, spilled_rows should increase, while spill_file_count and + // spilled_bytes remain the same (spilled_bytes is updated only after finish() is called) in_progress_file.append_batch(&batch1)?; - verify_metrics(&in_progress_file, 1, 356, 3)?; + verify_metrics(&in_progress_file, 1, 0, 3)?; in_progress_file.append_batch(&batch2)?; - verify_metrics(&in_progress_file, 1, 712, 6)?; + verify_metrics(&in_progress_file, 1, 0, 6)?; let completed_file = in_progress_file.finish()?; assert!(completed_file.is_some()); diff --git a/datafusion/physical-plan/src/spill/spill_manager.rs b/datafusion/physical-plan/src/spill/spill_manager.rs index 78cd47a8bad0..40e0848c25ab 100644 --- a/datafusion/physical-plan/src/spill/spill_manager.rs +++ b/datafusion/physical-plan/src/spill/spill_manager.rs @@ -17,19 +17,19 @@ //! Define the `SpillManager` struct, which is responsible for reading and writing `RecordBatch`es to raw files based on the provided configurations. 
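As a companion to `IPCStreamWriter::new` and the compression test above, here is a standalone sketch of opening a compressed Arrow IPC stream writer and reading the spill back. It assumes the `arrow` crate is built with its IPC compression support (the `ipc_compression` feature); the function names and path handling are illustrative only:

```rust
use std::fs::File;

use arrow::datatypes::Schema;
use arrow::ipc::reader::StreamReader;
use arrow::ipc::writer::{IpcWriteOptions, StreamWriter};
use arrow::ipc::{CompressionType, MetadataVersion};
use arrow::record_batch::RecordBatch;

/// Open a Zstd-compressed IPC stream writer with 8-byte alignment and V5
/// metadata, mirroring the options `IPCStreamWriter::new` builds above.
fn open_compressed_writer(
    path: &str,
    schema: &Schema,
) -> arrow::error::Result<StreamWriter<File>> {
    let options = IpcWriteOptions::try_new(8, false, MetadataVersion::V5)?
        .try_with_compression(Some(CompressionType::ZSTD))?;
    StreamWriter::try_new_with_options(File::create(path)?, schema, options)
}

/// Read the spilled stream back; `StreamReader` decompresses transparently,
/// so the read path does not need to know which codec was used at write time.
fn read_spill(path: &str) -> arrow::error::Result<Vec<RecordBatch>> {
    StreamReader::try_new(File::open(path)?, None)?.collect()
}
```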
-use std::sync::Arc; - use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_execution::runtime_env::RuntimeEnv; +use std::sync::Arc; -use datafusion_common::Result; +use datafusion_common::{config::SpillCompression, Result}; use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::SendableRecordBatchStream; -use crate::{common::spawn_buffered, metrics::SpillMetrics}; - use super::{in_progress_spill_file::InProgressSpillFile, SpillReaderStream}; +use crate::coop::cooperative; +use crate::spill::get_size::GetActualSize; +use crate::{common::spawn_buffered, metrics::SpillMetrics}; /// The `SpillManager` is responsible for the following tasks: /// - Reading and writing `RecordBatch`es to raw files based on the provided configurations. @@ -44,7 +44,8 @@ pub struct SpillManager { schema: SchemaRef, /// Number of batches to buffer in memory during disk reads batch_read_buffer_capacity: usize, - // TODO: Add general-purpose compression options + /// general-purpose compression options + pub(crate) compression: SpillCompression, } impl SpillManager { @@ -54,9 +55,23 @@ impl SpillManager { metrics, schema, batch_read_buffer_capacity: 2, + compression: SpillCompression::default(), } } + pub fn with_compression_type(mut self, spill_compression: SpillCompression) -> Self { + self.compression = spill_compression; + self + } + + pub fn with_batch_read_buffer_capacity( + mut self, + batch_read_buffer_capacity: usize, + ) -> Self { + self.batch_read_buffer_capacity = batch_read_buffer_capacity; + self + } + /// Creates a temporary file for in-progress operations, returning an error /// message if file creation fails. The file can be used to append batches /// incrementally and then finish the file when done. @@ -118,6 +133,156 @@ impl SpillManager { self.spill_record_batch_and_finish(&batches, request_description) } + /// Refer to the documentation for [`Self::spill_record_batch_and_finish`]. This method + /// additionally spills the `RecordBatch` into smaller batches, divided by `row_limit`. + /// + /// # Errors + /// - Returns an error if spilling would exceed the disk usage limit configured + /// by `max_temp_directory_size` in `DiskManager` + pub(crate) fn spill_record_batch_by_size_and_return_max_batch_memory( + &self, + batch: &RecordBatch, + request_description: &str, + row_limit: usize, + ) -> Result> { + let total_rows = batch.num_rows(); + let mut batches = Vec::new(); + let mut offset = 0; + + // It's ok to calculate all slices first, because slicing is zero-copy. + while offset < total_rows { + let length = std::cmp::min(total_rows - offset, row_limit); + let sliced_batch = batch.slice(offset, length); + batches.push(sliced_batch); + offset += length; + } + + let mut in_progress_file = self.create_in_progress_file(request_description)?; + + let mut max_record_batch_size = 0; + + for batch in batches { + in_progress_file.append_batch(&batch)?; + + max_record_batch_size = + max_record_batch_size.max(batch.get_actually_used_size()); + } + + let file = in_progress_file.finish()?; + + Ok(file.map(|f| (f, max_record_batch_size))) + } + + /// Spill the `RecordBatch` to disk as smaller batches + /// split by `batch_size_rows`. 
+ /// + /// will return the spill file and the size of the largest batch in memory + pub async fn spill_record_batch_stream_by_size( + &self, + stream: &mut SendableRecordBatchStream, + batch_size_rows: usize, + request_msg: &str, + ) -> Result> { + use futures::StreamExt; + let mut in_progress_file = self.create_in_progress_file(request_msg)?; + + let mut max_record_batch_size = 0; + + let mut maybe_last_batch: Option = None; + + while let Some(batch) = stream.next().await { + let mut batch = batch?; + + if let Some(mut last_batch) = maybe_last_batch.take() { + assert!( + last_batch.num_rows() < batch_size_rows, + "last batch size must be smaller than the requested batch size" + ); + + // Get the number of rows to take from current batch so the last_batch + // will have `batch_size_rows` rows + let current_batch_offset = std::cmp::min( + // rows needed to fill + batch_size_rows - last_batch.num_rows(), + // Current length of the batch + batch.num_rows(), + ); + + // if have last batch that has less rows than concat and spill + last_batch = arrow::compute::concat_batches( + &stream.schema(), + &[last_batch, batch.slice(0, current_batch_offset)], + )?; + + assert!(last_batch.num_rows() <= batch_size_rows, "must build a batch that is smaller or equal to the requested batch size from the current batch"); + + // If not enough rows + if last_batch.num_rows() < batch_size_rows { + // keep the last batch for next iteration + maybe_last_batch = Some(last_batch); + continue; + } + + max_record_batch_size = + max_record_batch_size.max(last_batch.get_actually_used_size()); + + in_progress_file.append_batch(&last_batch)?; + + if current_batch_offset == batch.num_rows() { + // No remainder + continue; + } + + // remainder + batch = batch.slice( + current_batch_offset, + batch.num_rows() - current_batch_offset, + ); + } + + let mut offset = 0; + let total_rows = batch.num_rows(); + + // Keep slicing the batch until we have left with a batch that is smaller than + // the wanted batch size + while total_rows - offset >= batch_size_rows { + let batch = batch.slice(offset, batch_size_rows); + offset += batch_size_rows; + + max_record_batch_size = + max_record_batch_size.max(batch.get_actually_used_size()); + + in_progress_file.append_batch(&batch)?; + } + + // If there is a remainder for the current batch that is smaller than the wanted batch size + // keep it for next iteration + if offset < total_rows { + // remainder + let batch = batch.slice(offset, total_rows - offset); + + maybe_last_batch = Some(batch); + } + } + if let Some(last_batch) = maybe_last_batch.take() { + assert!( + last_batch.num_rows() < batch_size_rows, + "last batch size must be smaller than the requested batch size" + ); + + // Write it to disk + in_progress_file.append_batch(&last_batch)?; + + max_record_batch_size = + max_record_batch_size.max(last_batch.get_actually_used_size()); + } + + // Flush disk + let spill_file = in_progress_file.finish()?; + + Ok(spill_file.map(|f| (f, max_record_batch_size))) + } + /// Reads a spill file as a stream. The file must be created by the current `SpillManager`. /// This method will generate output in FIFO order: the batch appended first /// will be read first. 
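The streaming re-batcher above keeps an undersized leftover batch and tops it up with a prefix of the next incoming batch before spilling. A minimal standalone sketch of that carry-over step (the 4-row target and single-column layout are illustrative):

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use arrow::compute::concat_batches;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;

fn main() -> arrow::error::Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let leftover = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef],
    )?;
    let next = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(Int32Array::from(vec![3, 4, 5, 6])) as ArrayRef],
    )?;

    // Top the leftover up to the target row count with a prefix of the next
    // batch; the remainder of `next` would be handled by the usual slicing loop.
    let target = 4;
    let take = target - leftover.num_rows();
    let filled = concat_batches(&schema, &[leftover, next.slice(0, take)])?;
    assert_eq!(filled.num_rows(), target);
    Ok(())
}
```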
@@ -125,10 +290,10 @@ impl SpillManager { &self, spill_file_path: RefCountedTempFile, ) -> Result { - let stream = Box::pin(SpillReaderStream::new( + let stream = Box::pin(cooperative(SpillReaderStream::new( Arc::clone(&self.schema), spill_file_path, - )); + ))); Ok(spawn_buffered(stream, self.batch_read_buffer_capacity)) } diff --git a/datafusion/physical-plan/src/streaming.rs b/datafusion/physical-plan/src/streaming.rs index 18c472a7e187..d4e6ba4c96c7 100644 --- a/datafusion/physical-plan/src/streaming.rs +++ b/datafusion/physical-plan/src/streaming.rs @@ -22,12 +22,13 @@ use std::fmt::Debug; use std::sync::Arc; use super::{DisplayAs, DisplayFormatType, PlanProperties}; +use crate::coop::make_cooperative; use crate::display::{display_orderings, ProjectSchemaDisplay}; -use crate::execution_plan::{Boundedness, EmissionType}; +use crate::execution_plan::{Boundedness, EmissionType, SchedulingType}; use crate::limit::LimitStream; use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use crate::projection::{ - all_alias_free_columns, new_projections_for_columns, update_expr, ProjectionExec, + all_alias_free_columns, new_projections_for_columns, update_ordering, ProjectionExec, }; use crate::stream::RecordBatchStreamAdapter; use crate::{ExecutionPlan, Partitioning, SendableRecordBatchStream}; @@ -35,7 +36,7 @@ use crate::{ExecutionPlan, Partitioning, SendableRecordBatchStream}; use arrow::datatypes::{Schema, SchemaRef}; use datafusion_common::{internal_err, plan_err, Result}; use datafusion_execution::TaskContext; -use datafusion_physical_expr::{EquivalenceProperties, LexOrdering, PhysicalSortExpr}; +use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; use async_trait::async_trait; use futures::stream::StreamExt; @@ -99,7 +100,7 @@ impl StreamingTableExec { projected_output_ordering.into_iter().collect::>(); let cache = Self::compute_properties( Arc::clone(&projected_schema), - &projected_output_ordering, + projected_output_ordering.clone(), &partitions, infinite, ); @@ -146,7 +147,7 @@ impl StreamingTableExec { /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. 
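Both `read_spill_as_stream` and `StreamingTableExec::execute` now wrap their streams with the crate's `coop` utilities so that long-running operators periodically hand control back to the runtime. As a generic illustration of the idea only, using Tokio's task budget directly rather than the crate's `coop` module:

```rust
// Assumes a reasonably recent `tokio` dependency with the runtime feature enabled.
#[tokio::main]
async fn main() {
    let mut checksum = 0u64;
    for i in 0..1_000_000u64 {
        checksum = checksum.wrapping_add(i);
        // Yields to the scheduler only once the task's cooperative budget is
        // exhausted, keeping other tasks on the same runtime responsive.
        tokio::task::consume_budget().await;
    }
    println!("checksum = {checksum}");
}
```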
fn compute_properties( schema: SchemaRef, - orderings: &[LexOrdering], + orderings: Vec, partitions: &[Arc], infinite: bool, ) -> PlanProperties { @@ -168,6 +169,7 @@ impl StreamingTableExec { EmissionType::Incremental, boundedness, ) + .with_scheduling_type(SchedulingType::Cooperative) } } @@ -262,7 +264,7 @@ impl ExecutionPlan for StreamingTableExec { partition: usize, ctx: Arc, ) -> Result { - let stream = self.partitions[partition].execute(ctx); + let stream = self.partitions[partition].execute(Arc::clone(&ctx)); let projected_stream = match self.projection.clone() { Some(projection) => Box::pin(RecordBatchStreamAdapter::new( Arc::clone(&self.projected_schema), @@ -272,16 +274,13 @@ impl ExecutionPlan for StreamingTableExec { )), None => stream, }; + let stream = make_cooperative(projected_stream); + Ok(match self.limit { - None => projected_stream, + None => stream, Some(fetch) => { let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); - Box::pin(LimitStream::new( - projected_stream, - 0, - Some(fetch), - baseline_metrics, - )) + Box::pin(LimitStream::new(stream, 0, Some(fetch), baseline_metrics)) } }) } @@ -302,24 +301,15 @@ impl ExecutionPlan for StreamingTableExec { let new_projections = new_projections_for_columns( projection, &streaming_table_projections - .unwrap_or((0..self.schema().fields().len()).collect()), + .unwrap_or_else(|| (0..self.schema().fields().len()).collect()), ); let mut lex_orderings = vec![]; - for lex_ordering in self.projected_output_ordering().into_iter() { - let mut orderings = LexOrdering::default(); - for order in lex_ordering { - let Some(new_ordering) = - update_expr(&order.expr, projection.expr(), false)? - else { - return Ok(None); - }; - orderings.push(PhysicalSortExpr { - expr: new_ordering, - options: order.options, - }); - } - lex_orderings.push(orderings); + for ordering in self.projected_output_ordering().into_iter() { + let Some(ordering) = update_ordering(ordering, projection.expr())? 
else { + return Ok(None); + }; + lex_orderings.push(ordering); } StreamingTableExec::try_new( diff --git a/datafusion/physical-plan/src/test.rs b/datafusion/physical-plan/src/test.rs index a2dc1d778436..5e6410a0171e 100644 --- a/datafusion/physical-plan/src/test.rs +++ b/datafusion/physical-plan/src/test.rs @@ -40,10 +40,12 @@ use datafusion_common::{ config::ConfigOptions, internal_err, project_schema, Result, Statistics, }; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; -use datafusion_physical_expr::{ - equivalence::ProjectionMapping, expressions::Column, utils::collect_columns, - EquivalenceProperties, LexOrdering, Partitioning, +use datafusion_physical_expr::equivalence::{ + OrderingEquivalenceClass, ProjectionMapping, }; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::utils::collect_columns; +use datafusion_physical_expr::{EquivalenceProperties, LexOrdering, Partitioning}; use futures::{Future, FutureExt}; @@ -87,9 +89,7 @@ impl DisplayAs for TestMemoryExec { let output_ordering = self .sort_information .first() - .map(|output_ordering| { - format!(", output_ordering={}", output_ordering) - }) + .map(|output_ordering| format!(", output_ordering={output_ordering}")) .unwrap_or_default(); let eq_properties = self.eq_properties(); @@ -97,12 +97,12 @@ impl DisplayAs for TestMemoryExec { let constraints = if constraints.is_empty() { String::new() } else { - format!(", {}", constraints) + format!(", {constraints}") }; let limit = self .fetch - .map_or(String::new(), |limit| format!(", fetch={}", limit)); + .map_or(String::new(), |limit| format!(", fetch={limit}")); if self.show_sizes { write!( f, @@ -170,7 +170,15 @@ impl ExecutionPlan for TestMemoryExec { } fn statistics(&self) -> Result { - self.statistics() + self.statistics_inner() + } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_some() { + Ok(Statistics::new_unknown(&self.schema)) + } else { + self.statistics_inner() + } } fn fetch(&self) -> Option { @@ -210,11 +218,11 @@ impl TestMemoryExec { fn eq_properties(&self) -> EquivalenceProperties { EquivalenceProperties::new_with_orderings( Arc::clone(&self.projected_schema), - self.sort_information.as_slice(), + self.sort_information.clone(), ) } - fn statistics(&self) -> Result { + fn statistics_inner(&self) -> Result { Ok(common::compute_record_batch_statistics( &self.partitions, &self.schema, @@ -234,7 +242,7 @@ impl TestMemoryExec { cache: PlanProperties::new( EquivalenceProperties::new_with_orderings( Arc::clone(&projected_schema), - vec![].as_slice(), + Vec::::new(), ), Partitioning::UnknownPartitioning(partitions.len()), EmissionType::Incremental, @@ -292,7 +300,7 @@ impl TestMemoryExec { } /// refer to `try_with_sort_information` at MemorySourceConfig for more information. 
- /// https://github.com/apache/datafusion/tree/main/datafusion/datasource/src/memory.rs + /// pub fn try_with_sort_information( mut self, mut sort_information: Vec, @@ -318,24 +326,21 @@ impl TestMemoryExec { // If there is a projection on the source, we also need to project orderings if let Some(projection) = &self.projection { + let base_schema = self.original_schema(); + let proj_exprs = projection.iter().map(|idx| { + let name = base_schema.field(*idx).name(); + (Arc::new(Column::new(name, *idx)) as _, name.to_string()) + }); + let projection_mapping = + ProjectionMapping::try_new(proj_exprs, &base_schema)?; let base_eqp = EquivalenceProperties::new_with_orderings( - self.original_schema(), - &sort_information, + Arc::clone(&base_schema), + sort_information, ); - let proj_exprs = projection - .iter() - .map(|idx| { - let base_schema = self.original_schema(); - let name = base_schema.field(*idx).name(); - (Arc::new(Column::new(name, *idx)) as _, name.to_string()) - }) - .collect::>(); - let projection_mapping = - ProjectionMapping::try_new(&proj_exprs, &self.original_schema())?; - sort_information = base_eqp - .project(&projection_mapping, Arc::clone(&self.projected_schema)) - .into_oeq_class() - .into_inner(); + let proj_eqp = + base_eqp.project(&projection_mapping, Arc::clone(&self.projected_schema)); + let oeq_class: OrderingEquivalenceClass = proj_eqp.into(); + sort_information = oeq_class.into(); } self.sort_information = sort_information; @@ -450,7 +455,7 @@ pub fn make_partition_utf8(sz: i32) -> RecordBatch { let seq_start = 0; let seq_end = sz; let values = (seq_start..seq_end) - .map(|i| format!("test_long_string_that_is_roughly_42_bytes_{}", i)) + .map(|i| format!("test_long_string_that_is_roughly_42_bytes_{i}")) .collect::>(); let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Utf8, true)])); let mut string_array = arrow::array::StringArray::from(values); diff --git a/datafusion/physical-plan/src/test/exec.rs b/datafusion/physical-plan/src/test/exec.rs index d0a0d25779cc..12ffca871f07 100644 --- a/datafusion/physical-plan/src/test/exec.rs +++ b/datafusion/physical-plan/src/test/exec.rs @@ -255,6 +255,13 @@ impl ExecutionPlan for MockExec { // Panics if one of the batches is an error fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_some() { + return Ok(Statistics::new_unknown(&self.schema)); + } let data: Result> = self .data .iter() @@ -405,6 +412,13 @@ impl ExecutionPlan for BarrierExec { } fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_some() { + return Ok(Statistics::new_unknown(&self.schema)); + } Ok(common::compute_record_batch_statistics( &self.data, &self.schema, @@ -590,6 +604,14 @@ impl ExecutionPlan for StatisticsExec { fn statistics(&self) -> Result { Ok(self.stats.clone()) } + + fn partition_statistics(&self, partition: Option) -> Result { + Ok(if partition.is_some() { + Statistics::new_unknown(&self.schema) + } else { + self.stats.clone() + }) + } } /// Execution plan that emits streams that block forever. diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index 0b5780b9143f..71029662f5f5 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -18,25 +18,32 @@ //! 
TopK: Combination of Sort / LIMIT use arrow::{ + array::Array, compute::interleave_record_batch, row::{RowConverter, Rows, SortField}, }; +use datafusion_expr::{ColumnarValue, Operator}; use std::mem::size_of; use std::{cmp::Ordering, collections::BinaryHeap, sync::Arc}; use super::metrics::{BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder}; use crate::spill::get_record_batch_memory_size; use crate::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream}; + use arrow::array::{ArrayRef, RecordBatch}; use arrow::datatypes::SchemaRef; -use datafusion_common::Result; -use datafusion_common::{internal_datafusion_err, HashMap}; +use datafusion_common::{ + internal_datafusion_err, internal_err, HashMap, Result, ScalarValue, +}; use datafusion_execution::{ memory_pool::{MemoryConsumer, MemoryReservation}, runtime_env::RuntimeEnv, }; -use datafusion_physical_expr::PhysicalSortExpr; -use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_expr::{ + expressions::{is_not_null, is_null, lit, BinaryExpr, DynamicFilterPhysicalExpr}, + PhysicalExpr, +}; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; /// Global TopK /// @@ -102,7 +109,7 @@ pub struct TopK { /// The target number of rows for output batches batch_size: usize, /// sort expressions - expr: Arc<[PhysicalSortExpr]>, + expr: LexOrdering, /// row converter, for sort keys row_converter: RowConverter, /// scratch space for converting rows @@ -113,6 +120,8 @@ pub struct TopK { common_sort_prefix_converter: Option, /// Common sort prefix between the input and the sort expressions to allow early exit optimization common_sort_prefix: Arc<[PhysicalSortExpr]>, + /// Filter matching the state of the `TopK` heap used for dynamic filter pushdown + filter: Option>, /// If true, indicates that all rows of subsequent batches are guaranteed /// to be greater (by byte order, after row conversion) than the top K, /// which means the top K won't change and the computation can be finished early. @@ -123,7 +132,7 @@ pub struct TopK { const ESTIMATED_BYTES_PER_ROW: usize = 20; fn build_sort_fields( - ordering: &LexOrdering, + ordering: &[PhysicalSortExpr], schema: &SchemaRef, ) -> Result> { ordering @@ -145,17 +154,18 @@ impl TopK { pub fn try_new( partition_id: usize, schema: SchemaRef, - common_sort_prefix: LexOrdering, + common_sort_prefix: Vec, expr: LexOrdering, k: usize, batch_size: usize, runtime: Arc, metrics: &ExecutionPlanMetricsSet, + filter: Option>, ) -> Result { let reservation = MemoryConsumer::new(format!("TopK[{partition_id}]")) .register(&runtime.memory_pool); - let sort_fields: Vec<_> = build_sort_fields(&expr, &schema)?; + let sort_fields = build_sort_fields(&expr, &schema)?; // TODO there is potential to add special cases for single column sort fields // to improve performance @@ -166,8 +176,7 @@ impl TopK { let prefix_row_converter = if common_sort_prefix.is_empty() { None } else { - let input_sort_fields: Vec<_> = - build_sort_fields(&common_sort_prefix, &schema)?; + let input_sort_fields = build_sort_fields(&common_sort_prefix, &schema)?; Some(RowConverter::new(input_sort_fields)?) 
}; @@ -176,13 +185,14 @@ impl TopK { metrics: TopKMetrics::new(metrics, partition_id), reservation, batch_size, - expr: Arc::from(expr), + expr, row_converter, scratch_rows, heap: TopKHeap::new(k, batch_size), common_sort_prefix_converter: prefix_row_converter, common_sort_prefix: Arc::from(common_sort_prefix), finished: false, + filter, }) } @@ -207,34 +217,156 @@ impl TopK { rows.clear(); self.row_converter.append(rows, &sort_keys)?; - // TODO make this algorithmically better?: - // Idea: filter out rows >= self.heap.max() early (before passing to `RowConverter`) - // this avoids some work and also might be better vectorizable. let mut batch_entry = self.heap.register_batch(batch.clone()); - for (index, row) in rows.iter().enumerate() { + + let replacements = + self.find_new_topk_items(0..sort_keys[0].len(), &mut batch_entry); + + if replacements > 0 { + self.metrics.row_replacements.add(replacements); + + self.heap.insert_batch_entry(batch_entry); + + // conserve memory + self.heap.maybe_compact()?; + + // update memory reservation + self.reservation.try_resize(self.size())?; + + // flag the topK as finished if we know that all + // subsequent batches are guaranteed to be greater (by byte order, after row conversion) than the top K, + // which means the top K won't change and the computation can be finished early. + self.attempt_early_completion(&batch)?; + + // update the filter representation of our TopK heap + self.update_filter()?; + } + + Ok(()) + } + + fn find_new_topk_items( + &mut self, + items: impl Iterator, + batch_entry: &mut RecordBatchEntry, + ) -> usize { + let mut replacements = 0; + let rows = &mut self.scratch_rows; + for (index, row) in items.zip(rows.iter()) { match self.heap.max() { // heap has k items, and the new row is greater than the // current max in the heap ==> it is not a new topk Some(max_row) if row.as_ref() >= max_row.row() => {} // don't yet have k items or new item is lower than the currently k low values None | Some(_) => { - self.heap.add(&mut batch_entry, row, index); - self.metrics.row_replacements.add(1); + self.heap.add(batch_entry, row, index); + replacements += 1; } } } - self.heap.insert_batch_entry(batch_entry); + replacements + } - // conserve memory - self.heap.maybe_compact()?; + /// Update the filter representation of our TopK heap. + /// For example, given the sort expression `ORDER BY a DESC, b ASC LIMIT 3`, + /// and the current heap values `[(1, 5), (1, 4), (2, 3)]`, + /// the filter will be updated to: + /// + /// ```sql + /// (a > 1 OR (a = 1 AND b < 5)) AND + /// (a > 1 OR (a = 1 AND b < 4)) AND + /// (a > 2 OR (a = 2 AND b < 3)) + /// ``` + fn update_filter(&mut self) -> Result<()> { + let Some(filter) = &self.filter else { + return Ok(()); + }; + let Some(thresholds) = self.heap.get_threshold_values(&self.expr)? else { + return Ok(()); + }; - // update memory reservation - self.reservation.try_resize(self.size())?; + // Create filter expressions for each threshold + let mut filters: Vec> = + Vec::with_capacity(thresholds.len()); - // flag the topK as finished if we know that all - // subsequent batches are guaranteed to be greater (by byte order, after row conversion) than the top K, - // which means the top K won't change and the computation can be finished early. 
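The rewritten insert path above only row-converts and adds rows that beat the current heap maximum, and that same maximum later drives the dynamic filter. A self-contained sketch of the underlying top-k idea using only the standard library (the values and `k` are arbitrary):

```rust
use std::collections::BinaryHeap;

/// Keep the k smallest values. The heap's current maximum acts as the
/// "threshold": anything not smaller than it can be skipped, which is the
/// property the TopK dynamic filter pushes down to the scan.
fn top_k_smallest(values: impl IntoIterator<Item = i64>, k: usize) -> Vec<i64> {
    let mut heap = BinaryHeap::with_capacity(k);
    for v in values {
        if heap.len() < k {
            heap.push(v);
        } else if let Some(&max) = heap.peek() {
            if v < max {
                heap.pop();
                heap.push(v);
            }
        }
    }
    let mut out = heap.into_vec();
    out.sort_unstable();
    out
}

fn main() {
    assert_eq!(top_k_smallest([5, 1, 9, 3, 7, 2], 3), vec![1, 2, 3]);
}
```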
- self.attempt_early_completion(&batch)?; + let mut prev_sort_expr: Option> = None; + for (sort_expr, value) in self.expr.iter().zip(thresholds.iter()) { + // Create the appropriate operator based on sort order + let op = if sort_expr.options.descending { + // For descending sort, we want col > threshold (exclude smaller values) + Operator::Gt + } else { + // For ascending sort, we want col < threshold (exclude larger values) + Operator::Lt + }; + + let value_null = value.is_null(); + + let comparison = Arc::new(BinaryExpr::new( + Arc::clone(&sort_expr.expr), + op, + lit(value.clone()), + )); + + let comparison_with_null = match (sort_expr.options.nulls_first, value_null) { + // For nulls first, transform to (threshold.value is not null) and (threshold.expr is null or comparison) + (true, true) => lit(false), + (true, false) => Arc::new(BinaryExpr::new( + is_null(Arc::clone(&sort_expr.expr))?, + Operator::Or, + comparison, + )), + // For nulls last, transform to (threshold.value is null and threshold.expr is not null) + // or (threshold.value is not null and comparison) + (false, true) => is_not_null(Arc::clone(&sort_expr.expr))?, + (false, false) => comparison, + }; + + let mut eq_expr = Arc::new(BinaryExpr::new( + Arc::clone(&sort_expr.expr), + Operator::Eq, + lit(value.clone()), + )); + + if value_null { + eq_expr = Arc::new(BinaryExpr::new( + is_null(Arc::clone(&sort_expr.expr))?, + Operator::Or, + eq_expr, + )); + } + + // For a query like order by a, b, the filter for column `b` is only applied if + // the condition a = threshold.value (considering null equality) is met. + // Therefore, we add equality predicates for all preceding fields to the filter logic of the current field, + // and include the current field's equality predicate in `prev_sort_expr` for use with subsequent fields. 
+ match prev_sort_expr.take() { + None => { + prev_sort_expr = Some(eq_expr); + filters.push(comparison_with_null); + } + Some(p) => { + filters.push(Arc::new(BinaryExpr::new( + Arc::clone(&p), + Operator::And, + comparison_with_null, + ))); + + prev_sort_expr = + Some(Arc::new(BinaryExpr::new(p, Operator::And, eq_expr))); + } + } + } + + let dynamic_predicate = filters + .into_iter() + .reduce(|a, b| Arc::new(BinaryExpr::new(a, Operator::Or, b))); + + if let Some(predicate) = dynamic_predicate { + if !predicate.eq(&lit(true)) { + filter.update(predicate)?; + } + } Ok(()) } @@ -328,6 +460,7 @@ impl TopK { common_sort_prefix_converter: _, common_sort_prefix: _, finished: _, + filter: _, } = self; let _timer = metrics.baseline.elapsed_compute().timer(); // time updated on drop @@ -570,6 +703,47 @@ impl TopKHeap { + self.store.size() + self.owned_bytes } + + fn get_threshold_values( + &self, + sort_exprs: &[PhysicalSortExpr], + ) -> Result>> { + // If the heap doesn't have k elements yet, we can't create thresholds + let max_row = match self.max() { + Some(row) => row, + None => return Ok(None), + }; + + // Get the batch that contains the max row + let batch_entry = match self.store.get(max_row.batch_id) { + Some(entry) => entry, + None => return internal_err!("Invalid batch ID in TopKRow"), + }; + + // Extract threshold values for each sort expression + let mut scalar_values = Vec::with_capacity(sort_exprs.len()); + for sort_expr in sort_exprs { + // Extract the value for this column from the max row + let expr = Arc::clone(&sort_expr.expr); + let value = expr.evaluate(&batch_entry.batch.slice(max_row.index, 1))?; + + // Convert to scalar value - should be a single value since we're evaluating on a single row batch + let scalar = match value { + ColumnarValue::Scalar(scalar) => scalar, + ColumnarValue::Array(array) if array.len() == 1 => { + // Extract the first (and only) value from the array + ScalarValue::try_from_array(&array, 0)? + } + array => { + return internal_err!("Expected a scalar value, got {:?}", array) + } + }; + + scalar_values.push(scalar); + } + + Ok(Some(scalar_values)) + } } /// Represents one of the top K rows held in this heap. Orders @@ -821,8 +995,8 @@ mod tests { }; // Input ordering uses only column "a" (a prefix of the full sort). - let input_ordering = LexOrdering::from(vec![sort_expr_a.clone()]); - let full_expr = LexOrdering::from(vec![sort_expr_a, sort_expr_b]); + let prefix = vec![sort_expr_a.clone()]; + let full_expr = LexOrdering::from([sort_expr_a, sort_expr_b]); // Create a dummy runtime environment and metrics. 
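Ignoring null handling, the predicate assembled above for a threshold row `(a, b) = (2, 3)` under `ORDER BY a DESC, b ASC` boils down to `a > 2 OR (a = 2 AND b < 3)`. The following standalone sketch builds that expression by hand with the same building blocks (`col`, `lit`, `BinaryExpr`); the schema and threshold values are illustrative:

```rust
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::ScalarValue;
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{col, lit, BinaryExpr};
use datafusion_physical_expr::PhysicalExpr;

fn main() -> datafusion_common::Result<()> {
    let schema = Schema::new(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Int32, true),
    ]);
    let a = col("a", &schema)?;
    let b = col("b", &schema)?;

    // Rows can still enter the heap only if a > 2 OR (a = 2 AND b < 3).
    let a_gt = Arc::new(BinaryExpr::new(
        Arc::clone(&a),
        Operator::Gt,
        lit(ScalarValue::Int32(Some(2))),
    ));
    let a_eq = Arc::new(BinaryExpr::new(a, Operator::Eq, lit(ScalarValue::Int32(Some(2)))));
    let b_lt = Arc::new(BinaryExpr::new(b, Operator::Lt, lit(ScalarValue::Int32(Some(3)))));
    let tie_break = Arc::new(BinaryExpr::new(a_eq, Operator::And, b_lt));

    let predicate: Arc<dyn PhysicalExpr> =
        Arc::new(BinaryExpr::new(a_gt, Operator::Or, tie_break));
    println!("{predicate}");
    Ok(())
}
```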
let runtime = Arc::new(RuntimeEnv::default()); @@ -832,12 +1006,13 @@ mod tests { let mut topk = TopK::try_new( 0, Arc::clone(&schema), - input_ordering, + prefix, full_expr, 3, 2, runtime, &metrics, + None, )?; // Create the first batch with two columns: diff --git a/datafusion/physical-plan/src/tree_node.rs b/datafusion/physical-plan/src/tree_node.rs index 69b0a165315e..78ba984ed1a5 100644 --- a/datafusion/physical-plan/src/tree_node.rs +++ b/datafusion/physical-plan/src/tree_node.rs @@ -94,7 +94,7 @@ impl PlanContext { impl Display for PlanContext { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { let node_string = displayable(self.plan.as_ref()).one_line(); - write!(f, "Node plan: {}", node_string)?; + write!(f, "Node plan: {node_string}")?; write!(f, "Node data: {}", self.data)?; write!(f, "") } diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index 2b666093f29e..73d7933e7c05 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -248,7 +248,7 @@ impl ExecutionPlan for UnionExec { } } - warn!("Error in Union: Partition {} not found", partition); + warn!("Error in Union: Partition {partition} not found"); exec_err!("Partition {partition} not found in Union") } @@ -258,16 +258,36 @@ impl ExecutionPlan for UnionExec { } fn statistics(&self) -> Result { - let stats = self - .inputs - .iter() - .map(|stat| stat.statistics()) - .collect::>>()?; + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if let Some(partition_idx) = partition { + // For a specific partition, find which input it belongs to + let mut remaining_idx = partition_idx; + for input in &self.inputs { + let input_partition_count = input.output_partitioning().partition_count(); + if remaining_idx < input_partition_count { + // This partition belongs to this input + return input.partition_statistics(Some(remaining_idx)); + } + remaining_idx -= input_partition_count; + } + // If we get here, the partition index is out of bounds + Ok(Statistics::new_unknown(&self.schema())) + } else { + // Collect statistics from all inputs + let stats = self + .inputs + .iter() + .map(|input_exec| input_exec.partition_statistics(None)) + .collect::>>()?; - Ok(stats - .into_iter() - .reduce(stats_union) - .unwrap_or_else(|| Statistics::new_unknown(&self.schema()))) + Ok(stats + .into_iter() + .reduce(stats_union) + .unwrap_or_else(|| Statistics::new_unknown(&self.schema()))) + } } fn benefits_from_input_partitioning(&self) -> Vec { @@ -461,7 +481,7 @@ impl ExecutionPlan for InterleaveExec { ))); } - warn!("Error in InterleaveExec: Partition {} not found", partition); + warn!("Error in InterleaveExec: Partition {partition} not found"); exec_err!("Partition {partition} not found in InterleaveExec") } @@ -471,10 +491,17 @@ impl ExecutionPlan for InterleaveExec { } fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_some() { + return Ok(Statistics::new_unknown(&self.schema())); + } let stats = self .inputs .iter() - .map(|stat| stat.statistics()) + .map(|stat| stat.partition_statistics(None)) .collect::>>()?; Ok(stats @@ -513,7 +540,12 @@ fn union_schema(inputs: &[Arc]) -> SchemaRef { let fields = (0..first_schema.fields().len()) .map(|i| { - inputs + // We take the name from the left side of the union to match how names are coerced during logical planning, + // which also uses the left side names. 
+ let base_field = first_schema.field(i).clone(); + + // Coerce metadata and nullability across all inputs + let merged_field = inputs .iter() .enumerate() .map(|(input_idx, input)| { @@ -535,6 +567,9 @@ fn union_schema(inputs: &[Arc]) -> SchemaRef { // We can unwrap this because if inputs was empty, this would've already panic'ed when we // indexed into inputs[0]. .unwrap() + .with_name(base_field.name()); + + merged_field }) .collect::>(); @@ -642,15 +677,13 @@ fn stats_union(mut left: Statistics, right: Statistics) -> Statistics { mod tests { use super::*; use crate::collect; - use crate::test; - use crate::test::TestMemoryExec; + use crate::test::{self, TestMemoryExec}; use arrow::compute::SortOptions; use arrow::datatypes::DataType; use datafusion_common::ScalarValue; + use datafusion_physical_expr::equivalence::convert_to_orderings; use datafusion_physical_expr::expressions::col; - use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; - use datafusion_physical_expr_common::sort_expr::LexOrdering; // Generate a schema which consists of 7 columns (a, b, c, d, e, f, g) fn create_test_schema() -> Result { @@ -666,19 +699,6 @@ mod tests { Ok(schema) } - // Convert each tuple to PhysicalSortExpr - fn convert_to_sort_exprs( - in_data: &[(&Arc, SortOptions)], - ) -> LexOrdering { - in_data - .iter() - .map(|(expr, options)| PhysicalSortExpr { - expr: Arc::clone(*expr), - options: *options, - }) - .collect::() - } - #[tokio::test] async fn test_union_partitions() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); @@ -854,18 +874,9 @@ mod tests { (first_child_orderings, second_child_orderings, union_orderings), ) in test_cases.iter().enumerate() { - let first_orderings = first_child_orderings - .iter() - .map(|ordering| convert_to_sort_exprs(ordering)) - .collect::>(); - let second_orderings = second_child_orderings - .iter() - .map(|ordering| convert_to_sort_exprs(ordering)) - .collect::>(); - let union_expected_orderings = union_orderings - .iter() - .map(|ordering| convert_to_sort_exprs(ordering)) - .collect::>(); + let first_orderings = convert_to_orderings(first_child_orderings); + let second_orderings = convert_to_orderings(second_child_orderings); + let union_expected_orderings = convert_to_orderings(union_orderings); let child1 = Arc::new(TestMemoryExec::update_cache(Arc::new( TestMemoryExec::try_new(&[], Arc::clone(&schema), None)? .try_with_sort_information(first_orderings)?, @@ -876,7 +887,7 @@ mod tests { ))); let mut union_expected_eq = EquivalenceProperties::new(Arc::clone(&schema)); - union_expected_eq.add_new_orderings(union_expected_orderings); + union_expected_eq.add_orderings(union_expected_orderings); let union = UnionExec::new(vec![child1, child2]); let union_eq_properties = union.properties().equivalence_properties(); @@ -897,7 +908,7 @@ mod tests { // Check whether orderings are same. 
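`UnionExec::partition_statistics` above resolves a union-level partition index to the child that actually produces it by walking the children's partition counts. A minimal sketch of just that index arithmetic (plain Rust, illustrative names):

```rust
/// Map a union-level partition index to (input index, local partition index).
fn locate_partition(partition: usize, counts: &[usize]) -> Option<(usize, usize)> {
    let mut remaining = partition;
    for (input_idx, &count) in counts.iter().enumerate() {
        if remaining < count {
            return Some((input_idx, remaining));
        }
        remaining -= count;
    }
    None // out of bounds -> the caller falls back to unknown statistics
}

fn main() {
    // Two inputs with 3 and 2 partitions: the union exposes partitions 0..5.
    assert_eq!(locate_partition(4, &[3, 2]), Some((1, 1)));
    assert_eq!(locate_partition(5, &[3, 2]), None);
}
```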
let lhs_orderings = lhs.oeq_class(); let rhs_orderings = rhs.oeq_class(); - assert_eq!(lhs_orderings.len(), rhs_orderings.len(), "{}", err_msg); + assert_eq!(lhs_orderings.len(), rhs_orderings.len(), "{err_msg}"); for rhs_ordering in rhs_orderings.iter() { assert!(lhs_orderings.contains(rhs_ordering), "{}", err_msg); } diff --git a/datafusion/physical-plan/src/unnest.rs b/datafusion/physical-plan/src/unnest.rs index c06b09f2fecd..e36cd2b6c242 100644 --- a/datafusion/physical-plan/src/unnest.rs +++ b/datafusion/physical-plan/src/unnest.rs @@ -21,7 +21,10 @@ use std::cmp::{self, Ordering}; use std::task::{ready, Poll}; use std::{any::Any, sync::Arc}; -use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; +use super::metrics::{ + self, BaselineMetrics, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet, + RecordOutput, +}; use super::{DisplayAs, ExecutionPlanProperties, PlanProperties}; use crate::{ DisplayFormatType, Distribution, ExecutionPlan, RecordBatchStream, @@ -38,13 +41,12 @@ use arrow::compute::{cast, is_not_null, kernels, sum}; use arrow::datatypes::{DataType, Int64Type, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use arrow_ord::cmp::lt; +use async_trait::async_trait; use datafusion_common::{ exec_datafusion_err, exec_err, internal_err, HashMap, HashSet, Result, UnnestOptions, }; use datafusion_execution::TaskContext; use datafusion_physical_expr::EquivalenceProperties; - -use async_trait::async_trait; use futures::{Stream, StreamExt}; use log::trace; @@ -203,22 +205,18 @@ impl ExecutionPlan for UnnestExec { #[derive(Clone, Debug)] struct UnnestMetrics { - /// Total time for column unnesting - elapsed_compute: metrics::Time, + /// Execution metrics + baseline_metrics: BaselineMetrics, /// Number of batches consumed input_batches: metrics::Count, /// Number of rows consumed input_rows: metrics::Count, /// Number of batches produced output_batches: metrics::Count, - /// Number of rows produced by this operator - output_rows: metrics::Count, } impl UnnestMetrics { fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self { - let elapsed_compute = MetricBuilder::new(metrics).elapsed_compute(partition); - let input_batches = MetricBuilder::new(metrics).counter("input_batches", partition); @@ -227,14 +225,11 @@ impl UnnestMetrics { let output_batches = MetricBuilder::new(metrics).counter("output_batches", partition); - let output_rows = MetricBuilder::new(metrics).output_rows(partition); - Self { + baseline_metrics: BaselineMetrics::new(metrics, partition), input_batches, input_rows, output_batches, - output_rows, - elapsed_compute, } } } @@ -284,7 +279,9 @@ impl UnnestStream { loop { return Poll::Ready(match ready!(self.input.poll_next_unpin(cx)) { Some(Ok(batch)) => { - let timer = self.metrics.elapsed_compute.timer(); + let elapsed_compute = + self.metrics.baseline_metrics.elapsed_compute().clone(); + let timer = elapsed_compute.timer(); self.metrics.input_batches.add(1); self.metrics.input_rows.add(batch.num_rows()); let result = build_batch( @@ -299,7 +296,7 @@ impl UnnestStream { continue; }; self.metrics.output_batches.add(1); - self.metrics.output_rows.add(result_batch.num_rows()); + (&result_batch).record_output(&self.metrics.baseline_metrics); // Empty record batches should not be emitted. 
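`UnnestStream` now reports timing and output rows through `BaselineMetrics` instead of hand-rolled counters. A small standalone sketch of that pattern (the partition index and row count are arbitrary):

```rust
use datafusion_physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet};

fn main() {
    let metrics = ExecutionPlanMetricsSet::new();
    let baseline = BaselineMetrics::new(&metrics, /* partition */ 0);

    {
        // Time a unit of work; elapsed_compute is updated when the guard drops.
        let _timer = baseline.elapsed_compute().timer();
        // ... produce a batch ...
        baseline.output_rows().add(1024);
    }

    assert_eq!(baseline.output_rows().value(), 1024);
}
```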
// They need to be treated as [`Option`]es and handled separately @@ -313,8 +310,8 @@ impl UnnestStream { self.metrics.input_batches, self.metrics.input_rows, self.metrics.output_batches, - self.metrics.output_rows, - self.metrics.elapsed_compute, + self.metrics.baseline_metrics.output_rows(), + self.metrics.baseline_metrics.elapsed_compute(), ); other } diff --git a/datafusion/physical-plan/src/values.rs b/datafusion/physical-plan/src/values.rs index 6cb64bcb5d86..fb27ccf30179 100644 --- a/datafusion/physical-plan/src/values.rs +++ b/datafusion/physical-plan/src/values.rs @@ -308,8 +308,10 @@ mod tests { data, )?; + #[allow(deprecated)] + let stats = values.statistics()?; assert_eq!( - values.statistics()?, + stats, Statistics { num_rows: Precision::Exact(rows), total_byte_size: Precision::Exact(8), // not important diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index 92138bf6a7a1..d851d08a101f 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -38,7 +38,7 @@ use crate::{ ExecutionPlanProperties, InputOrderMode, PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, WindowExpr, }; -use ahash::RandomState; + use arrow::compute::take_record_batch; use arrow::{ array::{Array, ArrayRef, RecordBatchOptions, UInt32Builder}, @@ -60,9 +60,12 @@ use datafusion_expr::ColumnarValue; use datafusion_physical_expr::window::{ PartitionBatches, PartitionKey, PartitionWindowAggStates, WindowState, }; -use datafusion_physical_expr::PhysicalExpr; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::{ + OrderingRequirements, PhysicalSortExpr, +}; +use ahash::RandomState; use futures::stream::Stream; use futures::{ready, StreamExt}; use hashbrown::hash_table::HashTable; @@ -111,7 +114,7 @@ impl BoundedWindowAggExec { let indices = get_ordered_partition_by_indices( window_expr[0].partition_by(), &input, - ); + )?; if indices.len() == partition_by_exprs.len() { indices } else { @@ -123,7 +126,7 @@ impl BoundedWindowAggExec { vec![] } }; - let cache = Self::compute_properties(&input, &schema, &window_expr); + let cache = Self::compute_properties(&input, &schema, &window_expr)?; Ok(Self { input, window_expr, @@ -151,7 +154,7 @@ impl BoundedWindowAggExec { // We are sure that partition by columns are always at the beginning of sort_keys // Hence returned `PhysicalSortExpr` corresponding to `PARTITION BY` columns can be used safely // to calculate partition separation points - pub fn partition_by_sort_keys(&self) -> Result { + pub fn partition_by_sort_keys(&self) -> Result> { let partition_by = self.window_expr()[0].partition_by(); get_partition_by_sort_exprs( &self.input, @@ -191,9 +194,9 @@ impl BoundedWindowAggExec { input: &Arc, schema: &SchemaRef, window_exprs: &[Arc], - ) -> PlanProperties { + ) -> Result { // Calculate equivalence properties: - let eq_properties = window_equivalence_properties(schema, input, window_exprs); + let eq_properties = window_equivalence_properties(schema, input, window_exprs)?; // As we can have repartitioning using the partition keys, this can // be either one or more than one, depending on the presence of @@ -201,13 +204,13 @@ impl BoundedWindowAggExec { let output_partitioning = input.output_partitioning().clone(); // Construct 
properties cache - PlanProperties::new( + Ok(PlanProperties::new( eq_properties, output_partitioning, // TODO: Emission type and boundedness information can be enhanced here input.pipeline_behavior(), input.boundedness(), - ) + )) } pub fn partition_keys(&self) -> Vec> { @@ -226,6 +229,23 @@ impl BoundedWindowAggExec { .unwrap_or_else(Vec::new) } } + + fn statistics_helper(&self, statistics: Statistics) -> Result { + let win_cols = self.window_expr.len(); + let input_cols = self.input.schema().fields().len(); + // TODO stats: some windowing function will maintain invariants such as min, max... + let mut column_statistics = Vec::with_capacity(win_cols + input_cols); + // copy stats of the input to the beginning of the schema. + column_statistics.extend(statistics.column_statistics); + for _ in 0..win_cols { + column_statistics.push(ColumnStatistics::new_unknown()) + } + Ok(Statistics { + num_rows: statistics.num_rows, + column_statistics, + total_byte_size: Precision::Absent, + }) + } } impl DisplayAs for BoundedWindowAggExec { @@ -261,7 +281,7 @@ impl DisplayAs for BoundedWindowAggExec { writeln!(f, "select_list={}", g.join(", "))?; let mode = &self.input_order_mode; - writeln!(f, "mode={:?}", mode)?; + writeln!(f, "mode={mode:?}")?; } } Ok(()) @@ -286,14 +306,14 @@ impl ExecutionPlan for BoundedWindowAggExec { vec![&self.input] } - fn required_input_ordering(&self) -> Vec> { + fn required_input_ordering(&self) -> Vec> { let partition_bys = self.window_expr()[0].partition_by(); let order_keys = self.window_expr()[0].order_by(); let partition_bys = self .ordered_partition_by_indices .iter() .map(|idx| &partition_bys[*idx]); - vec![calc_requirements(partition_bys, order_keys.iter())] + vec![calc_requirements(partition_bys, order_keys)] } fn required_input_distribution(&self) -> Vec { @@ -343,21 +363,12 @@ impl ExecutionPlan for BoundedWindowAggExec { } fn statistics(&self) -> Result { - let input_stat = self.input.statistics()?; - let win_cols = self.window_expr.len(); - let input_cols = self.input.schema().fields().len(); - // TODO stats: some windowing function will maintain invariants such as min, max... - let mut column_statistics = Vec::with_capacity(win_cols + input_cols); - // copy stats of the input to the beginning of the schema. - column_statistics.extend(input_stat.column_statistics); - for _ in 0..win_cols { - column_statistics.push(ColumnStatistics::new_unknown()) - } - Ok(Statistics { - num_rows: input_stat.num_rows, - column_statistics, - total_byte_size: Precision::Absent, - }) + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + let input_stat = self.input.partition_statistics(partition)?; + self.statistics_helper(input_stat) } } @@ -742,7 +753,7 @@ impl LinearSearch { /// when computing partitions. pub struct SortedSearch { /// Stores partition by columns and their ordering information - partition_by_sort_keys: LexOrdering, + partition_by_sort_keys: Vec, /// Input ordering and partition by key ordering need not be the same, so /// this vector stores the mapping between them. 
For instance, if the input /// is ordered by a, b and the window expression contains a PARTITION BY b, a @@ -1339,18 +1350,17 @@ mod tests { Arc::new(Column::new(schema.fields[0].name(), 0)) as Arc; let args = vec![col_expr]; let partitionby_exprs = vec![col(hash, &schema)?]; - let orderby_exprs = LexOrdering::new(vec![PhysicalSortExpr { + let orderby_exprs = vec![PhysicalSortExpr { expr: col(order_by, &schema)?, options: SortOptions::default(), - }]); + }]; let window_frame = WindowFrame::new_bounds( WindowFrameUnits::Range, WindowFrameBound::CurrentRow, WindowFrameBound::Following(ScalarValue::UInt64(Some(n_future_range as u64))), ); let fn_name = format!( - "{}({:?}) PARTITION BY: [{:?}], ORDER BY: [{:?}]", - window_fn, args, partitionby_exprs, orderby_exprs + "{window_fn}({args:?}) PARTITION BY: [{partitionby_exprs:?}], ORDER BY: [{orderby_exprs:?}]" ); let input_order_mode = InputOrderMode::Linear; Ok(Arc::new(BoundedWindowAggExec::try_new( @@ -1359,7 +1369,7 @@ mod tests { fn_name, &args, &partitionby_exprs, - orderby_exprs.as_ref(), + &orderby_exprs, Arc::new(window_frame), &input.schema(), false, @@ -1456,13 +1466,14 @@ mod tests { } fn schema_orders(schema: &SchemaRef) -> Result> { - let orderings = vec![LexOrdering::new(vec![PhysicalSortExpr { + let orderings = vec![[PhysicalSortExpr { expr: col("sn", schema)?, options: SortOptions { descending: false, nulls_first: false, }, - }])]; + }] + .into()]; Ok(orderings) } @@ -1613,7 +1624,7 @@ mod tests { Arc::new(StandardWindowExpr::new( last_value_func, &[], - &LexOrdering::default(), + &[], Arc::new(WindowFrame::new_bounds( WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::UInt64(None)), @@ -1624,7 +1635,7 @@ mod tests { Arc::new(StandardWindowExpr::new( nth_value_func1, &[], - &LexOrdering::default(), + &[], Arc::new(WindowFrame::new_bounds( WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::UInt64(None)), @@ -1635,7 +1646,7 @@ mod tests { Arc::new(StandardWindowExpr::new( nth_value_func2, &[], - &LexOrdering::default(), + &[], Arc::new(WindowFrame::new_bounds( WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::UInt64(None)), @@ -1776,8 +1787,8 @@ mod tests { let plan = projection_exec(window)?; let expected_plan = vec![ - "ProjectionExec: expr=[sn@0 as sn, hash@1 as hash, count([Column { name: \"sn\", index: 0 }]) PARTITION BY: [[Column { name: \"hash\", index: 1 }]], ORDER BY: [LexOrdering { inner: [PhysicalSortExpr { expr: Column { name: \"sn\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }] }]@2 as col_2]", - " BoundedWindowAggExec: wdw=[count([Column { name: \"sn\", index: 0 }]) PARTITION BY: [[Column { name: \"hash\", index: 1 }]], ORDER BY: [LexOrdering { inner: [PhysicalSortExpr { expr: Column { name: \"sn\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }] }]: Ok(Field { name: \"count([Column { name: \\\"sn\\\", index: 0 }]) PARTITION BY: [[Column { name: \\\"hash\\\", index: 1 }]], ORDER BY: [LexOrdering { inner: [PhysicalSortExpr { expr: Column { name: \\\"sn\\\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }] }]\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(1)), is_causal: false }], mode=[Linear]", + "ProjectionExec: expr=[sn@0 as sn, hash@1 as hash, count([Column { name: \"sn\", index: 0 }]) PARTITION BY: [[Column { name: \"hash\", index: 1 }]], ORDER BY: 
[[PhysicalSortExpr { expr: Column { name: \"sn\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }]]@2 as col_2]", + " BoundedWindowAggExec: wdw=[count([Column { name: \"sn\", index: 0 }]) PARTITION BY: [[Column { name: \"hash\", index: 1 }]], ORDER BY: [[PhysicalSortExpr { expr: Column { name: \"sn\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }]]: Ok(Field { name: \"count([Column { name: \\\"sn\\\", index: 0 }]) PARTITION BY: [[Column { name: \\\"hash\\\", index: 1 }]], ORDER BY: [[PhysicalSortExpr { expr: Column { name: \\\"sn\\\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }]]\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(1)), is_causal: false }], mode=[Linear]", " StreamingTableExec: partition_sizes=1, projection=[sn, hash], infinite_source=true, output_ordering=[sn@0 ASC NULLS LAST]", ]; diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index d38bf2a186a8..5583abfd72a2 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -22,7 +22,6 @@ mod utils; mod window_agg_exec; use std::borrow::Borrow; -use std::iter; use std::sync::Arc; use crate::{ @@ -30,8 +29,8 @@ use crate::{ InputOrderMode, PhysicalExpr, }; -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use arrow_schema::SortOptions; +use arrow::datatypes::{Schema, SchemaRef}; +use arrow_schema::{FieldRef, SortOptions}; use datafusion_common::{exec_err, Result}; use datafusion_expr::{ PartitionEvaluator, ReversedUDWF, SetMonotonicity, WindowFrame, @@ -42,12 +41,13 @@ use datafusion_functions_window_common::field::WindowUDFFieldArgs; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::{ - reverse_order_bys, - window::{SlidingAggregateWindowExpr, StandardWindowFunctionExpr}, - ConstExpr, EquivalenceProperties, LexOrdering, PhysicalSortRequirement, +use datafusion_physical_expr::window::{ + SlidingAggregateWindowExpr, StandardWindowFunctionExpr, +}; +use datafusion_physical_expr::{ConstExpr, EquivalenceProperties}; +use datafusion_physical_expr_common::sort_expr::{ + LexOrdering, LexRequirement, OrderingRequirements, PhysicalSortRequirement, }; -use datafusion_physical_expr_common::sort_expr::LexRequirement; use itertools::Itertools; @@ -65,16 +65,16 @@ pub fn schema_add_window_field( window_fn: &WindowFunctionDefinition, fn_name: &str, ) -> Result> { - let data_types = args + let fields = args .iter() - .map(|e| Arc::clone(e).as_ref().data_type(schema)) + .map(|e| Arc::clone(e).as_ref().return_field(schema)) .collect::>>()?; let nullability = args .iter() .map(|e| Arc::clone(e).as_ref().nullable(schema)) .collect::>>()?; - let window_expr_return_type = - window_fn.return_type(&data_types, &nullability, fn_name)?; + let window_expr_return_field = + window_fn.return_field(&fields, &nullability, fn_name)?; let mut window_fields = schema .fields() .iter() @@ -84,11 +84,10 @@ pub fn schema_add_window_field( if let WindowFunctionDefinition::AggregateUDF(_) = window_fn { Ok(Arc::new(Schema::new(window_fields))) } else { - window_fields.extend_from_slice(&[Field::new( - fn_name, - window_expr_return_type, - false, - 
)]); + window_fields.extend_from_slice(&[window_expr_return_field + .as_ref() + .clone() + .with_name(fn_name)]); Ok(Arc::new(Schema::new(window_fields))) } } @@ -100,7 +99,7 @@ pub fn create_window_expr( name: String, args: &[Arc], partition_by: &[Arc], - order_by: &LexOrdering, + order_by: &[PhysicalSortExpr], window_frame: Arc, input_schema: &Schema, ignore_nulls: bool, @@ -132,7 +131,7 @@ pub fn create_window_expr( /// Creates an appropriate [`WindowExpr`] based on the window frame and fn window_expr_from_aggregate_expr( partition_by: &[Arc], - order_by: &LexOrdering, + order_by: &[PhysicalSortExpr], window_frame: Arc, aggregate: Arc, ) -> Arc { @@ -165,15 +164,15 @@ pub fn create_udwf_window_expr( ignore_nulls: bool, ) -> Result> { // need to get the types into an owned vec for some reason - let input_types: Vec<_> = args + let input_fields: Vec<_> = args .iter() - .map(|arg| arg.data_type(input_schema)) + .map(|arg| arg.return_field(input_schema)) .collect::>()?; let udwf_expr = Arc::new(WindowUDFExpr { fun: Arc::clone(fun), args: args.to_vec(), - input_types, + input_fields, name, is_reversed: false, ignore_nulls, @@ -202,8 +201,8 @@ pub struct WindowUDFExpr { args: Vec>, /// Display name name: String, - /// Types of input expressions - input_types: Vec, + /// Fields of input expressions + input_fields: Vec, /// This is set to `true` only if the user-defined window function /// expression supports evaluation in reverse order, and the /// evaluation order is reversed. @@ -223,21 +222,21 @@ impl StandardWindowFunctionExpr for WindowUDFExpr { self } - fn field(&self) -> Result { + fn field(&self) -> Result { self.fun - .field(WindowUDFFieldArgs::new(&self.input_types, &self.name)) + .field(WindowUDFFieldArgs::new(&self.input_fields, &self.name)) } fn expressions(&self) -> Vec> { self.fun - .expressions(ExpressionArgs::new(&self.args, &self.input_types)) + .expressions(ExpressionArgs::new(&self.args, &self.input_fields)) } fn create_evaluator(&self) -> Result> { self.fun .partition_evaluator_factory(PartitionEvaluatorArgs::new( &self.args, - &self.input_types, + &self.input_fields, self.is_reversed, self.ignore_nulls, )) @@ -255,7 +254,7 @@ impl StandardWindowFunctionExpr for WindowUDFExpr { fun, args: self.args.clone(), name: self.name.clone(), - input_types: self.input_types.clone(), + input_fields: self.input_fields.clone(), is_reversed: !self.is_reversed, ignore_nulls: self.ignore_nulls, })), @@ -279,26 +278,33 @@ pub(crate) fn calc_requirements< >( partition_by_exprs: impl IntoIterator, orderby_sort_exprs: impl IntoIterator, -) -> Option { - let mut sort_reqs = LexRequirement::new( - partition_by_exprs - .into_iter() - .map(|partition_by| { - PhysicalSortRequirement::new(Arc::clone(partition_by.borrow()), None) - }) - .collect::>(), - ); +) -> Option { + let mut sort_reqs_with_partition = partition_by_exprs + .into_iter() + .map(|partition_by| { + PhysicalSortRequirement::new(Arc::clone(partition_by.borrow()), None) + }) + .collect::>(); + let mut sort_reqs = vec![]; for element in orderby_sort_exprs.into_iter() { let PhysicalSortExpr { expr, options } = element.borrow(); - if !sort_reqs.iter().any(|e| e.expr.eq(expr)) { - sort_reqs.push(PhysicalSortRequirement::new( - Arc::clone(expr), - Some(*options), - )); + let sort_req = PhysicalSortRequirement::new(Arc::clone(expr), Some(*options)); + if !sort_reqs_with_partition.iter().any(|e| e.expr.eq(expr)) { + sort_reqs_with_partition.push(sort_req.clone()); + } + if !sort_reqs + .iter() + .any(|e: &PhysicalSortRequirement| 
e.expr.eq(expr)) + { + sort_reqs.push(sort_req); } } - // Convert empty result to None. Otherwise wrap result inside Some() - (!sort_reqs.is_empty()).then_some(sort_reqs) + + let mut alternatives = vec![]; + alternatives.extend(LexRequirement::new(sort_reqs_with_partition)); + alternatives.extend(LexRequirement::new(sort_reqs)); + + OrderingRequirements::new_alternatives(alternatives, false) } /// This function calculates the indices such that when partition by expressions reordered with the indices @@ -309,18 +315,18 @@ pub(crate) fn calc_requirements< pub fn get_ordered_partition_by_indices( partition_by_exprs: &[Arc], input: &Arc, -) -> Vec { +) -> Result> { let (_, indices) = input .equivalence_properties() - .find_longest_permutation(partition_by_exprs); - indices + .find_longest_permutation(partition_by_exprs)?; + Ok(indices) } pub(crate) fn get_partition_by_sort_exprs( input: &Arc, partition_by_exprs: &[Arc], ordered_partition_by_indices: &[usize], -) -> Result { +) -> Result> { let ordered_partition_exprs = ordered_partition_by_indices .iter() .map(|idx| Arc::clone(&partition_by_exprs[*idx])) @@ -329,7 +335,7 @@ pub(crate) fn get_partition_by_sort_exprs( assert!(ordered_partition_by_indices.len() <= partition_by_exprs.len()); let (ordering, _) = input .equivalence_properties() - .find_longest_permutation(&ordered_partition_exprs); + .find_longest_permutation(&ordered_partition_exprs)?; if ordering.len() == ordered_partition_exprs.len() { Ok(ordering) } else { @@ -341,11 +347,11 @@ pub(crate) fn window_equivalence_properties( schema: &SchemaRef, input: &Arc, window_exprs: &[Arc], -) -> EquivalenceProperties { +) -> Result { // We need to update the schema, so we can't directly use input's equivalence // properties. let mut window_eq_properties = EquivalenceProperties::new(Arc::clone(schema)) - .extend(input.equivalence_properties().clone()); + .extend(input.equivalence_properties().clone())?; let window_schema_len = schema.fields.len(); let input_schema_len = window_schema_len - window_exprs.len(); @@ -357,22 +363,25 @@ pub(crate) fn window_equivalence_properties( // Collect columns defining partitioning, and construct all `SortOptions` // variations for them. Then, we will check each one whether it satisfies // the existing ordering provided by the input plan. - let partition_by_orders = partitioning_exprs + let mut all_satisfied_lexs = vec![]; + for lex in partitioning_exprs .iter() - .map(|pb_order| sort_options_resolving_constant(Arc::clone(pb_order))); - let all_satisfied_lexs = partition_by_orders + .map(|pb_order| sort_options_resolving_constant(Arc::clone(pb_order))) .multi_cartesian_product() - .map(LexOrdering::new) - .filter(|lex| window_eq_properties.ordering_satisfy(lex)) - .collect::>(); + .filter_map(LexOrdering::new) + { + if window_eq_properties.ordering_satisfy(lex.clone())? { + all_satisfied_lexs.push(lex); + } + } // If there is a partitioning, and no possible ordering cannot satisfy // the input plan's orderings, then we cannot further introduce any // new orderings for the window plan. if !no_partitioning && all_satisfied_lexs.is_empty() { - return window_eq_properties; + return Ok(window_eq_properties); } else if let Some(std_expr) = expr.as_any().downcast_ref::() { - std_expr.add_equal_orderings(&mut window_eq_properties); + std_expr.add_equal_orderings(&mut window_eq_properties)?; } else if let Some(plain_expr) = expr.as_any().downcast_ref::() { @@ -380,26 +389,26 @@ pub(crate) fn window_equivalence_properties( // unbounded starting point. 
// First, check if the frame covers the whole table: if plain_expr.get_window_frame().end_bound.is_unbounded() { - let window_col = Column::new(expr.name(), i + input_schema_len); + let window_col = + Arc::new(Column::new(expr.name(), i + input_schema_len)) as _; if no_partitioning { // Window function has a constant result across the table: - window_eq_properties = window_eq_properties - .with_constants(iter::once(ConstExpr::new(Arc::new(window_col)))) + window_eq_properties + .add_constants(std::iter::once(ConstExpr::from(window_col)))? } else { // Window function results in a partial constant value in // some ordering. Adjust the ordering equivalences accordingly: let new_lexs = all_satisfied_lexs.into_iter().flat_map(|lex| { - let orderings = lex.take_exprs(); let new_partial_consts = - sort_options_resolving_constant(Arc::new(window_col.clone())); + sort_options_resolving_constant(Arc::clone(&window_col)); new_partial_consts.into_iter().map(move |partial| { - let mut existing = orderings.clone(); + let mut existing = lex.clone(); existing.push(partial); - LexOrdering::new(existing) + existing }) }); - window_eq_properties.add_new_orderings(new_lexs); + window_eq_properties.add_orderings(new_lexs); } } else { // The window frame is ever expanding, so set monotonicity comes @@ -407,7 +416,7 @@ pub(crate) fn window_equivalence_properties( plain_expr.add_equal_orderings( &mut window_eq_properties, window_expr_indices[i], - ); + )?; } } else if let Some(sliding_expr) = expr.as_any().downcast_ref::() @@ -425,22 +434,18 @@ pub(crate) fn window_equivalence_properties( let window_col = Column::new(expr.name(), i + input_schema_len); if no_partitioning { // Reverse set-monotonic cases with no partitioning: - let new_ordering = - vec![LexOrdering::new(vec![PhysicalSortExpr::new( - Arc::new(window_col), - SortOptions::new(increasing, true), - )])]; - window_eq_properties.add_new_orderings(new_ordering); + window_eq_properties.add_ordering([PhysicalSortExpr::new( + Arc::new(window_col), + SortOptions::new(increasing, true), + )]); } else { // Reverse set-monotonic cases for all orderings: - for lex in all_satisfied_lexs.into_iter() { - let mut existing = lex.take_exprs(); - existing.push(PhysicalSortExpr::new( + for mut lex in all_satisfied_lexs.into_iter() { + lex.push(PhysicalSortExpr::new( Arc::new(window_col.clone()), SortOptions::new(increasing, true), )); - window_eq_properties - .add_new_ordering(LexOrdering::new(existing)); + window_eq_properties.add_ordering(lex); } } } @@ -451,44 +456,44 @@ pub(crate) fn window_equivalence_properties( // utilize set-monotonicity since the set shrinks as the frame // boundary starts "touching" the end of the table. else if frame.is_causal() { - let mut args_all_lexs = sliding_expr + let args_all_lexs = sliding_expr .get_aggregate_expr() .expressions() .into_iter() .map(sort_options_resolving_constant) .multi_cartesian_product(); - let mut asc = false; - if args_all_lexs.any(|order| { + let (mut asc, mut satisfied) = (false, false); + for order in args_all_lexs { if let Some(f) = order.first() { asc = !f.options.descending; } - window_eq_properties.ordering_satisfy(&LexOrdering::new(order)) - }) { + if window_eq_properties.ordering_satisfy(order)? 
{ + satisfied = true; + break; + } + } + if satisfied { let increasing = set_monotonicity.eq(&SetMonotonicity::Increasing); let window_col = Column::new(expr.name(), i + input_schema_len); if increasing && (asc || no_partitioning) { - let new_ordering = - LexOrdering::new(vec![PhysicalSortExpr::new( - Arc::new(window_col), - SortOptions::new(false, false), - )]); - window_eq_properties.add_new_ordering(new_ordering); + window_eq_properties.add_ordering([PhysicalSortExpr::new( + Arc::new(window_col), + SortOptions::new(false, false), + )]); } else if !increasing && (!asc || no_partitioning) { - let new_ordering = - LexOrdering::new(vec![PhysicalSortExpr::new( - Arc::new(window_col), - SortOptions::new(true, false), - )]); - window_eq_properties.add_new_ordering(new_ordering); + window_eq_properties.add_ordering([PhysicalSortExpr::new( + Arc::new(window_col), + SortOptions::new(true, false), + )]); }; } } } } } - window_eq_properties + Ok(window_eq_properties) } /// Constructs the best-fitting windowing operator (a `WindowAggExec` or a @@ -515,7 +520,7 @@ pub fn get_best_fitting_window( let orderby_keys = window_exprs[0].order_by(); let (should_reverse, input_order_mode) = if let Some((should_reverse, input_order_mode)) = - get_window_mode(partitionby_exprs, orderby_keys, input) + get_window_mode(partitionby_exprs, orderby_keys, input)? { (should_reverse, input_order_mode) } else { @@ -581,35 +586,29 @@ pub fn get_best_fitting_window( /// the mode this window operator should work in to accommodate the existing ordering. pub fn get_window_mode( partitionby_exprs: &[Arc], - orderby_keys: &LexOrdering, + orderby_keys: &[PhysicalSortExpr], input: &Arc, -) -> Option<(bool, InputOrderMode)> { - let input_eqs = input.equivalence_properties().clone(); - let mut partition_by_reqs: LexRequirement = LexRequirement::new(vec![]); - let (_, indices) = input_eqs.find_longest_permutation(partitionby_exprs); - vec![].extend(indices.iter().map(|&idx| PhysicalSortRequirement { - expr: Arc::clone(&partitionby_exprs[idx]), - options: None, - })); - partition_by_reqs - .inner - .extend(indices.iter().map(|&idx| PhysicalSortRequirement { +) -> Result> { + let mut input_eqs = input.equivalence_properties().clone(); + let (_, indices) = input_eqs.find_longest_permutation(partitionby_exprs)?; + let partition_by_reqs = indices + .iter() + .map(|&idx| PhysicalSortRequirement { expr: Arc::clone(&partitionby_exprs[idx]), options: None, - })); + }) + .collect::>(); // Treat partition by exprs as constant. During analysis of requirements are satisfied. 
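For orientation: once `get_window_mode` (reworked in this hunk) knows which PARTITION BY columns the input already orders, choosing the input order mode is a simple count check, as the code a few lines further down shows. Below is a minimal, standalone sketch of just that decision; the local `Mode` enum and `choose_mode` helper are illustrative stand-ins that mirror the crate's `InputOrderMode`, not DataFusion API.

```rust
/// Simplified mirror of the mode selection at the end of `get_window_mode`:
/// given which PARTITION BY columns are already ordered in the input (by index),
/// pick how the window operator should consume its input.
#[derive(Debug, PartialEq)]
enum Mode {
    /// Every PARTITION BY column arrives ordered.
    Sorted,
    /// No PARTITION BY column arrives ordered.
    Linear,
    /// Only the listed PARTITION BY columns (by index) arrive ordered.
    PartiallySorted(Vec<usize>),
}

fn choose_mode(num_partition_cols: usize, ordered_indices: Vec<usize>) -> Mode {
    if ordered_indices.len() == num_partition_cols {
        Mode::Sorted
    } else if ordered_indices.is_empty() {
        Mode::Linear
    } else {
        Mode::PartiallySorted(ordered_indices)
    }
}

fn main() {
    // Two PARTITION BY columns, only the first one is already ordered.
    assert_eq!(choose_mode(2, vec![0]), Mode::PartiallySorted(vec![0]));
}
```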
- let const_exprs = partitionby_exprs.iter().map(ConstExpr::from); - let partition_by_eqs = input_eqs.with_constants(const_exprs); - let order_by_reqs = LexRequirement::from(orderby_keys.clone()); - let reverse_order_by_reqs = LexRequirement::from(reverse_order_bys(orderby_keys)); - for (should_swap, order_by_reqs) in - [(false, order_by_reqs), (true, reverse_order_by_reqs)] + let const_exprs = partitionby_exprs.iter().cloned().map(ConstExpr::from); + input_eqs.add_constants(const_exprs)?; + let reverse_orderby_keys = + orderby_keys.iter().map(|e| e.reverse()).collect::>(); + for (should_swap, orderbys) in + [(false, orderby_keys), (true, reverse_orderby_keys.as_ref())] { - let req = LexRequirement::new( - [partition_by_reqs.inner.clone(), order_by_reqs.inner].concat(), - ) - .collapse(); - if partition_by_eqs.ordering_satisfy_requirement(&req) { + let mut req = partition_by_reqs.clone(); + req.extend(orderbys.iter().cloned().map(Into::into)); + if req.is_empty() || input_eqs.ordering_satisfy_requirement(req)? { // Window can be run with existing ordering let mode = if indices.len() == partitionby_exprs.len() { InputOrderMode::Sorted @@ -618,10 +617,10 @@ pub fn get_window_mode( } else { InputOrderMode::PartiallySorted(indices) }; - return Some((should_swap, mode)); + return Ok(Some((should_swap, mode))); } } - None + Ok(None) } fn sort_options_resolving_constant(expr: Arc) -> Vec { @@ -641,12 +640,13 @@ mod tests { use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; use arrow::compute::SortOptions; + use arrow_schema::{DataType, Field}; use datafusion_execution::TaskContext; - use datafusion_functions_aggregate::count::count_udaf; - use futures::FutureExt; use InputOrderMode::{Linear, PartiallySorted, Sorted}; + use futures::FutureExt; + fn create_test_schema() -> Result { let nullable_column = Field::new("nullable_col", DataType::Int32, true); let non_nullable_column = Field::new("non_nullable_col", DataType::Int32, false); @@ -696,16 +696,14 @@ mod tests { /// Created a sorted Streaming Table exec pub fn streaming_table_exec( schema: &SchemaRef, - sort_exprs: impl IntoIterator, + ordering: LexOrdering, infinite_source: bool, ) -> Result> { - let sort_exprs = sort_exprs.into_iter().collect(); - Ok(Arc::new(StreamingTableExec::try_new( Arc::clone(schema), vec![], None, - Some(sort_exprs), + Some(ordering), infinite_source, None, )?)) @@ -719,25 +717,38 @@ mod tests { ( vec!["a"], vec![("b", true, true)], - vec![("a", None), ("b", Some((true, true)))], + vec![ + vec![("a", None), ("b", Some((true, true)))], + vec![("b", Some((true, true)))], + ], ), // PARTITION BY a, ORDER BY a ASC NULLS FIRST - (vec!["a"], vec![("a", true, true)], vec![("a", None)]), + ( + vec!["a"], + vec![("a", true, true)], + vec![vec![("a", None)], vec![("a", Some((true, true)))]], + ), // PARTITION BY a, ORDER BY b ASC NULLS FIRST, c DESC NULLS LAST ( vec!["a"], vec![("b", true, true), ("c", false, false)], vec![ - ("a", None), - ("b", Some((true, true))), - ("c", Some((false, false))), + vec![ + ("a", None), + ("b", Some((true, true))), + ("c", Some((false, false))), + ], + vec![("b", Some((true, true))), ("c", Some((false, false)))], ], ), // PARTITION BY a, c, ORDER BY b ASC NULLS FIRST, c DESC NULLS LAST ( vec!["a", "c"], vec![("b", true, true), ("c", false, false)], - vec![("a", None), ("c", None), ("b", Some((true, true)))], + vec![ + vec![("a", None), ("c", None), ("b", Some((true, true)))], + vec![("b", Some((true, true))), ("c", Some((false, false)))], + ], ), ]; for (pb_params, 
ob_params, expected_params) in test_data { @@ -749,25 +760,26 @@ mod tests { let mut orderbys = vec![]; for (col_name, descending, nulls_first) in ob_params { let expr = col(col_name, &schema)?; - let options = SortOptions { - descending, - nulls_first, - }; - orderbys.push(PhysicalSortExpr { expr, options }); + let options = SortOptions::new(descending, nulls_first); + orderbys.push(PhysicalSortExpr::new(expr, options)); } - let mut expected: Option = None; - for (col_name, reqs) in expected_params { - let options = reqs.map(|(descending, nulls_first)| SortOptions { - descending, - nulls_first, - }); - let expr = col(col_name, &schema)?; - let res = PhysicalSortRequirement::new(expr, options); - if let Some(expected) = &mut expected { - expected.push(res); - } else { - expected = Some(LexRequirement::new(vec![res])); + let mut expected: Option = None; + for expected_param in expected_params.clone() { + let mut requirements = vec![]; + for (col_name, reqs) in expected_param { + let options = reqs.map(|(descending, nulls_first)| { + SortOptions::new(descending, nulls_first) + }); + let expr = col(col_name, &schema)?; + requirements.push(PhysicalSortRequirement::new(expr, options)); + } + if let Some(requirements) = LexRequirement::new(requirements) { + if let Some(alts) = expected.as_mut() { + alts.add_alternative(requirements); + } else { + expected = Some(OrderingRequirements::new(requirements)); + } } } assert_eq!(calc_requirements(partitionbys, orderbys), expected); @@ -789,7 +801,7 @@ mod tests { "count".to_owned(), &[col("a", &schema)?], &[], - &LexOrdering::default(), + &[], Arc::new(WindowFrame::new(None)), schema.as_ref(), false, @@ -893,13 +905,14 @@ mod tests { // Columns a,c are nullable whereas b,d are not nullable. // Source is sorted by a ASC NULLS FIRST, b ASC NULLS FIRST, c ASC NULLS FIRST, d ASC NULLS FIRST // Column e is not ordered. - let sort_exprs = vec![ + let ordering = [ sort_expr("a", &test_schema), sort_expr("b", &test_schema), sort_expr("c", &test_schema), sort_expr("d", &test_schema), - ]; - let exec_unbounded = streaming_table_exec(&test_schema, sort_exprs, true)?; + ] + .into(); + let exec_unbounded = streaming_table_exec(&test_schema, ordering, true)?; // test cases consists of vector of tuples. Where each tuple represents a single test case. // First field in the tuple is Vec where each element in the vector represents PARTITION BY columns @@ -986,7 +999,7 @@ mod tests { partition_by_exprs.push(col(col_name, &test_schema)?); } - let mut order_by_exprs = LexOrdering::default(); + let mut order_by_exprs = vec![]; for col_name in order_by_params { let expr = col(col_name, &test_schema)?; // Give default ordering, this is same with input ordering direction @@ -994,11 +1007,8 @@ mod tests { let options = SortOptions::default(); order_by_exprs.push(PhysicalSortExpr { expr, options }); } - let res = get_window_mode( - &partition_by_exprs, - order_by_exprs.as_ref(), - &exec_unbounded, - ); + let res = + get_window_mode(&partition_by_exprs, &order_by_exprs, &exec_unbounded)?; // Since reversibility is not important in this test. Convert Option<(bool, InputOrderMode)> to Option let res = res.map(|(_, mode)| mode); assert_eq!( @@ -1016,13 +1026,14 @@ mod tests { // Columns a,c are nullable whereas b,d are not nullable. // Source is sorted by a ASC NULLS FIRST, b ASC NULLS FIRST, c ASC NULLS FIRST, d ASC NULLS FIRST // Column e is not ordered. 
- let sort_exprs = vec![ + let ordering = [ sort_expr("a", &test_schema), sort_expr("b", &test_schema), sort_expr("c", &test_schema), sort_expr("d", &test_schema), - ]; - let exec_unbounded = streaming_table_exec(&test_schema, sort_exprs, true)?; + ] + .into(); + let exec_unbounded = streaming_table_exec(&test_schema, ordering, true)?; // test cases consists of vector of tuples. Where each tuple represents a single test case. // First field in the tuple is Vec where each element in the vector represents PARTITION BY columns @@ -1151,7 +1162,7 @@ mod tests { partition_by_exprs.push(col(col_name, &test_schema)?); } - let mut order_by_exprs = LexOrdering::default(); + let mut order_by_exprs = vec![]; for (col_name, descending, nulls_first) in order_by_params { let expr = col(col_name, &test_schema)?; let options = SortOptions { @@ -1162,7 +1173,7 @@ mod tests { } assert_eq!( - get_window_mode(&partition_by_exprs, order_by_exprs.as_ref(), &exec_unbounded), + get_window_mode(&partition_by_exprs, &order_by_exprs, &exec_unbounded)?, *expected, "Unexpected result for in unbounded test case#: {case_idx:?}, case: {test_case:?}" ); diff --git a/datafusion/physical-plan/src/windows/window_agg_exec.rs b/datafusion/physical-plan/src/windows/window_agg_exec.rs index 3c42d3032ed5..1b7cb9bb76e1 100644 --- a/datafusion/physical-plan/src/windows/window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/window_agg_exec.rs @@ -44,7 +44,9 @@ use datafusion_common::stats::Precision; use datafusion_common::utils::{evaluate_partition_ranges, transpose}; use datafusion_common::{internal_err, Result}; use datafusion_execution::TaskContext; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; +use datafusion_physical_expr_common::sort_expr::{ + OrderingRequirements, PhysicalSortExpr, +}; use futures::{ready, Stream, StreamExt}; @@ -79,8 +81,8 @@ impl WindowAggExec { let schema = Arc::new(schema); let ordered_partition_by_indices = - get_ordered_partition_by_indices(window_expr[0].partition_by(), &input); - let cache = Self::compute_properties(Arc::clone(&schema), &input, &window_expr); + get_ordered_partition_by_indices(window_expr[0].partition_by(), &input)?; + let cache = Self::compute_properties(Arc::clone(&schema), &input, &window_expr)?; Ok(Self { input, window_expr, @@ -107,7 +109,7 @@ impl WindowAggExec { // We are sure that partition by columns are always at the beginning of sort_keys // Hence returned `PhysicalSortExpr` corresponding to `PARTITION BY` columns can be used safely // to calculate partition separation points - pub fn partition_by_sort_keys(&self) -> Result { + pub fn partition_by_sort_keys(&self) -> Result> { let partition_by = self.window_expr()[0].partition_by(); get_partition_by_sort_exprs( &self.input, @@ -121,9 +123,9 @@ impl WindowAggExec { schema: SchemaRef, input: &Arc, window_exprs: &[Arc], - ) -> PlanProperties { + ) -> Result { // Calculate equivalence properties: - let eq_properties = window_equivalence_properties(&schema, input, window_exprs); + let eq_properties = window_equivalence_properties(&schema, input, window_exprs)?; // Get output partitioning: // Because we can have repartitioning using the partition keys this @@ -131,13 +133,13 @@ impl WindowAggExec { let output_partitioning = input.output_partitioning().clone(); // Construct properties cache: - PlanProperties::new( + Ok(PlanProperties::new( eq_properties, output_partitioning, // TODO: Emission type and boundedness information can be enhanced here EmissionType::Final, input.boundedness(), - 
) + )) } pub fn partition_keys(&self) -> Vec> { @@ -156,6 +158,24 @@ impl WindowAggExec { .unwrap_or_else(Vec::new) } } + + fn statistics_inner(&self) -> Result { + let input_stat = self.input.partition_statistics(None)?; + let win_cols = self.window_expr.len(); + let input_cols = self.input.schema().fields().len(); + // TODO stats: some windowing function will maintain invariants such as min, max... + let mut column_statistics = Vec::with_capacity(win_cols + input_cols); + // copy stats of the input to the beginning of the schema. + column_statistics.extend(input_stat.column_statistics); + for _ in 0..win_cols { + column_statistics.push(ColumnStatistics::new_unknown()) + } + Ok(Statistics { + num_rows: input_stat.num_rows, + column_statistics, + total_byte_size: Precision::Absent, + }) + } } impl DisplayAs for WindowAggExec { @@ -216,17 +236,17 @@ impl ExecutionPlan for WindowAggExec { vec![true] } - fn required_input_ordering(&self) -> Vec> { + fn required_input_ordering(&self) -> Vec> { let partition_bys = self.window_expr()[0].partition_by(); let order_keys = self.window_expr()[0].order_by(); if self.ordered_partition_by_indices.len() < partition_bys.len() { - vec![calc_requirements(partition_bys, order_keys.iter())] + vec![calc_requirements(partition_bys, order_keys)] } else { let partition_bys = self .ordered_partition_by_indices .iter() .map(|idx| &partition_bys[*idx]); - vec![calc_requirements(partition_bys, order_keys.iter())] + vec![calc_requirements(partition_bys, order_keys)] } } @@ -271,21 +291,15 @@ impl ExecutionPlan for WindowAggExec { } fn statistics(&self) -> Result { - let input_stat = self.input.statistics()?; - let win_cols = self.window_expr.len(); - let input_cols = self.input.schema().fields().len(); - // TODO stats: some windowing function will maintain invariants such as min, max... - let mut column_statistics = Vec::with_capacity(win_cols + input_cols); - // copy stats of the input to the beginning of the schema. - column_statistics.extend(input_stat.column_statistics); - for _ in 0..win_cols { - column_statistics.push(ColumnStatistics::new_unknown()) + self.statistics_inner() + } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_none() { + self.statistics_inner() + } else { + Ok(Statistics::new_unknown(&self.schema())) } - Ok(Statistics { - num_rows: input_stat.num_rows, - column_statistics, - total_byte_size: Precision::Absent, - }) } } @@ -307,7 +321,7 @@ pub struct WindowAggStream { batches: Vec, finished: bool, window_expr: Vec>, - partition_by_sort_keys: LexOrdering, + partition_by_sort_keys: Vec, baseline_metrics: BaselineMetrics, ordered_partition_by_indices: Vec, } @@ -319,7 +333,7 @@ impl WindowAggStream { window_expr: Vec>, input: SendableRecordBatchStream, baseline_metrics: BaselineMetrics, - partition_by_sort_keys: LexOrdering, + partition_by_sort_keys: Vec, ordered_partition_by_indices: Vec, ) -> Result { // In WindowAggExec all partition by columns should be ordered. 
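The `statistics_inner` helper added above keeps the input's per-column statistics, appends an unknown entry for each window output column, and preserves the row count (window functions add columns, not rows). A condensed sketch of that padding step, assuming the `datafusion_common` types used in the hunk (`Statistics`, `ColumnStatistics`, `Precision`):

```rust
use datafusion_common::stats::Precision;
use datafusion_common::{ColumnStatistics, Statistics};

/// Pad input statistics for a window operator: input columns keep their
/// statistics, every appended window column is unknown, and the row count
/// carries through unchanged.
fn pad_window_statistics(input_stat: &Statistics, num_window_cols: usize) -> Statistics {
    let mut column_statistics =
        Vec::with_capacity(input_stat.column_statistics.len() + num_window_cols);
    // Copy the input's column statistics to the beginning of the output schema.
    column_statistics.extend(input_stat.column_statistics.iter().cloned());
    // Window output columns get unknown statistics.
    column_statistics
        .extend(std::iter::repeat_with(ColumnStatistics::new_unknown).take(num_window_cols));
    Statistics {
        num_rows: input_stat.num_rows.clone(),
        column_statistics,
        total_byte_size: Precision::Absent,
    }
}
```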
diff --git a/datafusion/physical-plan/src/work_table.rs b/datafusion/physical-plan/src/work_table.rs index 126a7d0bba29..076e30ab902d 100644 --- a/datafusion/physical-plan/src/work_table.rs +++ b/datafusion/physical-plan/src/work_table.rs @@ -20,13 +20,14 @@ use std::any::Any; use std::sync::{Arc, Mutex}; -use crate::execution_plan::{Boundedness, EmissionType}; +use crate::coop::cooperative; +use crate::execution_plan::{Boundedness, EmissionType, SchedulingType}; use crate::memory::MemoryStream; +use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::{ - metrics::{ExecutionPlanMetricsSet, MetricsSet}, + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream, Statistics, }; -use crate::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; @@ -56,7 +57,7 @@ impl ReservedBatches { /// See /// This table serves as a mirror or buffer between each iteration of a recursive query. #[derive(Debug)] -pub(super) struct WorkTable { +pub struct WorkTable { batches: Mutex>, } @@ -131,16 +132,6 @@ impl WorkTableExec { Arc::clone(&self.schema) } - pub(super) fn with_work_table(&self, work_table: Arc) -> Self { - Self { - name: self.name.clone(), - schema: Arc::clone(&self.schema), - metrics: ExecutionPlanMetricsSet::new(), - work_table, - cache: self.cache.clone(), - } - } - /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. fn compute_properties(schema: SchemaRef) -> PlanProperties { PlanProperties::new( @@ -149,6 +140,7 @@ impl WorkTableExec { EmissionType::Incremental, Boundedness::Bounded, ) + .with_scheduling_type(SchedulingType::Cooperative) } } @@ -182,10 +174,6 @@ impl ExecutionPlan for WorkTableExec { &self.cache } - fn children(&self) -> Vec<&Arc> { - vec![] - } - fn maintains_input_order(&self) -> Vec { vec![false] } @@ -194,6 +182,10 @@ impl ExecutionPlan for WorkTableExec { vec![false] } + fn children(&self) -> Vec<&Arc> { + vec![] + } + fn with_new_children( self: Arc, _: Vec>, @@ -214,10 +206,11 @@ impl ExecutionPlan for WorkTableExec { ); } let batch = self.work_table.take()?; - Ok(Box::pin( + + let stream = MemoryStream::try_new(batch.batches, Arc::clone(&self.schema), None)? - .with_reservation(batch.reservation), - )) + .with_reservation(batch.reservation); + Ok(Box::pin(cooperative(stream))) } fn metrics(&self) -> Option { @@ -227,6 +220,33 @@ impl ExecutionPlan for WorkTableExec { fn statistics(&self) -> Result { Ok(Statistics::new_unknown(&self.schema())) } + + fn partition_statistics(&self, _partition: Option) -> Result { + Ok(Statistics::new_unknown(&self.schema())) + } + + /// Injects run-time state into this `WorkTableExec`. + /// + /// The only state this node currently understands is an [`Arc`]. + /// If `state` can be down-cast to that type, a new `WorkTableExec` backed + /// by the provided work table is returned. Otherwise `None` is returned + /// so that callers can attempt to propagate the state further down the + /// execution plan tree. 
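The doc comment above spells out the contract: the engine hands the node opaque state as an `Arc<dyn Any + Send + Sync>`, and the node claims it only when the downcast to the concrete type it understands succeeds, returning `None` otherwise so callers can keep propagating the state down the plan tree. A minimal, standalone sketch of that downcast-or-decline pattern; `ToyState` and `accept_state` are illustrative stand-ins, not DataFusion API:

```rust
use std::any::Any;
use std::sync::Arc;

// Illustrative stand-in for the concrete state a node understands
// (the work table in the hunk below).
struct ToyState {
    label: String,
}

// Accept the state only if it is of the expected concrete type; otherwise
// return `None` so the caller can try other nodes in the plan tree.
fn accept_state(state: Arc<dyn Any + Send + Sync>) -> Option<Arc<ToyState>> {
    state.downcast::<ToyState>().ok()
}

fn main() {
    let state: Arc<dyn Any + Send + Sync> = Arc::new(ToyState { label: "work table".into() });
    assert!(accept_state(state).is_some());

    let wrong: Arc<dyn Any + Send + Sync> = Arc::new(42_u32);
    assert!(accept_state(wrong).is_none());
}
```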
+ fn with_new_state( + &self, + state: Arc, + ) -> Option> { + // Down-cast to the expected state type; propagate `None` on failure + let work_table = state.downcast::().ok()?; + + Some(Arc::new(Self { + name: self.name.clone(), + schema: Arc::clone(&self.schema), + metrics: ExecutionPlanMetricsSet::new(), + work_table, + cache: self.cache.clone(), + })) + } } #[cfg(test)] diff --git a/datafusion/proto-common/README.md b/datafusion/proto-common/README.md index c8b46424f701..67b3b2787006 100644 --- a/datafusion/proto-common/README.md +++ b/datafusion/proto-common/README.md @@ -24,5 +24,10 @@ bytes, which can be useful for sending data over the network. See [API Docs] for details and examples. +Most projects should use the [`datafusion-proto`] crate directly, which re-exports +this module. If you are already using the [`datafusion-proto`] crate, there is no +reason to use this crate directly in your project as well. + +[`datafusion-proto`]: https://crates.io/crates/datafusion-proto + [datafusion]: https://datafusion.apache.org [api docs]: http://docs.rs/datafusion-proto/latest diff --git a/datafusion/proto-common/proto/datafusion_common.proto b/datafusion/proto-common/proto/datafusion_common.proto index 82f1e91d9c9b..81fc9cceb777 100644 --- a/datafusion/proto-common/proto/datafusion_common.proto +++ b/datafusion/proto-common/proto/datafusion_common.proto @@ -85,6 +85,7 @@ enum JoinType { RIGHTSEMI = 6; RIGHTANTI = 7; LEFTMARK = 8; + RIGHTMARK = 9; } enum JoinConstraint { @@ -92,6 +93,11 @@ enum JoinConstraint { USING = 1; } +enum NullEquality { + NULL_EQUALS_NOTHING = 0; + NULL_EQUALS_NULL = 1; +} + message AvroOptions {} message ArrowOptions {} @@ -108,7 +114,6 @@ message Field { // for complex data types like structs, unions repeated Field children = 4; map metadata = 5; - bool dict_ordered = 6; } message Timestamp{ diff --git a/datafusion/proto-common/src/common.rs b/datafusion/proto-common/src/common.rs index 61711dcf8e08..9af63e3b0736 100644 --- a/datafusion/proto-common/src/common.rs +++ b/datafusion/proto-common/src/common.rs @@ -17,6 +17,7 @@ use datafusion_common::{internal_datafusion_err, DataFusionError}; +/// Return a `DataFusionError::Internal` with the given message pub fn proto_error>(message: S) -> DataFusionError { internal_datafusion_err!("{}", message.into()) } diff --git a/datafusion/proto-common/src/from_proto/mod.rs b/datafusion/proto-common/src/from_proto/mod.rs index bd969db31687..0823e150268d 100644 --- a/datafusion/proto-common/src/from_proto/mod.rs +++ b/datafusion/proto-common/src/from_proto/mod.rs @@ -1066,6 +1066,7 @@ impl TryFrom<&protobuf::TableParquetOptions> for TableParquetOptions { .unwrap(), column_specific_options, key_value_metadata: Default::default(), + crypto: Default::default(), }) } } diff --git a/datafusion/proto-common/src/generated/pbjson.rs b/datafusion/proto-common/src/generated/pbjson.rs index b44b05e9ca29..c3b6686df005 100644 --- a/datafusion/proto-common/src/generated/pbjson.rs +++ b/datafusion/proto-common/src/generated/pbjson.rs @@ -3107,9 +3107,6 @@ impl serde::Serialize for Field { if !self.metadata.is_empty() { len += 1; } - if self.dict_ordered { - len += 1; - } let mut struct_ser = serializer.serialize_struct("datafusion_common.Field", len)?; if !self.name.is_empty() { struct_ser.serialize_field("name", &self.name)?; @@ -3126,9 +3123,6 @@ impl serde::Serialize for Field { if !self.metadata.is_empty() { struct_ser.serialize_field("metadata", &self.metadata)?; } - if self.dict_ordered { - struct_ser.serialize_field("dictOrdered", 
&self.dict_ordered)?; - } struct_ser.end() } } @@ -3145,8 +3139,6 @@ impl<'de> serde::Deserialize<'de> for Field { "nullable", "children", "metadata", - "dict_ordered", - "dictOrdered", ]; #[allow(clippy::enum_variant_names)] @@ -3156,7 +3148,6 @@ impl<'de> serde::Deserialize<'de> for Field { Nullable, Children, Metadata, - DictOrdered, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -3183,7 +3174,6 @@ impl<'de> serde::Deserialize<'de> for Field { "nullable" => Ok(GeneratedField::Nullable), "children" => Ok(GeneratedField::Children), "metadata" => Ok(GeneratedField::Metadata), - "dictOrdered" | "dict_ordered" => Ok(GeneratedField::DictOrdered), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -3208,7 +3198,6 @@ impl<'de> serde::Deserialize<'de> for Field { let mut nullable__ = None; let mut children__ = None; let mut metadata__ = None; - let mut dict_ordered__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Name => { @@ -3243,12 +3232,6 @@ impl<'de> serde::Deserialize<'de> for Field { map_.next_value::>()? ); } - GeneratedField::DictOrdered => { - if dict_ordered__.is_some() { - return Err(serde::de::Error::duplicate_field("dictOrdered")); - } - dict_ordered__ = Some(map_.next_value()?); - } } } Ok(Field { @@ -3257,7 +3240,6 @@ impl<'de> serde::Deserialize<'de> for Field { nullable: nullable__.unwrap_or_default(), children: children__.unwrap_or_default(), metadata: metadata__.unwrap_or_default(), - dict_ordered: dict_ordered__.unwrap_or_default(), }) } } @@ -3856,6 +3838,7 @@ impl serde::Serialize for JoinType { Self::Rightsemi => "RIGHTSEMI", Self::Rightanti => "RIGHTANTI", Self::Leftmark => "LEFTMARK", + Self::Rightmark => "RIGHTMARK", }; serializer.serialize_str(variant) } @@ -3876,6 +3859,7 @@ impl<'de> serde::Deserialize<'de> for JoinType { "RIGHTSEMI", "RIGHTANTI", "LEFTMARK", + "RIGHTMARK", ]; struct GeneratedVisitor; @@ -3925,6 +3909,7 @@ impl<'de> serde::Deserialize<'de> for JoinType { "RIGHTSEMI" => Ok(JoinType::Rightsemi), "RIGHTANTI" => Ok(JoinType::Rightanti), "LEFTMARK" => Ok(JoinType::Leftmark), + "RIGHTMARK" => Ok(JoinType::Rightmark), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } @@ -4433,6 +4418,77 @@ impl<'de> serde::Deserialize<'de> for NdJsonFormat { deserializer.deserialize_struct("datafusion_common.NdJsonFormat", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for NullEquality { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::NullEqualsNothing => "NULL_EQUALS_NOTHING", + Self::NullEqualsNull => "NULL_EQUALS_NULL", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for NullEquality { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "NULL_EQUALS_NOTHING", + "NULL_EQUALS_NULL", + ]; + + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = NullEquality; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn 
visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "NULL_EQUALS_NOTHING" => Ok(NullEquality::NullEqualsNothing), + "NULL_EQUALS_NULL" => Ok(NullEquality::NullEqualsNull), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} impl serde::Serialize for ParquetColumnOptions { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto-common/src/generated/prost.rs b/datafusion/proto-common/src/generated/prost.rs index e029327d481d..411d72af4c62 100644 --- a/datafusion/proto-common/src/generated/prost.rs +++ b/datafusion/proto-common/src/generated/prost.rs @@ -106,8 +106,6 @@ pub struct Field { ::prost::alloc::string::String, ::prost::alloc::string::String, >, - #[prost(bool, tag = "6")] - pub dict_ordered: bool, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct Timestamp { @@ -906,6 +904,7 @@ pub enum JoinType { Rightsemi = 6, Rightanti = 7, Leftmark = 8, + Rightmark = 9, } impl JoinType { /// String value of the enum field names used in the ProtoBuf definition. @@ -923,6 +922,7 @@ impl JoinType { Self::Rightsemi => "RIGHTSEMI", Self::Rightanti => "RIGHTANTI", Self::Leftmark => "LEFTMARK", + Self::Rightmark => "RIGHTMARK", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -937,6 +937,7 @@ impl JoinType { "RIGHTSEMI" => Some(Self::Rightsemi), "RIGHTANTI" => Some(Self::Rightanti), "LEFTMARK" => Some(Self::Leftmark), + "RIGHTMARK" => Some(Self::Rightmark), _ => None, } } @@ -969,6 +970,32 @@ impl JoinConstraint { } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] +pub enum NullEquality { + NullEqualsNothing = 0, + NullEqualsNull = 1, +} +impl NullEquality { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::NullEqualsNothing => "NULL_EQUALS_NOTHING", + Self::NullEqualsNull => "NULL_EQUALS_NULL", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. 
+ pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "NULL_EQUALS_NOTHING" => Some(Self::NullEqualsNothing), + "NULL_EQUALS_NULL" => Some(Self::NullEqualsNull), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] pub enum TimeUnit { Second = 0, Millisecond = 1, diff --git a/datafusion/proto-common/src/to_proto/mod.rs b/datafusion/proto-common/src/to_proto/mod.rs index 28927cad03b4..b6cbe5759cfc 100644 --- a/datafusion/proto-common/src/to_proto/mod.rs +++ b/datafusion/proto-common/src/to_proto/mod.rs @@ -97,7 +97,6 @@ impl TryFrom<&Field> for protobuf::Field { nullable: field.is_nullable(), children: Vec::new(), metadata: field.metadata().clone(), - dict_ordered: field.dict_is_ordered().unwrap_or(false), }) } } diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index 92e697ad2d9c..a1eeabdf87f4 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -60,5 +60,4 @@ datafusion-functions = { workspace = true, default-features = true } datafusion-functions-aggregate = { workspace = true } datafusion-functions-window-common = { workspace = true } doc-comment = { workspace = true } -strum = { version = "0.27.1", features = ["derive"] } tokio = { workspace = true, features = ["rt-multi-thread"] } diff --git a/datafusion/proto/README.md b/datafusion/proto/README.md index f51e4664d5d9..f8930779db89 100644 --- a/datafusion/proto/README.md +++ b/datafusion/proto/README.md @@ -19,11 +19,11 @@ # `datafusion-proto`: Apache DataFusion Protobuf Serialization / Deserialization -This crate contains code to convert Apache [DataFusion] plans to and from +This crate contains code to convert [Apache DataFusion] plans to and from bytes, which can be useful for sending plans over the network, for example when building a distributed query engine. See [API Docs] for details and examples. 
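To make the round trip this README describes concrete, here is a hedged sketch using the crate's `bytes` helpers (`logical_plan_to_bytes` / `logical_plan_from_bytes`); the `t1` table name and CSV path are placeholders, and the snippet assumes the `datafusion`, `datafusion-proto`, and `tokio` crates as dependencies:

```rust
use datafusion::prelude::*;
use datafusion_proto::bytes::{logical_plan_from_bytes, logical_plan_to_bytes};

#[tokio::main]
async fn main() -> datafusion_common::Result<()> {
    let ctx = SessionContext::new();
    // Placeholder table and path: any registered table works the same way.
    ctx.register_csv("t1", "data/example.csv", CsvReadOptions::new()).await?;
    let plan = ctx.table("t1").await?.into_optimized_plan()?;

    // Serialize the plan to protobuf-encoded bytes ...
    let bytes = logical_plan_to_bytes(&plan)?;
    // ... and decode it back against a context that knows the same tables/UDFs.
    let round_trip = logical_plan_from_bytes(&bytes, &ctx)?;
    assert_eq!(format!("{plan:?}"), format!("{round_trip:?}"));
    Ok(())
}
```

Decoding needs a `SessionContext` (or other function registry) that can resolve the tables and any UDFs referenced by the plan, which is why the context is passed back in on the deserialization side.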
-[datafusion]: https://datafusion.apache.org +[apache datafusion]: https://datafusion.apache.org [api docs]: http://docs.rs/datafusion-proto/latest diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 908b95ab56a4..64789f5de0d2 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -90,7 +90,7 @@ message ListingTableScanNode { ProjectionColumns projection = 4; datafusion_common.Schema schema = 5; repeated LogicalExprNode filters = 6; - repeated string table_partition_cols = 7; + repeated PartitionColumn table_partition_cols = 7; bool collect_stat = 8; uint32 target_partitions = 9; oneof FileFormatType { @@ -243,7 +243,7 @@ message JoinNode { datafusion_common.JoinConstraint join_constraint = 4; repeated LogicalExprNode left_join_key = 5; repeated LogicalExprNode right_join_key = 6; - bool null_equals_null = 7; + datafusion_common.NullEquality null_equality = 7; LogicalExprNode filter = 8; } @@ -726,6 +726,7 @@ message PhysicalPlanNode { ParquetSinkExecNode parquet_sink = 29; UnnestExecNode unnest = 30; JsonScanExecNode json_scan = 31; + CooperativeExecNode cooperative = 32; } } @@ -1033,6 +1034,10 @@ message AvroScanExecNode { FileScanExecConf base_conf = 1; } +message CooperativeExecNode { + PhysicalPlanNode input = 1; +} + enum PartitionMode { COLLECT_LEFT = 0; PARTITIONED = 1; @@ -1045,7 +1050,7 @@ message HashJoinExecNode { repeated JoinOn on = 3; datafusion_common.JoinType join_type = 4; PartitionMode partition_mode = 6; - bool null_equals_null = 7; + datafusion_common.NullEquality null_equality = 7; JoinFilter filter = 8; repeated uint32 projection = 9; } @@ -1061,7 +1066,7 @@ message SymmetricHashJoinExecNode { repeated JoinOn on = 3; datafusion_common.JoinType join_type = 4; StreamPartitionMode partition_mode = 6; - bool null_equals_null = 7; + datafusion_common.NullEquality null_equality = 7; JoinFilter filter = 8; repeated PhysicalSortExprNode left_sort_exprs = 9; repeated PhysicalSortExprNode right_sort_exprs = 10; @@ -1217,6 +1222,7 @@ message CoalesceBatchesExecNode { message CoalescePartitionsExecNode { PhysicalPlanNode input = 1; + optional uint32 fetch = 2; } message PhysicalHashRepartition { diff --git a/datafusion/proto/src/generated/datafusion_proto_common.rs b/datafusion/proto/src/generated/datafusion_proto_common.rs index e029327d481d..411d72af4c62 100644 --- a/datafusion/proto/src/generated/datafusion_proto_common.rs +++ b/datafusion/proto/src/generated/datafusion_proto_common.rs @@ -106,8 +106,6 @@ pub struct Field { ::prost::alloc::string::String, ::prost::alloc::string::String, >, - #[prost(bool, tag = "6")] - pub dict_ordered: bool, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct Timestamp { @@ -906,6 +904,7 @@ pub enum JoinType { Rightsemi = 6, Rightanti = 7, Leftmark = 8, + Rightmark = 9, } impl JoinType { /// String value of the enum field names used in the ProtoBuf definition. @@ -923,6 +922,7 @@ impl JoinType { Self::Rightsemi => "RIGHTSEMI", Self::Rightanti => "RIGHTANTI", Self::Leftmark => "LEFTMARK", + Self::Rightmark => "RIGHTMARK", } } /// Creates an enum from field names used in the ProtoBuf definition. 
@@ -937,6 +937,7 @@ impl JoinType { "RIGHTSEMI" => Some(Self::Rightsemi), "RIGHTANTI" => Some(Self::Rightanti), "LEFTMARK" => Some(Self::Leftmark), + "RIGHTMARK" => Some(Self::Rightmark), _ => None, } } @@ -969,6 +970,32 @@ impl JoinConstraint { } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] +pub enum NullEquality { + NullEqualsNothing = 0, + NullEqualsNull = 1, +} +impl NullEquality { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::NullEqualsNothing => "NULL_EQUALS_NOTHING", + Self::NullEqualsNull => "NULL_EQUALS_NULL", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "NULL_EQUALS_NOTHING" => Some(Self::NullEqualsNothing), + "NULL_EQUALS_NULL" => Some(Self::NullEqualsNull), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] pub enum TimeUnit { Second = 0, Millisecond = 1, diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 6166b6ec4796..92309ea6a5cb 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -2050,10 +2050,16 @@ impl serde::Serialize for CoalescePartitionsExecNode { if self.input.is_some() { len += 1; } + if self.fetch.is_some() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.CoalescePartitionsExecNode", len)?; if let Some(v) = self.input.as_ref() { struct_ser.serialize_field("input", v)?; } + if let Some(v) = self.fetch.as_ref() { + struct_ser.serialize_field("fetch", v)?; + } struct_ser.end() } } @@ -2065,11 +2071,13 @@ impl<'de> serde::Deserialize<'de> for CoalescePartitionsExecNode { { const FIELDS: &[&str] = &[ "input", + "fetch", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Input, + Fetch, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -2092,6 +2100,7 @@ impl<'de> serde::Deserialize<'de> for CoalescePartitionsExecNode { { match value { "input" => Ok(GeneratedField::Input), + "fetch" => Ok(GeneratedField::Fetch), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -2112,6 +2121,7 @@ impl<'de> serde::Deserialize<'de> for CoalescePartitionsExecNode { V: serde::de::MapAccess<'de>, { let mut input__ = None; + let mut fetch__ = None; while let Some(k) = map_.next_key()? 
{ match k { GeneratedField::Input => { @@ -2120,10 +2130,19 @@ impl<'de> serde::Deserialize<'de> for CoalescePartitionsExecNode { } input__ = map_.next_value()?; } + GeneratedField::Fetch => { + if fetch__.is_some() { + return Err(serde::de::Error::duplicate_field("fetch")); + } + fetch__ = + map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| x.0) + ; + } } } Ok(CoalescePartitionsExecNode { input: input__, + fetch: fetch__, }) } } @@ -2555,6 +2574,97 @@ impl<'de> serde::Deserialize<'de> for ColumnUnnestListRecursions { deserializer.deserialize_struct("datafusion.ColumnUnnestListRecursions", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for CooperativeExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.input.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.CooperativeExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for CooperativeExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "input", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Input, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "input" => Ok(GeneratedField::Input), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = CooperativeExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.CooperativeExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut input__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + } + } + Ok(CooperativeExecNode { + input: input__, + }) + } + } + deserializer.deserialize_struct("datafusion.CooperativeExecNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for CopyToNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -6832,7 +6942,7 @@ impl serde::Serialize for HashJoinExecNode { if self.partition_mode != 0 { len += 1; } - if self.null_equals_null { + if self.null_equality != 0 { len += 1; } if self.filter.is_some() { @@ -6861,8 +6971,10 @@ impl serde::Serialize for HashJoinExecNode { .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.partition_mode)))?; struct_ser.serialize_field("partitionMode", &v)?; } - if self.null_equals_null { - struct_ser.serialize_field("nullEqualsNull", &self.null_equals_null)?; + if self.null_equality != 0 { + let v = super::datafusion_common::NullEquality::try_from(self.null_equality) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.null_equality)))?; + struct_ser.serialize_field("nullEquality", &v)?; } if let Some(v) = self.filter.as_ref() { struct_ser.serialize_field("filter", v)?; @@ -6887,8 +6999,8 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { "joinType", "partition_mode", "partitionMode", - "null_equals_null", - "nullEqualsNull", + "null_equality", + "nullEquality", "filter", "projection", ]; @@ -6900,7 +7012,7 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { On, JoinType, PartitionMode, - NullEqualsNull, + NullEquality, Filter, Projection, } @@ -6929,7 +7041,7 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { "on" => Ok(GeneratedField::On), "joinType" | "join_type" => Ok(GeneratedField::JoinType), "partitionMode" | "partition_mode" => Ok(GeneratedField::PartitionMode), - "nullEqualsNull" | "null_equals_null" => Ok(GeneratedField::NullEqualsNull), + "nullEquality" | "null_equality" => Ok(GeneratedField::NullEquality), "filter" => Ok(GeneratedField::Filter), "projection" => Ok(GeneratedField::Projection), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), @@ -6956,7 +7068,7 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { let mut on__ = None; let mut join_type__ = None; let mut partition_mode__ = None; - let mut null_equals_null__ = None; + let mut null_equality__ = None; let mut filter__ = None; let mut projection__ = None; while let Some(k) = map_.next_key()? { @@ -6991,11 +7103,11 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { } partition_mode__ = Some(map_.next_value::()? as i32); } - GeneratedField::NullEqualsNull => { - if null_equals_null__.is_some() { - return Err(serde::de::Error::duplicate_field("nullEqualsNull")); + GeneratedField::NullEquality => { + if null_equality__.is_some() { + return Err(serde::de::Error::duplicate_field("nullEquality")); } - null_equals_null__ = Some(map_.next_value()?); + null_equality__ = Some(map_.next_value::()? 
as i32); } GeneratedField::Filter => { if filter__.is_some() { @@ -7020,7 +7132,7 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { on: on__.unwrap_or_default(), join_type: join_type__.unwrap_or_default(), partition_mode: partition_mode__.unwrap_or_default(), - null_equals_null: null_equals_null__.unwrap_or_default(), + null_equality: null_equality__.unwrap_or_default(), filter: filter__, projection: projection__.unwrap_or_default(), }) @@ -8456,7 +8568,7 @@ impl serde::Serialize for JoinNode { if !self.right_join_key.is_empty() { len += 1; } - if self.null_equals_null { + if self.null_equality != 0 { len += 1; } if self.filter.is_some() { @@ -8485,8 +8597,10 @@ impl serde::Serialize for JoinNode { if !self.right_join_key.is_empty() { struct_ser.serialize_field("rightJoinKey", &self.right_join_key)?; } - if self.null_equals_null { - struct_ser.serialize_field("nullEqualsNull", &self.null_equals_null)?; + if self.null_equality != 0 { + let v = super::datafusion_common::NullEquality::try_from(self.null_equality) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.null_equality)))?; + struct_ser.serialize_field("nullEquality", &v)?; } if let Some(v) = self.filter.as_ref() { struct_ser.serialize_field("filter", v)?; @@ -8511,8 +8625,8 @@ impl<'de> serde::Deserialize<'de> for JoinNode { "leftJoinKey", "right_join_key", "rightJoinKey", - "null_equals_null", - "nullEqualsNull", + "null_equality", + "nullEquality", "filter", ]; @@ -8524,7 +8638,7 @@ impl<'de> serde::Deserialize<'de> for JoinNode { JoinConstraint, LeftJoinKey, RightJoinKey, - NullEqualsNull, + NullEquality, Filter, } impl<'de> serde::Deserialize<'de> for GeneratedField { @@ -8553,7 +8667,7 @@ impl<'de> serde::Deserialize<'de> for JoinNode { "joinConstraint" | "join_constraint" => Ok(GeneratedField::JoinConstraint), "leftJoinKey" | "left_join_key" => Ok(GeneratedField::LeftJoinKey), "rightJoinKey" | "right_join_key" => Ok(GeneratedField::RightJoinKey), - "nullEqualsNull" | "null_equals_null" => Ok(GeneratedField::NullEqualsNull), + "nullEquality" | "null_equality" => Ok(GeneratedField::NullEquality), "filter" => Ok(GeneratedField::Filter), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } @@ -8580,7 +8694,7 @@ impl<'de> serde::Deserialize<'de> for JoinNode { let mut join_constraint__ = None; let mut left_join_key__ = None; let mut right_join_key__ = None; - let mut null_equals_null__ = None; + let mut null_equality__ = None; let mut filter__ = None; while let Some(k) = map_.next_key()? { match k { @@ -8620,11 +8734,11 @@ impl<'de> serde::Deserialize<'de> for JoinNode { } right_join_key__ = Some(map_.next_value()?); } - GeneratedField::NullEqualsNull => { - if null_equals_null__.is_some() { - return Err(serde::de::Error::duplicate_field("nullEqualsNull")); + GeneratedField::NullEquality => { + if null_equality__.is_some() { + return Err(serde::de::Error::duplicate_field("nullEquality")); } - null_equals_null__ = Some(map_.next_value()?); + null_equality__ = Some(map_.next_value::()? 
as i32); } GeneratedField::Filter => { if filter__.is_some() { @@ -8641,7 +8755,7 @@ impl<'de> serde::Deserialize<'de> for JoinNode { join_constraint: join_constraint__.unwrap_or_default(), left_join_key: left_join_key__.unwrap_or_default(), right_join_key: right_join_key__.unwrap_or_default(), - null_equals_null: null_equals_null__.unwrap_or_default(), + null_equality: null_equality__.unwrap_or_default(), filter: filter__, }) } @@ -15777,6 +15891,9 @@ impl serde::Serialize for PhysicalPlanNode { physical_plan_node::PhysicalPlanType::JsonScan(v) => { struct_ser.serialize_field("jsonScan", v)?; } + physical_plan_node::PhysicalPlanType::Cooperative(v) => { + struct_ser.serialize_field("cooperative", v)?; + } } } struct_ser.end() @@ -15835,6 +15952,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "unnest", "json_scan", "jsonScan", + "cooperative", ]; #[allow(clippy::enum_variant_names)] @@ -15869,6 +15987,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { ParquetSink, Unnest, JsonScan, + Cooperative, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -15920,6 +16039,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "parquetSink" | "parquet_sink" => Ok(GeneratedField::ParquetSink), "unnest" => Ok(GeneratedField::Unnest), "jsonScan" | "json_scan" => Ok(GeneratedField::JsonScan), + "cooperative" => Ok(GeneratedField::Cooperative), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -16150,6 +16270,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { return Err(serde::de::Error::duplicate_field("jsonScan")); } physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::JsonScan) +; + } + GeneratedField::Cooperative => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("cooperative")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Cooperative) ; } } @@ -20008,7 +20135,7 @@ impl serde::Serialize for SymmetricHashJoinExecNode { if self.partition_mode != 0 { len += 1; } - if self.null_equals_null { + if self.null_equality != 0 { len += 1; } if self.filter.is_some() { @@ -20040,8 +20167,10 @@ impl serde::Serialize for SymmetricHashJoinExecNode { .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.partition_mode)))?; struct_ser.serialize_field("partitionMode", &v)?; } - if self.null_equals_null { - struct_ser.serialize_field("nullEqualsNull", &self.null_equals_null)?; + if self.null_equality != 0 { + let v = super::datafusion_common::NullEquality::try_from(self.null_equality) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.null_equality)))?; + struct_ser.serialize_field("nullEquality", &v)?; } if let Some(v) = self.filter.as_ref() { struct_ser.serialize_field("filter", v)?; @@ -20069,8 +20198,8 @@ impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode { "joinType", "partition_mode", "partitionMode", - "null_equals_null", - "nullEqualsNull", + "null_equality", + "nullEquality", "filter", "left_sort_exprs", "leftSortExprs", @@ -20085,7 +20214,7 @@ impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode { On, JoinType, PartitionMode, - NullEqualsNull, + NullEquality, Filter, LeftSortExprs, RightSortExprs, @@ -20115,7 +20244,7 @@ impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode { "on" => Ok(GeneratedField::On), "joinType" | "join_type" => 
Ok(GeneratedField::JoinType), "partitionMode" | "partition_mode" => Ok(GeneratedField::PartitionMode), - "nullEqualsNull" | "null_equals_null" => Ok(GeneratedField::NullEqualsNull), + "nullEquality" | "null_equality" => Ok(GeneratedField::NullEquality), "filter" => Ok(GeneratedField::Filter), "leftSortExprs" | "left_sort_exprs" => Ok(GeneratedField::LeftSortExprs), "rightSortExprs" | "right_sort_exprs" => Ok(GeneratedField::RightSortExprs), @@ -20143,7 +20272,7 @@ impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode { let mut on__ = None; let mut join_type__ = None; let mut partition_mode__ = None; - let mut null_equals_null__ = None; + let mut null_equality__ = None; let mut filter__ = None; let mut left_sort_exprs__ = None; let mut right_sort_exprs__ = None; @@ -20179,11 +20308,11 @@ impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode { } partition_mode__ = Some(map_.next_value::()? as i32); } - GeneratedField::NullEqualsNull => { - if null_equals_null__.is_some() { - return Err(serde::de::Error::duplicate_field("nullEqualsNull")); + GeneratedField::NullEquality => { + if null_equality__.is_some() { + return Err(serde::de::Error::duplicate_field("nullEquality")); } - null_equals_null__ = Some(map_.next_value()?); + null_equality__ = Some(map_.next_value::()? as i32); } GeneratedField::Filter => { if filter__.is_some() { @@ -20211,7 +20340,7 @@ impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode { on: on__.unwrap_or_default(), join_type: join_type__.unwrap_or_default(), partition_mode: partition_mode__.unwrap_or_default(), - null_equals_null: null_equals_null__.unwrap_or_default(), + null_equality: null_equality__.unwrap_or_default(), filter: filter__, left_sort_exprs: left_sort_exprs__.unwrap_or_default(), right_sort_exprs: right_sort_exprs__.unwrap_or_default(), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index d2165dad4850..b0fc0ce60436 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -115,8 +115,8 @@ pub struct ListingTableScanNode { pub schema: ::core::option::Option, #[prost(message, repeated, tag = "6")] pub filters: ::prost::alloc::vec::Vec, - #[prost(string, repeated, tag = "7")] - pub table_partition_cols: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, + #[prost(message, repeated, tag = "7")] + pub table_partition_cols: ::prost::alloc::vec::Vec, #[prost(bool, tag = "8")] pub collect_stat: bool, #[prost(uint32, tag = "9")] @@ -369,8 +369,8 @@ pub struct JoinNode { pub left_join_key: ::prost::alloc::vec::Vec, #[prost(message, repeated, tag = "6")] pub right_join_key: ::prost::alloc::vec::Vec, - #[prost(bool, tag = "7")] - pub null_equals_null: bool, + #[prost(enumeration = "super::datafusion_common::NullEquality", tag = "7")] + pub null_equality: i32, #[prost(message, optional, tag = "8")] pub filter: ::core::option::Option, } @@ -1048,7 +1048,7 @@ pub mod table_reference { pub struct PhysicalPlanNode { #[prost( oneof = "physical_plan_node::PhysicalPlanType", - tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32" )] pub physical_plan_type: ::core::option::Option, } @@ -1118,6 +1118,8 @@ pub mod physical_plan_node { Unnest(::prost::alloc::boxed::Box), #[prost(message, tag = "31")] JsonScan(super::JsonScanExecNode), + #[prost(message, tag = 
"32")] + Cooperative(::prost::alloc::boxed::Box), } } #[derive(Clone, PartialEq, ::prost::Message)] @@ -1572,6 +1574,11 @@ pub struct AvroScanExecNode { pub base_conf: ::core::option::Option, } #[derive(Clone, PartialEq, ::prost::Message)] +pub struct CooperativeExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[derive(Clone, PartialEq, ::prost::Message)] pub struct HashJoinExecNode { #[prost(message, optional, boxed, tag = "1")] pub left: ::core::option::Option<::prost::alloc::boxed::Box>, @@ -1583,8 +1590,8 @@ pub struct HashJoinExecNode { pub join_type: i32, #[prost(enumeration = "PartitionMode", tag = "6")] pub partition_mode: i32, - #[prost(bool, tag = "7")] - pub null_equals_null: bool, + #[prost(enumeration = "super::datafusion_common::NullEquality", tag = "7")] + pub null_equality: i32, #[prost(message, optional, tag = "8")] pub filter: ::core::option::Option, #[prost(uint32, repeated, tag = "9")] @@ -1602,8 +1609,8 @@ pub struct SymmetricHashJoinExecNode { pub join_type: i32, #[prost(enumeration = "StreamPartitionMode", tag = "6")] pub partition_mode: i32, - #[prost(bool, tag = "7")] - pub null_equals_null: bool, + #[prost(enumeration = "super::datafusion_common::NullEquality", tag = "7")] + pub null_equality: i32, #[prost(message, optional, tag = "8")] pub filter: ::core::option::Option, #[prost(message, repeated, tag = "9")] @@ -1824,6 +1831,8 @@ pub struct CoalesceBatchesExecNode { pub struct CoalescePartitionsExecNode { #[prost(message, optional, boxed, tag = "1")] pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(uint32, optional, tag = "2")] + pub fetch: ::core::option::Option, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalHashRepartition { diff --git a/datafusion/proto/src/logical_plan/file_formats.rs b/datafusion/proto/src/logical_plan/file_formats.rs index 5c33277dc9f7..620442c79e72 100644 --- a/datafusion/proto/src/logical_plan/file_formats.rs +++ b/datafusion/proto/src/logical_plan/file_formats.rs @@ -205,10 +205,7 @@ impl LogicalExtensionCodec for CsvLogicalExtensionCodec { _ctx: &SessionContext, ) -> datafusion_common::Result> { let proto = CsvOptionsProto::decode(buf).map_err(|e| { - DataFusionError::Execution(format!( - "Failed to decode CsvOptionsProto: {:?}", - e - )) + DataFusionError::Execution(format!("Failed to decode CsvOptionsProto: {e:?}")) })?; let options: CsvOptions = (&proto).into(); Ok(Arc::new(CsvFormatFactory { @@ -233,7 +230,7 @@ impl LogicalExtensionCodec for CsvLogicalExtensionCodec { }); proto.encode(buf).map_err(|e| { - DataFusionError::Execution(format!("Failed to encode CsvOptions: {:?}", e)) + DataFusionError::Execution(format!("Failed to encode CsvOptions: {e:?}")) })?; Ok(()) @@ -316,8 +313,7 @@ impl LogicalExtensionCodec for JsonLogicalExtensionCodec { ) -> datafusion_common::Result> { let proto = JsonOptionsProto::decode(buf).map_err(|e| { DataFusionError::Execution(format!( - "Failed to decode JsonOptionsProto: {:?}", - e + "Failed to decode JsonOptionsProto: {e:?}" )) })?; let options: JsonOptions = (&proto).into(); @@ -346,7 +342,7 @@ impl LogicalExtensionCodec for JsonLogicalExtensionCodec { }); proto.encode(buf).map_err(|e| { - DataFusionError::Execution(format!("Failed to encode JsonOptions: {:?}", e)) + DataFusionError::Execution(format!("Failed to encode JsonOptions: {e:?}")) })?; Ok(()) @@ -580,6 +576,7 @@ impl From<&TableParquetOptionsProto> for TableParquetOptions { .iter() .map(|(k, v)| (k.clone(), 
Some(v.clone()))) .collect(), + crypto: Default::default(), } } } @@ -632,8 +629,7 @@ impl LogicalExtensionCodec for ParquetLogicalExtensionCodec { ) -> datafusion_common::Result> { let proto = TableParquetOptionsProto::decode(buf).map_err(|e| { DataFusionError::Execution(format!( - "Failed to decode TableParquetOptionsProto: {:?}", - e + "Failed to decode TableParquetOptionsProto: {e:?}" )) })?; let options: TableParquetOptions = (&proto).into(); @@ -663,8 +659,7 @@ impl LogicalExtensionCodec for ParquetLogicalExtensionCodec { proto.encode(buf).map_err(|e| { DataFusionError::Execution(format!( - "Failed to encode TableParquetOptionsProto: {:?}", - e + "Failed to encode TableParquetOptionsProto: {e:?}" )) })?; diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index cac2f9db1645..66ef0ebfe361 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -19,8 +19,8 @@ use std::sync::Arc; use datafusion::execution::registry::FunctionRegistry; use datafusion_common::{ - exec_datafusion_err, internal_err, plan_datafusion_err, RecursionUnnestOption, - Result, ScalarValue, TableReference, UnnestOptions, + exec_datafusion_err, internal_err, plan_datafusion_err, NullEquality, + RecursionUnnestOption, Result, ScalarValue, TableReference, UnnestOptions, }; use datafusion_expr::dml::InsertOp; use datafusion_expr::expr::{Alias, Placeholder, Sort}; @@ -205,6 +205,7 @@ impl From for JoinType { protobuf::JoinType::Leftanti => JoinType::LeftAnti, protobuf::JoinType::Rightanti => JoinType::RightAnti, protobuf::JoinType::Leftmark => JoinType::LeftMark, + protobuf::JoinType::Rightmark => JoinType::RightMark, } } } @@ -218,6 +219,15 @@ impl From for JoinConstraint { } } +impl From for NullEquality { + fn from(t: protobuf::NullEquality) -> Self { + match t { + protobuf::NullEquality::NullEqualsNothing => NullEquality::NullEqualsNothing, + protobuf::NullEquality::NullEqualsNull => NullEquality::NullEqualsNull, + } + } +} + impl From for WriteOp { fn from(t: protobuf::dml_node::Type) -> Self { match t { @@ -268,7 +278,7 @@ pub fn parse_expr( ExprType::Column(column) => Ok(Expr::Column(column.into())), ExprType::Literal(literal) => { let scalar_value: ScalarValue = literal.try_into()?; - Ok(Expr::Literal(scalar_value)) + Ok(Expr::Literal(scalar_value, None)) } ExprType::WindowExpr(expr) => { let window_function = expr @@ -296,11 +306,13 @@ pub fn parse_expr( window_expr_node::WindowFunction::Udaf(udaf_name) => { let udaf_function = match &expr.fun_definition { Some(buf) => codec.try_decode_udaf(udaf_name, buf)?, - None => registry.udaf(udaf_name)?, + None => registry + .udaf(udaf_name) + .or_else(|_| codec.try_decode_udaf(udaf_name, &[]))?, }; let args = parse_exprs(&expr.exprs, registry, codec)?; - Expr::WindowFunction(WindowFunction::new( + Expr::from(WindowFunction::new( expr::WindowFunctionDefinition::AggregateUDF(udaf_function), args, )) @@ -313,11 +325,13 @@ pub fn parse_expr( window_expr_node::WindowFunction::Udwf(udwf_name) => { let udwf_function = match &expr.fun_definition { Some(buf) => codec.try_decode_udwf(udwf_name, buf)?, - None => registry.udwf(udwf_name)?, + None => registry + .udwf(udwf_name) + .or_else(|_| codec.try_decode_udwf(udwf_name, &[]))?, }; let args = parse_exprs(&expr.exprs, registry, codec)?; - Expr::WindowFunction(WindowFunction::new( + Expr::from(WindowFunction::new( expr::WindowFunctionDefinition::WindowUDF(udwf_function), args, )) @@ -540,7 +554,9 @@ pub fn 
parse_expr( }) => { let scalar_fn = match fun_definition { Some(buf) => codec.try_decode_udf(fun_name, buf)?, - None => registry.udf(fun_name.as_str())?, + None => registry + .udf(fun_name.as_str()) + .or_else(|_| codec.try_decode_udf(fun_name, &[]))?, }; Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( scalar_fn, @@ -550,7 +566,9 @@ pub fn parse_expr( ExprType::AggregateUdfExpr(pb) => { let agg_fn = match &pb.fun_definition { Some(buf) => codec.try_decode_udaf(&pb.fun_name, buf)?, - None => registry.udaf(&pb.fun_name)?, + None => registry + .udaf(&pb.fun_name) + .or_else(|_| codec.try_decode_udaf(&pb.fun_name, &[]))?, }; Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index c65569ef1cfb..1acf1ee27bfe 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -33,7 +33,7 @@ use crate::{ }; use crate::protobuf::{proto_error, ToProtoError}; -use arrow::datatypes::{DataType, Schema, SchemaRef}; +use arrow::datatypes::{DataType, Schema, SchemaBuilder, SchemaRef}; use datafusion::datasource::cte_worktable::CteWorkTable; #[cfg(feature = "avro")] use datafusion::datasource::file_format::avro::AvroFormat; @@ -355,10 +355,7 @@ impl AsLogicalPlan for LogicalPlanNode { .as_ref() .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .transpose()? - .ok_or_else(|| { - DataFusionError::Internal("expression required".to_string()) - })?; - // .try_into()?; + .ok_or_else(|| proto_error("expression required"))?; LogicalPlanBuilder::from(input).filter(expr)?.build() } LogicalPlanType::Window(window) => { @@ -458,23 +455,25 @@ impl AsLogicalPlan for LogicalPlanNode { .map(ListingTableUrl::parse) .collect::, _>>()?; + let partition_columns = scan + .table_partition_cols + .iter() + .map(|col| { + let Some(arrow_type) = col.arrow_type.as_ref() else { + return Err(proto_error( + "Missing Arrow type in partition columns", + )); + }; + let arrow_type = DataType::try_from(arrow_type).map_err(|e| { + proto_error(format!("Received an unknown ArrowType: {e}")) + })?; + Ok((col.name.clone(), arrow_type)) + }) + .collect::>>()?; + let options = ListingOptions::new(file_format) .with_file_extension(&scan.file_extension) - .with_table_partition_cols( - scan.table_partition_cols - .iter() - .map(|col| { - ( - col.clone(), - schema - .field_with_name(col) - .unwrap() - .data_type() - .clone(), - ) - }) - .collect(), - ) + .with_table_partition_cols(partition_columns) .with_collect_stat(scan.collect_stat) .with_target_partitions(scan.target_partitions as usize) .with_file_sort_order(all_sort_orders); @@ -1046,7 +1045,6 @@ impl AsLogicalPlan for LogicalPlanNode { }) } }; - let schema: protobuf::Schema = schema.as_ref().try_into()?; let filters: Vec = serialize_exprs(filters, extension_codec)?; @@ -1099,6 +1097,21 @@ impl AsLogicalPlan for LogicalPlanNode { let options = listing_table.options(); + let mut builder = SchemaBuilder::from(schema.as_ref()); + for (idx, field) in schema.fields().iter().enumerate().rev() { + if options + .table_partition_cols + .iter() + .any(|(name, _)| name == field.name()) + { + builder.remove(idx); + } + } + + let schema = builder.finish(); + + let schema: protobuf::Schema = (&schema).try_into()?; + let mut exprs_vec: Vec = vec![]; for order in &options.file_sort_order { let expr_vec = SortExprNodeCollection { @@ -1107,6 +1120,23 @@ impl AsLogicalPlan for LogicalPlanNode { exprs_vec.push(expr_vec); } + let partition_columns = 
options + .table_partition_cols + .iter() + .map(|(name, arrow_type)| { + let arrow_type = protobuf::ArrowType::try_from(arrow_type) + .map_err(|e| { + proto_error(format!( + "Received an unknown ArrowType: {e}" + )) + })?; + Ok(protobuf::PartitionColumn { + name: name.clone(), + arrow_type: Some(arrow_type), + }) + }) + .collect::>>()?; + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::ListingScan( protobuf::ListingTableScanNode { @@ -1114,11 +1144,7 @@ impl AsLogicalPlan for LogicalPlanNode { table_name: Some(table_name.clone().into()), collect_stat: options.collect_stat, file_extension: options.file_extension.clone(), - table_partition_cols: options - .table_partition_cols - .iter() - .map(|x| x.0.clone()) - .collect::>(), + table_partition_cols: partition_columns, paths: listing_table .table_paths() .iter() @@ -1133,6 +1159,7 @@ impl AsLogicalPlan for LogicalPlanNode { )), }) } else if let Some(view_table) = source.downcast_ref::() { + let schema: protobuf::Schema = schema.as_ref().try_into()?; Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::ViewScan(Box::new( protobuf::ViewTableScanNode { @@ -1167,6 +1194,7 @@ impl AsLogicalPlan for LogicalPlanNode { )), }) } else { + let schema: protobuf::Schema = schema.as_ref().try_into()?; let mut bytes = vec![]; extension_codec .try_encode_table_provider(table_name, provider, &mut bytes) @@ -1299,7 +1327,7 @@ impl AsLogicalPlan for LogicalPlanNode { filter, join_type, join_constraint, - null_equals_null, + null_equality, .. }) => { let left: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( @@ -1324,6 +1352,8 @@ impl AsLogicalPlan for LogicalPlanNode { let join_type: protobuf::JoinType = join_type.to_owned().into(); let join_constraint: protobuf::JoinConstraint = join_constraint.to_owned().into(); + let null_equality: protobuf::NullEquality = + null_equality.to_owned().into(); let filter = filter .as_ref() .map(|e| serialize_expr(e, extension_codec)) @@ -1337,7 +1367,7 @@ impl AsLogicalPlan for LogicalPlanNode { join_constraint: join_constraint.into(), left_join_key, right_join_key, - null_equals_null: *null_equals_null, + null_equality: null_equality.into(), filter, }, ))), diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 841c31fa035f..b14ad7aadf58 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -21,7 +21,7 @@ use std::collections::HashMap; -use datafusion_common::{TableReference, UnnestOptions}; +use datafusion_common::{NullEquality, TableReference, UnnestOptions}; use datafusion_expr::dml::InsertOp; use datafusion_expr::expr::{ self, AggregateFunctionParams, Alias, Between, BinaryExpr, Cast, GroupingSet, InList, @@ -211,13 +211,16 @@ pub fn serialize_expr( .map(|r| vec![r.into()]) .unwrap_or(vec![]), alias: name.to_owned(), - metadata: metadata.to_owned().unwrap_or(HashMap::new()), + metadata: metadata + .as_ref() + .map(|m| m.to_hashmap()) + .unwrap_or(HashMap::new()), }); protobuf::LogicalExprNode { expr_type: Some(ExprType::Alias(alias)), } } - Expr::Literal(value) => { + Expr::Literal(value, _) => { let pb_value: protobuf::ScalarValue = value.try_into()?; protobuf::LogicalExprNode { expr_type: Some(ExprType::Literal(pb_value)), @@ -302,40 +305,35 @@ pub fn serialize_expr( expr_type: Some(ExprType::SimilarTo(pb)), } } - Expr::WindowFunction(expr::WindowFunction { - ref fun, - params: - expr::WindowFunctionParams { - ref args, - ref partition_by, - ref order_by, - ref 
window_frame, - // TODO: support null treatment in proto - null_treatment: _, - }, - }) => { - let (window_function, fun_definition) = match fun { + Expr::WindowFunction(window_fun) => { + let expr::WindowFunction { + ref fun, + params: + expr::WindowFunctionParams { + ref args, + ref partition_by, + ref order_by, + ref window_frame, + // TODO: support null treatment in proto + null_treatment: _, + }, + } = window_fun.as_ref(); + let mut buf = Vec::new(); + let window_function = match fun { WindowFunctionDefinition::AggregateUDF(aggr_udf) => { - let mut buf = Vec::new(); let _ = codec.try_encode_udaf(aggr_udf, &mut buf); - ( - protobuf::window_expr_node::WindowFunction::Udaf( - aggr_udf.name().to_string(), - ), - (!buf.is_empty()).then_some(buf), + protobuf::window_expr_node::WindowFunction::Udaf( + aggr_udf.name().to_string(), ) } WindowFunctionDefinition::WindowUDF(window_udf) => { - let mut buf = Vec::new(); let _ = codec.try_encode_udwf(window_udf, &mut buf); - ( - protobuf::window_expr_node::WindowFunction::Udwf( - window_udf.name().to_string(), - ), - (!buf.is_empty()).then_some(buf), + protobuf::window_expr_node::WindowFunction::Udwf( + window_udf.name().to_string(), ) } }; + let fun_definition = (!buf.is_empty()).then_some(buf); let partition_by = serialize_exprs(partition_by, codec)?; let order_by = serialize_sorts(order_by, codec)?; @@ -687,6 +685,7 @@ impl From for protobuf::JoinType { JoinType::LeftAnti => protobuf::JoinType::Leftanti, JoinType::RightAnti => protobuf::JoinType::Rightanti, JoinType::LeftMark => protobuf::JoinType::Leftmark, + JoinType::RightMark => protobuf::JoinType::Rightmark, } } } @@ -700,6 +699,15 @@ impl From for protobuf::JoinConstraint { } } +impl From for protobuf::NullEquality { + fn from(t: NullEquality) -> Self { + match t { + NullEquality::NullEqualsNothing => protobuf::NullEquality::NullEqualsNothing, + NullEquality::NullEqualsNull => protobuf::NullEquality::NullEqualsNull, + } + } +} + impl From<&WriteOp> for protobuf::dml_node::Type { fn from(t: &WriteOp) -> Self { match t { diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index a886fc242545..1c60470b2218 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -20,6 +20,7 @@ use std::sync::Arc; use arrow::compute::SortOptions; +use arrow::datatypes::Field; use chrono::{TimeZone, Utc}; use datafusion_expr::dml::InsertOp; use object_store::path::Path; @@ -101,13 +102,13 @@ pub fn parse_physical_sort_exprs( registry: &dyn FunctionRegistry, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, -) -> Result { +) -> Result> { proto .iter() .map(|sort_expr| { parse_physical_sort_expr(sort_expr, registry, input_schema, codec) }) - .collect::>() + .collect() } /// Parses a physical window expr from a protobuf. @@ -151,13 +152,13 @@ pub fn parse_physical_window_expr( protobuf::physical_window_expr_node::WindowFunction::UserDefinedAggrFunction(udaf_name) => { WindowFunctionDefinition::AggregateUDF(match &proto.fun_definition { Some(buf) => codec.try_decode_udaf(udaf_name, buf)?, - None => registry.udaf(udaf_name)? + None => registry.udaf(udaf_name).or_else(|_| codec.try_decode_udaf(udaf_name, &[]))?, }) } protobuf::physical_window_expr_node::WindowFunction::UserDefinedWindowFunction(udwf_name) => { WindowFunctionDefinition::WindowUDF(match &proto.fun_definition { Some(buf) => codec.try_decode_udwf(udwf_name, buf)?, - None => registry.udwf(udwf_name)? 
+ None => registry.udwf(udwf_name).or_else(|_| codec.try_decode_udwf(udwf_name, &[]))? }) } } @@ -174,7 +175,7 @@ pub fn parse_physical_window_expr( name, &window_node_expr, &partition_by, - order_by.as_ref(), + &order_by, Arc::new(window_frame), &extended_schema, false, @@ -354,7 +355,9 @@ pub fn parse_physical_expr( ExprType::ScalarUdf(e) => { let udf = match &e.fun_definition { Some(buf) => codec.try_decode_udf(&e.name, buf)?, - None => registry.udf(e.name.as_str())?, + None => registry + .udf(e.name.as_str()) + .or_else(|_| codec.try_decode_udf(&e.name, &[]))?, }; let scalar_fun_def = Arc::clone(&udf); @@ -365,7 +368,7 @@ pub fn parse_physical_expr( e.name.as_str(), scalar_fun_def, args, - convert_required!(e.return_type)?, + Field::new("f", convert_required!(e.return_type)?, true).into(), ) .with_nullable(e.nullable), ) @@ -525,13 +528,13 @@ pub fn parse_protobuf_file_scan_config( let mut output_ordering = vec![]; for node_collection in &proto.output_ordering { - let sort_expr = parse_physical_sort_exprs( + let sort_exprs = parse_physical_sort_exprs( &node_collection.physical_sort_expr_nodes, registry, &schema, codec, )?; - output_ordering.push(sort_expr); + output_ordering.extend(LexOrdering::new(sort_exprs)); } let config = FileScanConfigBuilder::new(object_store_url, file_schema, file_source) diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 90d071ab23f5..242b36786d07 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -64,6 +64,7 @@ use datafusion::physical_plan::aggregates::{AggregateExec, PhysicalGroupBy}; use datafusion::physical_plan::analyze::AnalyzeExec; use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; +use datafusion::physical_plan::coop::CooperativeExec; use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::explain::ExplainExec; use datafusion::physical_plan::expressions::PhysicalSortExpr; @@ -309,7 +310,6 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { runtime, extension_codec, ), - #[cfg_attr(not(feature = "parquet"), allow(unused_variables))] PhysicalPlanType::ParquetSink(sink) => self .try_into_parquet_sink_physical_plan( @@ -324,6 +324,13 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { runtime, extension_codec, ), + PhysicalPlanType::Cooperative(cooperative) => self + .try_into_cooperative_physical_plan( + cooperative, + registry, + runtime, + extension_codec, + ), } } @@ -513,6 +520,13 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { ); } + if let Some(exec) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_cooperative_exec( + exec, + extension_codec, + ); + } + let mut buf: Vec = vec![]; match extension_codec.try_encode(Arc::clone(&plan_clone), &mut buf) { Ok(_) => { @@ -728,7 +742,7 @@ impl protobuf::PhysicalPlanNode { let mut source = ParquetSource::new(options); if let Some(predicate) = predicate { - source = source.with_predicate(Arc::clone(&schema), predicate); + source = source.with_predicate(predicate); } let base_config = parse_protobuf_file_scan_config( scan.base_conf.as_ref().unwrap(), @@ -792,7 +806,10 @@ impl protobuf::PhysicalPlanNode { ) -> Result> { let input: Arc = into_physical_plan(&merge.input, registry, runtime, extension_codec)?; - Ok(Arc::new(CoalescePartitionsExec::new(input))) + Ok(Arc::new( + CoalescePartitionsExec::new(input) + .with_fetch(merge.fetch.map(|f| f as 
usize)), + )) } fn try_into_repartition_physical_plan( @@ -1027,7 +1044,7 @@ impl protobuf::PhysicalPlanNode { ) }) .collect::>>()?; - let ordering_req: LexOrdering = agg_node + let order_bys = agg_node .ordering_req .iter() .map(|e| { @@ -1038,7 +1055,7 @@ impl protobuf::PhysicalPlanNode { extension_codec, ) }) - .collect::>()?; + .collect::>()?; agg_node .aggregate_function .as_ref() @@ -1047,7 +1064,12 @@ impl protobuf::PhysicalPlanNode { let agg_udf = match &agg_node.fun_definition { Some(buf) => extension_codec .try_decode_udaf(udaf_name, buf)?, - None => registry.udaf(udaf_name)?, + None => { + registry.udaf(udaf_name).or_else(|_| { + extension_codec + .try_decode_udaf(udaf_name, &[]) + })? + } }; AggregateExprBuilder::new(agg_udf, input_phy_expr) @@ -1055,7 +1077,7 @@ impl protobuf::PhysicalPlanNode { .alias(name) .with_ignore_nulls(agg_node.ignore_nulls) .with_distinct(agg_node.distinct) - .order_by(ordering_req) + .order_by(order_bys) .build() .map(Arc::new) } @@ -1130,6 +1152,13 @@ impl protobuf::PhysicalPlanNode { hashjoin.join_type )) })?; + let null_equality = protobuf::NullEquality::try_from(hashjoin.null_equality) + .map_err(|_| { + proto_error(format!( + "Received a HashJoinNode message with unknown NullEquality {}", + hashjoin.null_equality + )) + })?; let filter = hashjoin .filter .as_ref() @@ -1198,7 +1227,7 @@ impl protobuf::PhysicalPlanNode { &join_type.into(), projection, partition_mode, - hashjoin.null_equals_null, + null_equality.into(), )?)) } @@ -1241,6 +1270,13 @@ impl protobuf::PhysicalPlanNode { sym_join.join_type )) })?; + let null_equality = protobuf::NullEquality::try_from(sym_join.null_equality) + .map_err(|_| { + proto_error(format!( + "Received a SymmetricHashJoin message with unknown NullEquality {}", + sym_join.null_equality + )) + })?; let filter = sym_join .filter .as_ref() @@ -1284,11 +1320,7 @@ impl protobuf::PhysicalPlanNode { &left_schema, extension_codec, )?; - let left_sort_exprs = if left_sort_exprs.is_empty() { - None - } else { - Some(left_sort_exprs) - }; + let left_sort_exprs = LexOrdering::new(left_sort_exprs); let right_sort_exprs = parse_physical_sort_exprs( &sym_join.right_sort_exprs, @@ -1296,11 +1328,7 @@ impl protobuf::PhysicalPlanNode { &right_schema, extension_codec, )?; - let right_sort_exprs = if right_sort_exprs.is_empty() { - None - } else { - Some(right_sort_exprs) - }; + let right_sort_exprs = LexOrdering::new(right_sort_exprs); let partition_mode = protobuf::StreamPartitionMode::try_from( sym_join.partition_mode, @@ -1325,7 +1353,7 @@ impl protobuf::PhysicalPlanNode { on, filter, &join_type.into(), - sym_join.null_equals_null, + null_equality.into(), left_sort_exprs, right_sort_exprs, partition_mode, @@ -1412,47 +1440,45 @@ impl protobuf::PhysicalPlanNode { runtime: &RuntimeEnv, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { - let input: Arc = - into_physical_plan(&sort.input, registry, runtime, extension_codec)?; + let input = into_physical_plan(&sort.input, registry, runtime, extension_codec)?; let exprs = sort - .expr - .iter() - .map(|expr| { - let expr = expr.expr_type.as_ref().ok_or_else(|| { + .expr + .iter() + .map(|expr| { + let expr = expr.expr_type.as_ref().ok_or_else(|| { + proto_error(format!( + "physical_plan::from_proto() Unexpected expr {self:?}" + )) + })?; + if let ExprType::Sort(sort_expr) = expr { + let expr = sort_expr + .expr + .as_ref() + .ok_or_else(|| { proto_error(format!( - "physical_plan::from_proto() Unexpected expr {self:?}" + "physical_plan::from_proto() Unexpected sort expr 
{self:?}" )) - })?; - if let ExprType::Sort(sort_expr) = expr { - let expr = sort_expr - .expr - .as_ref() - .ok_or_else(|| { - proto_error(format!( - "physical_plan::from_proto() Unexpected sort expr {self:?}" - )) - })? - .as_ref(); - Ok(PhysicalSortExpr { - expr: parse_physical_expr(expr, registry, input.schema().as_ref(), extension_codec)?, - options: SortOptions { - descending: !sort_expr.asc, - nulls_first: sort_expr.nulls_first, - }, - }) - } else { - internal_err!( - "physical_plan::from_proto() {self:?}" - ) - } + })? + .as_ref(); + Ok(PhysicalSortExpr { + expr: parse_physical_expr(expr, registry, input.schema().as_ref(), extension_codec)?, + options: SortOptions { + descending: !sort_expr.asc, + nulls_first: sort_expr.nulls_first, + }, }) - .collect::>()?; - let fetch = if sort.fetch < 0 { - None - } else { - Some(sort.fetch as usize) + } else { + internal_err!( + "physical_plan::from_proto() {self:?}" + ) + } + }) + .collect::>>()?; + let Some(ordering) = LexOrdering::new(exprs) else { + return internal_err!("SortExec requires an ordering"); }; - let new_sort = SortExec::new(exprs, input) + let fetch = (sort.fetch >= 0).then_some(sort.fetch as _); + let new_sort = SortExec::new(ordering, input) .with_fetch(fetch) .with_preserve_partitioning(sort.preserve_partitioning); @@ -1466,8 +1492,7 @@ impl protobuf::PhysicalPlanNode { runtime: &RuntimeEnv, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result> { - let input: Arc = - into_physical_plan(&sort.input, registry, runtime, extension_codec)?; + let input = into_physical_plan(&sort.input, registry, runtime, extension_codec)?; let exprs = sort .expr .iter() @@ -1503,14 +1528,13 @@ impl protobuf::PhysicalPlanNode { internal_err!("physical_plan::from_proto() {self:?}") } }) - .collect::>()?; - let fetch = if sort.fetch < 0 { - None - } else { - Some(sort.fetch as usize) + .collect::>>()?; + let Some(ordering) = LexOrdering::new(exprs) else { + return internal_err!("SortExec requires an ordering"); }; + let fetch = (sort.fetch >= 0).then_some(sort.fetch as _); Ok(Arc::new( - SortPreservingMergeExec::new(exprs, input).with_fetch(fetch), + SortPreservingMergeExec::new(ordering, input).with_fetch(fetch), )) } @@ -1649,9 +1673,12 @@ impl protobuf::PhysicalPlanNode { &sink_schema, extension_codec, ) - .map(LexRequirement::from) + .map(|sort_exprs| { + LexRequirement::new(sort_exprs.into_iter().map(Into::into)) + }) }) - .transpose()?; + .transpose()? + .flatten(); Ok(Arc::new(DataSinkExec::new( input, Arc::new(data_sink), @@ -1684,9 +1711,12 @@ impl protobuf::PhysicalPlanNode { &sink_schema, extension_codec, ) - .map(LexRequirement::from) + .map(|sort_exprs| { + LexRequirement::new(sort_exprs.into_iter().map(Into::into)) + }) }) - .transpose()?; + .transpose()? + .flatten(); Ok(Arc::new(DataSinkExec::new( input, Arc::new(data_sink), @@ -1722,9 +1752,12 @@ impl protobuf::PhysicalPlanNode { &sink_schema, extension_codec, ) - .map(LexRequirement::from) + .map(|sort_exprs| { + LexRequirement::new(sort_exprs.into_iter().map(Into::into)) + }) }) - .transpose()?; + .transpose()? 
+ .flatten(); Ok(Arc::new(DataSinkExec::new( input, Arc::new(data_sink), @@ -1761,6 +1794,18 @@ impl protobuf::PhysicalPlanNode { ))) } + fn try_into_cooperative_physical_plan( + &self, + field_stream: &protobuf::CooperativeExecNode, + registry: &dyn FunctionRegistry, + runtime: &RuntimeEnv, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + let input = + into_physical_plan(&field_stream.input, registry, runtime, extension_codec)?; + Ok(Arc::new(CooperativeExec::new(input))) + } + fn try_from_explain_exec( exec: &ExplainExec, _extension_codec: &dyn PhysicalExtensionCodec, @@ -1916,6 +1961,7 @@ impl protobuf::PhysicalPlanNode { }) .collect::>()?; let join_type: protobuf::JoinType = exec.join_type().to_owned().into(); + let null_equality: protobuf::NullEquality = exec.null_equality().into(); let filter = exec .filter() .as_ref() @@ -1956,7 +2002,7 @@ impl protobuf::PhysicalPlanNode { on, join_type: join_type.into(), partition_mode: partition_mode.into(), - null_equals_null: exec.null_equals_null(), + null_equality: null_equality.into(), filter, projection: exec.projection.as_ref().map_or_else(Vec::new, |v| { v.iter().map(|x| *x as u32).collect::>() @@ -1991,6 +2037,7 @@ impl protobuf::PhysicalPlanNode { }) .collect::>()?; let join_type: protobuf::JoinType = exec.join_type().to_owned().into(); + let null_equality: protobuf::NullEquality = exec.null_equality().into(); let filter = exec .filter() .as_ref() @@ -2074,7 +2121,7 @@ impl protobuf::PhysicalPlanNode { on, join_type: join_type.into(), partition_mode: partition_mode.into(), - null_equals_null: exec.null_equals_null(), + null_equality: null_equality.into(), left_sort_exprs, right_sort_exprs, filter, @@ -2354,6 +2401,7 @@ impl protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Merge(Box::new( protobuf::CoalescePartitionsExecNode { input: Some(Box::new(input)), + fetch: exec.fetch().map(|f| f as u32), }, ))), }) @@ -2739,6 +2787,24 @@ impl protobuf::PhysicalPlanNode { ))), }) } + + fn try_from_cooperative_exec( + exec: &CooperativeExec, + extension_codec: &dyn PhysicalExtensionCodec, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.input().to_owned(), + extension_codec, + )?; + + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Cooperative(Box::new( + protobuf::CooperativeExecNode { + input: Some(Box::new(input)), + }, + ))), + }) + } } pub trait AsExecutionPlan: Debug + Send + Sync + Clone { diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index d1b1f51ae107..d22a0b545161 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -21,8 +21,9 @@ use std::sync::Arc; use datafusion::datasource::file_format::parquet::ParquetSink; use datafusion::datasource::physical_plan::FileSink; use datafusion::physical_expr::window::{SlidingAggregateWindowExpr, StandardWindowExpr}; -use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr, ScalarFunctionExpr}; +use datafusion::physical_expr::ScalarFunctionExpr; use datafusion::physical_expr_common::physical_expr::snapshot_physical_expr; +use datafusion::physical_expr_common::sort_expr::PhysicalSortExpr; use datafusion::physical_plan::expressions::{ BinaryExpr, CaseExpr, CastExpr, Column, InListExpr, IsNotNullExpr, IsNullExpr, Literal, NegativeExpr, NotExpr, TryCastExpr, UnKnownColumn, @@ -53,11 +54,8 @@ pub fn serialize_physical_aggr_expr( codec: &dyn PhysicalExtensionCodec, ) -> 
Result { let expressions = serialize_physical_exprs(&aggr_expr.expressions(), codec)?; - let ordering_req = match aggr_expr.order_bys() { - Some(order) => order.clone(), - None => LexOrdering::default(), - }; - let ordering_req = serialize_physical_sort_exprs(ordering_req, codec)?; + let order_bys = + serialize_physical_sort_exprs(aggr_expr.order_bys().iter().cloned(), codec)?; let name = aggr_expr.fun().name().to_string(); let mut buf = Vec::new(); @@ -67,7 +65,7 @@ pub fn serialize_physical_aggr_expr( protobuf::PhysicalAggregateExprNode { aggregate_function: Some(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(name)), expr: expressions, - ordering_req, + ordering_req: order_bys, distinct: aggr_expr.is_distinct(), ignore_nulls: aggr_expr.ignore_nulls(), fun_definition: (!buf.is_empty()).then_some(buf), @@ -506,7 +504,7 @@ pub fn serialize_file_scan_config( .iter() .cloned() .collect::>(); - fields.extend(conf.table_partition_cols.iter().cloned().map(Arc::new)); + fields.extend(conf.table_partition_cols.iter().cloned()); let schema = Arc::new(arrow::datatypes::Schema::new(fields.clone())); Ok(protobuf::FileScanExecConf { diff --git a/datafusion/proto/tests/cases/mod.rs b/datafusion/proto/tests/cases/mod.rs index 92d961fc7556..4c7da2768e74 100644 --- a/datafusion/proto/tests/cases/mod.rs +++ b/datafusion/proto/tests/cases/mod.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion::logical_expr::ColumnarValue; use datafusion_common::plan_err; use datafusion_expr::function::AccumulatorArgs; @@ -166,8 +166,11 @@ impl WindowUDFImpl for CustomUDWF { Ok(Box::new(CustomUDWFEvaluator {})) } - fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { - Ok(Field::new(field_args.name(), DataType::UInt64, false)) + fn field( + &self, + field_args: WindowUDFFieldArgs, + ) -> datafusion_common::Result { + Ok(Field::new(field_args.name(), DataType::UInt64, false).into()) } } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 9fa1f74ae188..993cc6f87ca3 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -19,12 +19,15 @@ use arrow::array::{ ArrayRef, FixedSizeListArray, Int32Builder, MapArray, MapBuilder, StringBuilder, }; use arrow::datatypes::{ - DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType, - IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, - DECIMAL256_MAX_PRECISION, + DataType, Field, FieldRef, Fields, Int32Type, IntervalDayTimeType, + IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, + UnionMode, DECIMAL256_MAX_PRECISION, }; use arrow::util::pretty::pretty_format_batches; -use datafusion::datasource::file_format::json::JsonFormatFactory; +use datafusion::datasource::file_format::json::{JsonFormat, JsonFormatFactory}; +use datafusion::datasource::listing::{ + ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, +}; use datafusion::optimizer::eliminate_nested_union::EliminateNestedUnion; use datafusion::optimizer::Optimizer; use datafusion_common::parsers::CompressionTypeVariant; @@ -110,15 +113,21 @@ fn roundtrip_json_test(proto: &protobuf::LogicalExprNode) { #[cfg(not(feature = "json"))] fn roundtrip_json_test(_proto: &protobuf::LogicalExprNode) 
{} -// Given a DataFusion logical Expr, convert it to protobuf and back, using debug formatting to test -// equality. fn roundtrip_expr_test(initial_struct: Expr, ctx: SessionContext) { let extension_codec = DefaultLogicalExtensionCodec {}; - let proto: protobuf::LogicalExprNode = - serialize_expr(&initial_struct, &extension_codec) - .unwrap_or_else(|e| panic!("Error serializing expression: {:?}", e)); - let round_trip: Expr = - from_proto::parse_expr(&proto, &ctx, &extension_codec).unwrap(); + roundtrip_expr_test_with_codec(initial_struct, ctx, &extension_codec); +} + +// Given a DataFusion logical Expr, convert it to protobuf and back, using debug formatting to test +// equality. +fn roundtrip_expr_test_with_codec( + initial_struct: Expr, + ctx: SessionContext, + codec: &dyn LogicalExtensionCodec, +) { + let proto: protobuf::LogicalExprNode = serialize_expr(&initial_struct, codec) + .unwrap_or_else(|e| panic!("Error serializing expression: {e:?}")); + let round_trip: Expr = from_proto::parse_expr(&proto, &ctx, codec).unwrap(); assert_eq!(format!("{:?}", &initial_struct), format!("{round_trip:?}")); @@ -970,8 +979,8 @@ async fn roundtrip_expr_api() -> Result<()> { stddev_pop(lit(2.2)), approx_distinct(lit(2)), approx_median(lit(2)), - approx_percentile_cont(lit(2), lit(0.5), None), - approx_percentile_cont(lit(2), lit(0.5), Some(lit(50))), + approx_percentile_cont(lit(2).sort(true, false), lit(0.5), None), + approx_percentile_cont(lit(2).sort(true, false), lit(0.5), Some(lit(50))), approx_percentile_cont_with_weight(lit(2), lit(1), lit(0.5)), grouping(lit(1)), bit_and(lit(2)), @@ -1959,7 +1968,7 @@ fn roundtrip_case_with_null() { let test_expr = Expr::Case(Case::new( Some(Box::new(lit(1.0_f32))), vec![(Box::new(lit(2.0_f32)), Box::new(lit(3.0_f32)))], - Some(Box::new(Expr::Literal(ScalarValue::Null))), + Some(Box::new(Expr::Literal(ScalarValue::Null, None))), )); let ctx = SessionContext::new(); @@ -1968,7 +1977,7 @@ fn roundtrip_case_with_null() { #[test] fn roundtrip_null_literal() { - let test_expr = Expr::Literal(ScalarValue::Null); + let test_expr = Expr::Literal(ScalarValue::Null, None); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); @@ -2182,8 +2191,7 @@ fn roundtrip_aggregate_udf() { roundtrip_expr_test(test_expr, ctx); } -#[test] -fn roundtrip_scalar_udf() { +fn dummy_udf() -> ScalarUDF { let scalar_fn = Arc::new(|args: &[ColumnarValue]| { let ColumnarValue::Array(array) = &args[0] else { panic!("should be array") @@ -2191,13 +2199,18 @@ fn roundtrip_scalar_udf() { Ok(ColumnarValue::from(Arc::new(array.clone()) as ArrayRef)) }); - let udf = create_udf( + create_udf( "dummy", vec![DataType::Utf8], DataType::Utf8, Volatility::Immutable, scalar_fn, - ); + ) +} + +#[test] +fn roundtrip_scalar_udf() { + let udf = dummy_udf(); let test_expr = Expr::ScalarFunction(ScalarFunction::new_udf( Arc::new(udf.clone()), @@ -2207,7 +2220,57 @@ fn roundtrip_scalar_udf() { let ctx = SessionContext::new(); ctx.register_udf(udf); - roundtrip_expr_test(test_expr, ctx); + roundtrip_expr_test(test_expr.clone(), ctx); + + // Now test loading the UDF without registering it in the context, but rather creating it in the + // extension codec. 
+ #[derive(Debug)] + struct DummyUDFExtensionCodec; + + impl LogicalExtensionCodec for DummyUDFExtensionCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[LogicalPlan], + _ctx: &SessionContext, + ) -> Result { + not_impl_err!("LogicalExtensionCodec is not provided") + } + + fn try_encode(&self, _node: &Extension, _buf: &mut Vec) -> Result<()> { + not_impl_err!("LogicalExtensionCodec is not provided") + } + + fn try_decode_table_provider( + &self, + _buf: &[u8], + _table_ref: &TableReference, + _schema: SchemaRef, + _ctx: &SessionContext, + ) -> Result> { + not_impl_err!("LogicalExtensionCodec is not provided") + } + + fn try_encode_table_provider( + &self, + _table_ref: &TableReference, + _node: Arc, + _buf: &mut Vec, + ) -> Result<()> { + not_impl_err!("LogicalExtensionCodec is not provided") + } + + fn try_decode_udf(&self, name: &str, _buf: &[u8]) -> Result> { + if name == "dummy" { + Ok(Arc::new(dummy_udf())) + } else { + Err(DataFusionError::Internal(format!("UDF {name} not found"))) + } + } + } + + let ctx = SessionContext::new(); + roundtrip_expr_test_with_codec(test_expr, ctx, &DummyUDFExtensionCodec) } #[test] @@ -2296,7 +2359,7 @@ fn roundtrip_window() { let ctx = SessionContext::new(); // 1. without window_frame - let test_expr1 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr1 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(rank_udwf()), vec![], )) @@ -2307,7 +2370,7 @@ fn roundtrip_window() { .unwrap(); // 2. with default window_frame - let test_expr2 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr2 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(rank_udwf()), vec![], )) @@ -2324,7 +2387,7 @@ fn roundtrip_window() { WindowFrameBound::Following(ScalarValue::UInt64(Some(2))), ); - let test_expr3 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr3 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(rank_udwf()), vec![], )) @@ -2341,7 +2404,7 @@ fn roundtrip_window() { WindowFrameBound::Following(ScalarValue::UInt64(Some(2))), ); - let test_expr4 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr4 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("col1")], )) @@ -2391,7 +2454,7 @@ fn roundtrip_window() { Arc::new(vec![DataType::Float64, DataType::UInt32]), ); - let test_expr5 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr5 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(Arc::new(dummy_agg.clone())), vec![col("col1")], )) @@ -2453,14 +2516,18 @@ fn roundtrip_window() { make_partition_evaluator() } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - if let Some(return_type) = field_args.get_input_type(0) { - Ok(Field::new(field_args.name(), return_type, true)) + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + if let Some(return_field) = field_args.get_input_field(0) { + Ok(return_field + .as_ref() + .clone() + .with_name(field_args.name()) + .into()) } else { plan_err!( "dummy_udwf expects 1 argument, got {}: {:?}", - field_args.input_types().len(), - field_args.input_types() + field_args.input_fields().len(), + field_args.input_fields() ) } } @@ -2472,7 +2539,7 @@ fn roundtrip_window() { let dummy_window_udf = WindowUDF::from(SimpleWindowUDF::new()); - let test_expr6 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr6 = Expr::from(expr::WindowFunction::new( 
WindowFunctionDefinition::WindowUDF(Arc::new(dummy_window_udf.clone())), vec![col("col1")], )) @@ -2482,7 +2549,7 @@ fn roundtrip_window() { .build() .unwrap(); - let text_expr7 = Expr::WindowFunction(expr::WindowFunction::new( + let text_expr7 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(avg_udaf()), vec![col("col1")], )) @@ -2559,3 +2626,33 @@ async fn roundtrip_union_query() -> Result<()> { ); Ok(()) } + +#[tokio::test] +async fn roundtrip_custom_listing_tables_schema() -> Result<()> { + let ctx = SessionContext::new(); + // Make sure during round-trip, constraint information is preserved + let file_format = JsonFormat::default(); + let table_partition_cols = vec![("part".to_owned(), DataType::Int64)]; + let data = "../core/tests/data/partitioned_table_json"; + let listing_table_url = ListingTableUrl::parse(data)?; + let listing_options = ListingOptions::new(Arc::new(file_format)) + .with_table_partition_cols(table_partition_cols); + + let config = ListingTableConfig::new(listing_table_url) + .with_listing_options(listing_options) + .infer_schema(&ctx.state()) + .await?; + + ctx.register_table("hive_style", Arc::new(ListingTable::try_new(config)?))?; + + let plan = ctx + .sql("SELECT part, value FROM hive_style LIMIT 1") + .await? + .logical_plan() + .clone(); + + let bytes = logical_plan_to_bytes(&plan)?; + let new_plan = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(plan, new_plan); + Ok(()) +} diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index be90497a6e21..43f9942a0a06 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -60,12 +60,13 @@ use datafusion::logical_expr::{create_udf, JoinType, Operator, Volatility}; use datafusion::physical_expr::expressions::Literal; use datafusion::physical_expr::window::{SlidingAggregateWindowExpr, StandardWindowExpr}; use datafusion::physical_expr::{ - LexOrdering, LexRequirement, PhysicalSortRequirement, ScalarFunctionExpr, + LexOrdering, PhysicalSortRequirement, ScalarFunctionExpr, }; use datafusion::physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, }; use datafusion::physical_plan::analyze::AnalyzeExec; +use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::expressions::{ binary, cast, col, in_list, like, lit, BinaryExpr, Column, NotExpr, PhysicalSortExpr, @@ -73,6 +74,7 @@ use datafusion::physical_plan::expressions::{ use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::joins::{ HashJoinExec, NestedLoopJoinExec, PartitionMode, StreamJoinPartitionMode, + SymmetricHashJoinExec, }; use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion::physical_plan::placeholder_row::PlaceholderRowExec; @@ -96,7 +98,7 @@ use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; use datafusion_common::{ - internal_err, not_impl_err, DataFusionError, Result, UnnestOptions, + internal_err, not_impl_err, DataFusionError, NullEquality, Result, UnnestOptions, }; use datafusion_expr::{ Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarUDF, @@ -108,8 +110,7 @@ use datafusion_functions_aggregate::string_agg::string_agg_udaf; use 
datafusion_proto::physical_plan::{ AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, }; -use datafusion_proto::protobuf; -use datafusion_proto::protobuf::PhysicalPlanNode; +use datafusion_proto::protobuf::{self, PhysicalPlanNode}; /// Perform a serde roundtrip and assert that the string representation of the before and after plans /// are identical. Note that this often isn't sufficient to guarantee that no information is @@ -267,7 +268,7 @@ fn roundtrip_hash_join() -> Result<()> { join_type, None, *partition_mode, - false, + NullEquality::NullEqualsNothing, )?))?; } } @@ -320,9 +321,9 @@ fn roundtrip_udwf() -> Result<()> { &[ col("a", &schema)? ], - &LexOrdering::new(vec![ - PhysicalSortExpr::new(col("b", &schema)?, SortOptions::new(true, true)), - ]), + &[ + PhysicalSortExpr::new(col("b", &schema)?, SortOptions::new(true, true)) + ], Arc::new(WindowFrame::new(None)), )); @@ -359,13 +360,13 @@ fn roundtrip_window() -> Result<()> { let udwf_expr = Arc::new(StandardWindowExpr::new( nth_value_window, &[col("b", &schema)?], - &LexOrdering::new(vec![PhysicalSortExpr { + &[PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions { descending: false, nulls_first: false, }, - }]), + }], Arc::new(window_frame), )); @@ -379,7 +380,7 @@ fn roundtrip_window() -> Result<()> { .build() .map(Arc::new)?, &[], - &LexOrdering::default(), + &[], Arc::new(WindowFrame::new(None)), )); @@ -399,7 +400,7 @@ fn roundtrip_window() -> Result<()> { let sliding_aggr_window_expr = Arc::new(SlidingAggregateWindowExpr::new( sum_expr, &[], - &LexOrdering::default(), + &[], Arc::new(window_frame), )); @@ -504,7 +505,7 @@ fn rountrip_aggregate_with_approx_pencentile_cont() -> Result<()> { vec![col("b", &schema)?, lit(0.5)], ) .schema(Arc::clone(&schema)) - .alias("APPROX_PERCENTILE_CONT(b, 0.5)") + .alias("APPROX_PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY b)") .build() .map(Arc::new)?]; @@ -527,13 +528,13 @@ fn rountrip_aggregate_with_sort() -> Result<()> { let groups: Vec<(Arc, String)> = vec![(col("a", &schema)?, "unused".to_string())]; - let sort_exprs = LexOrdering::new(vec![PhysicalSortExpr { + let sort_exprs = vec![PhysicalSortExpr { expr: col("b", &schema)?, options: SortOptions { descending: false, nulls_first: true, }, - }]); + }]; let aggregates = vec![ @@ -594,7 +595,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> { Signature::exact(vec![DataType::Int64], Volatility::Immutable), return_type, accumulator, - vec![Field::new("value", DataType::Int64, true)], + vec![Field::new("value", DataType::Int64, true).into()], )); let ctx = SessionContext::new(); @@ -653,7 +654,7 @@ fn roundtrip_sort() -> Result<()> { let field_a = Field::new("a", DataType::Boolean, false); let field_b = Field::new("b", DataType::Int64, false); let schema = Arc::new(Schema::new(vec![field_a, field_b])); - let sort_exprs = LexOrdering::new(vec![ + let sort_exprs = [ PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions { @@ -668,7 +669,8 @@ fn roundtrip_sort() -> Result<()> { nulls_first: true, }, }, - ]); + ] + .into(); roundtrip_test(Arc::new(SortExec::new( sort_exprs, Arc::new(EmptyExec::new(schema)), @@ -680,7 +682,7 @@ fn roundtrip_sort_preserve_partitioning() -> Result<()> { let field_a = Field::new("a", DataType::Boolean, false); let field_b = Field::new("b", DataType::Int64, false); let schema = Arc::new(Schema::new(vec![field_a, field_b])); - let sort_exprs = LexOrdering::new(vec![ + let sort_exprs: LexOrdering = [ PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions { @@ 
-695,7 +697,8 @@ fn roundtrip_sort_preserve_partitioning() -> Result<()> { nulls_first: true, }, }, - ]); + ] + .into(); roundtrip_test(Arc::new(SortExec::new( sort_exprs.clone(), @@ -709,7 +712,7 @@ fn roundtrip_sort_preserve_partitioning() -> Result<()> { } #[test] -fn roundtrip_coalesce_with_fetch() -> Result<()> { +fn roundtrip_coalesce_batches_with_fetch() -> Result<()> { let field_a = Field::new("a", DataType::Boolean, false); let field_b = Field::new("b", DataType::Int64, false); let schema = Arc::new(Schema::new(vec![field_a, field_b])); @@ -725,6 +728,22 @@ fn roundtrip_coalesce_with_fetch() -> Result<()> { )) } +#[test] +fn roundtrip_coalesce_partitions_with_fetch() -> Result<()> { + let field_a = Field::new("a", DataType::Boolean, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + + roundtrip_test(Arc::new(CoalescePartitionsExec::new(Arc::new( + EmptyExec::new(schema.clone()), + ))))?; + + roundtrip_test(Arc::new( + CoalescePartitionsExec::new(Arc::new(EmptyExec::new(schema))) + .with_fetch(Some(10)), + )) +} + #[test] fn roundtrip_parquet_exec_with_pruning_predicate() -> Result<()> { let file_schema = @@ -739,9 +758,7 @@ fn roundtrip_parquet_exec_with_pruning_predicate() -> Result<()> { let mut options = TableParquetOptions::new(); options.global.pushdown_filters = true; - let file_source = Arc::new( - ParquetSource::new(options).with_predicate(Arc::clone(&file_schema), predicate), - ); + let file_source = Arc::new(ParquetSource::new(options).with_predicate(predicate)); let scan_config = FileScanConfigBuilder::new( ObjectStoreUrl::local_filesystem(), @@ -800,10 +817,8 @@ fn roundtrip_parquet_exec_with_custom_predicate_expr() -> Result<()> { inner: Arc::new(Column::new("col", 1)), }); - let file_source = Arc::new( - ParquetSource::default() - .with_predicate(Arc::clone(&file_schema), custom_predicate_expr), - ); + let file_source = + Arc::new(ParquetSource::default().with_predicate(custom_predicate_expr)); let scan_config = FileScanConfigBuilder::new( ObjectStoreUrl::local_filesystem(), @@ -968,7 +983,7 @@ fn roundtrip_scalar_udf() -> Result<()> { "dummy", fun_def, vec![col("a", &schema)?], - DataType::Int64, + Field::new("f", DataType::Int64, true).into(), ); let project = @@ -1096,7 +1111,7 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { "regex_udf", Arc::new(ScalarUDF::from(MyRegexUdf::new(".*".to_string()))), vec![col("text", &schema)?], - DataType::Int64, + Field::new("f", DataType::Int64, true).into(), )); let filter = Arc::new(FilterExec::try_new( @@ -1118,7 +1133,7 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { vec![Arc::new(PlainAggregateWindowExpr::new( aggr_expr.clone(), &[col("author", &schema)?], - &LexOrdering::default(), + &[], Arc::new(WindowFrame::new(None)), ))], filter, @@ -1163,13 +1178,13 @@ fn roundtrip_udwf_extension_codec() -> Result<()> { let udwf_expr = Arc::new(StandardWindowExpr::new( udwf, &[col("b", &schema)?], - &LexOrdering::new(vec![PhysicalSortExpr { + &[PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions { descending: false, nulls_first: false, }, - }]), + }], Arc::new(window_frame), )); @@ -1198,7 +1213,7 @@ fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { "regex_udf", Arc::new(ScalarUDF::from(MyRegexUdf::new(".*".to_string()))), vec![col("text", &schema)?], - DataType::Int64, + Field::new("f", DataType::Int64, true).into(), )); let udaf = Arc::new(AggregateUDF::from(MyAggregateUDF::new( @@ -1226,7 +1241,7 @@ fn 
roundtrip_aggregate_udf_extension_codec() -> Result<()> { vec![Arc::new(PlainAggregateWindowExpr::new( aggr_expr, &[col("author", &schema)?], - &LexOrdering::default(), + &[], Arc::new(WindowFrame::new(None)), ))], filter, @@ -1322,13 +1337,14 @@ fn roundtrip_json_sink() -> Result<()> { file_sink_config, JsonWriterOptions::new(CompressionTypeVariant::UNCOMPRESSED), )); - let sort_order = LexRequirement::new(vec![PhysicalSortRequirement::new( + let sort_order = [PhysicalSortRequirement::new( Arc::new(Column::new("plan_type", 0)), Some(SortOptions { descending: true, nulls_first: false, }), - )]); + )] + .into(); roundtrip_test(Arc::new(DataSinkExec::new( input, @@ -1359,13 +1375,14 @@ fn roundtrip_csv_sink() -> Result<()> { file_sink_config, CsvWriterOptions::new(WriterBuilder::default(), CompressionTypeVariant::ZSTD), )); - let sort_order = LexRequirement::new(vec![PhysicalSortRequirement::new( + let sort_order = [PhysicalSortRequirement::new( Arc::new(Column::new("plan_type", 0)), Some(SortOptions { descending: true, nulls_first: false, }), - )]); + )] + .into(); let ctx = SessionContext::new(); let codec = DefaultPhysicalExtensionCodec {}; @@ -1415,13 +1432,14 @@ fn roundtrip_parquet_sink() -> Result<()> { file_sink_config, TableParquetOptions::default(), )); - let sort_order = LexRequirement::new(vec![PhysicalSortRequirement::new( + let sort_order = [PhysicalSortRequirement::new( Arc::new(Column::new("plan_type", 0)), Some(SortOptions { descending: true, nulls_first: false, }), - )]); + )] + .into(); roundtrip_test(Arc::new(DataSinkExec::new( input, @@ -1458,31 +1476,29 @@ fn roundtrip_sym_hash_join() -> Result<()> { ] { for left_order in &[ None, - Some(LexOrdering::new(vec![PhysicalSortExpr { + LexOrdering::new(vec![PhysicalSortExpr { expr: Arc::new(Column::new("col", schema_left.index_of("col")?)), options: Default::default(), - }])), + }]), ] { - for right_order in &[ + for right_order in [ None, - Some(LexOrdering::new(vec![PhysicalSortExpr { + LexOrdering::new(vec![PhysicalSortExpr { expr: Arc::new(Column::new("col", schema_right.index_of("col")?)), options: Default::default(), - }])), + }]), ] { - roundtrip_test(Arc::new( - datafusion::physical_plan::joins::SymmetricHashJoinExec::try_new( - Arc::new(EmptyExec::new(schema_left.clone())), - Arc::new(EmptyExec::new(schema_right.clone())), - on.clone(), - None, - join_type, - false, - left_order.clone(), - right_order.clone(), - *partition_mode, - )?, - ))?; + roundtrip_test(Arc::new(SymmetricHashJoinExec::try_new( + Arc::new(EmptyExec::new(schema_left.clone())), + Arc::new(EmptyExec::new(schema_right.clone())), + on.clone(), + None, + join_type, + NullEquality::NullEqualsNothing, + left_order.clone(), + right_order, + *partition_mode, + )?))?; } } } diff --git a/datafusion/proto/tests/cases/serialize.rs b/datafusion/proto/tests/cases/serialize.rs index d15e62909f7e..c9ef4377d43b 100644 --- a/datafusion/proto/tests/cases/serialize.rs +++ b/datafusion/proto/tests/cases/serialize.rs @@ -83,7 +83,7 @@ fn udf_roundtrip_with_registry() { #[test] #[should_panic( - expected = "No function registry provided to deserialize, so can not deserialize User Defined Function 'dummy'" + expected = "LogicalExtensionCodec is not provided for scalar function dummy" )] fn udf_roundtrip_without_registry() { let ctx = context_with_udf(); @@ -256,7 +256,7 @@ fn test_expression_serialization_roundtrip() { use datafusion_proto::logical_plan::from_proto::parse_expr; let ctx = SessionContext::new(); - let lit = Expr::Literal(ScalarValue::Utf8(None)); + let lit 
= Expr::Literal(ScalarValue::Utf8(None), None);
     for function in string::functions() {
         // default to 4 args (though some exprs like substr have error checking)
         let num_args = 4;
diff --git a/datafusion/pruning/Cargo.toml b/datafusion/pruning/Cargo.toml
new file mode 100644
index 000000000000..6acf178e4e2b
--- /dev/null
+++ b/datafusion/pruning/Cargo.toml
@@ -0,0 +1,29 @@
+[package]
+name = "datafusion-pruning"
+description = "DataFusion Pruning Logic"
+version = { workspace = true }
+edition = { workspace = true }
+homepage = { workspace = true }
+repository = { workspace = true }
+license = { workspace = true }
+authors = { workspace = true }
+
+[lints]
+workspace = true
+
+[dependencies]
+arrow = { workspace = true }
+arrow-schema = { workspace = true }
+datafusion-common = { workspace = true, default-features = true }
+datafusion-datasource = { workspace = true }
+datafusion-expr-common = { workspace = true, default-features = true }
+datafusion-physical-expr = { workspace = true }
+datafusion-physical-expr-common = { workspace = true }
+datafusion-physical-plan = { workspace = true }
+itertools = { workspace = true }
+log = { workspace = true }
+
+[dev-dependencies]
+datafusion-expr = { workspace = true }
+datafusion-functions-nested = { workspace = true }
+insta = { workspace = true }
diff --git a/datafusion/pruning/src/file_pruner.rs b/datafusion/pruning/src/file_pruner.rs
new file mode 100644
index 000000000000..bce1a64edaa3
--- /dev/null
+++ b/datafusion/pruning/src/file_pruner.rs
@@ -0,0 +1,133 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! File-level pruning based on partition values and file-level statistics
+
+use std::sync::Arc;
+
+use arrow::datatypes::{FieldRef, Schema, SchemaRef};
+use datafusion_common::{
+    pruning::{
+        CompositePruningStatistics, PartitionPruningStatistics, PrunableStatistics,
+        PruningStatistics,
+    },
+    Result,
+};
+use datafusion_datasource::PartitionedFile;
+use datafusion_physical_expr_common::physical_expr::{snapshot_generation, PhysicalExpr};
+use datafusion_physical_plan::metrics::Count;
+use itertools::Itertools;
+use log::debug;
+
+use crate::build_pruning_predicate;
+
+/// Prune based on partition values and file-level statistics.
+pub struct FilePruner {
+    predicate_generation: Option<u64>,
+    predicate: Arc<dyn PhysicalExpr>,
+    /// Schema used for pruning, which combines the file schema and partition fields.
+    /// Partition fields are always at the end, as they are during scans.
+    pruning_schema: Arc<Schema>,
+    file: PartitionedFile,
+    partition_fields: Vec<FieldRef>,
+    predicate_creation_errors: Count,
+}
+
+impl FilePruner {
+    pub fn new(
+        predicate: Arc<dyn PhysicalExpr>,
+        logical_file_schema: &SchemaRef,
+        partition_fields: Vec<FieldRef>,
+        file: PartitionedFile,
+        predicate_creation_errors: Count,
+    ) -> Result<Self> {
+        // Build a pruning schema that combines the file fields and partition fields.
+        // Partition fields are always at the end.
+        let pruning_schema = Arc::new(
+            Schema::new(
+                logical_file_schema
+                    .fields()
+                    .iter()
+                    .cloned()
+                    .chain(partition_fields.iter().cloned())
+                    .collect_vec(),
+            )
+            .with_metadata(logical_file_schema.metadata().clone()),
+        );
+        Ok(Self {
+            // Initialize the predicate generation to None so that the first time we call `should_prune` we actually check the predicate
+            // Subsequent calls will only do work if the predicate itself has changed.
+            // See `snapshot_generation` for more info.
+            predicate_generation: None,
+            predicate,
+            pruning_schema,
+            file,
+            partition_fields,
+            predicate_creation_errors,
+        })
+    }
+
+    pub fn should_prune(&mut self) -> Result<bool> {
+        let new_generation = snapshot_generation(&self.predicate);
+        if let Some(current_generation) = self.predicate_generation.as_mut() {
+            if *current_generation == new_generation {
+                return Ok(false);
+            }
+            *current_generation = new_generation;
+        } else {
+            self.predicate_generation = Some(new_generation);
+        }
+        let pruning_predicate = build_pruning_predicate(
+            Arc::clone(&self.predicate),
+            &self.pruning_schema,
+            &self.predicate_creation_errors,
+        );
+        if let Some(pruning_predicate) = pruning_predicate {
+            // The partition column schema is the schema of the table - the schema of the file
+            let mut pruning = Box::new(PartitionPruningStatistics::try_new(
+                vec![self.file.partition_values.clone()],
+                self.partition_fields.clone(),
+            )?) as Box<dyn PruningStatistics>;
+            if let Some(stats) = &self.file.statistics {
+                let stats_pruning = Box::new(PrunableStatistics::new(
+                    vec![Arc::clone(stats)],
+                    Arc::clone(&self.pruning_schema),
+                ));
+                pruning = Box::new(CompositePruningStatistics::new(vec![
+                    pruning,
+                    stats_pruning,
+                ]));
+            }
+            match pruning_predicate.prune(pruning.as_ref()) {
+                Ok(values) => {
+                    assert!(values.len() == 1);
+                    // We expect a single container -> if all containers are false skip this file
+                    if values.into_iter().all(|v| !v) {
+                        return Ok(true);
+                    }
+                }
+                // Stats filter array could not be built, so we can't prune
+                Err(e) => {
+                    debug!("Ignoring error building pruning predicate for file: {e}");
+                    self.predicate_creation_errors.add(1);
+                }
+            }
+        }
+
+        Ok(false)
+    }
+}
diff --git a/datafusion/pruning/src/lib.rs b/datafusion/pruning/src/lib.rs
new file mode 100644
index 000000000000..cec4fab2262f
--- /dev/null
+++ b/datafusion/pruning/src/lib.rs
@@ -0,0 +1,25 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.
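For readers skimming the new crate, here is a minimal editorial sketch of how `FilePruner` is meant to be driven, based only on the signatures above. It assumes the crate is imported as `datafusion_pruning`, builds the predicate with the `datafusion_physical_expr::expressions` helpers, and uses a made-up column name, partition column, and file path; it is not taken from the PR itself.

```rust
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{Result, ScalarValue};
use datafusion_datasource::PartitionedFile;
use datafusion_expr_common::operator::Operator;
use datafusion_physical_expr::expressions::{col, lit, BinaryExpr};
use datafusion_physical_plan::metrics::Count;
use datafusion_physical_plan::PhysicalExpr;
use datafusion_pruning::FilePruner;

fn can_skip_file() -> Result<bool> {
    // Logical file schema (no partition columns) plus one partition column `year`.
    let file_schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let partition_fields = vec![Arc::new(Field::new("year", DataType::Int32, false))];

    // `a > 10` as a physical expression over the file schema.
    let predicate: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
        col("a", &file_schema)?,
        Operator::Gt,
        lit(ScalarValue::Int32(Some(10))),
    ));

    // One file together with its partition value; file-level statistics,
    // if present on the `PartitionedFile`, are folded in by `should_prune`.
    let mut file = PartitionedFile::new("year=2024/part-0.parquet", 1024);
    file.partition_values = vec![ScalarValue::Int32(Some(2024))];

    let mut pruner = FilePruner::new(
        predicate,
        &file_schema,
        partition_fields,
        file,
        Count::new(),
    )?;

    // `true` means the whole file can be skipped.
    pruner.should_prune()
}
```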
See the License for the +// specific language governing permissions and limitations +// under the License. + +mod file_pruner; +mod pruning_predicate; + +pub use file_pruner::FilePruner; +pub use pruning_predicate::{ + build_pruning_predicate, PredicateRewriter, PruningPredicate, PruningStatistics, + RequiredColumns, UnhandledPredicateHook, +}; diff --git a/datafusion/physical-optimizer/src/pruning.rs b/datafusion/pruning/src/pruning_predicate.rs similarity index 91% rename from datafusion/physical-optimizer/src/pruning.rs rename to datafusion/pruning/src/pruning_predicate.rs index 1dd168f18167..1551a8f79a7a 100644 --- a/datafusion/physical-optimizer/src/pruning.rs +++ b/datafusion/pruning/src/pruning_predicate.rs @@ -28,7 +28,10 @@ use arrow::{ datatypes::{DataType, Field, Schema, SchemaRef}, record_batch::{RecordBatch, RecordBatchOptions}, }; -use log::trace; +// pub use for backwards compatibility +pub use datafusion_common::pruning::PruningStatistics; +use datafusion_physical_plan::metrics::Count; +use log::{debug, trace}; use datafusion_common::error::{DataFusionError, Result}; use datafusion_common::tree_node::TransformedResult; @@ -44,106 +47,6 @@ use datafusion_physical_expr::{expressions as phys_expr, PhysicalExprRef}; use datafusion_physical_expr_common::physical_expr::snapshot_physical_expr; use datafusion_physical_plan::{ColumnarValue, PhysicalExpr}; -/// A source of runtime statistical information to [`PruningPredicate`]s. -/// -/// # Supported Information -/// -/// 1. Minimum and maximum values for columns -/// -/// 2. Null counts and row counts for columns -/// -/// 3. Whether the values in a column are contained in a set of literals -/// -/// # Vectorized Interface -/// -/// Information for containers / files are returned as Arrow [`ArrayRef`], so -/// the evaluation happens once on a single `RecordBatch`, which amortizes the -/// overhead of evaluating the predicate. This is important when pruning 1000s -/// of containers which often happens in analytic systems that have 1000s of -/// potential files to consider. -/// -/// For example, for the following three files with a single column `a`: -/// ```text -/// file1: column a: min=5, max=10 -/// file2: column a: No stats -/// file2: column a: min=20, max=30 -/// ``` -/// -/// PruningStatistics would return: -/// -/// ```text -/// min_values("a") -> Some([5, Null, 20]) -/// max_values("a") -> Some([10, Null, 30]) -/// min_values("X") -> None -/// ``` -pub trait PruningStatistics { - /// Return the minimum values for the named column, if known. - /// - /// If the minimum value for a particular container is not known, the - /// returned array should have `null` in that row. If the minimum value is - /// not known for any row, return `None`. - /// - /// Note: the returned array must contain [`Self::num_containers`] rows - fn min_values(&self, column: &Column) -> Option; - - /// Return the maximum values for the named column, if known. - /// - /// See [`Self::min_values`] for when to return `None` and null values. - /// - /// Note: the returned array must contain [`Self::num_containers`] rows - fn max_values(&self, column: &Column) -> Option; - - /// Return the number of containers (e.g. Row Groups) being pruned with - /// these statistics. - /// - /// This value corresponds to the size of the [`ArrayRef`] returned by - /// [`Self::min_values`], [`Self::max_values`], [`Self::null_counts`], - /// and [`Self::row_counts`]. 
-    fn num_containers(&self) -> usize;
-
-    /// Return the number of null values for the named column as an
-    /// [`UInt64Array`]
-    ///
-    /// See [`Self::min_values`] for when to return `None` and null values.
-    ///
-    /// Note: the returned array must contain [`Self::num_containers`] rows
-    ///
-    /// [`UInt64Array`]: arrow::array::UInt64Array
-    fn null_counts(&self, column: &Column) -> Option<ArrayRef>;
-
-    /// Return the number of rows for the named column in each container
-    /// as an [`UInt64Array`].
-    ///
-    /// See [`Self::min_values`] for when to return `None` and null values.
-    ///
-    /// Note: the returned array must contain [`Self::num_containers`] rows
-    ///
-    /// [`UInt64Array`]: arrow::array::UInt64Array
-    fn row_counts(&self, column: &Column) -> Option<ArrayRef>;
-
-    /// Returns [`BooleanArray`] where each row represents information known
-    /// about specific literal `values` in a column.
-    ///
-    /// For example, Parquet Bloom Filters implement this API to communicate
-    /// that `values` are known not to be present in a Row Group.
-    ///
-    /// The returned array has one row for each container, with the following
-    /// meanings:
-    /// * `true` if the values in `column` ONLY contain values from `values`
-    /// * `false` if the values in `column` are NOT ANY of `values`
-    /// * `null` if the neither of the above holds or is unknown.
-    ///
-    /// If these statistics can not determine column membership for any
-    /// container, return `None` (the default).
-    ///
-    /// Note: the returned array must contain [`Self::num_containers`] rows
-    fn contained(
-        &self,
-        column: &Column,
-        values: &HashSet<ScalarValue>,
-    ) -> Option<BooleanArray>;
-}
-
 /// Used to prove that arbitrary predicates (boolean expression) can not
 /// possibly evaluate to `true` given information about a column provided by
 /// [`PruningStatistics`].
@@ -474,6 +377,30 @@ pub struct PruningPredicate {
     literal_guarantees: Vec<LiteralGuarantee>,
 }
 
+/// Build a pruning predicate from an optional predicate expression.
+/// If the predicate is None or the predicate cannot be converted to a pruning
+/// predicate, return None.
+/// If there is an error creating the pruning predicate it is recorded by incrementing
+/// the `predicate_creation_errors` counter.
+pub fn build_pruning_predicate(
+    predicate: Arc<dyn PhysicalExpr>,
+    file_schema: &SchemaRef,
+    predicate_creation_errors: &Count,
+) -> Option<Arc<PruningPredicate>> {
+    match PruningPredicate::try_new(predicate, Arc::clone(file_schema)) {
+        Ok(pruning_predicate) => {
+            if !pruning_predicate.always_true() {
+                return Some(Arc::new(pruning_predicate));
+            }
+        }
+        Err(e) => {
+            debug!("Could not create pruning predicate for: {e}");
+            predicate_creation_errors.add(1);
+        }
+    }
+    None
+}
+
 /// Rewrites predicates that [`PredicateRewriter`] can not handle, e.g. certain
 /// complex expressions or predicates that reference columns that are not in the
 /// schema.
@@ -567,7 +494,10 @@ impl PruningPredicate {
     /// simplified version `b`. See [`ExprSimplifier`] to simplify expressions.
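Since the `PruningStatistics` trait now lives in `datafusion_common::pruning` and is only re-exported here, a small self-contained sketch of the intended flow may help: implement the trait over per-file min/max values, then ask a `PruningPredicate` which containers can be skipped. The struct, column name, and expected output below are illustrative, and the expression helpers (`col`, `lit`, `BinaryExpr`, `Operator`) are assumed to come from `datafusion_physical_expr` / `datafusion_expr_common` rather than from this file.

```rust
use std::collections::HashSet;
use std::sync::Arc;

use arrow::array::{ArrayRef, BooleanArray, Int32Array};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{Column, Result, ScalarValue};
use datafusion_expr_common::operator::Operator;
use datafusion_physical_expr::expressions::{col, lit, BinaryExpr};
use datafusion_pruning::{PruningPredicate, PruningStatistics};

/// Min/max statistics for a single Int32 column `a`, one entry per file.
struct FileStats {
    mins: Int32Array,
    maxes: Int32Array,
}

impl PruningStatistics for FileStats {
    fn min_values(&self, column: &Column) -> Option<ArrayRef> {
        (column.name == "a").then(|| Arc::new(self.mins.clone()) as ArrayRef)
    }
    fn max_values(&self, column: &Column) -> Option<ArrayRef> {
        (column.name == "a").then(|| Arc::new(self.maxes.clone()) as ArrayRef)
    }
    fn num_containers(&self) -> usize {
        self.mins.len()
    }
    fn null_counts(&self, _column: &Column) -> Option<ArrayRef> {
        None
    }
    fn row_counts(&self, _column: &Column) -> Option<ArrayRef> {
        None
    }
    fn contained(
        &self,
        _column: &Column,
        _values: &HashSet<ScalarValue>,
    ) -> Option<BooleanArray> {
        None
    }
}

fn main() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

    // Physical predicate `a = 15`.
    let predicate = Arc::new(BinaryExpr::new(
        col("a", &schema)?,
        Operator::Eq,
        lit(ScalarValue::Int32(Some(15))),
    ));
    let pruning_predicate = PruningPredicate::try_new(predicate, Arc::clone(&schema))?;

    // file 0: a in [5, 10]; file 1: stats unknown; file 2: a in [20, 30]
    let stats = FileStats {
        mins: Int32Array::from(vec![Some(5), None, Some(20)]),
        maxes: Int32Array::from(vec![Some(10), None, Some(30)]),
    };

    // One bool per container; `false` means the file provably has no match,
    // so here only the middle file (unknown stats) should need to be read.
    println!("{:?}", pruning_predicate.prune(&stats)?);
    Ok(())
}
```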
/// /// [`ExprSimplifier`]: https://docs.rs/datafusion/latest/datafusion/optimizer/simplify_expressions/struct.ExprSimplifier.html - pub fn prune(&self, statistics: &S) -> Result> { + pub fn prune( + &self, + statistics: &S, + ) -> Result> { let mut builder = BoolVecBuilder::new(statistics.num_containers()); // Try to prove the predicate can't be true for the containers based on @@ -751,6 +681,13 @@ fn is_always_true(expr: &Arc) -> bool { .unwrap_or_default() } +fn is_always_false(expr: &Arc) -> bool { + expr.as_any() + .downcast_ref::() + .map(|l| matches!(l.value(), ScalarValue::Boolean(Some(false)))) + .unwrap_or_default() +} + /// Describes which columns statistics are necessary to evaluate a /// [`PruningPredicate`]. /// @@ -942,7 +879,7 @@ impl From> for RequiredColumns { /// -------+-------- /// 5 | 1000 /// ``` -fn build_statistics_record_batch( +fn build_statistics_record_batch( statistics: &S, required_columns: &RequiredColumns, ) -> Result { @@ -984,11 +921,7 @@ fn build_statistics_record_batch( let mut options = RecordBatchOptions::default(); options.row_count = Some(statistics.num_containers()); - trace!( - "Creating statistics batch for {:#?} with {:#?}", - required_columns, - arrays - ); + trace!("Creating statistics batch for {required_columns:#?} with {arrays:#?}"); RecordBatch::try_new_with_options(schema, arrays, &options).map_err(|err| { plan_datafusion_err!("Can not create statistics record batch: {err}") @@ -1210,23 +1143,35 @@ fn is_compare_op(op: Operator) -> bool { ) } +fn is_string_type(data_type: &DataType) -> bool { + matches!( + data_type, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View + ) +} + // The pruning logic is based on the comparing the min/max bounds. // Must make sure the two type has order. // For example, casts from string to numbers is not correct. // Because the "13" is less than "3" with UTF8 comparison order. fn verify_support_type_for_prune(from_type: &DataType, to_type: &DataType) -> Result<()> { - // TODO: support other data type for prunable cast or try cast - if matches!( - from_type, - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Decimal128(_, _) - ) && matches!( - to_type, - DataType::Int8 | DataType::Int32 | DataType::Int64 | DataType::Decimal128(_, _) - ) { + // Dictionary casts are always supported as long as the value types are supported + let from_type = match from_type { + DataType::Dictionary(_, t) => { + return verify_support_type_for_prune(t.as_ref(), to_type) + } + _ => from_type, + }; + let to_type = match to_type { + DataType::Dictionary(_, t) => { + return verify_support_type_for_prune(from_type, t.as_ref()) + } + _ => to_type, + }; + // If both types are strings or both are not strings (number, timestamp, etc) + // then we can compare them. + // PruningPredicate does not support casting of strings to numbers and such. + if is_string_type(from_type) == is_string_type(to_type) { Ok(()) } else { plan_err!( @@ -1427,6 +1372,11 @@ fn build_predicate_expression( required_columns: &mut RequiredColumns, unhandled_hook: &Arc, ) -> Arc { + if is_always_false(expr) { + // Shouldn't return `unhandled_hook.handle(expr)` + // Because it will transfer false to true. 
+ return Arc::clone(expr); + } // predicate expression can only be a binary expression let expr_any = expr.as_any(); if let Some(is_null) = expr_any.downcast_ref::() { @@ -1526,6 +1476,11 @@ fn build_predicate_expression( build_predicate_expression(&right, schema, required_columns, unhandled_hook); // simplify boolean expression if applicable let expr = match (&left_expr, op, &right_expr) { + (left, Operator::And, right) + if is_always_false(left) || is_always_false(right) => + { + Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(false)))) + } (left, Operator::And, _) if is_always_true(left) => right_expr, (_, Operator::And, right) if is_always_true(right) => left_expr, (left, Operator::Or, right) @@ -1533,6 +1488,9 @@ fn build_predicate_expression( { Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true)))) } + (left, Operator::Or, _) if is_always_false(left) => right_expr, + (_, Operator::Or, right) if is_always_false(right) => left_expr, + _ => Arc::new(phys_expr::BinaryExpr::new(left_expr, op, right_expr)), }; return expr; @@ -1544,7 +1502,10 @@ fn build_predicate_expression( Ok(builder) => builder, // allow partial failure in predicate expression generation // this can still produce a useful predicate when multiple conditions are joined using AND - Err(_) => return unhandled_hook.handle(expr), + Err(e) => { + debug!("Error building pruning expression: {e}"); + return unhandled_hook.handle(expr); + } }; build_statistics_expr(&mut expr_builder) @@ -1889,7 +1850,7 @@ mod tests { use super::*; use datafusion_common::test_util::batches_to_string; - use datafusion_expr::{col, lit}; + use datafusion_expr::{and, col, lit, or}; use insta::assert_snapshot; use arrow::array::Decimal128Array; @@ -2305,8 +2266,7 @@ mod tests { let was_new = fields.insert(field); if !was_new { panic!( - "Duplicate field in required schema: {:?}. Previous fields:\n{:#?}", - field, fields + "Duplicate field in required schema: {field:?}. 
Previous fields:\n{fields:#?}" ); } } @@ -2811,8 +2771,8 @@ mod tests { let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut required_columns); assert_eq!(predicate_expr.to_string(), expected_expr); - println!("required_columns: {:#?}", required_columns); // for debugging assertions below - // c1 < 1 should add c1_min + println!("required_columns: {required_columns:#?}"); // for debugging assertions below + // c1 < 1 should add c1_min let c1_min_field = Field::new("c1_min", DataType::Int32, false); assert_eq!( required_columns.columns[0], @@ -3006,7 +2966,7 @@ mod tests { } #[test] - fn row_group_predicate_cast() -> Result<()> { + fn row_group_predicate_cast_int_int() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64)"; @@ -3043,6 +3003,291 @@ mod tests { Ok(()) } + #[test] + fn row_group_predicate_cast_string_string() -> Result<()> { + let schema = Schema::new(vec![Field::new("c1", DataType::Utf8View, false)]); + let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Utf8) <= 1 AND 1 <= CAST(c1_max@1 AS Utf8)"; + + // test column on the left + let expr = cast(col("c1"), DataType::Utf8) + .eq(lit(ScalarValue::Utf8(Some("1".to_string())))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + // test column on the right + let expr = lit(ScalarValue::Utf8(Some("1".to_string()))) + .eq(cast(col("c1"), DataType::Utf8)); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_cast_string_int() -> Result<()> { + let schema = Schema::new(vec![Field::new("c1", DataType::Utf8View, false)]); + let expected_expr = "true"; + + // test column on the left + let expr = cast(col("c1"), DataType::Int32).eq(lit(ScalarValue::Int32(Some(1)))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + // test column on the right + let expr = lit(ScalarValue::Int32(Some(1))).eq(cast(col("c1"), DataType::Int32)); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_cast_int_string() -> Result<()> { + let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); + let expected_expr = "true"; + + // test column on the left + let expr = cast(col("c1"), DataType::Utf8) + .eq(lit(ScalarValue::Utf8(Some("1".to_string())))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + // test column on the right + let expr = lit(ScalarValue::Utf8(Some("1".to_string()))) + .eq(cast(col("c1"), DataType::Utf8)); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_date_date() -> Result<()> { + let schema = Schema::new(vec![Field::new("c1", DataType::Date32, false)]); + let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 
AS Date64) <= 1970-01-01 AND 1970-01-01 <= CAST(c1_max@1 AS Date64)"; + + // test column on the left + let expr = + cast(col("c1"), DataType::Date64).eq(lit(ScalarValue::Date64(Some(123)))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + // test column on the right + let expr = + lit(ScalarValue::Date64(Some(123))).eq(cast(col("c1"), DataType::Date64)); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_dict_string_date() -> Result<()> { + // Test with Dictionary for the literal + let schema = Schema::new(vec![Field::new("c1", DataType::Date32, false)]); + let expected_expr = "true"; + + // test column on the left + let expr = cast( + col("c1"), + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)), + ) + .eq(lit(ScalarValue::Utf8(Some("2024-01-01".to_string())))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + // test column on the right + let expr = lit(ScalarValue::Utf8(Some("2024-01-01".to_string()))).eq(cast( + col("c1"), + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)), + )); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_date_dict_string() -> Result<()> { + // Test with Dictionary for the column + let schema = Schema::new(vec![Field::new( + "c1", + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)), + false, + )]); + let expected_expr = "true"; + + // test column on the left + let expr = + cast(col("c1"), DataType::Date32).eq(lit(ScalarValue::Date32(Some(123)))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + // test column on the right + let expr = + lit(ScalarValue::Date32(Some(123))).eq(cast(col("c1"), DataType::Date32)); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_dict_dict_same_value_type() -> Result<()> { + // Test with Dictionary types that have the same value type but different key types + let schema = Schema::new(vec![Field::new( + "c1", + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)), + false, + )]); + + // Direct comparison with no cast + let expr = col("c1").eq(lit(ScalarValue::Utf8(Some("test".to_string())))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + let expected_expr = + "c1_null_count@2 != row_count@3 AND c1_min@0 <= test AND test <= c1_max@1"; + assert_eq!(predicate_expr.to_string(), expected_expr); + + // Test with column cast to a dictionary with different key type + let expr = cast( + col("c1"), + DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)), + ) + .eq(lit(ScalarValue::Utf8(Some("test".to_string())))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + let expected_expr = "c1_null_count@2 != 
row_count@3 AND CAST(c1_min@0 AS Dictionary(UInt16, Utf8)) <= test AND test <= CAST(c1_max@1 AS Dictionary(UInt16, Utf8))"; + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_dict_dict_different_value_type() -> Result<()> { + // Test with Dictionary types that have different value types + let schema = Schema::new(vec![Field::new( + "c1", + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Int32)), + false, + )]); + let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Int64) <= 123 AND 123 <= CAST(c1_max@1 AS Int64)"; + + // Test with literal of a different type + let expr = + cast(col("c1"), DataType::Int64).eq(lit(ScalarValue::Int64(Some(123)))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_nested_dict() -> Result<()> { + // Test with nested Dictionary types + let schema = Schema::new(vec![Field::new( + "c1", + DataType::Dictionary( + Box::new(DataType::UInt8), + Box::new(DataType::Dictionary( + Box::new(DataType::UInt16), + Box::new(DataType::Utf8), + )), + ), + false, + )]); + let expected_expr = + "c1_null_count@2 != row_count@3 AND c1_min@0 <= test AND test <= c1_max@1"; + + // Test with a simple literal + let expr = col("c1").eq(lit(ScalarValue::Utf8(Some("test".to_string())))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_dict_date_dict_date() -> Result<()> { + // Test with dictionary-wrapped date types for both sides + let schema = Schema::new(vec![Field::new( + "c1", + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Date32)), + false, + )]); + let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Dictionary(UInt16, Date64)) <= 1970-01-01 AND 1970-01-01 <= CAST(c1_max@1 AS Dictionary(UInt16, Date64))"; + + // Test with a cast to a different date type + let expr = cast( + col("c1"), + DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Date64)), + ) + .eq(lit(ScalarValue::Date64(Some(123)))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_date_string() -> Result<()> { + let schema = Schema::new(vec![Field::new("c1", DataType::Utf8, false)]); + let expected_expr = "true"; + + // test column on the left + let expr = + cast(col("c1"), DataType::Date32).eq(lit(ScalarValue::Date32(Some(123)))); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + // test column on the right + let expr = + lit(ScalarValue::Date32(Some(123))).eq(cast(col("c1"), DataType::Date32)); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + + #[test] + fn row_group_predicate_string_date() -> Result<()> { + let schema = Schema::new(vec![Field::new("c1", DataType::Date32, false)]); + let expected_expr = "true"; + + // test column on the left + let expr = cast(col("c1"), DataType::Utf8) + .eq(lit(ScalarValue::Utf8(Some("2024-01-01".to_string())))); + 
let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + // test column on the right + let expr = lit(ScalarValue::Utf8(Some("2024-01-01".to_string()))) + .eq(cast(col("c1"), DataType::Utf8)); + let predicate_expr = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + assert_eq!(predicate_expr.to_string(), expected_expr); + + Ok(()) + } + #[test] fn row_group_predicate_cast_list() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); @@ -3285,12 +3530,10 @@ mod tests { prune_with_expr( // false - // constant literals that do NOT refer to any columns are currently not evaluated at all, hence the result is - // "all true" lit(false), &schema, &statistics, - &[true, true, true, true, true], + &[false, false, false, false, false], ); } @@ -4855,7 +5098,7 @@ mod tests { statistics: &TestStatistics, expected: &[bool], ) { - println!("Pruning with expr: {}", expr); + println!("Pruning with expr: {expr}"); let expr = logical2physical(&expr, schema); let p = PruningPredicate::try_new(expr, Arc::::clone(schema)).unwrap(); let result = p.prune(statistics).unwrap(); @@ -4871,4 +5114,42 @@ mod tests { let unhandled_hook = Arc::new(ConstantUnhandledPredicateHook::default()) as _; build_predicate_expression(&expr, schema, required_columns, &unhandled_hook) } + + #[test] + fn test_build_predicate_expression_with_false() { + let expr = lit(ScalarValue::Boolean(Some(false))); + let schema = Schema::empty(); + let res = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + let expected = logical2physical(&expr, &schema); + assert_eq!(&res, &expected); + } + + #[test] + fn test_build_predicate_expression_with_and_false() { + let schema = Schema::new(vec![Field::new("c1", DataType::Utf8View, false)]); + let expr = and( + col("c1").eq(lit("a")), + lit(ScalarValue::Boolean(Some(false))), + ); + let res = + test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); + let expected = logical2physical(&lit(ScalarValue::Boolean(Some(false))), &schema); + assert_eq!(&res, &expected); + } + + #[test] + fn test_build_predicate_expression_with_or_false() { + let schema = Schema::new(vec![Field::new("c1", DataType::Utf8View, false)]); + let left_expr = col("c1").eq(lit("a")); + let right_expr = lit(ScalarValue::Boolean(Some(false))); + let res = test_build_predicate_expression( + &or(left_expr.clone(), right_expr.clone()), + &schema, + &mut RequiredColumns::new(), + ); + let expected = + "c1_null_count@2 != row_count@3 AND c1_min@0 <= a AND a <= c1_max@1"; + assert_eq!(res.to_string(), expected); + } } diff --git a/datafusion/session/README.md b/datafusion/session/README.md index 019f9f889247..f029c797366f 100644 --- a/datafusion/session/README.md +++ b/datafusion/session/README.md @@ -23,4 +23,9 @@ This crate provides **session-related abstractions** used in the DataFusion query engine. A _session_ represents the runtime context for query execution, including configuration, runtime environment, function registry, and planning. +Most projects should use the [`datafusion`] crate directly, which re-exports +this module. If you are already using the [`datafusion`] crate, there is no +reason to use this crate directly in your project as well. 
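To make the README's guidance concrete, here is a tiny sketch of reaching the session abstractions through the main crate rather than through `datafusion-session` directly; the config accessor shown is just one of many and the example is editorial, not part of the crate docs.

```rust
use datafusion::prelude::SessionContext;

fn main() {
    // No direct dependency on `datafusion-session` is needed; the session
    // abstractions arrive via the main crate's re-exports.
    let ctx = SessionContext::new();
    let state = ctx.state(); // configuration, runtime environment, function registry
    println!("target_partitions = {}", state.config().target_partitions());
}
```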
+ [df]: https://crates.io/crates/datafusion +[`datafusion`]: https://crates.io/crates/datafusion diff --git a/datafusion/spark/Cargo.toml b/datafusion/spark/Cargo.toml new file mode 100644 index 000000000000..2c46cac6b7b0 --- /dev/null +++ b/datafusion/spark/Cargo.toml @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion-spark" +description = "DataFusion expressions that emulate Apache Spark's behavior" +version = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +authors = { workspace = true } +readme = { workspace = true } +license = { workspace = true } +edition = { workspace = true } + +[package.metadata.docs.rs] +all-features = true + +[lints] +workspace = true + +[lib] +name = "datafusion_spark" + +[dependencies] +arrow = { workspace = true } +datafusion-catalog = { workspace = true } +datafusion-common = { workspace = true } +datafusion-execution = { workspace = true } +datafusion-expr = { workspace = true } +datafusion-functions = { workspace = true, features = ["crypto_expressions"] } +datafusion-macros = { workspace = true } +log = { workspace = true } diff --git a/datafusion/spark/LICENSE.txt b/datafusion/spark/LICENSE.txt new file mode 120000 index 000000000000..1ef648f64b34 --- /dev/null +++ b/datafusion/spark/LICENSE.txt @@ -0,0 +1 @@ +../../LICENSE.txt \ No newline at end of file diff --git a/datafusion/spark/NOTICE.txt b/datafusion/spark/NOTICE.txt new file mode 120000 index 000000000000..fb051c92b10b --- /dev/null +++ b/datafusion/spark/NOTICE.txt @@ -0,0 +1 @@ +../../NOTICE.txt \ No newline at end of file diff --git a/datafusion/spark/README.md b/datafusion/spark/README.md new file mode 100644 index 000000000000..c92ada0ab477 --- /dev/null +++ b/datafusion/spark/README.md @@ -0,0 +1,40 @@ + + +# datafusion-spark: Spark-compatible Expressions + +This crate provides Apache Spark-compatible expressions for use with DataFusion. + +## Testing Guide + +When testing functions by directly invoking them (e.g., `test_scalar_function!()`), input coercion (from the `signature` +or `coerce_types`) is not applied. + +Therefore, direct invocation tests should only be used to verify that the function is correctly implemented. + +Please be sure to add additional tests beyond direct invocation. +For more detailed testing guidelines, refer to +the [Spark SQLLogicTest README](../sqllogictest/test_files/spark/README.md). + +## Implementation References + +When implementing Spark-compatible functions, you can check if there are existing implementations in +the [Sail](https://github.com/lakehq/sail) or [Comet](https://github.com/apache/datafusion-comet) projects first. 
+If you do port functionality from these sources, make sure to port over the corresponding tests too, to ensure +correctness and compatibility. diff --git a/datafusion/spark/src/function/aggregate/mod.rs b/datafusion/spark/src/function/aggregate/mod.rs new file mode 100644 index 000000000000..0856e2872d4f --- /dev/null +++ b/datafusion/spark/src/function/aggregate/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::AggregateUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/array/mod.rs b/datafusion/spark/src/function/array/mod.rs new file mode 100644 index 000000000000..a87df9a2c87a --- /dev/null +++ b/datafusion/spark/src/function/array/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/bitwise/mod.rs b/datafusion/spark/src/function/bitwise/mod.rs new file mode 100644 index 000000000000..a87df9a2c87a --- /dev/null +++ b/datafusion/spark/src/function/bitwise/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/collection/mod.rs b/datafusion/spark/src/function/collection/mod.rs new file mode 100644 index 000000000000..a87df9a2c87a --- /dev/null +++ b/datafusion/spark/src/function/collection/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/conditional/mod.rs b/datafusion/spark/src/function/conditional/mod.rs new file mode 100644 index 000000000000..a87df9a2c87a --- /dev/null +++ b/datafusion/spark/src/function/conditional/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/conversion/mod.rs b/datafusion/spark/src/function/conversion/mod.rs new file mode 100644 index 000000000000..a87df9a2c87a --- /dev/null +++ b/datafusion/spark/src/function/conversion/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/csv/mod.rs b/datafusion/spark/src/function/csv/mod.rs new file mode 100644 index 000000000000..a87df9a2c87a --- /dev/null +++ b/datafusion/spark/src/function/csv/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/datetime/mod.rs b/datafusion/spark/src/function/datetime/mod.rs new file mode 100644 index 000000000000..a87df9a2c87a --- /dev/null +++ b/datafusion/spark/src/function/datetime/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/error_utils.rs b/datafusion/spark/src/function/error_utils.rs new file mode 100644 index 000000000000..b972d64ed3e9 --- /dev/null +++ b/datafusion/spark/src/function/error_utils.rs @@ -0,0 +1,71 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +// TODO: https://github.com/apache/spark/tree/master/common/utils/src/main/resources/error + +use arrow::datatypes::DataType; +use datafusion_common::{exec_datafusion_err, internal_datafusion_err, DataFusionError}; + +pub fn invalid_arg_count_exec_err( + function_name: &str, + required_range: (i32, i32), + provided: usize, +) -> DataFusionError { + let (min_required, max_required) = required_range; + let required = if min_required == max_required { + format!( + "{min_required} argument{}", + if min_required == 1 { "" } else { "s" } + ) + } else { + format!("{min_required} to {max_required} arguments") + }; + exec_datafusion_err!( + "Spark `{function_name}` function requires {required}, got {provided}" + ) +} + +pub fn unsupported_data_type_exec_err( + function_name: &str, + required: &str, + provided: &DataType, +) -> DataFusionError { + exec_datafusion_err!("Unsupported Data Type: Spark `{function_name}` function expects {required}, got {provided}") +} + +pub fn unsupported_data_types_exec_err( + function_name: &str, + required: &str, + provided: &[DataType], +) -> DataFusionError { + exec_datafusion_err!( + "Unsupported Data Type: Spark `{function_name}` function expects {required}, got {}", + provided + .iter() + .map(|dt| format!("{dt}")) + .collect::>() + .join(", ") + ) +} + +pub fn generic_exec_err(function_name: &str, message: &str) -> DataFusionError { + exec_datafusion_err!("Spark `{function_name}` function: {message}") +} + +pub fn generic_internal_err(function_name: &str, message: &str) -> DataFusionError { + internal_datafusion_err!("Spark `{function_name}` function: {message}") +} diff --git a/datafusion/spark/src/function/generator/mod.rs b/datafusion/spark/src/function/generator/mod.rs new file mode 100644 index 000000000000..a87df9a2c87a --- /dev/null +++ b/datafusion/spark/src/function/generator/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/hash/mod.rs b/datafusion/spark/src/function/hash/mod.rs new file mode 100644 index 000000000000..f31918e6a46b --- /dev/null +++ b/datafusion/spark/src/function/hash/mod.rs @@ -0,0 +1,33 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
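A short test-style sketch of the messages these helpers produce, derived from the bodies above; the `datafusion_spark::function::error_utils` path is an assumption about how the crate root exposes the `function` module.

```rust
use arrow::datatypes::DataType;
use datafusion_spark::function::error_utils::{
    invalid_arg_count_exec_err, unsupported_data_type_exec_err,
};

fn main() {
    let e = invalid_arg_count_exec_err("sha2", (2, 2), 1);
    assert!(e
        .to_string()
        .contains("Spark `sha2` function requires 2 arguments, got 1"));

    let e = unsupported_data_type_exec_err("expm1", "Numeric Type", &DataType::Utf8);
    assert!(e
        .to_string()
        .contains("Spark `expm1` function expects Numeric Type, got Utf8"));
}
```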
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod sha2; + +use datafusion_expr::ScalarUDF; +use datafusion_functions::make_udf_function; +use std::sync::Arc; + +make_udf_function!(sha2::SparkSha2, sha2); + +pub mod expr_fn { + use datafusion_functions::export_functions; + export_functions!((sha2, "sha2(expr, bitLength) - Returns a checksum of SHA-2 family as a hex string of expr. SHA-224, SHA-256, SHA-384, and SHA-512 are supported. Bit length of 0 is equivalent to 256.", arg1 arg2)); +} + +pub fn functions() -> Vec> { + vec![sha2()] +} diff --git a/datafusion/spark/src/function/hash/sha2.rs b/datafusion/spark/src/function/hash/sha2.rs new file mode 100644 index 000000000000..a8bb8c21a2a4 --- /dev/null +++ b/datafusion/spark/src/function/hash/sha2.rs @@ -0,0 +1,220 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
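A hedged end-to-end sketch of using the new UDF from SQL: register the UDFs returned by the hash module's `functions()` accessor on a `SessionContext` and call `sha2` like any other scalar function. The module path and the use of `register_udf` are assumptions about how consumers wire this up (the crate may well grow a dedicated registration helper), and the query string is illustrative.

```rust
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();

    // Register everything the hash module exposes (currently just `sha2`).
    for udf in datafusion_spark::function::hash::functions() {
        ctx.register_udf(udf.as_ref().clone());
    }

    // Per the docs above, a bit length of 0 behaves like 256.
    let df = ctx
        .sql("SELECT sha2('Spark', 256) AS h256, sha2('Spark', 0) AS h0")
        .await?;
    df.show().await?;
    Ok(())
}
```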
+ +extern crate datafusion_functions; + +use crate::function::error_utils::{ + invalid_arg_count_exec_err, unsupported_data_type_exec_err, +}; +use crate::function::math::hex::spark_sha2_hex; +use arrow::array::{ArrayRef, AsArray, StringArray}; +use arrow::datatypes::{DataType, Int32Type}; +use datafusion_common::{exec_err, internal_datafusion_err, Result, ScalarValue}; +use datafusion_expr::Signature; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility}; +pub use datafusion_functions::crypto::basic::{sha224, sha256, sha384, sha512}; +use std::any::Any; +use std::sync::Arc; + +/// +#[derive(Debug)] +pub struct SparkSha2 { + signature: Signature, + aliases: Vec, +} + +impl Default for SparkSha2 { + fn default() -> Self { + Self::new() + } +} + +impl SparkSha2 { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec![], + } + } +} + +impl ScalarUDFImpl for SparkSha2 { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "sha2" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if arg_types[1].is_null() { + return Ok(DataType::Null); + } + Ok(match arg_types[0] { + DataType::Utf8View + | DataType::LargeUtf8 + | DataType::Utf8 + | DataType::Binary + | DataType::BinaryView + | DataType::LargeBinary => DataType::Utf8, + DataType::Null => DataType::Null, + _ => { + return exec_err!( + "{} function can only accept strings or binary arrays.", + self.name() + ) + } + }) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let args: [ColumnarValue; 2] = args.args.try_into().map_err(|_| { + internal_datafusion_err!("Expected 2 arguments for function sha2") + })?; + + sha2(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() != 2 { + return Err(invalid_arg_count_exec_err( + self.name(), + (2, 2), + arg_types.len(), + )); + } + let expr_type = match &arg_types[0] { + DataType::Utf8View + | DataType::LargeUtf8 + | DataType::Utf8 + | DataType::Binary + | DataType::BinaryView + | DataType::LargeBinary + | DataType::Null => Ok(arg_types[0].clone()), + _ => Err(unsupported_data_type_exec_err( + self.name(), + "String, Binary", + &arg_types[0], + )), + }?; + let bit_length_type = if arg_types[1].is_numeric() { + Ok(DataType::Int32) + } else if arg_types[1].is_null() { + Ok(DataType::Null) + } else { + Err(unsupported_data_type_exec_err( + self.name(), + "Numeric Type", + &arg_types[1], + )) + }?; + + Ok(vec![expr_type, bit_length_type]) + } +} + +pub fn sha2(args: [ColumnarValue; 2]) -> Result { + match args { + [ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg)), ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length_arg)))] => { + compute_sha2( + bit_length_arg, + &[ColumnarValue::from(ScalarValue::Utf8(expr_arg))], + ) + } + [ColumnarValue::Array(expr_arg), ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length_arg)))] => { + compute_sha2(bit_length_arg, &[ColumnarValue::from(expr_arg)]) + } + [ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg)), ColumnarValue::Array(bit_length_arg)] => + { + let arr: StringArray = bit_length_arg + .as_primitive::() + .iter() + .map(|bit_length| { + match sha2([ + ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg.clone())), + ColumnarValue::Scalar(ScalarValue::Int32(bit_length)), + ]) + .unwrap() + { + ColumnarValue::Scalar(ScalarValue::Utf8(str)) => str, + 
ColumnarValue::Array(arr) => arr + .as_string::() + .iter() + .map(|str| str.unwrap().to_string()) + .next(), // first element + _ => unreachable!(), + } + }) + .collect(); + Ok(ColumnarValue::Array(Arc::new(arr) as ArrayRef)) + } + [ColumnarValue::Array(expr_arg), ColumnarValue::Array(bit_length_arg)] => { + let expr_iter = expr_arg.as_string::().iter(); + let bit_length_iter = bit_length_arg.as_primitive::().iter(); + let arr: StringArray = expr_iter + .zip(bit_length_iter) + .map(|(expr, bit_length)| { + match sha2([ + ColumnarValue::Scalar(ScalarValue::Utf8(Some( + expr.unwrap().to_string(), + ))), + ColumnarValue::Scalar(ScalarValue::Int32(bit_length)), + ]) + .unwrap() + { + ColumnarValue::Scalar(ScalarValue::Utf8(str)) => str, + ColumnarValue::Array(arr) => arr + .as_string::() + .iter() + .map(|str| str.unwrap().to_string()) + .next(), // first element + _ => unreachable!(), + } + }) + .collect(); + Ok(ColumnarValue::Array(Arc::new(arr) as ArrayRef)) + } + _ => exec_err!("Unsupported argument types for sha2 function"), + } +} + +fn compute_sha2( + bit_length_arg: i32, + expr_arg: &[ColumnarValue], +) -> Result { + match bit_length_arg { + 0 | 256 => sha256(expr_arg), + 224 => sha224(expr_arg), + 384 => sha384(expr_arg), + 512 => sha512(expr_arg), + _ => { + // Return null for unsupported bit lengths instead of error, because spark sha2 does not + // error out for this. + return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); + } + } + .map(|hashed| spark_sha2_hex(&[hashed]).unwrap()) +} diff --git a/datafusion/spark/src/function/json/mod.rs b/datafusion/spark/src/function/json/mod.rs new file mode 100644 index 000000000000..a87df9a2c87a --- /dev/null +++ b/datafusion/spark/src/function/json/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/lambda/mod.rs b/datafusion/spark/src/function/lambda/mod.rs new file mode 100644 index 000000000000..a87df9a2c87a --- /dev/null +++ b/datafusion/spark/src/function/lambda/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
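One behavioural detail worth calling out from `compute_sha2` above: an unsupported bit length produces a SQL NULL rather than an error, matching Spark. A small direct-invocation sketch follows (module path assumed; note that direct invocation bypasses `coerce_types`, as the crate README also points out).

```rust
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::ColumnarValue;
use datafusion_spark::function::hash::sha2::sha2;

fn main() -> Result<()> {
    let out = sha2([
        ColumnarValue::Scalar(ScalarValue::Utf8(Some("Spark".to_string()))),
        ColumnarValue::Scalar(ScalarValue::Int32(Some(999))), // not a SHA-2 width
    ])?;
    match out {
        ColumnarValue::Scalar(ScalarValue::Utf8(v)) => assert!(v.is_none()),
        _ => unreachable!("scalar inputs yield a scalar result"),
    }
    Ok(())
}
```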
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/map/mod.rs b/datafusion/spark/src/function/map/mod.rs new file mode 100644 index 000000000000..a87df9a2c87a --- /dev/null +++ b/datafusion/spark/src/function/map/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/math/expm1.rs b/datafusion/spark/src/function/math/expm1.rs new file mode 100644 index 000000000000..3a3a0c3835d3 --- /dev/null +++ b/datafusion/spark/src/function/math/expm1.rs @@ -0,0 +1,145 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::function::error_utils::{ + invalid_arg_count_exec_err, unsupported_data_type_exec_err, +}; +use arrow::array::{ArrayRef, AsArray}; +use arrow::datatypes::{DataType, Float64Type}; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use std::any::Any; +use std::sync::Arc; + +/// +#[derive(Debug)] +pub struct SparkExpm1 { + signature: Signature, + aliases: Vec, +} + +impl Default for SparkExpm1 { + fn default() -> Self { + Self::new() + } +} + +impl SparkExpm1 { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec![], + } + } +} + +impl ScalarUDFImpl for SparkExpm1 { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "expm1" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + if args.args.len() != 1 { + return Err(invalid_arg_count_exec_err("expm1", (1, 1), args.args.len())); + } + match &args.args[0] { + ColumnarValue::Scalar(ScalarValue::Float64(value)) => Ok( + ColumnarValue::Scalar(ScalarValue::Float64(value.map(|x| x.exp_m1()))), + ), + ColumnarValue::Array(array) => match array.data_type() { + DataType::Float64 => Ok(ColumnarValue::Array(Arc::new( + array + .as_primitive::() + .unary::<_, Float64Type>(|x| x.exp_m1()), + ) + as ArrayRef)), + other => Err(unsupported_data_type_exec_err( + "expm1", + format!("{}", DataType::Float64).as_str(), + other, + )), + }, + other => Err(unsupported_data_type_exec_err( + "expm1", + format!("{}", DataType::Float64).as_str(), + &other.data_type(), + )), + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() != 1 { + return Err(invalid_arg_count_exec_err("expm1", (1, 1), arg_types.len())); + } + if arg_types[0].is_numeric() { + Ok(vec![DataType::Float64]) + } else { + Err(unsupported_data_type_exec_err( + "expm1", + "Numeric Type", + &arg_types[0], + )) + } + } +} + +#[cfg(test)] +mod tests { + use crate::function::math::expm1::SparkExpm1; + use crate::function::utils::test::test_scalar_function; + use arrow::array::{Array, Float64Array}; + use arrow::datatypes::DataType::Float64; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + macro_rules! test_expm1_float64_invoke { + ($INPUT:expr, $EXPECTED:expr) => { + test_scalar_function!( + SparkExpm1::new(), + vec![ColumnarValue::Scalar(ScalarValue::Float64($INPUT))], + $EXPECTED, + f64, + Float64, + Float64Array + ); + }; + } + + #[test] + fn test_expm1_invoke() -> Result<()> { + test_expm1_float64_invoke!(Some(0f64), Ok(Some(0.0f64))); + Ok(()) + } +} diff --git a/datafusion/spark/src/function/math/factorial.rs b/datafusion/spark/src/function/math/factorial.rs new file mode 100644 index 000000000000..10f7f0696469 --- /dev/null +++ b/datafusion/spark/src/function/math/factorial.rs @@ -0,0 +1,196 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
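The factorial implementation that follows uses a 21-entry lookup table and returns NULL outside the 0..=20 range. That bound is exactly where the result stops fitting in an `i64`, which this standalone check (illustration only) confirms:

```rust
// 20! fits in i64, 21! does not, hence the 0..=20 lookup table and NULL elsewhere.
fn main() {
    let fact20: i64 = (1..=20).product();
    assert_eq!(fact20, 2_432_902_008_176_640_000);

    // 21! overflows i64, so checked arithmetic reports None.
    let fact21 = (1..=21i64).try_fold(1i64, |acc, n| acc.checked_mul(n));
    assert_eq!(fact21, None);
    println!("20! = {fact20}; 21! overflows i64");
}
```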
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{Array, Int64Array}; +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::{Int32, Int64}; +use datafusion_common::cast::as_int32_array; +use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::Signature; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility}; + +/// +#[derive(Debug)] +pub struct SparkFactorial { + signature: Signature, + aliases: Vec, +} + +impl Default for SparkFactorial { + fn default() -> Self { + Self::new() + } +} + +impl SparkFactorial { + pub fn new() -> Self { + Self { + signature: Signature::exact(vec![Int32], Volatility::Immutable), + aliases: vec![], + } + } +} + +impl ScalarUDFImpl for SparkFactorial { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "factorial" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Int64) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + spark_factorial(&args.args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +const FACTORIALS: [i64; 21] = [ + 1, + 1, + 2, + 6, + 24, + 120, + 720, + 5040, + 40320, + 362880, + 3628800, + 39916800, + 479001600, + 6227020800, + 87178291200, + 1307674368000, + 20922789888000, + 355687428096000, + 6402373705728000, + 121645100408832000, + 2432902008176640000, +]; + +pub fn spark_factorial(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return Err(DataFusionError::Internal( + "`factorial` expects exactly one argument".to_string(), + )); + } + + match &args[0] { + ColumnarValue::Scalar(ScalarValue::Int32(value)) => { + let result = compute_factorial(*value); + Ok(ColumnarValue::Scalar(ScalarValue::Int64(result))) + } + ColumnarValue::Scalar(other) => { + exec_err!("`factorial` got an unexpected scalar type: {:?}", other) + } + ColumnarValue::Array(array) => match array.data_type() { + Int32 => { + let array = as_int32_array(array)?; + + let result: Int64Array = array.iter().map(compute_factorial).collect(); + + Ok(ColumnarValue::Array(Arc::new(result))) + } + other => { + exec_err!("`factorial` got an unexpected argument type: {:?}", other) + } + }, + } +} + +#[inline] +fn compute_factorial(num: Option) -> Option { + num.filter(|&v| (0..=20).contains(&v)) + .map(|v| FACTORIALS[v as usize]) +} + +#[cfg(test)] +mod test { + use crate::function::math::factorial::spark_factorial; + use arrow::array::{Int32Array, Int64Array}; + use datafusion_common::cast::as_int64_array; + use datafusion_common::ScalarValue; + use datafusion_expr::ColumnarValue; + use std::sync::Arc; + + #[test] + fn test_spark_factorial_array() { + let input = Int32Array::from(vec![ + Some(-1), + Some(0), + Some(1), + Some(2), + Some(4), + Some(20), + Some(21), + None, + ]); + + let args = ColumnarValue::Array(Arc::new(input)); + let result = spark_factorial(&[args]).unwrap(); + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array"), + }; + + let actual = 
as_int64_array(&result).unwrap(); + let expected = Int64Array::from(vec![ + None, + Some(1), + Some(1), + Some(2), + Some(24), + Some(2432902008176640000), + None, + None, + ]); + + assert_eq!(actual, &expected); + } + + #[test] + fn test_spark_factorial_scalar() { + let input = ScalarValue::Int32(Some(5)); + + let args = ColumnarValue::Scalar(input); + let result = spark_factorial(&[args]).unwrap(); + let result = match result { + ColumnarValue::Scalar(ScalarValue::Int64(val)) => val, + _ => panic!("Expected scalar"), + }; + let actual = result.unwrap(); + let expected = 120_i64; + + assert_eq!(actual, expected); + } +} diff --git a/datafusion/spark/src/function/math/hex.rs b/datafusion/spark/src/function/math/hex.rs new file mode 100644 index 000000000000..614d1d4e9ac1 --- /dev/null +++ b/datafusion/spark/src/function/math/hex.rs @@ -0,0 +1,419 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use crate::function::error_utils::{ + invalid_arg_count_exec_err, unsupported_data_type_exec_err, +}; +use arrow::array::{Array, StringArray}; +use arrow::datatypes::DataType; +use arrow::{ + array::{as_dictionary_array, as_largestring_array, as_string_array}, + datatypes::Int32Type, +}; +use datafusion_common::{ + cast::{as_binary_array, as_fixed_size_binary_array, as_int64_array}, + exec_err, DataFusionError, +}; +use datafusion_expr::Signature; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility}; +use std::fmt::Write; + +/// +#[derive(Debug)] +pub struct SparkHex { + signature: Signature, + aliases: Vec, +} + +impl Default for SparkHex { + fn default() -> Self { + Self::new() + } +} + +impl SparkHex { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec![], + } + } +} + +impl ScalarUDFImpl for SparkHex { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "hex" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type( + &self, + _arg_types: &[DataType], + ) -> datafusion_common::Result { + Ok(DataType::Utf8) + } + + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + spark_hex(&args.args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn coerce_types( + &self, + arg_types: &[DataType], + ) -> datafusion_common::Result> { + if arg_types.len() != 1 { + return Err(invalid_arg_count_exec_err("hex", (1, 1), arg_types.len())); + } + match &arg_types[0] { + DataType::Int64 + | DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Binary + | DataType::LargeBinary => Ok(vec![arg_types[0].clone()]), + DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { + DataType::Int64 + | DataType::Utf8 + | 
DataType::LargeUtf8 + | DataType::Binary + | DataType::LargeBinary => Ok(vec![arg_types[0].clone()]), + other => { + if other.is_numeric() { + Ok(vec![DataType::Dictionary( + key_type.clone(), + Box::new(DataType::Int64), + )]) + } else { + Err(unsupported_data_type_exec_err( + "hex", + "Numeric, String, or Binary", + &arg_types[0], + )) + } + } + }, + other => { + if other.is_numeric() { + Ok(vec![DataType::Int64]) + } else { + Err(unsupported_data_type_exec_err( + "hex", + "Numeric, String, or Binary", + &arg_types[0], + )) + } + } + } + } +} + +fn hex_int64(num: i64) -> String { + format!("{num:X}") +} + +#[inline(always)] +fn hex_encode>(data: T, lower_case: bool) -> String { + let mut s = String::with_capacity(data.as_ref().len() * 2); + if lower_case { + for b in data.as_ref() { + // Writing to a string never errors, so we can unwrap here. + write!(&mut s, "{b:02x}").unwrap(); + } + } else { + for b in data.as_ref() { + // Writing to a string never errors, so we can unwrap here. + write!(&mut s, "{b:02X}").unwrap(); + } + } + s +} + +#[inline(always)] +fn hex_bytes>( + bytes: T, + lowercase: bool, +) -> Result { + let hex_string = hex_encode(bytes, lowercase); + Ok(hex_string) +} + +/// Spark-compatible `hex` function +pub fn spark_hex(args: &[ColumnarValue]) -> Result { + compute_hex(args, false) +} + +/// Spark-compatible `sha2` function +pub fn spark_sha2_hex(args: &[ColumnarValue]) -> Result { + compute_hex(args, true) +} + +pub fn compute_hex( + args: &[ColumnarValue], + lowercase: bool, +) -> Result { + if args.len() != 1 { + return Err(DataFusionError::Internal( + "hex expects exactly one argument".to_string(), + )); + } + + let input = match &args[0] { + ColumnarValue::Scalar(value) => ColumnarValue::Array(value.to_array()?), + ColumnarValue::Array(_) => args[0].clone(), + }; + + match &input { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Int64 => { + let array = as_int64_array(array)?; + + let hexed_array: StringArray = + array.iter().map(|v| v.map(hex_int64)).collect(); + + Ok(ColumnarValue::Array(Arc::new(hexed_array))) + } + DataType::Utf8 => { + let array = as_string_array(array); + + let hexed: StringArray = array + .iter() + .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) + .collect::>()?; + + Ok(ColumnarValue::Array(Arc::new(hexed))) + } + DataType::LargeUtf8 => { + let array = as_largestring_array(array); + + let hexed: StringArray = array + .iter() + .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) + .collect::>()?; + + Ok(ColumnarValue::Array(Arc::new(hexed))) + } + DataType::Binary => { + let array = as_binary_array(array)?; + + let hexed: StringArray = array + .iter() + .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) + .collect::>()?; + + Ok(ColumnarValue::Array(Arc::new(hexed))) + } + DataType::FixedSizeBinary(_) => { + let array = as_fixed_size_binary_array(array)?; + + let hexed: StringArray = array + .iter() + .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) + .collect::>()?; + + Ok(ColumnarValue::Array(Arc::new(hexed))) + } + DataType::Dictionary(_, value_type) => { + let dict = as_dictionary_array::(&array); + + let values = match **value_type { + DataType::Int64 => as_int64_array(dict.values())? + .iter() + .map(|v| v.map(hex_int64)) + .collect::>(), + DataType::Utf8 => as_string_array(dict.values()) + .iter() + .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) + .collect::>()?, + DataType::Binary => as_binary_array(dict.values())? 
+ .iter() + .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) + .collect::>()?, + _ => exec_err!( + "hex got an unexpected argument type: {:?}", + array.data_type() + )?, + }; + + let new_values: Vec> = dict + .keys() + .iter() + .map(|key| key.map(|k| values[k as usize].clone()).unwrap_or(None)) + .collect(); + + let string_array_values = StringArray::from(new_values); + + Ok(ColumnarValue::Array(Arc::new(string_array_values))) + } + _ => exec_err!( + "hex got an unexpected argument type: {:?}", + array.data_type() + ), + }, + _ => exec_err!("native hex does not support scalar values at this time"), + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use arrow::array::{Int64Array, StringArray}; + use arrow::{ + array::{ + as_string_array, BinaryDictionaryBuilder, PrimitiveDictionaryBuilder, + StringBuilder, StringDictionaryBuilder, + }, + datatypes::{Int32Type, Int64Type}, + }; + use datafusion_expr::ColumnarValue; + + #[test] + fn test_dictionary_hex_utf8() { + let mut input_builder = StringDictionaryBuilder::::new(); + input_builder.append_value("hi"); + input_builder.append_value("bye"); + input_builder.append_null(); + input_builder.append_value("rust"); + let input = input_builder.finish(); + + let mut string_builder = StringBuilder::new(); + string_builder.append_value("6869"); + string_builder.append_value("627965"); + string_builder.append_null(); + string_builder.append_value("72757374"); + let expected = string_builder.finish(); + + let columnar_value = ColumnarValue::Array(Arc::new(input)); + let result = super::spark_hex(&[columnar_value]).unwrap(); + + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array"), + }; + + let result = as_string_array(&result); + + assert_eq!(result, &expected); + } + + #[test] + fn test_dictionary_hex_int64() { + let mut input_builder = PrimitiveDictionaryBuilder::::new(); + input_builder.append_value(1); + input_builder.append_value(2); + input_builder.append_null(); + input_builder.append_value(3); + let input = input_builder.finish(); + + let mut string_builder = StringBuilder::new(); + string_builder.append_value("1"); + string_builder.append_value("2"); + string_builder.append_null(); + string_builder.append_value("3"); + let expected = string_builder.finish(); + + let columnar_value = ColumnarValue::Array(Arc::new(input)); + let result = super::spark_hex(&[columnar_value]).unwrap(); + + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array"), + }; + + let result = as_string_array(&result); + + assert_eq!(result, &expected); + } + + #[test] + fn test_dictionary_hex_binary() { + let mut input_builder = BinaryDictionaryBuilder::::new(); + input_builder.append_value("1"); + input_builder.append_value("j"); + input_builder.append_null(); + input_builder.append_value("3"); + let input = input_builder.finish(); + + let mut expected_builder = StringBuilder::new(); + expected_builder.append_value("31"); + expected_builder.append_value("6A"); + expected_builder.append_null(); + expected_builder.append_value("33"); + let expected = expected_builder.finish(); + + let columnar_value = ColumnarValue::Array(Arc::new(input)); + let result = super::spark_hex(&[columnar_value]).unwrap(); + + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array"), + }; + + let result = as_string_array(&result); + + assert_eq!(result, &expected); + } + + #[test] + fn test_hex_int64() { + let num = 1234; + let hexed = 
super::hex_int64(num); + assert_eq!(hexed, "4D2".to_string()); + + let num = -1; + let hexed = super::hex_int64(num); + assert_eq!(hexed, "FFFFFFFFFFFFFFFF".to_string()); + } + + #[test] + fn test_spark_hex_int64() { + let int_array = Int64Array::from(vec![Some(1), Some(2), None, Some(3)]); + let columnar_value = ColumnarValue::Array(Arc::new(int_array)); + + let result = super::spark_hex(&[columnar_value]).unwrap(); + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array"), + }; + + let string_array = as_string_array(&result); + let expected_array = StringArray::from(vec![ + Some("1".to_string()), + Some("2".to_string()), + None, + Some("3".to_string()), + ]); + + assert_eq!(string_array, &expected_array); + } +} diff --git a/datafusion/spark/src/function/math/mod.rs b/datafusion/spark/src/function/math/mod.rs new file mode 100644 index 000000000000..1f2a3b1d67f6 --- /dev/null +++ b/datafusion/spark/src/function/math/mod.rs @@ -0,0 +1,44 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod expm1; +pub mod factorial; +pub mod hex; + +use datafusion_expr::ScalarUDF; +use datafusion_functions::make_udf_function; +use std::sync::Arc; + +make_udf_function!(expm1::SparkExpm1, expm1); +make_udf_function!(factorial::SparkFactorial, factorial); +make_udf_function!(hex::SparkHex, hex); + +pub mod expr_fn { + use datafusion_functions::export_functions; + + export_functions!((expm1, "Returns exp(expr) - 1 as a Float64.", arg1)); + export_functions!(( + factorial, + "Returns the factorial of expr. expr is [0..20]. Otherwise, null.", + arg1 + )); + export_functions!((hex, "Computes hex value of the given column.", arg1)); +} + +pub fn functions() -> Vec> { + vec![expm1(), factorial(), hex()] +} diff --git a/datafusion/spark/src/function/misc/mod.rs b/datafusion/spark/src/function/misc/mod.rs new file mode 100644 index 000000000000..a87df9a2c87a --- /dev/null +++ b/datafusion/spark/src/function/misc/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
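With the `make_udf_function!` / `export_functions!` wiring in `math/mod.rs` above, each math function is exposed both as a `ScalarUDF` and as a fluent `expr_fn` helper. A hedged sketch of building logical expressions with those helpers, assuming a crate that depends on `datafusion-spark` and `datafusion-expr`:

```rust
// Sketch only: the expr_fn helpers exported by datafusion_spark::function::math
// (also re-exported at datafusion_spark::expr_fn) build ordinary logical Exprs.
use datafusion_expr::{col, lit, Expr};
use datafusion_spark::expr_fn::{expm1, factorial, hex};

fn example_exprs() -> Vec<Expr> {
    vec![
        expm1(col("x")),       // exp(x) - 1 as Float64
        factorial(lit(20i32)), // 2432902008176640000; NULL outside 0..=20
        hex(col("payload")),   // upper-case hex string of the input
    ]
}

fn main() {
    for e in example_exprs() {
        println!("{e}");
    }
}
```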
See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/mod.rs b/datafusion/spark/src/function/mod.rs new file mode 100644 index 000000000000..dfdd94a040a9 --- /dev/null +++ b/datafusion/spark/src/function/mod.rs @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod aggregate; +pub mod array; +pub mod bitwise; +pub mod collection; +pub mod conditional; +pub mod conversion; +pub mod csv; +pub mod datetime; +pub mod error_utils; +pub mod generator; +pub mod hash; +pub mod json; +pub mod lambda; +pub mod map; +pub mod math; +pub mod misc; +pub mod predicate; +pub mod string; +pub mod r#struct; +pub mod table; +pub mod url; +pub mod utils; +pub mod window; +pub mod xml; diff --git a/datafusion/spark/src/function/predicate/mod.rs b/datafusion/spark/src/function/predicate/mod.rs new file mode 100644 index 000000000000..a87df9a2c87a --- /dev/null +++ b/datafusion/spark/src/function/predicate/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/string/ascii.rs b/datafusion/spark/src/function/string/ascii.rs new file mode 100644 index 000000000000..c05aa214ccc0 --- /dev/null +++ b/datafusion/spark/src/function/string/ascii.rs @@ -0,0 +1,174 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayAccessor, ArrayIter, ArrayRef, AsArray, Int32Array}; +use arrow::datatypes::DataType; +use arrow::error::ArrowError; +use datafusion_common::{internal_err, plan_err, Result}; +use datafusion_expr::ColumnarValue; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; +use datafusion_functions::utils::make_scalar_function; +use std::any::Any; +use std::sync::Arc; + +/// +#[derive(Debug)] +pub struct SparkAscii { + signature: Signature, + aliases: Vec, +} + +impl Default for SparkAscii { + fn default() -> Self { + Self::new() + } +} + +impl SparkAscii { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec![], + } + } +} + +impl ScalarUDFImpl for SparkAscii { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "ascii" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int32) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(ascii, vec![])(&args.args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() != 1 { + return plan_err!( + "The {} function requires 1 argument, but got {}.", + self.name(), + arg_types.len() + ); + } + Ok(vec![DataType::Utf8]) + } +} + +fn calculate_ascii<'a, V>(array: V) -> Result +where + V: ArrayAccessor, +{ + let iter = ArrayIter::new(array); + let result = iter + .map(|string| { + string.map(|s| { + let mut chars = s.chars(); + chars.next().map_or(0, |v| v as i32) + }) + }) + .collect::(); + + Ok(Arc::new(result) as ArrayRef) +} + +/// Returns the numeric code of the first character of the argument. +pub fn ascii(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + DataType::Utf8 => { + let string_array = args[0].as_string::(); + Ok(calculate_ascii(string_array)?) + } + DataType::LargeUtf8 => { + let string_array = args[0].as_string::(); + Ok(calculate_ascii(string_array)?) + } + DataType::Utf8View => { + let string_array = args[0].as_string_view(); + Ok(calculate_ascii(string_array)?) + } + _ => internal_err!("Unsupported data type"), + } +} + +#[cfg(test)] +mod tests { + use crate::function::string::ascii::SparkAscii; + use crate::function::utils::test::test_scalar_function; + use arrow::array::{Array, Int32Array}; + use arrow::datatypes::DataType::Int32; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + macro_rules! 
test_ascii_string_invoke { + ($INPUT:expr, $EXPECTED:expr) => { + test_scalar_function!( + SparkAscii::new(), + vec![ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))], + $EXPECTED, + i32, + Int32, + Int32Array + ); + + test_scalar_function!( + SparkAscii::new(), + vec![ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))], + $EXPECTED, + i32, + Int32, + Int32Array + ); + + test_scalar_function!( + SparkAscii::new(), + vec![ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))], + $EXPECTED, + i32, + Int32, + Int32Array + ); + }; + } + + #[test] + fn test_ascii_invoke() -> Result<()> { + test_ascii_string_invoke!(Some(String::from("x")), Ok(Some(120))); + test_ascii_string_invoke!(Some(String::from("a")), Ok(Some(97))); + test_ascii_string_invoke!(Some(String::from("")), Ok(Some(0))); + test_ascii_string_invoke!(Some(String::from("\n")), Ok(Some(10))); + test_ascii_string_invoke!(Some(String::from("\t")), Ok(Some(9))); + test_ascii_string_invoke!(None, Ok(None)); + + Ok(()) + } +} diff --git a/datafusion/spark/src/function/string/char.rs b/datafusion/spark/src/function/string/char.rs new file mode 100644 index 000000000000..dd6cdc83b30d --- /dev/null +++ b/datafusion/spark/src/function/string/char.rs @@ -0,0 +1,130 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{any::Any, sync::Arc}; + +use arrow::{ + array::{ArrayRef, StringArray}, + datatypes::{ + DataType, + DataType::{Int64, Utf8}, + }, +}; + +use datafusion_common::{cast::as_int64_array, exec_err, Result, ScalarValue}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; + +/// Spark-compatible `char` expression +/// +#[derive(Debug)] +pub struct SparkChar { + signature: Signature, +} + +impl Default for SparkChar { + fn default() -> Self { + Self::new() + } +} + +impl SparkChar { + pub fn new() -> Self { + Self { + signature: Signature::uniform(1, vec![Int64], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SparkChar { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "char" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + spark_chr(&args.args) + } +} + +/// Returns the ASCII character having the binary equivalent to the input expression. +/// E.g., chr(65) = 'A'. 
+/// Compatible with Apache Spark's Chr function +fn spark_chr(args: &[ColumnarValue]) -> Result { + let array = args[0].clone(); + match array { + ColumnarValue::Array(array) => { + let array = chr(&[array])?; + Ok(ColumnarValue::Array(array)) + } + ColumnarValue::Scalar(ScalarValue::Int64(Some(value))) => { + if value < 0 { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( + "".to_string(), + )))) + } else { + match core::char::from_u32((value % 256) as u32) { + Some(ch) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( + ch.to_string(), + )))), + None => { + exec_err!("requested character was incompatible for encoding.") + } + } + } + } + _ => exec_err!("The argument must be an Int64 array or scalar."), + } +} + +fn chr(args: &[ArrayRef]) -> Result { + let integer_array = as_int64_array(&args[0])?; + + // first map is the iterator, second is for the `Option<_>` + let result = integer_array + .iter() + .map(|integer: Option| { + integer + .map(|integer| { + if integer < 0 { + return Ok("".to_string()); // Return empty string for negative integers + } + match core::char::from_u32((integer % 256) as u32) { + Some(ch) => Ok(ch.to_string()), + None => { + exec_err!("requested character not compatible for encoding.") + } + } + }) + .transpose() + }) + .collect::>()?; + + Ok(Arc::new(result) as ArrayRef) +} diff --git a/datafusion/spark/src/function/string/mod.rs b/datafusion/spark/src/function/string/mod.rs new file mode 100644 index 000000000000..9d5fabe832e9 --- /dev/null +++ b/datafusion/spark/src/function/string/mod.rs @@ -0,0 +1,45 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod ascii; +pub mod char; + +use datafusion_expr::ScalarUDF; +use datafusion_functions::make_udf_function; +use std::sync::Arc; + +make_udf_function!(ascii::SparkAscii, ascii); +make_udf_function!(char::SparkChar, char); + +pub mod expr_fn { + use datafusion_functions::export_functions; + + export_functions!(( + ascii, + "Returns the ASCII code point of the first character of string.", + arg1 + )); + export_functions!(( + char, + "Returns the ASCII character having the binary equivalent to col. If col is larger than 256 the result is equivalent to char(col % 256).", + arg1 + )); +} + +pub fn functions() -> Vec> { + vec![ascii(), char()] +} diff --git a/datafusion/spark/src/function/struct/mod.rs b/datafusion/spark/src/function/struct/mod.rs new file mode 100644 index 000000000000..a87df9a2c87a --- /dev/null +++ b/datafusion/spark/src/function/struct/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
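The `ascii` and `char` functions registered in `string/mod.rs` above are inverses within the single-byte range; `char` additionally wraps its argument modulo 256 and maps negative inputs to an empty string. The following plain-Rust sketch (illustration only, not the UDF code path) mirrors those semantics:

```rust
// Mirrors the SparkAscii / SparkChar semantics shown above: ascii returns the
// code point of the first character (0 for ""), char returns chr(n % 256),
// and negative inputs produce an empty string.
fn spark_ascii(s: &str) -> i32 {
    s.chars().next().map_or(0, |c| c as i32)
}

fn spark_char(n: i64) -> String {
    if n < 0 {
        return String::new();
    }
    char::from_u32((n % 256) as u32)
        .map(String::from)
        .unwrap_or_default()
}

fn main() {
    assert_eq!(spark_ascii("x"), 120);
    assert_eq!(spark_ascii(""), 0);
    assert_eq!(spark_char(65), "A");
    assert_eq!(spark_char(65 + 256), "A"); // wraps modulo 256
    assert_eq!(spark_char(-1), "");
}
```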
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/table/mod.rs b/datafusion/spark/src/function/table/mod.rs new file mode 100644 index 000000000000..aba7b7ceb78e --- /dev/null +++ b/datafusion/spark/src/function/table/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_catalog::TableFunction; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/url/mod.rs b/datafusion/spark/src/function/url/mod.rs new file mode 100644 index 000000000000..a87df9a2c87a --- /dev/null +++ b/datafusion/spark/src/function/url/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/utils.rs b/datafusion/spark/src/function/utils.rs new file mode 100644 index 000000000000..85af4bb927ca --- /dev/null +++ b/datafusion/spark/src/function/utils.rs @@ -0,0 +1,117 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[cfg(test)] +pub mod test { + /// $FUNC ScalarUDFImpl to test + /// $ARGS arguments (vec) to pass to function + /// $EXPECTED a Result + /// $EXPECTED_TYPE is the expected value type + /// $EXPECTED_DATA_TYPE is the expected result type + /// $ARRAY_TYPE is the column type after function applied + macro_rules! test_scalar_function { + ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident) => { + let expected: datafusion_common::Result> = $EXPECTED; + let func = $FUNC; + + let arg_fields: Vec = $ARGS + .iter() + .enumerate() + .map(|(idx, arg)| { + + let nullable = match arg { + datafusion_expr::ColumnarValue::Scalar(scalar) => scalar.is_null(), + datafusion_expr::ColumnarValue::Array(a) => a.null_count() > 0, + }; + + std::sync::Arc::new(arrow::datatypes::Field::new(format!("arg_{idx}"), arg.data_type(), nullable)) + }) + .collect::>(); + + let cardinality = $ARGS + .iter() + .fold(Option::::None, |acc, arg| match arg { + datafusion_expr::ColumnarValue::Scalar(_) => acc, + datafusion_expr::ColumnarValue::Array(a) => Some(a.len()), + }) + .unwrap_or(1); + + let scalar_arguments = $ARGS.iter().map(|arg| match arg { + datafusion_expr::ColumnarValue::Scalar(scalar) => Some(scalar.clone()), + datafusion_expr::ColumnarValue::Array(_) => None, + }).collect::>(); + let scalar_arguments_refs = scalar_arguments.iter().map(|arg| arg.as_ref()).collect::>(); + + + let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs { + arg_fields: &arg_fields, + scalar_arguments: &scalar_arguments_refs + }); + + match expected { + Ok(expected) => { + let return_field = return_field.unwrap(); + assert_eq!(return_field.data_type(), &$EXPECTED_DATA_TYPE); + + let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{ + args: $ARGS, + number_rows: cardinality, + return_field, + arg_fields: arg_fields.clone(), + }); + assert_eq!(result.is_ok(), true, "function returned an error: {}", result.unwrap_err()); + + let result = result.unwrap().to_array(cardinality).expect("Failed to convert to array"); + let result = result.as_any().downcast_ref::<$ARRAY_TYPE>().expect("Failed to convert to type"); + assert_eq!(result.data_type(), &$EXPECTED_DATA_TYPE); + + // value is correct + match expected { + Some(v) => assert_eq!(result.value(0), v), + None => assert!(result.is_null(0)), + }; + } + Err(expected_error) => { + if return_field.is_err() { + match return_field { + Ok(_) => assert!(false, "expected error"), + Err(error) => { datafusion_common::assert_contains!(expected_error.strip_backtrace(), error.strip_backtrace()); } + } + } + else { + let return_field = return_field.unwrap(); + + // invoke is expected error - cannot use .expect_err() due to Debug not being implemented + match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{ + args: $ARGS, + number_rows: cardinality, + return_field, + arg_fields, + }) { + Ok(_) => assert!(false, "expected 
error"), + Err(error) => { + assert!(expected_error.strip_backtrace().starts_with(&error.strip_backtrace())); + } + } + } + } + }; + }; + } + + pub(crate) use test_scalar_function; +} diff --git a/datafusion/spark/src/function/window/mod.rs b/datafusion/spark/src/function/window/mod.rs new file mode 100644 index 000000000000..97ab4a9e3542 --- /dev/null +++ b/datafusion/spark/src/function/window/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::WindowUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/function/xml/mod.rs b/datafusion/spark/src/function/xml/mod.rs new file mode 100644 index 000000000000..a87df9a2c87a --- /dev/null +++ b/datafusion/spark/src/function/xml/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + +pub mod expr_fn {} + +pub fn functions() -> Vec> { + vec![] +} diff --git a/datafusion/spark/src/lib.rs b/datafusion/spark/src/lib.rs new file mode 100644 index 000000000000..4ce9be1263ef --- /dev/null +++ b/datafusion/spark/src/lib.rs @@ -0,0 +1,199 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#![doc( + html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", + html_favicon_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg" +)] +#![cfg_attr(docsrs, feature(doc_auto_cfg))] +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] + +//! Spark Expression packages for [DataFusion]. +//! +//! This crate contains a collection of various Spark function packages for DataFusion, +//! implemented using the extension API. +//! +//! [DataFusion]: https://crates.io/crates/datafusion +//! +//! +//! # Available Function Packages +//! See the list of [modules](#modules) in this crate for available packages. +//! +//! # Example: using all function packages +//! +//! You can register all the functions in all packages using the [`register_all`] +//! function as shown below. +//! +//! ``` +//! # use datafusion_execution::FunctionRegistry; +//! # use datafusion_expr::{ScalarUDF, AggregateUDF, WindowUDF}; +//! # use datafusion_expr::planner::ExprPlanner; +//! # use datafusion_common::Result; +//! # use std::collections::HashSet; +//! # use std::sync::Arc; +//! # // Note: We can't use a real SessionContext here because the +//! # // `datafusion_spark` crate has no dependence on the DataFusion crate +//! # // thus use a dummy SessionContext that has enough of the implementation +//! # struct SessionContext {} +//! # impl FunctionRegistry for SessionContext { +//! # fn register_udf(&mut self, _udf: Arc) -> Result>> { Ok (None) } +//! # fn udfs(&self) -> HashSet { unimplemented!() } +//! # fn udf(&self, _name: &str) -> Result> { unimplemented!() } +//! # fn udaf(&self, name: &str) -> Result> {unimplemented!() } +//! # fn udwf(&self, name: &str) -> Result> { unimplemented!() } +//! # fn expr_planners(&self) -> Vec> { unimplemented!() } +//! # } +//! # impl SessionContext { +//! # fn new() -> Self { SessionContext {} } +//! # async fn sql(&mut self, _query: &str) -> Result<()> { Ok(()) } +//! # } +//! # +//! # async fn stub() -> Result<()> { +//! // Create a new session context +//! let mut ctx = SessionContext::new(); +//! // register all spark functions with the context +//! datafusion_spark::register_all(&mut ctx)?; +//! // run a query. Note the `sha2` function is now available which +//! // has Spark semantics +//! let df = ctx.sql("SELECT sha2('The input String', 256)").await?; +//! # Ok(()) +//! # } +//! ``` +//! +//! # Example: calling a specific function in Rust +//! +//! Each package also exports an `expr_fn` submodule that create [`Expr`]s for +//! invoking functions via rust using a fluent style. For example, to invoke the +//! `sha2` function, you can use the following code: +//! +//! ```rust +//! # use datafusion_expr::{col, lit}; +//! use datafusion_spark::expr_fn::sha2; +//! // Create the expression `sha2(my_data, 256)` +//! let expr = sha2(col("my_data"), lit(256)); +//!``` +//! 
+//![`Expr`]: datafusion_expr::Expr + +pub mod function; + +use datafusion_catalog::TableFunction; +use datafusion_common::Result; +use datafusion_execution::FunctionRegistry; +use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; +use log::debug; +use std::sync::Arc; + +/// Fluent-style API for creating `Expr`s +#[allow(unused)] +pub mod expr_fn { + pub use super::function::aggregate::expr_fn::*; + pub use super::function::array::expr_fn::*; + pub use super::function::bitwise::expr_fn::*; + pub use super::function::collection::expr_fn::*; + pub use super::function::conditional::expr_fn::*; + pub use super::function::conversion::expr_fn::*; + pub use super::function::csv::expr_fn::*; + pub use super::function::datetime::expr_fn::*; + pub use super::function::generator::expr_fn::*; + pub use super::function::hash::expr_fn::*; + pub use super::function::json::expr_fn::*; + pub use super::function::lambda::expr_fn::*; + pub use super::function::map::expr_fn::*; + pub use super::function::math::expr_fn::*; + pub use super::function::misc::expr_fn::*; + pub use super::function::predicate::expr_fn::*; + pub use super::function::r#struct::expr_fn::*; + pub use super::function::string::expr_fn::*; + pub use super::function::table::expr_fn::*; + pub use super::function::url::expr_fn::*; + pub use super::function::window::expr_fn::*; + pub use super::function::xml::expr_fn::*; +} + +/// Returns all default scalar functions +pub fn all_default_scalar_functions() -> Vec> { + function::array::functions() + .into_iter() + .chain(function::bitwise::functions()) + .chain(function::collection::functions()) + .chain(function::conditional::functions()) + .chain(function::conversion::functions()) + .chain(function::csv::functions()) + .chain(function::datetime::functions()) + .chain(function::generator::functions()) + .chain(function::hash::functions()) + .chain(function::json::functions()) + .chain(function::lambda::functions()) + .chain(function::map::functions()) + .chain(function::math::functions()) + .chain(function::misc::functions()) + .chain(function::predicate::functions()) + .chain(function::string::functions()) + .chain(function::r#struct::functions()) + .chain(function::url::functions()) + .chain(function::xml::functions()) + .collect::>() +} + +/// Returns all default aggregate functions +pub fn all_default_aggregate_functions() -> Vec> { + function::aggregate::functions() +} + +/// Returns all default window functions +pub fn all_default_window_functions() -> Vec> { + function::window::functions() +} + +/// Returns all default table functions +pub fn all_default_table_functions() -> Vec> { + function::table::functions() +} + +/// Registers all enabled packages with a [`FunctionRegistry`] +pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { + let scalar_functions: Vec> = all_default_scalar_functions(); + scalar_functions.into_iter().try_for_each(|udf| { + let existing_udf = registry.register_udf(udf)?; + if let Some(existing_udf) = existing_udf { + debug!("Overwrite existing UDF: {}", existing_udf.name()); + } + Ok(()) as Result<()> + })?; + + let aggregate_functions: Vec> = all_default_aggregate_functions(); + aggregate_functions.into_iter().try_for_each(|udf| { + let existing_udaf = registry.register_udaf(udf)?; + if let Some(existing_udaf) = existing_udaf { + debug!("Overwrite existing UDAF: {}", existing_udaf.name()); + } + Ok(()) as Result<()> + })?; + + let window_functions: Vec> = all_default_window_functions(); + window_functions.into_iter().try_for_each(|udf| { 
+ let existing_udwf = registry.register_udwf(udf)?; + if let Some(existing_udwf) = existing_udwf { + debug!("Overwrite existing UDWF: {}", existing_udwf.name()); + } + Ok(()) as Result<()> + })?; + + Ok(()) +} diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index b778db46769d..eca40c553280 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -43,6 +43,10 @@ unicode_expressions = [] unparser = [] recursive_protection = ["dep:recursive"] +# Note the sql planner should not depend directly on the datafusion-function packages +# so that it can be used in a standalone manner with other function implementations. +# +# They are used for testing purposes only, so they are in the dev-dependencies section. [dependencies] arrow = { workspace = true } bigdecimal = { workspace = true } @@ -56,6 +60,7 @@ sqlparser = { workspace = true } [dev-dependencies] ctor = { workspace = true } +# please do not move these dependencies to the main dependencies section datafusion-functions = { workspace = true, default-features = true } datafusion-functions-aggregate = { workspace = true } datafusion-functions-nested = { workspace = true } diff --git a/datafusion/sql/README.md b/datafusion/sql/README.md index 98f3c4faa2ec..d5ef3114c14e 100644 --- a/datafusion/sql/README.md +++ b/datafusion/sql/README.md @@ -25,6 +25,13 @@ project that requires a SQL query planner and does not make any assumptions abou will be translated to a physical plan. For example, there is no concept of row-based versus columnar execution in the logical plan. +Note that the [`datafusion`] crate re-exports this module. If you are already +using the [`datafusion`] crate in your project, there is no reason to use this +crate directly in your project as well. + +[df]: https://crates.io/crates/datafusion +[`datafusion`]: https://crates.io/crates/datafusion + ## Example Usage See the [examples](examples) directory for fully working examples. 
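Stepping back to the `datafusion-spark` entry points introduced above (`register_all` and `expr_fn`): in an application that already depends on the full `datafusion` crate, registration can target a real `SessionContext` instead of the documentation stub used in the crate docs. This is a hedged sketch; it assumes `datafusion`, `datafusion-spark`, and `tokio` dependencies, and that the DataFusion version in use implements `FunctionRegistry` for `SessionContext`.

```rust
// Sketch, not part of the patch: register the Spark-compatible functions on a
// real SessionContext (assumes SessionContext implements FunctionRegistry).
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    let mut ctx = SessionContext::new();
    datafusion_spark::register_all(&mut ctx)?;

    // `sha2` now resolves with Spark semantics.
    let df = ctx.sql("SELECT sha2('The input String', 256)").await?;
    df.show().await?;
    Ok(())
}
```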
diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 436f4388d8a3..d0cb4263dbd9 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -74,7 +74,7 @@ fn find_closest_match(candidates: Vec, target: &str) -> Option { }) } -/// Arguments to for a function call extracted from the SQL AST +/// Arguments for a function call extracted from the SQL AST #[derive(Debug)] struct FunctionArgs { /// Function name @@ -91,6 +91,8 @@ struct FunctionArgs { null_treatment: Option, /// DISTINCT distinct: bool, + /// WITHIN GROUP clause, if any + within_group: Vec, } impl FunctionArgs { @@ -115,6 +117,7 @@ impl FunctionArgs { filter, null_treatment, distinct: false, + within_group, }); }; @@ -144,6 +147,9 @@ impl FunctionArgs { } FunctionArgumentClause::OrderBy(oby) => { if order_by.is_some() { + if !within_group.is_empty() { + return plan_err!("ORDER BY clause is only permitted in WITHIN GROUP clause when a WITHIN GROUP is used"); + } return not_impl_err!("Calling {name}: Duplicated ORDER BY clause in function arguments"); } order_by = Some(oby); @@ -176,8 +182,10 @@ impl FunctionArgs { } } - if !within_group.is_empty() { - return not_impl_err!("WITHIN GROUP is not supported yet: {within_group:?}"); + if within_group.len() > 1 { + return not_impl_err!( + "Only a single ordering expression is permitted in a WITHIN GROUP clause" + ); } let order_by = order_by.unwrap_or_default(); @@ -190,6 +198,7 @@ impl FunctionArgs { filter, null_treatment, distinct, + within_group, }) } } @@ -210,8 +219,18 @@ impl SqlToRel<'_, S> { filter, null_treatment, distinct, + within_group, } = function_args; + if over.is_some() && !within_group.is_empty() { + return plan_err!("OVER and WITHIN GROUP clause cannot be used together. \ + OVER is for window functions, whereas WITHIN GROUP is for ordered set aggregate functions"); + } + + if !order_by.is_empty() && !within_group.is_empty() { + return plan_err!("ORDER BY and WITHIN GROUP clauses cannot be used together in the same aggregate function"); + } + // If function is a window function (it has an OVER clause), // it shouldn't have ordering requirement as function argument // required ordering should be defined in OVER clause. 
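The validation added above rejects mixing WITHIN GROUP with an OVER clause or with a regular ORDER BY in the argument list, and permits only a single ordering expression inside WITHIN GROUP. The constants below sketch the accepted and rejected SQL shapes; `percentile_cont` is just a stand-in name for an ordered-set aggregate registered in the session.

```rust
// Illustration only: SQL shapes accepted / rejected by the WITHIN GROUP checks above.
const OK_WITHIN_GROUP: &str =
    "SELECT percentile_cont(0.5) WITHIN GROUP (ORDER BY price) FROM t";

// Rejected: OVER and WITHIN GROUP cannot be combined.
const ERR_WITH_OVER: &str =
    "SELECT percentile_cont(0.5) WITHIN GROUP (ORDER BY price) OVER (PARTITION BY k) FROM t";

// Rejected: only a single ordering expression is permitted in WITHIN GROUP.
const ERR_MULTI_ORDER: &str =
    "SELECT percentile_cont(0.5) WITHIN GROUP (ORDER BY price, qty) FROM t";

fn main() {
    for (label, sql) in [
        ("accepted", OK_WITHIN_GROUP),
        ("rejected: OVER + WITHIN GROUP", ERR_WITH_OVER),
        ("rejected: multiple orderings", ERR_MULTI_ORDER),
    ] {
        println!("{label}: {sql}");
    }
}
```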
@@ -346,7 +365,7 @@ impl SqlToRel<'_, S> { null_treatment, } = window_expr; - return Expr::WindowFunction(expr::WindowFunction::new(func_def, args)) + return Expr::from(expr::WindowFunction::new(func_def, args)) .partition_by(partition_by) .order_by(order_by) .window_frame(window_frame) @@ -356,15 +375,54 @@ impl SqlToRel<'_, S> { } else { // User defined aggregate functions (UDAF) have precedence in case it has the same name as a scalar built-in function if let Some(fm) = self.context_provider.get_aggregate_meta(&name) { - let order_by = self.order_by_to_sort_expr( - order_by, - schema, - planner_context, - true, - None, - )?; - let order_by = (!order_by.is_empty()).then_some(order_by); - let args = self.function_args_to_expr(args, schema, planner_context)?; + if fm.is_ordered_set_aggregate() && within_group.is_empty() { + return plan_err!("WITHIN GROUP clause is required when calling ordered set aggregate function({})", fm.name()); + } + + if null_treatment.is_some() && !fm.supports_null_handling_clause() { + return plan_err!( + "[IGNORE | RESPECT] NULLS are not permitted for {}", + fm.name() + ); + } + + let mut args = + self.function_args_to_expr(args, schema, planner_context)?; + + let order_by = if fm.is_ordered_set_aggregate() { + let within_group = self.order_by_to_sort_expr( + within_group, + schema, + planner_context, + false, + None, + )?; + + // add target column expression in within group clause to function arguments + if !within_group.is_empty() { + args = within_group + .iter() + .map(|sort| sort.expr.clone()) + .chain(args) + .collect::>(); + } + (!within_group.is_empty()).then_some(within_group) + } else { + let order_by = if !order_by.is_empty() { + order_by + } else { + within_group + }; + let order_by = self.order_by_to_sort_expr( + order_by, + schema, + planner_context, + true, + None, + )?; + (!order_by.is_empty()).then_some(order_by) + }; + let filter: Option> = filter .map(|e| self.sql_expr_to_logical_expr(*e, schema, planner_context)) .transpose()? @@ -408,17 +466,12 @@ impl SqlToRel<'_, S> { if let Some(suggested_func_name) = suggest_valid_function(&name, is_function_window, self.context_provider) { - plan_err!("Invalid function '{name}'.\nDid you mean '{suggested_func_name}'?") - .map_err(|e| { - let span = Span::try_from_sqlparser_span(sql_parser_span); - let mut diagnostic = - Diagnostic::new_error(format!("Invalid function '{name}'"), span); - diagnostic.add_note( - format!("Possible function '{}'", suggested_func_name), - None, - ); - e.with_diagnostic(diagnostic) - }) + let span = Span::try_from_sqlparser_span(sql_parser_span); + let mut diagnostic = + Diagnostic::new_error(format!("Invalid function '{name}'"), span); + diagnostic + .add_note(format!("Possible function '{suggested_func_name}'"), None); + plan_err!("Invalid function '{name}'.\nDid you mean '{suggested_func_name}'?"; diagnostic=diagnostic) } else { internal_err!("No functions registered with this context.") } diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index d29ccdc6a7e9..e92869873731 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -215,7 +215,7 @@ impl SqlToRel<'_, S> { } SQLExpr::Extract { field, expr, .. 
} => { let mut extract_args = vec![ - Expr::Literal(ScalarValue::from(format!("{field}"))), + Expr::Literal(ScalarValue::from(format!("{field}")), None), self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, ]; @@ -644,7 +644,9 @@ impl SqlToRel<'_, S> { values: Vec, ) -> Result { match values.first() { - Some(SQLExpr::Identifier(_)) | Some(SQLExpr::Value(_)) => { + Some(SQLExpr::Identifier(_)) + | Some(SQLExpr::Value(_)) + | Some(SQLExpr::CompoundIdentifier(_)) => { self.parse_struct(schema, planner_context, values, vec![]) } None => not_impl_err!("Empty tuple not supported yet"), diff --git a/datafusion/sql/src/expr/order_by.rs b/datafusion/sql/src/expr/order_by.rs index cce3f3004809..d357c3753e13 100644 --- a/datafusion/sql/src/expr/order_by.rs +++ b/datafusion/sql/src/expr/order_by.rs @@ -41,13 +41,13 @@ impl SqlToRel<'_, S> { /// If false, interpret numeric literals as constant values. pub(crate) fn order_by_to_sort_expr( &self, - exprs: Vec, + order_by_exprs: Vec, input_schema: &DFSchema, planner_context: &mut PlannerContext, literal_to_column: bool, additional_schema: Option<&DFSchema>, ) -> Result> { - if exprs.is_empty() { + if order_by_exprs.is_empty() { return Ok(vec![]); } @@ -61,13 +61,23 @@ impl SqlToRel<'_, S> { None => input_schema, }; - let mut expr_vec = vec![]; - for e in exprs { + let mut sort_expr_vec = Vec::with_capacity(order_by_exprs.len()); + + let make_sort_expr = + |expr: Expr, asc: Option, nulls_first: Option| { + let asc = asc.unwrap_or(true); + // When asc is true, by default nulls last to be consistent with postgres + // postgres rule: https://www.postgresql.org/docs/current/queries-order.html + let nulls_first = nulls_first.unwrap_or(!asc); + Sort::new(expr, asc, nulls_first) + }; + + for order_by_expr in order_by_exprs { let OrderByExpr { expr, options: OrderByOptions { asc, nulls_first }, with_fill, - } = e; + } = order_by_expr; if let Some(with_fill) = with_fill { return not_impl_err!("ORDER BY WITH FILL is not supported: {with_fill}"); @@ -102,15 +112,9 @@ impl SqlToRel<'_, S> { self.sql_expr_to_logical_expr(e, order_by_schema, planner_context)? 
} }; - let asc = asc.unwrap_or(true); - expr_vec.push(Sort::new( - expr, - asc, - // When asc is true, by default nulls last to be consistent with postgres - // postgres rule: https://www.postgresql.org/docs/current/queries-order.html - nulls_first.unwrap_or(!asc), - )) + sort_expr_vec.push(make_sort_expr(expr, asc, nulls_first)); } - Ok(expr_vec) + + Ok(sort_expr_vec) } } diff --git a/datafusion/sql/src/expr/subquery.rs b/datafusion/sql/src/expr/subquery.rs index 225c5d74c2ab..602d39233d58 100644 --- a/datafusion/sql/src/expr/subquery.rs +++ b/datafusion/sql/src/expr/subquery.rs @@ -138,15 +138,9 @@ impl SqlToRel<'_, S> { if sub_plan.schema().fields().len() > 1 { let sub_schema = sub_plan.schema(); let field_names = sub_schema.field_names(); - - plan_err!("{}: {}", error_message, field_names.join(", ")).map_err(|err| { - let diagnostic = self.build_multi_column_diagnostic( - spans, - error_message, - help_message, - ); - err.with_diagnostic(diagnostic) - }) + let diagnostic = + self.build_multi_column_diagnostic(spans, error_message, help_message); + plan_err!("{}: {}", error_message, field_names.join(", "); diagnostic=diagnostic) } else { Ok(()) } diff --git a/datafusion/sql/src/expr/substring.rs b/datafusion/sql/src/expr/substring.rs index 59c78bc713cc..8f6e77e035c1 100644 --- a/datafusion/sql/src/expr/substring.rs +++ b/datafusion/sql/src/expr/substring.rs @@ -51,7 +51,7 @@ impl SqlToRel<'_, S> { (None, Some(for_expr)) => { let arg = self.sql_expr_to_logical_expr(*expr, schema, planner_context)?; - let from_logic = Expr::Literal(ScalarValue::Int64(Some(1))); + let from_logic = Expr::Literal(ScalarValue::Int64(Some(1)), None); let for_logic = self.sql_expr_to_logical_expr(*for_expr, schema, planner_context)?; vec![arg, from_logic, for_logic] diff --git a/datafusion/sql/src/expr/unary_op.rs b/datafusion/sql/src/expr/unary_op.rs index 626b79d6c3b6..e0c94543f601 100644 --- a/datafusion/sql/src/expr/unary_op.rs +++ b/datafusion/sql/src/expr/unary_op.rs @@ -45,16 +45,18 @@ impl SqlToRel<'_, S> { { Ok(operand) } else { - plan_err!("Unary operator '+' only supports numeric, interval and timestamp types").map_err(|e| { - let span = operand.spans().and_then(|s| s.first()); - let mut diagnostic = Diagnostic::new_error( - format!("+ cannot be used with {data_type}"), - span - ); - diagnostic.add_note("+ can only be used with numbers, intervals, and timestamps", None); - diagnostic.add_help(format!("perhaps you need to cast {operand}"), None); - e.with_diagnostic(diagnostic) - }) + let span = operand.spans().and_then(|s| s.first()); + let mut diagnostic = Diagnostic::new_error( + format!("+ cannot be used with {data_type}"), + span, + ); + diagnostic.add_note( + "+ can only be used with numbers, intervals, and timestamps", + None, + ); + diagnostic + .add_help(format!("perhaps you need to cast {operand}"), None); + plan_err!("Unary operator '+' only supports numeric, interval and timestamp types"; diagnostic=diagnostic) } } UnaryOperator::Minus => { diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index be4a45a25750..7075a1afd9dd 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -50,7 +50,7 @@ impl SqlToRel<'_, S> { match value { Value::Number(n, _) => self.parse_sql_number(&n, false), Value::SingleQuotedString(s) | Value::DoubleQuotedString(s) => Ok(lit(s)), - Value::Null => Ok(Expr::Literal(ScalarValue::Null)), + Value::Null => Ok(Expr::Literal(ScalarValue::Null, None)), Value::Boolean(n) => Ok(lit(n)), 
Value::Placeholder(param) => { Self::create_placeholder_expr(param, param_data_types) @@ -131,10 +131,7 @@ impl SqlToRel<'_, S> { // Check if the placeholder is in the parameter list let param_type = param_data_types.get(idx); // Data type of the parameter - debug!( - "type of param {} param_data_types[idx]: {:?}", - param, param_type - ); + debug!("type of param {param} param_data_types[idx]: {param_type:?}"); Ok(Expr::Placeholder(Placeholder::new( param, @@ -383,11 +380,10 @@ fn parse_decimal(unsigned_number: &str, negative: bool) -> Result { int_val ) })?; - Ok(Expr::Literal(ScalarValue::Decimal128( - Some(val), - precision as u8, - scale as i8, - ))) + Ok(Expr::Literal( + ScalarValue::Decimal128(Some(val), precision as u8, scale as i8), + None, + )) } else if precision <= DECIMAL256_MAX_PRECISION as u64 { let val = bigint_to_i256(&int_val).ok_or_else(|| { // Failures are unexpected here as we have already checked the precision @@ -396,11 +392,10 @@ fn parse_decimal(unsigned_number: &str, negative: bool) -> Result { int_val ) })?; - Ok(Expr::Literal(ScalarValue::Decimal256( - Some(val), - precision as u8, - scale as i8, - ))) + Ok(Expr::Literal( + ScalarValue::Decimal256(Some(val), precision as u8, scale as i8), + None, + )) } else { not_impl_err!( "Decimal precision {} exceeds the maximum supported precision: {}", @@ -486,10 +481,13 @@ mod tests { ]; for (input, expect) in cases { let output = parse_decimal(input, true).unwrap(); - assert_eq!(output, Expr::Literal(expect.arithmetic_negate().unwrap())); + assert_eq!( + output, + Expr::Literal(expect.arithmetic_negate().unwrap(), None) + ); let output = parse_decimal(input, false).unwrap(); - assert_eq!(output, Expr::Literal(expect)); + assert_eq!(output, Expr::Literal(expect, None)); } // scale < i8::MIN diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs index 822b651eae86..9731eebad167 100644 --- a/datafusion/sql/src/parser.rs +++ b/datafusion/sql/src/parser.rs @@ -20,9 +20,9 @@ //! This parser implements DataFusion specific statements such as //! `CREATE EXTERNAL TABLE` -use std::collections::VecDeque; -use std::fmt; - +use datafusion_common::config::SqlParserOptions; +use datafusion_common::DataFusionError; +use datafusion_common::{sql_err, Diagnostic, Span}; use sqlparser::ast::{ExprWithAlias, OrderByOptions}; use sqlparser::tokenizer::TokenWithSpan; use sqlparser::{ @@ -34,15 +34,22 @@ use sqlparser::{ parser::{Parser, ParserError}, tokenizer::{Token, Tokenizer, Word}, }; +use std::collections::VecDeque; +use std::fmt; // Use `Parser::expected` instead, if possible macro_rules! parser_err { - ($MSG:expr) => { - Err(ParserError::ParserError($MSG.to_string())) - }; + ($MSG:expr $(; diagnostic = $DIAG:expr)?) => {{ + + let err = DataFusionError::from(ParserError::ParserError($MSG.to_string())); + $( + let err = err.with_diagnostic($DIAG); + )? 
+ Err(err) + }}; } -fn parse_file_type(s: &str) -> Result { +fn parse_file_type(s: &str) -> Result { Ok(s.to_uppercase()) } @@ -140,7 +147,7 @@ impl fmt::Display for CopyToStatement { write!(f, "COPY {source} TO {target}")?; if let Some(file_type) = stored_as { - write!(f, " STORED AS {}", file_type)?; + write!(f, " STORED AS {file_type}")?; } if !partitioned_by.is_empty() { write!(f, " PARTITIONED BY ({})", partitioned_by.join(", "))?; @@ -266,11 +273,9 @@ impl fmt::Display for Statement { } } -fn ensure_not_set(field: &Option, name: &str) -> Result<(), ParserError> { +fn ensure_not_set(field: &Option, name: &str) -> Result<(), DataFusionError> { if field.is_some() { - return Err(ParserError::ParserError(format!( - "{name} specified more than once", - ))); + parser_err!(format!("{name} specified more than once",))? } Ok(()) } @@ -285,6 +290,7 @@ fn ensure_not_set(field: &Option, name: &str) -> Result<(), ParserError> { /// [`Statement`] for a list of this special syntax pub struct DFParser<'a> { pub parser: Parser<'a>, + options: SqlParserOptions, } /// Same as `sqlparser` @@ -356,21 +362,28 @@ impl<'a> DFParserBuilder<'a> { self } - pub fn build(self) -> Result, ParserError> { + pub fn build(self) -> Result, DataFusionError> { let mut tokenizer = Tokenizer::new(self.dialect, self.sql); - let tokens = tokenizer.tokenize_with_location()?; + // Convert TokenizerError -> ParserError + let tokens = tokenizer + .tokenize_with_location() + .map_err(ParserError::from)?; Ok(DFParser { parser: Parser::new(self.dialect) .with_tokens_with_locations(tokens) .with_recursion_limit(self.recursion_limit), + options: SqlParserOptions { + recursion_limit: self.recursion_limit, + ..Default::default() + }, }) } } impl<'a> DFParser<'a> { #[deprecated(since = "46.0.0", note = "DFParserBuilder")] - pub fn new(sql: &'a str) -> Result { + pub fn new(sql: &'a str) -> Result { DFParserBuilder::new(sql).build() } @@ -378,13 +391,13 @@ impl<'a> DFParser<'a> { pub fn new_with_dialect( sql: &'a str, dialect: &'a dyn Dialect, - ) -> Result { + ) -> Result { DFParserBuilder::new(sql).with_dialect(dialect).build() } /// Parse a sql string into one or [`Statement`]s using the /// [`GenericDialect`]. 
- pub fn parse_sql(sql: &'a str) -> Result, ParserError> { + pub fn parse_sql(sql: &'a str) -> Result, DataFusionError> { let mut parser = DFParserBuilder::new(sql).build()?; parser.parse_statements() @@ -395,7 +408,7 @@ impl<'a> DFParser<'a> { pub fn parse_sql_with_dialect( sql: &str, dialect: &dyn Dialect, - ) -> Result, ParserError> { + ) -> Result, DataFusionError> { let mut parser = DFParserBuilder::new(sql).with_dialect(dialect).build()?; parser.parse_statements() } @@ -403,14 +416,14 @@ impl<'a> DFParser<'a> { pub fn parse_sql_into_expr_with_dialect( sql: &str, dialect: &dyn Dialect, - ) -> Result { + ) -> Result { let mut parser = DFParserBuilder::new(sql).with_dialect(dialect).build()?; parser.parse_expr() } /// Parse a sql string into one or [`Statement`]s - pub fn parse_statements(&mut self) -> Result, ParserError> { + pub fn parse_statements(&mut self) -> Result, DataFusionError> { let mut stmts = VecDeque::new(); let mut expecting_statement_delimiter = false; loop { @@ -438,12 +451,22 @@ impl<'a> DFParser<'a> { &self, expected: &str, found: TokenWithSpan, - ) -> Result { - parser_err!(format!("Expected {expected}, found: {found}")) + ) -> Result { + let sql_parser_span = found.span; + let span = Span::try_from_sqlparser_span(sql_parser_span); + let diagnostic = Diagnostic::new_error( + format!("Expected: {expected}, found: {found}{}", found.span.start), + span, + ); + parser_err!( + format!("Expected: {expected}, found: {found}{}", found.span.start); + diagnostic= + diagnostic + ) } /// Parse a new expression - pub fn parse_statement(&mut self) -> Result { + pub fn parse_statement(&mut self) -> Result { match self.parser.peek_token().token { Token::Word(w) => { match w.keyword { @@ -455,9 +478,7 @@ impl<'a> DFParser<'a> { if let Token::Word(w) = self.parser.peek_nth_token(1).token { // use native parser for COPY INTO if w.keyword == Keyword::INTO { - return Ok(Statement::Statement(Box::from( - self.parser.parse_statement()?, - ))); + return self.parse_and_handle_statement(); } } self.parser.next_token(); // COPY @@ -469,36 +490,49 @@ impl<'a> DFParser<'a> { } _ => { // use sqlparser-rs parser - Ok(Statement::Statement(Box::from( - self.parser.parse_statement()?, - ))) + self.parse_and_handle_statement() } } } _ => { // use the native parser - Ok(Statement::Statement(Box::from( - self.parser.parse_statement()?, - ))) + self.parse_and_handle_statement() } } } - pub fn parse_expr(&mut self) -> Result { + pub fn parse_expr(&mut self) -> Result { if let Token::Word(w) = self.parser.peek_token().token { match w.keyword { Keyword::CREATE | Keyword::COPY | Keyword::EXPLAIN => { - return parser_err!("Unsupported command in expression"); + return parser_err!("Unsupported command in expression")?; } _ => {} } } - self.parser.parse_expr_with_alias() + Ok(self.parser.parse_expr_with_alias()?) 
+ } + + /// Helper method to parse a statement and handle errors consistently, especially for recursion limits + fn parse_and_handle_statement(&mut self) -> Result { + self.parser + .parse_statement() + .map(|stmt| Statement::Statement(Box::from(stmt))) + .map_err(|e| match e { + ParserError::RecursionLimitExceeded => DataFusionError::SQL( + ParserError::RecursionLimitExceeded, + Some(format!( + " (current limit: {})", + self.options.recursion_limit + )), + ), + other => DataFusionError::SQL(other, None), + }) } /// Parse a SQL `COPY TO` statement - pub fn parse_copy(&mut self) -> Result { + pub fn parse_copy(&mut self) -> Result { // parse as a query let source = if self.parser.consume_token(&Token::LParen) { let query = self.parser.parse_query()?; @@ -541,7 +575,7 @@ impl<'a> DFParser<'a> { Keyword::WITH => { self.parser.expect_keyword(Keyword::HEADER)?; self.parser.expect_keyword(Keyword::ROW)?; - return parser_err!("WITH HEADER ROW clause is no longer in use. Please use the OPTIONS clause with 'format.has_header' set appropriately, e.g., OPTIONS ('format.has_header' 'true')"); + return parser_err!("WITH HEADER ROW clause is no longer in use. Please use the OPTIONS clause with 'format.has_header' set appropriately, e.g., OPTIONS ('format.has_header' 'true')")?; } Keyword::PARTITIONED => { self.parser.expect_keyword(Keyword::BY)?; @@ -561,17 +595,13 @@ impl<'a> DFParser<'a> { if token == Token::EOF || token == Token::SemiColon { break; } else { - return Err(ParserError::ParserError(format!( - "Unexpected token {token}" - ))); + return self.expected("end of statement or ;", token)?; } } } let Some(target) = builder.target else { - return Err(ParserError::ParserError( - "Missing TO clause in COPY statement".into(), - )); + return parser_err!("Missing TO clause in COPY statement")?; }; Ok(Statement::CopyTo(CopyToStatement { @@ -589,7 +619,7 @@ impl<'a> DFParser<'a> { /// because it allows keywords as well as other non words /// /// [`parse_literal_string`]: sqlparser::parser::Parser::parse_literal_string - pub fn parse_option_key(&mut self) -> Result { + pub fn parse_option_key(&mut self) -> Result { let next_token = self.parser.next_token(); match next_token.token { Token::Word(Word { value, .. }) => { @@ -602,7 +632,7 @@ impl<'a> DFParser<'a> { // Unquoted namespaced keys have to conform to the syntax // "[\.]*". If we have a key that breaks this // pattern, error out: - return self.parser.expected("key name", next_token); + return self.expected("key name", next_token); } } Ok(parts.join(".")) @@ -610,7 +640,7 @@ impl<'a> DFParser<'a> { Token::SingleQuotedString(s) => Ok(s), Token::DoubleQuotedString(s) => Ok(s), Token::EscapedStringLiteral(s) => Ok(s), - _ => self.parser.expected("key name", next_token), + _ => self.expected("key name", next_token), } } @@ -620,7 +650,7 @@ impl<'a> DFParser<'a> { /// word or keyword in this location. /// /// [`parse_value`]: sqlparser::parser::Parser::parse_value - pub fn parse_option_value(&mut self) -> Result { + pub fn parse_option_value(&mut self) -> Result { let next_token = self.parser.next_token(); match next_token.token { // e.g. 
things like "snappy" or "gzip" that may be keywords @@ -629,12 +659,12 @@ impl<'a> DFParser<'a> { Token::DoubleQuotedString(s) => Ok(Value::DoubleQuotedString(s)), Token::EscapedStringLiteral(s) => Ok(Value::EscapedStringLiteral(s)), Token::Number(n, l) => Ok(Value::Number(n, l)), - _ => self.parser.expected("string or numeric value", next_token), + _ => self.expected("string or numeric value", next_token), } } /// Parse a SQL `EXPLAIN` - pub fn parse_explain(&mut self) -> Result { + pub fn parse_explain(&mut self) -> Result { let analyze = self.parser.parse_keyword(Keyword::ANALYZE); let verbose = self.parser.parse_keyword(Keyword::VERBOSE); let format = self.parse_explain_format()?; @@ -649,7 +679,7 @@ impl<'a> DFParser<'a> { })) } - pub fn parse_explain_format(&mut self) -> Result, ParserError> { + pub fn parse_explain_format(&mut self) -> Result, DataFusionError> { if !self.parser.parse_keyword(Keyword::FORMAT) { return Ok(None); } @@ -659,15 +689,13 @@ impl<'a> DFParser<'a> { Token::Word(w) => Ok(w.value), Token::SingleQuotedString(w) => Ok(w), Token::DoubleQuotedString(w) => Ok(w), - _ => self - .parser - .expected("an explain format such as TREE", next_token), + _ => self.expected("an explain format such as TREE", next_token), }?; Ok(Some(format)) } /// Parse a SQL `CREATE` statement handling `CREATE EXTERNAL TABLE` - pub fn parse_create(&mut self) -> Result { + pub fn parse_create(&mut self) -> Result { if self.parser.parse_keyword(Keyword::EXTERNAL) { self.parse_create_external_table(false) } else if self.parser.parse_keyword(Keyword::UNBOUNDED) { @@ -678,7 +706,7 @@ impl<'a> DFParser<'a> { } } - fn parse_partitions(&mut self) -> Result, ParserError> { + fn parse_partitions(&mut self) -> Result, DataFusionError> { let mut partitions: Vec = vec![]; if !self.parser.consume_token(&Token::LParen) || self.parser.consume_token(&Token::RParen) @@ -708,7 +736,7 @@ impl<'a> DFParser<'a> { } /// Parse the ordering clause of a `CREATE EXTERNAL TABLE` SQL statement - pub fn parse_order_by_exprs(&mut self) -> Result, ParserError> { + pub fn parse_order_by_exprs(&mut self) -> Result, DataFusionError> { let mut values = vec![]; self.parser.expect_token(&Token::LParen)?; loop { @@ -721,7 +749,7 @@ impl<'a> DFParser<'a> { } /// Parse an ORDER BY sub-expression optionally followed by ASC or DESC. - pub fn parse_order_by_expr(&mut self) -> Result { + pub fn parse_order_by_expr(&mut self) -> Result { let expr = self.parser.parse_expr()?; let asc = if self.parser.parse_keyword(Keyword::ASC) { @@ -753,7 +781,7 @@ impl<'a> DFParser<'a> { // This is a copy of the equivalent implementation in sqlparser. 
fn parse_columns( &mut self, - ) -> Result<(Vec, Vec), ParserError> { + ) -> Result<(Vec, Vec), DataFusionError> { let mut columns = vec![]; let mut constraints = vec![]; if !self.parser.consume_token(&Token::LParen) @@ -789,7 +817,7 @@ impl<'a> DFParser<'a> { Ok((columns, constraints)) } - fn parse_column_def(&mut self) -> Result { + fn parse_column_def(&mut self) -> Result { let name = self.parser.parse_identifier()?; let data_type = self.parser.parse_data_type()?; let mut options = vec![]; @@ -820,7 +848,7 @@ impl<'a> DFParser<'a> { fn parse_create_external_table( &mut self, unbounded: bool, - ) -> Result { + ) -> Result { let temporary = self .parser .parse_one_of_keywords(&[Keyword::TEMP, Keyword::TEMPORARY]) @@ -868,15 +896,15 @@ impl<'a> DFParser<'a> { } else { self.parser.expect_keyword(Keyword::HEADER)?; self.parser.expect_keyword(Keyword::ROW)?; - return parser_err!("WITH HEADER ROW clause is no longer in use. Please use the OPTIONS clause with 'format.has_header' set appropriately, e.g., OPTIONS (format.has_header true)"); + return parser_err!("WITH HEADER ROW clause is no longer in use. Please use the OPTIONS clause with 'format.has_header' set appropriately, e.g., OPTIONS (format.has_header true)")?; } } Keyword::DELIMITER => { - return parser_err!("DELIMITER clause is no longer in use. Please use the OPTIONS clause with 'format.delimiter' set appropriately, e.g., OPTIONS (format.delimiter ',')"); + return parser_err!("DELIMITER clause is no longer in use. Please use the OPTIONS clause with 'format.delimiter' set appropriately, e.g., OPTIONS (format.delimiter ',')")?; } Keyword::COMPRESSION => { self.parser.expect_keyword(Keyword::TYPE)?; - return parser_err!("COMPRESSION TYPE clause is no longer in use. Please use the OPTIONS clause with 'format.compression' set appropriately, e.g., OPTIONS (format.compression gzip)"); + return parser_err!("COMPRESSION TYPE clause is no longer in use. Please use the OPTIONS clause with 'format.compression' set appropriately, e.g., OPTIONS (format.compression gzip)")?; } Keyword::PARTITIONED => { self.parser.expect_keyword(Keyword::BY)?; @@ -899,7 +927,7 @@ impl<'a> DFParser<'a> { columns.extend(cols); if !cons.is_empty() { - return Err(ParserError::ParserError( + return sql_err!(ParserError::ParserError( "Constraints on Partition Columns are not supported" .to_string(), )); @@ -919,21 +947,19 @@ impl<'a> DFParser<'a> { if token == Token::EOF || token == Token::SemiColon { break; } else { - return Err(ParserError::ParserError(format!( - "Unexpected token {token}" - ))); + return self.expected("end of statement or ;", token)?; } } } // Validations: location and file_type are required if builder.file_type.is_none() { - return Err(ParserError::ParserError( + return sql_err!(ParserError::ParserError( "Missing STORED AS clause in CREATE EXTERNAL TABLE statement".into(), )); } if builder.location.is_none() { - return Err(ParserError::ParserError( + return sql_err!(ParserError::ParserError( "Missing LOCATION clause in CREATE EXTERNAL TABLE statement".into(), )); } @@ -955,7 +981,7 @@ impl<'a> DFParser<'a> { } /// Parses the set of valid formats - fn parse_file_format(&mut self) -> Result { + fn parse_file_format(&mut self) -> Result { let token = self.parser.next_token(); match &token.token { Token::Word(w) => parse_file_type(&w.value), @@ -967,7 +993,7 @@ impl<'a> DFParser<'a> { /// /// This method supports keywords as key names as well as multiple /// value types such as Numbers as well as Strings. 
- fn parse_value_options(&mut self) -> Result, ParserError> { + fn parse_value_options(&mut self) -> Result, DataFusionError> { let mut options = vec![]; self.parser.expect_token(&Token::LParen)?; @@ -999,7 +1025,7 @@ mod tests { use sqlparser::dialect::SnowflakeDialect; use sqlparser::tokenizer::Span; - fn expect_parse_ok(sql: &str, expected: Statement) -> Result<(), ParserError> { + fn expect_parse_ok(sql: &str, expected: Statement) -> Result<(), DataFusionError> { let statements = DFParser::parse_sql(sql)?; assert_eq!( statements.len(), @@ -1041,7 +1067,7 @@ mod tests { } #[test] - fn create_external_table() -> Result<(), ParserError> { + fn create_external_table() -> Result<(), DataFusionError> { // positive case let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv'"; let display = None; @@ -1262,13 +1288,13 @@ mod tests { "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV PARTITIONED BY (p1 int, c1) LOCATION 'foo.csv'"; expect_parse_error( sql, - "sql parser error: Expected: a data type name, found: )", + "SQL error: ParserError(\"Expected: a data type name, found: ) at Line: 1, Column: 73\")", ); // negative case: mixed column defs and column names in `PARTITIONED BY` clause let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV PARTITIONED BY (c1, p1 int) LOCATION 'foo.csv'"; - expect_parse_error(sql, "sql parser error: Expected ',' or ')' after partition definition, found: int"); + expect_parse_error(sql, "SQL error: ParserError(\"Expected: ',' or ')' after partition definition, found: int at Line: 1, Column: 70\")"); // positive case: additional options (one entry) can be specified let sql = @@ -1514,7 +1540,7 @@ mod tests { } #[test] - fn copy_to_table_to_table() -> Result<(), ParserError> { + fn copy_to_table_to_table() -> Result<(), DataFusionError> { // positive case let sql = "COPY foo TO bar STORED AS CSV"; let expected = Statement::CopyTo(CopyToStatement { @@ -1530,7 +1556,7 @@ mod tests { } #[test] - fn skip_copy_into_snowflake() -> Result<(), ParserError> { + fn skip_copy_into_snowflake() -> Result<(), DataFusionError> { let sql = "COPY INTO foo FROM @~/staged FILE_FORMAT = (FORMAT_NAME = 'mycsv');"; let dialect = Box::new(SnowflakeDialect); let statements = DFParser::parse_sql_with_dialect(sql, dialect.as_ref())?; @@ -1547,7 +1573,7 @@ mod tests { } #[test] - fn explain_copy_to_table_to_table() -> Result<(), ParserError> { + fn explain_copy_to_table_to_table() -> Result<(), DataFusionError> { let cases = vec![ ("EXPLAIN COPY foo TO bar STORED AS PARQUET", false, false), ( @@ -1588,7 +1614,7 @@ mod tests { } #[test] - fn copy_to_query_to_table() -> Result<(), ParserError> { + fn copy_to_query_to_table() -> Result<(), DataFusionError> { let statement = verified_stmt("SELECT 1"); // unwrap the various layers @@ -1621,7 +1647,7 @@ mod tests { } #[test] - fn copy_to_options() -> Result<(), ParserError> { + fn copy_to_options() -> Result<(), DataFusionError> { let sql = "COPY foo TO bar STORED AS CSV OPTIONS ('row_group_size' '55')"; let expected = Statement::CopyTo(CopyToStatement { source: object_name("foo"), @@ -1638,7 +1664,7 @@ mod tests { } #[test] - fn copy_to_partitioned_by() -> Result<(), ParserError> { + fn copy_to_partitioned_by() -> Result<(), DataFusionError> { let sql = "COPY foo TO bar STORED AS CSV PARTITIONED BY (a) OPTIONS ('row_group_size' '55')"; let expected = Statement::CopyTo(CopyToStatement { source: object_name("foo"), @@ -1655,7 +1681,7 @@ mod tests { } #[test] - fn copy_to_multi_options() -> Result<(), ParserError> { + fn 
copy_to_multi_options() -> Result<(), DataFusionError> { // order of options is preserved let sql = "COPY foo TO bar STORED AS parquet OPTIONS ('format.row_group_size' 55, 'format.compression' snappy, 'execution.keep_partition_by_columns' true)"; @@ -1754,7 +1780,7 @@ mod tests { assert_contains!( err.to_string(), - "sql parser error: recursion limit exceeded" + "SQL error: RecursionLimitExceeded (current limit: 1)" ); } } diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 3325c98aa74b..b50ad1fafda0 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -52,8 +52,8 @@ pub struct ParserOptions { pub enable_options_value_normalization: bool, /// Whether to collect spans pub collect_spans: bool, - /// Whether `VARCHAR` is mapped to `Utf8View` during SQL planning. - pub map_varchar_to_utf8view: bool, + /// Whether string types (VARCHAR, CHAR, Text, and String) are mapped to `Utf8View` during SQL planning. + pub map_string_types_to_utf8view: bool, } impl ParserOptions { @@ -72,7 +72,7 @@ impl ParserOptions { parse_float_as_decimal: false, enable_ident_normalization: true, support_varchar_with_length: true, - map_varchar_to_utf8view: false, + map_string_types_to_utf8view: true, enable_options_value_normalization: false, collect_spans: false, } @@ -112,9 +112,9 @@ impl ParserOptions { self } - /// Sets the `map_varchar_to_utf8view` option. - pub fn with_map_varchar_to_utf8view(mut self, value: bool) -> Self { - self.map_varchar_to_utf8view = value; + /// Sets the `map_string_types_to_utf8view` option. + pub fn with_map_string_types_to_utf8view(mut self, value: bool) -> Self { + self.map_string_types_to_utf8view = value; self } @@ -143,7 +143,7 @@ impl From<&SqlParserOptions> for ParserOptions { parse_float_as_decimal: options.parse_float_as_decimal, enable_ident_normalization: options.enable_ident_normalization, support_varchar_with_length: options.support_varchar_with_length, - map_varchar_to_utf8view: options.map_varchar_to_utf8view, + map_string_types_to_utf8view: options.map_string_types_to_utf8view, enable_options_value_normalization: options .enable_options_value_normalization, collect_spans: options.collect_spans, @@ -577,7 +577,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { please set `support_varchar_with_length` to be true" ), _ => { - if self.options.map_varchar_to_utf8view { + if self.options.map_string_types_to_utf8view { Ok(DataType::Utf8View) } else { Ok(DataType::Utf8) @@ -601,7 +601,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) } SQLDataType::Char(_) | SQLDataType::Text | SQLDataType::String(_) => { - Ok(DataType::Utf8) + if self.options.map_string_types_to_utf8view { + Ok(DataType::Utf8View) + } else { + Ok(DataType::Utf8) + } } SQLDataType::Timestamp(precision, tz_info) if precision.is_none() || [0, 3, 6, 9].contains(&precision.unwrap()) => @@ -612,7 +616,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Timestamp With Time Zone // INPUT : [SQLDataType] TimestampTz + [Config] Time Zone // OUTPUT: [ArrowDataType] Timestamp - self.context_provider.options().execution.time_zone.clone() + Some(self.context_provider.options().execution.time_zone.clone()) } else { // Timestamp Without Time zone None @@ -816,7 +820,7 @@ impl std::fmt::Display for IdentTaker { if !first { write!(f, ".")?; } - write!(f, "{}", ident)?; + write!(f, "{ident}")?; first = false; } diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index ea641320c01b..2ea2299c1fcf 100644 --- a/datafusion/sql/src/query.rs +++ 
b/datafusion/sql/src/query.rs @@ -22,14 +22,15 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use crate::stack::StackGuard; use datafusion_common::{not_impl_err, Constraints, DFSchema, Result}; use datafusion_expr::expr::Sort; -use datafusion_expr::select_expr::SelectExpr; + use datafusion_expr::{ - CreateMemoryTable, DdlStatement, Distinct, LogicalPlan, LogicalPlanBuilder, + CreateMemoryTable, DdlStatement, Distinct, Expr, LogicalPlan, LogicalPlanBuilder, }; use sqlparser::ast::{ - Expr as SQLExpr, Offset as SQLOffset, OrderBy, OrderByExpr, OrderByKind, Query, - SelectInto, SetExpr, + Expr as SQLExpr, Ident, Offset as SQLOffset, OrderBy, OrderByExpr, OrderByKind, + Query, SelectInto, SetExpr, }; +use sqlparser::tokenizer::Span; impl SqlToRel<'_, S> { /// Generate a logical plan from an SQL query/subquery @@ -137,7 +138,7 @@ impl SqlToRel<'_, S> { Some(into) => Ok(LogicalPlan::Ddl(DdlStatement::CreateMemoryTable( CreateMemoryTable { name: self.object_name_to_table_reference(into.name)?, - constraints: Constraints::empty(), + constraints: Constraints::default(), input: Arc::new(plan), if_not_exists: false, or_replace: false, @@ -158,7 +159,7 @@ fn to_order_by_exprs(order_by: Option) -> Result> { /// Returns the order by expressions from the query with the select expressions. pub(crate) fn to_order_by_exprs_with_select( order_by: Option, - _select_exprs: Option<&Vec>, // TODO: ORDER BY ALL + select_exprs: Option<&Vec>, ) -> Result> { let Some(OrderBy { kind, interpolate }) = order_by else { // If no order by, return an empty array. @@ -168,7 +169,30 @@ pub(crate) fn to_order_by_exprs_with_select( return not_impl_err!("ORDER BY INTERPOLATE is not supported"); } match kind { - OrderByKind::All(_) => not_impl_err!("ORDER BY ALL is not supported"), + OrderByKind::All(order_by_options) => { + let Some(exprs) = select_exprs else { + return Ok(vec![]); + }; + let order_by_exprs = exprs + .iter() + .map(|select_expr| match select_expr { + Expr::Column(column) => Ok(OrderByExpr { + expr: SQLExpr::Identifier(Ident { + value: column.name.clone(), + quote_style: None, + span: Span::empty(), + }), + options: order_by_options.clone(), + with_fill: None, + }), + // TODO: Support other types of expressions + _ => not_impl_err!( + "ORDER BY ALL is not supported for non-column expressions" + ), + }) + .collect::>>()?; + Ok(order_by_exprs) + } OrderByKind::Expressions(order_by_exprs) => Ok(order_by_exprs), } } diff --git a/datafusion/sql/src/relation/join.rs b/datafusion/sql/src/relation/join.rs index 8a3c20e3971b..10491963e3ce 100644 --- a/datafusion/sql/src/relation/join.rs +++ b/datafusion/sql/src/relation/join.rs @@ -142,7 +142,7 @@ impl SqlToRel<'_, S> { "Expected identifier in USING clause" ) }) - .map(|ident| self.ident_normalizer.normalize(ident.clone())) + .map(|ident| Column::from_name(self.ident_normalizer.normalize(ident.clone()))) } }) .collect::>>()?; diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index dee855f8c000..aa37d74fd4d8 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -66,7 +66,7 @@ impl SqlToRel<'_, S> { .get_table_function_source(&tbl_func_name, args)?; let plan = LogicalPlanBuilder::scan( TableReference::Bare { - table: "tmp_table".into(), + table: format!("{tbl_func_name}()").into(), }, provider, None, @@ -92,7 +92,7 @@ impl SqlToRel<'_, S> { .build(), (None, Err(e)) => { let e = e.with_diagnostic(Diagnostic::new_error( - format!("table '{}' not found", table_ref), + format!("table 
'{table_ref}' not found"), Span::try_from_sqlparser_span(relation_span), )); Err(e) @@ -154,6 +154,35 @@ impl SqlToRel<'_, S> { "UNNEST table factor with offset is not supported yet" ); } + TableFactor::Function { + name, args, alias, .. + } => { + let tbl_func_ref = self.object_name_to_table_reference(name)?; + let schema = planner_context + .outer_query_schema() + .cloned() + .unwrap_or_else(DFSchema::empty); + let func_args = args + .into_iter() + .map(|arg| match arg { + FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) + | FunctionArg::Named { + arg: FunctionArgExpr::Expr(expr), + .. + } => { + self.sql_expr_to_logical_expr(expr, &schema, planner_context) + } + _ => plan_err!("Unsupported function argument: {arg:?}"), + }) + .collect::>>()?; + let provider = self + .context_provider + .get_table_function_source(tbl_func_ref.table(), func_args)?; + let plan = + LogicalPlanBuilder::scan(tbl_func_ref.table(), provider, None)? + .build()?; + (plan, alias) + } // @todo Support TableFactory::TableFunction? _ => { return not_impl_err!( diff --git a/datafusion/sql/src/resolve.rs b/datafusion/sql/src/resolve.rs index 96012a92c09a..9e909f66fa97 100644 --- a/datafusion/sql/src/resolve.rs +++ b/datafusion/sql/src/resolve.rs @@ -78,7 +78,7 @@ impl Visitor for RelationVisitor { if !with.recursive { // This is a bit hackish as the CTE will be visited again as part of visiting `q`, // but thankfully `insert_relation` is idempotent. - cte.visit(self); + let _ = cte.visit(self); } self.ctes_in_scope .push(ObjectName::from(vec![cte.alias.name.clone()])); @@ -143,7 +143,7 @@ fn visit_statement(statement: &DFStatement, visitor: &mut RelationVisitor) { visitor.insert_relation(table_name); } CopyToSource::Query(query) => { - query.visit(visitor); + let _ = query.visit(visitor); } }, DFStatement::Explain(explain) => visit_statement(&explain.statement, visitor), diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 33994b60b735..b50fbf68129c 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -94,13 +94,13 @@ impl SqlToRel<'_, S> { planner_context, )?; - let order_by = - to_order_by_exprs_with_select(query_order_by, Some(&select_exprs))?; - // Having and group by clause may reference aliases defined in select projection let projected_plan = self.project(base_plan.clone(), select_exprs)?; let select_exprs = projected_plan.expressions(); + let order_by = + to_order_by_exprs_with_select(query_order_by, Some(&select_exprs))?; + // Place the fields of the base plan at the front so that when there are references // with the same name, the fields of the base plan will be searched first. // See https://github.com/apache/datafusion/issues/9162 @@ -307,6 +307,15 @@ impl SqlToRel<'_, S> { let mut intermediate_plan = input; let mut intermediate_select_exprs = select_exprs; + // Fast path: If there is are no unnests in the select_exprs, wrap the plan in a projection + if !intermediate_select_exprs + .iter() + .any(has_unnest_expr_recursively) + { + return LogicalPlanBuilder::from(intermediate_plan) + .project(intermediate_select_exprs)? 
+ .build(); + } // Each expr in select_exprs can contains multiple unnest stage // The transformation happen bottom up, one at a time for each iteration @@ -374,6 +383,12 @@ impl SqlToRel<'_, S> { fn try_process_aggregate_unnest(&self, input: LogicalPlan) -> Result { match input { + // Fast path if there are no unnest in group by + LogicalPlan::Aggregate(ref agg) + if !&agg.group_expr.iter().any(has_unnest_expr_recursively) => + { + Ok(input) + } LogicalPlan::Aggregate(agg) => { let agg_expr = agg.aggr_expr.clone(); let (new_input, new_group_by_exprs) = @@ -885,7 +900,7 @@ impl SqlToRel<'_, S> { | SelectItem::UnnamedExpr(expr) = proj { let mut err = None; - visit_expressions_mut(expr, |expr| { + let _ = visit_expressions_mut(expr, |expr| { if let SQLExpr::Function(f) = expr { if let Some(WindowType::NamedWindow(ident)) = &f.over { let normalized_ident = @@ -939,3 +954,17 @@ fn check_conflicting_windows(window_defs: &[NamedWindowDefinition]) -> Result<() } Ok(()) } + +/// Returns true if the expression recursively contains an `Expr::Unnest` expression +fn has_unnest_expr_recursively(expr: &Expr) -> bool { + let mut has_unnest = false; + let _ = expr.apply(|e| { + if let Expr::Unnest(_) = e { + has_unnest = true; + Ok(TreeNodeRecursion::Stop) + } else { + Ok(TreeNodeRecursion::Continue) + } + }); + has_unnest +} diff --git a/datafusion/sql/src/set_expr.rs b/datafusion/sql/src/set_expr.rs index 272d6f874b4d..5b65e1c045bd 100644 --- a/datafusion/sql/src/set_expr.rs +++ b/datafusion/sql/src/set_expr.rs @@ -95,26 +95,22 @@ impl SqlToRel<'_, S> { if left_plan.schema().fields().len() == right_plan.schema().fields().len() { return Ok(()); } - - plan_err!("{} queries have different number of columns", op).map_err(|err| { - err.with_diagnostic( - Diagnostic::new_error( - format!("{} queries have different number of columns", op), - set_expr_span, - ) - .with_note( - format!("this side has {} fields", left_plan.schema().fields().len()), - left_span, - ) - .with_note( - format!( - "this side has {} fields", - right_plan.schema().fields().len() - ), - right_span, - ), - ) - }) + let diagnostic = Diagnostic::new_error( + format!("{op} queries have different number of columns"), + set_expr_span, + ) + .with_note( + format!("this side has {} fields", left_plan.schema().fields().len()), + left_span, + ) + .with_note( + format!( + "this side has {} fields", + right_plan.schema().fields().len() + ), + right_span, + ); + plan_err!("{} queries have different number of columns", op; diagnostic =diagnostic) } pub(super) fn set_operation_to_plan( diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 1f1c235fee6f..f83cffe47a17 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -215,7 +215,7 @@ impl SqlToRel<'_, S> { ) -> Result { match statement { Statement::ExplainTable { - describe_alias: DescribeAlias::Describe, // only parse 'DESCRIBE table_name' and not 'EXPLAIN table_name' + describe_alias: DescribeAlias::Describe | DescribeAlias::Desc, // only parse 'DESCRIBE table_name' or 'DESC table_name' and not 'EXPLAIN table_name' table_name, .. 
} => self.describe_table_to_plan(table_name), @@ -696,7 +696,7 @@ impl SqlToRel<'_, S> { statement, } => { // Convert parser data types to DataFusion data types - let data_types: Vec = data_types + let mut data_types: Vec = data_types .into_iter() .map(|t| self.convert_data_type(&t)) .collect::>()?; @@ -710,6 +710,19 @@ impl SqlToRel<'_, S> { *statement, &mut planner_context, )?; + + if data_types.is_empty() { + let map_types = plan.get_parameter_types()?; + let param_types: Vec<_> = (1..=map_types.len()) + .filter_map(|i| { + let key = format!("${i}"); + map_types.get(&key).and_then(|opt| opt.clone()) + }) + .collect(); + data_types.extend(param_types.iter().cloned()); + planner_context.with_prepare_param_data_types(param_types); + } + Ok(LogicalPlan::Statement(PlanStatement::Prepare(Prepare { name: ident_to_string(&name), data_types, @@ -1609,7 +1622,7 @@ impl SqlToRel<'_, S> { // If config does not belong to any namespace, assume it is // a format option and apply the format prefix for backwards // compatibility. - let renamed_key = format!("format.{}", key); + let renamed_key = format!("format.{key}"); options_map.insert(renamed_key.to_lowercase(), value_string); } else { options_map.insert(key.to_lowercase(), value_string); @@ -1794,7 +1807,10 @@ impl SqlToRel<'_, S> { // Do a table lookup to verify the table exists let table_ref = self.object_name_to_table_reference(table_name.clone())?; let table_source = self.context_provider.get_table_source(table_ref.clone())?; - let schema = table_source.schema().to_dfschema_ref()?; + let schema = DFSchema::try_from_qualified_schema( + table_ref.clone(), + &table_source.schema(), + )?; let scan = LogicalPlanBuilder::scan(table_ref.clone(), Arc::clone(&table_source), None)? .build()?; @@ -2049,7 +2065,7 @@ impl SqlToRel<'_, S> { .cloned() .unwrap_or_else(|| { // If there is no default for the column, then the default is NULL - Expr::Literal(ScalarValue::Null) + Expr::Literal(ScalarValue::Null, None) }) .cast_to(target_field.data_type(), &DFSchema::empty())?, }; diff --git a/datafusion/sql/src/unparser/ast.rs b/datafusion/sql/src/unparser/ast.rs index 6fcc203637cc..d9ade822aa00 100644 --- a/datafusion/sql/src/unparser/ast.rs +++ b/datafusion/sql/src/unparser/ast.rs @@ -32,6 +32,8 @@ pub struct QueryBuilder { fetch: Option, locks: Vec, for_clause: Option, + // If true, we need to unparse LogicalPlan::Union as a SQL `UNION` rather than a `UNION ALL`. 
+ distinct_union: bool, } #[allow(dead_code)] @@ -75,6 +77,13 @@ impl QueryBuilder { self.for_clause = value; self } + pub fn distinct_union(&mut self) -> &mut Self { + self.distinct_union = true; + self + } + pub fn is_distinct_union(&self) -> bool { + self.distinct_union + } pub fn build(&self) -> Result { let order_by = self .order_by_kind @@ -112,6 +121,7 @@ impl QueryBuilder { fetch: Default::default(), locks: Default::default(), for_clause: Default::default(), + distinct_union: false, } } } @@ -155,6 +165,11 @@ impl SelectBuilder { self.projection = value; self } + pub fn pop_projections(&mut self) -> Vec { + let ret = self.projection.clone(); + self.projection.clear(); + ret + } pub fn already_projected(&self) -> bool { !self.projection.is_empty() } @@ -198,7 +213,7 @@ impl SelectBuilder { value: &ast::Expr, ) -> &mut Self { if let Some(selection) = &mut self.selection { - visit_expressions_mut(selection, |expr| { + let _ = visit_expressions_mut(selection, |expr| { if expr == existing_expr { *expr = value.clone(); } @@ -383,6 +398,7 @@ pub struct RelationBuilder { #[allow(dead_code)] #[derive(Clone)] +#[allow(clippy::large_enum_variant)] enum TableFactorBuilder { Table(TableRelationBuilder), Derived(DerivedRelationBuilder), @@ -690,9 +706,9 @@ impl fmt::Display for BuilderError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Self::UninitializedField(ref field) => { - write!(f, "`{}` must be initialized", field) + write!(f, "`{field}` must be initialized") } - Self::ValidationError(ref error) => write!(f, "{}", error), + Self::ValidationError(ref error) => write!(f, "{error}"), } } } diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index a7bde967f2fa..3c8de7e74032 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -197,6 +197,13 @@ pub trait Dialect: Send + Sync { fn unnest_as_table_factor(&self) -> bool { false } + + /// Allows the dialect to override column alias unparsing if the dialect has specific rules. + /// Returns None if the default unparsing should be used, or Some(String) if there is + /// a custom implementation for the alias. 
+ fn col_alias_overrides(&self, _alias: &str) -> Result> { + Ok(None) + } } /// `IntervalStyle` to use for unparsing @@ -500,6 +507,49 @@ impl Dialect for SqliteDialect { } } +#[derive(Default)] +pub struct BigQueryDialect {} + +impl Dialect for BigQueryDialect { + fn identifier_quote_style(&self, _: &str) -> Option { + Some('`') + } + + fn col_alias_overrides(&self, alias: &str) -> Result> { + // Check if alias contains any special characters not supported by BigQuery col names + // https://cloud.google.com/bigquery/docs/schemas#flexible-column-names + let special_chars: [char; 20] = [ + '!', '"', '$', '(', ')', '*', ',', '.', '/', ';', '?', '@', '[', '\\', ']', + '^', '`', '{', '}', '~', + ]; + + if alias.chars().any(|c| special_chars.contains(&c)) { + let mut encoded_name = String::new(); + for c in alias.chars() { + if special_chars.contains(&c) { + encoded_name.push_str(&format!("_{}", c as u32)); + } else { + encoded_name.push(c); + } + } + Ok(Some(encoded_name)) + } else { + Ok(Some(alias.to_string())) + } + } + + fn unnest_as_table_factor(&self) -> bool { + true + } +} + +impl BigQueryDialect { + #[must_use] + pub fn new() -> Self { + Self {} + } +} + pub struct CustomDialect { identifier_quote_style: Option, supports_nulls_first_in_sort: bool, diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 0b6e9b7a3ddd..1e39d7186d13 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -187,19 +187,20 @@ impl Unparser<'_> { Expr::Cast(Cast { expr, data_type }) => { Ok(self.cast_to_sql(expr, data_type)?) } - Expr::Literal(value) => Ok(self.scalar_to_sql(value)?), + Expr::Literal(value, _) => Ok(self.scalar_to_sql(value)?), Expr::Alias(Alias { expr, name: _, .. }) => self.expr_to_sql_inner(expr), - Expr::WindowFunction(WindowFunction { - fun, - params: - WindowFunctionParams { - args, - partition_by, - order_by, - window_frame, - .. - }, - }) => { + Expr::WindowFunction(window_fun) => { + let WindowFunction { + fun, + params: + WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + .. + }, + } = window_fun.as_ref(); let func_name = fun.name(); let args = self.function_args_to_sql(args)?; @@ -273,13 +274,6 @@ impl Unparser<'_> { pattern, escape_char, case_insensitive: _, - }) - | Expr::Like(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive: _, }) => Ok(ast::Expr::Like { negated: *negated, expr: Box::new(self.expr_to_sql_inner(expr)?), @@ -287,12 +281,39 @@ impl Unparser<'_> { escape_char: escape_char.map(|c| c.to_string()), any: false, }), + Expr::Like(Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + }) => { + if *case_insensitive { + Ok(ast::Expr::ILike { + negated: *negated, + expr: Box::new(self.expr_to_sql_inner(expr)?), + pattern: Box::new(self.expr_to_sql_inner(pattern)?), + escape_char: escape_char.map(|c| c.to_string()), + any: false, + }) + } else { + Ok(ast::Expr::Like { + negated: *negated, + expr: Box::new(self.expr_to_sql_inner(expr)?), + pattern: Box::new(self.expr_to_sql_inner(pattern)?), + escape_char: escape_char.map(|c| c.to_string()), + any: false, + }) + } + } + Expr::AggregateFunction(agg) => { let func_name = agg.func.name(); let AggregateFunctionParams { distinct, args, filter, + order_by, .. 
} = &agg.params; @@ -301,6 +322,16 @@ impl Unparser<'_> { Some(filter) => Some(Box::new(self.expr_to_sql_inner(filter)?)), None => None, }; + let within_group = if agg.func.is_ordered_set_aggregate() { + order_by + .as_ref() + .unwrap_or(&Vec::new()) + .iter() + .map(|sort_expr| self.sort_to_sql(sort_expr)) + .collect::>>()? + } else { + Vec::new() + }; Ok(ast::Expr::Function(Function { name: ObjectName::from(vec![Ident { value: func_name.to_string(), @@ -316,7 +347,7 @@ impl Unparser<'_> { filter, null_treatment: None, over: None, - within_group: vec![], + within_group, parameters: ast::FunctionArguments::None, uses_odbc_syntax: false, })) @@ -571,7 +602,7 @@ impl Unparser<'_> { .chunks_exact(2) .map(|chunk| { let key = match &chunk[0] { - Expr::Literal(ScalarValue::Utf8(Some(s))) => self.new_ident_quoted_if_needs(s.to_string()), + Expr::Literal(ScalarValue::Utf8(Some(s)), _) => self.new_ident_quoted_if_needs(s.to_string()), _ => return internal_err!("named_struct expects even arguments to be strings, but received: {:?}", &chunk[0]) }; @@ -600,7 +631,7 @@ impl Unparser<'_> { }; let field = match &args[1] { - Expr::Literal(lit) => self.new_ident_quoted_if_needs(lit.to_string()), + Expr::Literal(lit, _) => self.new_ident_quoted_if_needs(lit.to_string()), _ => { return internal_err!( "get_field expects second argument to be a string, but received: {:?}", @@ -679,13 +710,21 @@ impl Unparser<'_> { } pub fn col_to_sql(&self, col: &Column) -> Result { + // Replace the column name if the dialect has an override + let col_name = + if let Some(rewritten_name) = self.dialect.col_alias_overrides(&col.name)? { + rewritten_name + } else { + col.name.to_string() + }; + if let Some(table_ref) = &col.relation { let mut id = if self.dialect.full_qualified_col() { table_ref.to_vec() } else { vec![table_ref.table().to_string()] }; - id.push(col.name.to_string()); + id.push(col_name); return Ok(ast::Expr::CompoundIdentifier( id.iter() .map(|i| self.new_ident_quoted_if_needs(i.to_string())) @@ -693,7 +732,7 @@ impl Unparser<'_> { )); } Ok(ast::Expr::Identifier( - self.new_ident_quoted_if_needs(col.name.to_string()), + self.new_ident_quoted_if_needs(col_name), )) } @@ -1103,16 +1142,16 @@ impl Unparser<'_> { ScalarValue::Float16(None) => Ok(ast::Expr::value(ast::Value::Null)), ScalarValue::Float32(Some(f)) => { let f_val = match f.fract() { - 0.0 => format!("{:.1}", f), - _ => format!("{}", f), + 0.0 => format!("{f:.1}"), + _ => format!("{f}"), }; Ok(ast::Expr::value(ast::Value::Number(f_val, false))) } ScalarValue::Float32(None) => Ok(ast::Expr::value(ast::Value::Null)), ScalarValue::Float64(Some(f)) => { let f_val = match f.fract() { - 0.0 => format!("{:.1}", f), - _ => format!("{}", f), + 0.0 => format!("{f:.1}"), + _ => format!("{f}"), }; Ok(ast::Expr::value(ast::Value::Number(f_val, false))) } @@ -1855,10 +1894,20 @@ mod tests { expr: Box::new(col("a")), pattern: Box::new(lit("foo")), escape_char: Some('o'), - case_insensitive: true, + case_insensitive: false, }), r#"a NOT LIKE 'foo' ESCAPE 'o'"#, ), + ( + Expr::Like(Like { + negated: true, + expr: Box::new(col("a")), + pattern: Box::new(lit("foo")), + escape_char: Some('o'), + case_insensitive: true, + }), + r#"a NOT ILIKE 'foo' ESCAPE 'o'"#, + ), ( Expr::SimilarTo(Like { negated: false, @@ -1870,87 +1919,87 @@ mod tests { r#"a LIKE 'foo' ESCAPE 'o'"#, ), ( - Expr::Literal(ScalarValue::Date64(Some(0))), + Expr::Literal(ScalarValue::Date64(Some(0)), None), r#"CAST('1970-01-01 00:00:00' AS DATETIME)"#, ), ( - Expr::Literal(ScalarValue::Date64(Some(10000))), 
+ Expr::Literal(ScalarValue::Date64(Some(10000)), None), r#"CAST('1970-01-01 00:00:10' AS DATETIME)"#, ), ( - Expr::Literal(ScalarValue::Date64(Some(-10000))), + Expr::Literal(ScalarValue::Date64(Some(-10000)), None), r#"CAST('1969-12-31 23:59:50' AS DATETIME)"#, ), ( - Expr::Literal(ScalarValue::Date32(Some(0))), + Expr::Literal(ScalarValue::Date32(Some(0)), None), r#"CAST('1970-01-01' AS DATE)"#, ), ( - Expr::Literal(ScalarValue::Date32(Some(10))), + Expr::Literal(ScalarValue::Date32(Some(10)), None), r#"CAST('1970-01-11' AS DATE)"#, ), ( - Expr::Literal(ScalarValue::Date32(Some(-1))), + Expr::Literal(ScalarValue::Date32(Some(-1)), None), r#"CAST('1969-12-31' AS DATE)"#, ), ( - Expr::Literal(ScalarValue::TimestampSecond(Some(10001), None)), + Expr::Literal(ScalarValue::TimestampSecond(Some(10001), None), None), r#"CAST('1970-01-01 02:46:41' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampSecond( - Some(10001), - Some("+08:00".into()), - )), + Expr::Literal( + ScalarValue::TimestampSecond(Some(10001), Some("+08:00".into())), + None, + ), r#"CAST('1970-01-01 10:46:41 +08:00' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampMillisecond(Some(10001), None)), + Expr::Literal(ScalarValue::TimestampMillisecond(Some(10001), None), None), r#"CAST('1970-01-01 00:00:10.001' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampMillisecond( - Some(10001), - Some("+08:00".into()), - )), + Expr::Literal( + ScalarValue::TimestampMillisecond(Some(10001), Some("+08:00".into())), + None, + ), r#"CAST('1970-01-01 08:00:10.001 +08:00' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampMicrosecond(Some(10001), None)), + Expr::Literal(ScalarValue::TimestampMicrosecond(Some(10001), None), None), r#"CAST('1970-01-01 00:00:00.010001' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampMicrosecond( - Some(10001), - Some("+08:00".into()), - )), + Expr::Literal( + ScalarValue::TimestampMicrosecond(Some(10001), Some("+08:00".into())), + None, + ), r#"CAST('1970-01-01 08:00:00.010001 +08:00' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampNanosecond(Some(10001), None)), + Expr::Literal(ScalarValue::TimestampNanosecond(Some(10001), None), None), r#"CAST('1970-01-01 00:00:00.000010001' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::TimestampNanosecond( - Some(10001), - Some("+08:00".into()), - )), + Expr::Literal( + ScalarValue::TimestampNanosecond(Some(10001), Some("+08:00".into())), + None, + ), r#"CAST('1970-01-01 08:00:00.000010001 +08:00' AS TIMESTAMP)"#, ), ( - Expr::Literal(ScalarValue::Time32Second(Some(10001))), + Expr::Literal(ScalarValue::Time32Second(Some(10001)), None), r#"CAST('02:46:41' AS TIME)"#, ), ( - Expr::Literal(ScalarValue::Time32Millisecond(Some(10001))), + Expr::Literal(ScalarValue::Time32Millisecond(Some(10001)), None), r#"CAST('00:00:10.001' AS TIME)"#, ), ( - Expr::Literal(ScalarValue::Time64Microsecond(Some(10001))), + Expr::Literal(ScalarValue::Time64Microsecond(Some(10001)), None), r#"CAST('00:00:00.010001' AS TIME)"#, ), ( - Expr::Literal(ScalarValue::Time64Nanosecond(Some(10001))), + Expr::Literal(ScalarValue::Time64Nanosecond(Some(10001)), None), r#"CAST('00:00:00.000010001' AS TIME)"#, ), (sum(col("a")), r#"sum(a)"#), @@ -1979,7 +2028,7 @@ mod tests { "count(*) FILTER (WHERE true)", ), ( - Expr::WindowFunction(WindowFunction { + Expr::from(WindowFunction { fun: WindowFunctionDefinition::WindowUDF(row_number_udwf()), params: WindowFunctionParams { args: vec![col("col")], @@ -1993,7 +2042,7 @@ mod tests { ), ( 
#[expect(deprecated)] - Expr::WindowFunction(WindowFunction { + Expr::from(WindowFunction { fun: WindowFunctionDefinition::AggregateUDF(count_udaf()), params: WindowFunctionParams { args: vec![Expr::Wildcard { @@ -2095,19 +2144,17 @@ mod tests { (col("need quoted").eq(lit(1)), r#"("need quoted" = 1)"#), // See test_interval_scalar_to_expr for interval literals ( - (col("a") + col("b")).gt(Expr::Literal(ScalarValue::Decimal128( - Some(100123), - 28, - 3, - ))), + (col("a") + col("b")).gt(Expr::Literal( + ScalarValue::Decimal128(Some(100123), 28, 3), + None, + )), r#"((a + b) > 100.123)"#, ), ( - (col("a") + col("b")).gt(Expr::Literal(ScalarValue::Decimal256( - Some(100123.into()), - 28, - 3, - ))), + (col("a") + col("b")).gt(Expr::Literal( + ScalarValue::Decimal256(Some(100123.into()), 28, 3), + None, + )), r#"((a + b) > 100.123)"#, ), ( @@ -2143,28 +2190,39 @@ mod tests { "MAP {'a': 1, 'b': 2}", ), ( - Expr::Literal(ScalarValue::Dictionary( - Box::new(DataType::Int32), - Box::new(ScalarValue::Utf8(Some("foo".into()))), - )), + Expr::Literal( + ScalarValue::Dictionary( + Box::new(DataType::Int32), + Box::new(ScalarValue::Utf8(Some("foo".into()))), + ), + None, + ), "'foo'", ), ( - Expr::Literal(ScalarValue::List(Arc::new( - ListArray::from_iter_primitive::(vec![Some(vec![ + Expr::Literal( + ScalarValue::List(Arc::new(ListArray::from_iter_primitive::< + Int32Type, + _, + _, + >(vec![Some(vec![ Some(1), Some(2), Some(3), - ])]), - ))), + ])]))), + None, + ), "[1, 2, 3]", ), ( - Expr::Literal(ScalarValue::LargeList(Arc::new( - LargeListArray::from_iter_primitive::(vec![Some( - vec![Some(1), Some(2), Some(3)], - )]), - ))), + Expr::Literal( + ScalarValue::LargeList(Arc::new( + LargeListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + ]), + )), + None, + ), "[1, 2, 3]", ), ( @@ -2188,7 +2246,7 @@ mod tests { for (expr, expected) in tests { let ast = expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); assert_eq!(actual, expected); } @@ -2206,7 +2264,7 @@ mod tests { let expr = col("a").gt(lit(4)); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = r#"('a' > 4)"#; assert_eq!(actual, expected); @@ -2222,7 +2280,7 @@ mod tests { let expr = col("a").gt(lit(4)); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = r#"(a > 4)"#; assert_eq!(actual, expected); @@ -2246,7 +2304,7 @@ mod tests { }); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = format!(r#"CAST(a AS {identifier})"#); assert_eq!(actual, expected); @@ -2271,7 +2329,7 @@ mod tests { }); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = format!(r#"CAST(a AS {identifier})"#); assert_eq!(actual, expected); @@ -2293,7 +2351,7 @@ mod tests { let unparser = Unparser::new(&dialect); let ast = unparser.sort_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); assert_eq!(actual, expected); } @@ -2469,11 +2527,17 @@ mod tests { #[test] fn test_float_scalar_to_expr() { let tests = [ - (Expr::Literal(ScalarValue::Float64(Some(3f64))), "3.0"), - (Expr::Literal(ScalarValue::Float64(Some(3.1f64))), "3.1"), - (Expr::Literal(ScalarValue::Float32(Some(-2f32))), "-2.0"), + (Expr::Literal(ScalarValue::Float64(Some(3f64)), None), "3.0"), + ( + 
Expr::Literal(ScalarValue::Float64(Some(3.1f64)), None), + "3.1", + ), ( - Expr::Literal(ScalarValue::Float32(Some(-2.989f32))), + Expr::Literal(ScalarValue::Float32(Some(-2f32)), None), + "-2.0", + ), + ( + Expr::Literal(ScalarValue::Float32(Some(-2.989f32)), None), "-2.989", ), ]; @@ -2493,18 +2557,20 @@ mod tests { let tests = [ ( Expr::Cast(Cast { - expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some( - "blah".to_string(), - )))), + expr: Box::new(Expr::Literal( + ScalarValue::Utf8(Some("blah".to_string())), + None, + )), data_type: DataType::Binary, }), "'blah'", ), ( Expr::Cast(Cast { - expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some( - "blah".to_string(), - )))), + expr: Box::new(Expr::Literal( + ScalarValue::Utf8(Some("blah".to_string())), + None, + )), data_type: DataType::BinaryView, }), "'blah'", @@ -2543,7 +2609,7 @@ mod tests { }); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = format!(r#"CAST(a AS {identifier})"#); assert_eq!(actual, expected); @@ -2596,10 +2662,13 @@ mod tests { let expr = ScalarUDF::new_from_impl( datafusion_functions::datetime::date_part::DatePartFunc::new(), ) - .call(vec![Expr::Literal(ScalarValue::new_utf8(unit)), col("x")]); + .call(vec![ + Expr::Literal(ScalarValue::new_utf8(unit), None), + col("x"), + ]); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); assert_eq!(actual, expected); } @@ -2626,7 +2695,7 @@ mod tests { }); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = format!(r#"CAST(a AS {identifier})"#); assert_eq!(actual, expected); @@ -2654,7 +2723,7 @@ mod tests { }); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = format!(r#"CAST(a AS {identifier})"#); assert_eq!(actual, expected); @@ -2693,7 +2762,7 @@ mod tests { }); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = format!(r#"CAST(a AS {identifier})"#); assert_eq!(actual, expected); @@ -2716,13 +2785,13 @@ mod tests { (&mysql_dialect, "DATETIME"), ] { let unparser = Unparser::new(dialect); - let expr = Expr::Literal(ScalarValue::TimestampMillisecond( - Some(1738285549123), + let expr = Expr::Literal( + ScalarValue::TimestampMillisecond(Some(1738285549123), None), None, - )); + ); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = format!(r#"CAST('2025-01-31 01:05:49.123' AS {identifier})"#); assert_eq!(actual, expected); @@ -2749,7 +2818,7 @@ mod tests { }); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = format!(r#"CAST(a AS {identifier})"#); assert_eq!(actual, expected); @@ -2775,7 +2844,7 @@ mod tests { }); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = expected.to_string(); assert_eq!(actual, expected); @@ -2787,9 +2856,10 @@ mod tests { fn test_cast_value_to_dict_expr() { let tests = [( Expr::Cast(Cast { - expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some( - "variation".to_string(), - )))), + expr: Box::new(Expr::Literal( + ScalarValue::Utf8(Some("variation".to_string())), + None, + )), data_type: DataType::Dictionary(Box::new(Int8), Box::new(DataType::Utf8)), }), "'variation'", @@ -2827,12 +2897,12 @@ mod 
tests { expr: Box::new(col("a")), data_type: DataType::Float64, }), - Expr::Literal(ScalarValue::Int64(Some(2))), + Expr::Literal(ScalarValue::Int64(Some(2)), None), ], }); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = format!(r#"round(CAST("a" AS {identifier}), 2)"#); assert_eq!(actual, expected); @@ -2862,7 +2932,7 @@ mod tests { let func = WindowFunctionDefinition::WindowUDF(rank_udwf()); let mut window_func = WindowFunction::new(func, vec![]); window_func.params.order_by = vec![Sort::new(col("a"), true, true)]; - let expr = Expr::WindowFunction(window_func); + let expr = Expr::from(window_func); let ast = unparser.expr_to_sql(&expr)?; let actual = ast.to_string(); @@ -2967,7 +3037,7 @@ mod tests { datafusion_functions::datetime::date_trunc::DateTruncFunc::new(), )), args: vec![ - Expr::Literal(ScalarValue::Utf8(Some(precision.to_string()))), + Expr::Literal(ScalarValue::Utf8(Some(precision.to_string())), None), col("date_col"), ], }); @@ -3012,7 +3082,7 @@ mod tests { let expr = cast(col("a"), DataType::Utf8View); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = r#"CAST(a AS CHAR)"#.to_string(); assert_eq!(actual, expected); @@ -3020,7 +3090,7 @@ mod tests { let expr = col("a").eq(lit(ScalarValue::Utf8View(Some("hello".to_string())))); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = r#"(a = 'hello')"#.to_string(); assert_eq!(actual, expected); @@ -3028,7 +3098,7 @@ mod tests { let expr = col("a").is_not_null(); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = r#"a IS NOT NULL"#.to_string(); assert_eq!(actual, expected); @@ -3036,7 +3106,7 @@ mod tests { let expr = col("a").is_null(); let ast = unparser.expr_to_sql(&expr)?; - let actual = format!("{}", ast); + let actual = format!("{ast}"); let expected = r#"a IS NULL"#.to_string(); assert_eq!(actual, expected); diff --git a/datafusion/sql/src/unparser/extension_unparser.rs b/datafusion/sql/src/unparser/extension_unparser.rs index f7deabe7c902..b778130ca5a2 100644 --- a/datafusion/sql/src/unparser/extension_unparser.rs +++ b/datafusion/sql/src/unparser/extension_unparser.rs @@ -64,6 +64,7 @@ pub enum UnparseWithinStatementResult { } /// The result of unparsing a custom logical node to a statement. +#[allow(clippy::large_enum_variant)] pub enum UnparseToStatementResult { /// If the custom logical node was successfully unparsed to a statement. 
Modified(Statement), diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index c41effa47885..4fb1e42d6028 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -50,7 +50,7 @@ use datafusion_expr::{ UserDefinedLogicalNode, }; use sqlparser::ast::{self, Ident, OrderByKind, SetExpr, TableAliasColumnDef}; -use std::sync::Arc; +use std::{sync::Arc, vec}; /// Convert a DataFusion [`LogicalPlan`] to [`ast::Statement`] /// @@ -309,12 +309,13 @@ impl Unparser<'_> { plan: &LogicalPlan, relation: &mut RelationBuilder, lateral: bool, + columns: Vec, ) -> Result<()> { - if self.dialect.requires_derived_table_alias() { + if self.dialect.requires_derived_table_alias() || !columns.is_empty() { self.derive( plan, relation, - Some(self.new_table_alias(alias.to_string(), vec![])), + Some(self.new_table_alias(alias.to_string(), columns)), lateral, ) } else { @@ -392,6 +393,18 @@ impl Unparser<'_> { } } + // If it's a unnest projection, we should provide the table column alias + // to provide a column name for the unnest relation. + let columns = if unnest_input_type.is_some() { + p.expr + .iter() + .map(|e| { + self.new_ident_quoted_if_needs(e.schema_name().to_string()) + }) + .collect() + } else { + vec![] + }; // Projection can be top-level plan for derived table if select.already_projected() { return self.derive_with_dialect_alias( @@ -401,6 +414,7 @@ impl Unparser<'_> { unnest_input_type .filter(|t| matches!(t, UnnestInputType::OuterReference)) .is_some(), + columns, ); } self.reconstruct_select_statement(plan, p, select)?; @@ -434,6 +448,7 @@ impl Unparser<'_> { plan, relation, false, + vec![], ); } if let Some(fetch) = &limit.fetch { @@ -472,6 +487,7 @@ impl Unparser<'_> { plan, relation, false, + vec![], ); } let Some(query_ref) = query else { @@ -493,7 +509,7 @@ impl Unparser<'_> { .expr .iter() .map(|sort_expr| { - unproject_sort_expr(sort_expr, agg, sort.input.as_ref()) + unproject_sort_expr(sort_expr.clone(), agg, sort.input.as_ref()) }) .collect::>>()?; @@ -543,8 +559,26 @@ impl Unparser<'_> { plan, relation, false, + vec![], ); } + + // If this distinct is the parent of a Union and we're in a query context, + // then we need to unparse as a `UNION` rather than a `UNION ALL`. + if let Distinct::All(input) = distinct { + if matches!(input.as_ref(), LogicalPlan::Union(_)) { + if let Some(query_mut) = query.as_mut() { + query_mut.distinct_union(); + return self.select_to_sql_recursively( + input.as_ref(), + query, + select, + relation, + ); + } + } + } + let (select_distinct, input) = match distinct { Distinct::All(input) => (ast::Distinct::Distinct, input.as_ref()), Distinct::On(on) => { @@ -582,6 +616,10 @@ impl Unparser<'_> { } _ => (&join.left, &join.right), }; + // If there's an outer projection plan, it will already set up the projection. + // In that case, we don't need to worry about setting up the projection here. + // The outer projection plan will handle projecting the correct columns. + let already_projected = select.already_projected(); let left_plan = match try_transform_to_simple_table_scan_with_filters(left_plan)? { @@ -599,6 +637,13 @@ impl Unparser<'_> { relation, )?; + let left_projection: Option> = if !already_projected + { + Some(select.pop_projections()) + } else { + None + }; + let right_plan = match try_transform_to_simple_table_scan_with_filters(right_plan)? 
{ Some((plan, filters)) => { @@ -657,12 +702,20 @@ impl Unparser<'_> { &mut right_relation, )?; + let right_projection: Option> = if !already_projected + { + Some(select.pop_projections()) + } else { + None + }; + match join.join_type { JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark | JoinType::RightSemi - | JoinType::RightAnti => { + | JoinType::RightAnti + | JoinType::RightMark => { let mut query_builder = QueryBuilder::default(); let mut from = TableWithJoinsBuilder::default(); let mut exists_select: SelectBuilder = SelectBuilder::default(); @@ -686,7 +739,8 @@ impl Unparser<'_> { let negated = match join.join_type { JoinType::LeftSemi | JoinType::RightSemi - | JoinType::LeftMark => false, + | JoinType::LeftMark + | JoinType::RightMark => false, JoinType::LeftAnti | JoinType::RightAnti => true, _ => unreachable!(), }; @@ -694,13 +748,28 @@ impl Unparser<'_> { subquery: Box::new(query_builder.build()?), negated, }; - if join.join_type == JoinType::LeftMark { - let (table_ref, _) = right_plan.schema().qualified_field(0); - let column = self - .col_to_sql(&Column::new(table_ref.cloned(), "mark"))?; - select.replace_mark(&column, &exists_expr); - } else { - select.selection(Some(exists_expr)); + + match join.join_type { + JoinType::LeftMark | JoinType::RightMark => { + let source_schema = + if join.join_type == JoinType::LeftMark { + right_plan.schema() + } else { + left_plan.schema() + }; + let (table_ref, _) = source_schema.qualified_field(0); + let column = self.col_to_sql(&Column::new( + table_ref.cloned(), + "mark", + ))?; + select.replace_mark(&column, &exists_expr); + } + _ => { + select.selection(Some(exists_expr)); + } + } + if let Some(projection) = left_projection { + select.projection(projection); } } JoinType::Inner @@ -719,6 +788,21 @@ impl Unparser<'_> { let mut from = select.pop_from().unwrap(); from.push_join(ast_join); select.push_from(from); + if !already_projected { + let Some(left_projection) = left_projection else { + return internal_err!("Left projection is missing"); + }; + + let Some(right_projection) = right_projection else { + return internal_err!("Right projection is missing"); + }; + + let projection = left_projection + .into_iter() + .chain(right_projection.into_iter()) + .collect(); + select.projection(projection); + } } }; @@ -780,6 +864,7 @@ impl Unparser<'_> { plan, relation, false, + vec![], ); } @@ -793,6 +878,15 @@ impl Unparser<'_> { return internal_err!("UNION operator requires at least 2 inputs"); } + let set_quantifier = + if query.as_ref().is_some_and(|q| q.is_distinct_union()) { + // Setting the SetQuantifier to None will unparse as a `UNION` + // rather than a `UNION ALL`. + ast::SetQuantifier::None + } else { + ast::SetQuantifier::All + }; + // Build the union expression tree bottom-up by reversing the order // note that we are also swapping left and right inputs because of the rev let union_expr = input_exprs @@ -800,7 +894,7 @@ impl Unparser<'_> { .rev() .reduce(|a, b| SetExpr::SetOperation { op: ast::SetOperator::Union, - set_quantifier: ast::SetQuantifier::All, + set_quantifier, left: Box::new(b), right: Box::new(a), }) @@ -888,6 +982,7 @@ impl Unparser<'_> { subquery.subquery.as_ref(), relation, true, + vec![], ) } } @@ -910,8 +1005,7 @@ impl Unparser<'_> { if let Expr::Alias(Alias { expr, .. }) = expr { if let Expr::Column(Column { name, .. 
}) = expr.as_ref() { if let Some(prefix) = name.strip_prefix(UNNEST_PLACEHOLDER) { - if prefix.starts_with(&format!("({}(", OUTER_REFERENCE_COLUMN_PREFIX)) - { + if prefix.starts_with(&format!("({OUTER_REFERENCE_COLUMN_PREFIX}(")) { return Some(UnnestInputType::OuterReference); } return Some(UnnestInputType::Scalar); @@ -998,6 +1092,7 @@ impl Unparser<'_> { if project_vec.is_empty() { builder = builder.project(vec![Expr::Literal( ScalarValue::Int64(Some(1)), + None, )])?; } else { let project_columns = project_vec @@ -1118,9 +1213,18 @@ impl Unparser<'_> { Expr::Alias(Alias { expr, name, .. }) => { let inner = self.expr_to_sql(expr)?; + // Determine the alias name to use + let col_name = if let Some(rewritten_name) = + self.dialect.col_alias_overrides(name)? + { + rewritten_name.to_string() + } else { + name.to_string() + }; + Ok(ast::SelectItem::ExprWithAlias { expr: inner, - alias: self.new_ident_quoted_if_needs(name.to_string()), + alias: self.new_ident_quoted_if_needs(col_name), }) } _ => { @@ -1163,7 +1267,9 @@ impl Unparser<'_> { JoinType::LeftSemi => ast::JoinOperator::LeftSemi(constraint), JoinType::RightAnti => ast::JoinOperator::RightAnti(constraint), JoinType::RightSemi => ast::JoinOperator::RightSemi(constraint), - JoinType::LeftMark => unimplemented!("Unparsing of Left Mark join type"), + JoinType::LeftMark | JoinType::RightMark => { + unimplemented!("Unparsing of Mark join type") + } }) } diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index de0b23fba225..89fa392c183f 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -270,51 +270,58 @@ fn find_window_expr<'a>( .find(|expr| expr.schema_name().to_string() == column_name) } -/// Transforms a Column expression into the actual expression from aggregation or projection if found. +/// Transforms all Column expressions in a sort expression into the actual expression from aggregation or projection if found. /// This is required because if an ORDER BY expression is present in an Aggregate or Select, it is replaced /// with a Column expression (e.g., "sum(catalog_returns.cr_net_loss)"). We need to transform it back to /// the actual expression, such as sum("catalog_returns"."cr_net_loss"). 
pub(crate) fn unproject_sort_expr( - sort_expr: &SortExpr, + mut sort_expr: SortExpr, agg: Option<&Aggregate>, input: &LogicalPlan, ) -> Result { - let mut sort_expr = sort_expr.clone(); - - // Remove alias if present, because ORDER BY cannot use aliases - if let Expr::Alias(alias) = &sort_expr.expr { - sort_expr.expr = *alias.expr.clone(); - } - - let Expr::Column(ref col_ref) = sort_expr.expr else { - return Ok(sort_expr); - }; + sort_expr.expr = sort_expr + .expr + .transform(|sub_expr| { + match sub_expr { + // Remove alias if present, because ORDER BY cannot use aliases + Expr::Alias(alias) => Ok(Transformed::yes(*alias.expr)), + Expr::Column(col) => { + if col.relation.is_some() { + return Ok(Transformed::no(Expr::Column(col))); + } - if col_ref.relation.is_some() { - return Ok(sort_expr); - }; + // In case of aggregation there could be columns containing aggregation functions we need to unproject + if let Some(agg) = agg { + if agg.schema.is_column_from_schema(&col) { + return Ok(Transformed::yes(unproject_agg_exprs( + Expr::Column(col), + agg, + None, + )?)); + } + } - // In case of aggregation there could be columns containing aggregation functions we need to unproject - if let Some(agg) = agg { - if agg.schema.is_column_from_schema(col_ref) { - let new_expr = unproject_agg_exprs(sort_expr.expr, agg, None)?; - sort_expr.expr = new_expr; - return Ok(sort_expr); - } - } + // If SELECT and ORDER BY contain the same expression with a scalar function, the ORDER BY expression will + // be replaced by a Column expression (e.g., "substr(customer.c_last_name, Int64(0), Int64(5))"), and we need + // to transform it back to the actual expression. + if let LogicalPlan::Projection(Projection { expr, schema, .. }) = + input + { + if let Ok(idx) = schema.index_of_column(&col) { + if let Some(Expr::ScalarFunction(scalar_fn)) = expr.get(idx) { + return Ok(Transformed::yes(Expr::ScalarFunction( + scalar_fn.clone(), + ))); + } + } + } - // If SELECT and ORDER BY contain the same expression with a scalar function, the ORDER BY expression will - // be replaced by a Column expression (e.g., "substr(customer.c_last_name, Int64(0), Int64(5))"), and we need - // to transform it back to the actual expression. - if let LogicalPlan::Projection(Projection { expr, schema, .. 
}) = input { - if let Ok(idx) = schema.index_of_column(col_ref) { - if let Some(Expr::ScalarFunction(scalar_fn)) = expr.get(idx) { - sort_expr.expr = Expr::ScalarFunction(scalar_fn.clone()); + Ok(Transformed::no(Expr::Column(col))) + } + _ => Ok(Transformed::no(sub_expr)), } - } - return Ok(sort_expr); - } - + }) + .map(|e| e.data)?; Ok(sort_expr) } @@ -385,7 +392,7 @@ pub(crate) fn try_transform_to_simple_table_scan_with_filters( let mut builder = LogicalPlanBuilder::scan( table_scan.table_name.clone(), Arc::clone(&table_scan.source), - None, + table_scan.projection.clone(), )?; if let Some(alias) = table_alias.take() { @@ -415,7 +422,7 @@ pub(crate) fn date_part_to_sql( match (style, date_part_args.len()) { (DateFieldExtractStyle::Extract, 2) => { let date_expr = unparser.expr_to_sql(&date_part_args[1])?; - if let Expr::Literal(ScalarValue::Utf8(Some(field))) = &date_part_args[0] { + if let Expr::Literal(ScalarValue::Utf8(Some(field)), _) = &date_part_args[0] { let field = match field.to_lowercase().as_str() { "year" => ast::DateTimeField::Year, "month" => ast::DateTimeField::Month, @@ -436,7 +443,7 @@ pub(crate) fn date_part_to_sql( (DateFieldExtractStyle::Strftime, 2) => { let column = unparser.expr_to_sql(&date_part_args[1])?; - if let Expr::Literal(ScalarValue::Utf8(Some(field))) = &date_part_args[0] { + if let Expr::Literal(ScalarValue::Utf8(Some(field)), _) = &date_part_args[0] { let field = match field.to_lowercase().as_str() { "year" => "%Y", "month" => "%m", @@ -524,7 +531,7 @@ pub(crate) fn sqlite_from_unixtime_to_sql( "datetime", &[ from_unixtime_args[0].clone(), - Expr::Literal(ScalarValue::Utf8(Some("unixepoch".to_string()))), + Expr::Literal(ScalarValue::Utf8(Some("unixepoch".to_string())), None), ], )?)) } @@ -547,7 +554,7 @@ pub(crate) fn sqlite_date_trunc_to_sql( ); } - if let Expr::Literal(ScalarValue::Utf8(Some(unit))) = &date_trunc_args[0] { + if let Expr::Literal(ScalarValue::Utf8(Some(unit)), _) = &date_trunc_args[0] { let format = match unit.to_lowercase().as_str() { "year" => "%Y", "month" => "%Y-%m", @@ -561,7 +568,7 @@ pub(crate) fn sqlite_date_trunc_to_sql( return Ok(Some(unparser.scalar_function_to_sql( "strftime", &[ - Expr::Literal(ScalarValue::Utf8(Some(format.to_string()))), + Expr::Literal(ScalarValue::Utf8(Some(format.to_string())), None), date_trunc_args[1].clone(), ], )?)); diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index bc2a94cd44ff..52832e1324be 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -158,20 +158,19 @@ fn check_column_satisfies_expr( purpose: CheckColumnsSatisfyExprsPurpose, ) -> Result<()> { if !columns.contains(expr) { + let diagnostic = Diagnostic::new_error( + purpose.diagnostic_message(expr), + expr.spans().and_then(|spans| spans.first()), + ) + .with_help(format!("Either add '{expr}' to GROUP BY clause, or use an aggregare function like ANY_VALUE({expr})"), None); + return plan_err!( "{}: While expanding wildcard, column \"{}\" must appear in the GROUP BY clause or must be part of an aggregate function, currently only \"{}\" appears in the SELECT clause satisfies this requirement", purpose.message_prefix(), expr, - expr_vec_fmt!(columns) - ) - .map_err(|err| { - let diagnostic = Diagnostic::new_error( - purpose.diagnostic_message(expr), - expr.spans().and_then(|spans| spans.first()), - ) - .with_help(format!("Either add '{expr}' to GROUP BY clause, or use an aggregare function like ANY_VALUE({expr})"), None); - err.with_diagnostic(diagnostic) - }); + expr_vec_fmt!(columns); + 
diagnostic=diagnostic + ); } Ok(()) } @@ -199,7 +198,7 @@ pub(crate) fn resolve_positions_to_exprs( match expr { // sql_expr_to_logical_expr maps number to i64 // https://github.com/apache/datafusion/blob/8d175c759e17190980f270b5894348dc4cff9bbf/datafusion/src/sql/planner.rs#L882-L887 - Expr::Literal(ScalarValue::Int64(Some(position))) + Expr::Literal(ScalarValue::Int64(Some(position)), _) if position > 0_i64 && position <= select_exprs.len() as i64 => { let index = (position - 1) as usize; @@ -209,7 +208,7 @@ pub(crate) fn resolve_positions_to_exprs( _ => select_expr.clone(), }) } - Expr::Literal(ScalarValue::Int64(Some(position))) => plan_err!( + Expr::Literal(ScalarValue::Int64(Some(position)), _) => plan_err!( "Cannot find column with position {} in SELECT clause. Valid columns: 1 to {}", position, select_exprs.len() ), @@ -242,15 +241,21 @@ pub fn window_expr_common_partition_keys(window_exprs: &[Expr]) -> Result<&[Expr let all_partition_keys = window_exprs .iter() .map(|expr| match expr { - Expr::WindowFunction(WindowFunction { - params: WindowFunctionParams { partition_by, .. }, - .. - }) => Ok(partition_by), - Expr::Alias(Alias { expr, .. }) => match expr.as_ref() { - Expr::WindowFunction(WindowFunction { + Expr::WindowFunction(window_fun) => { + let WindowFunction { params: WindowFunctionParams { partition_by, .. }, .. - }) => Ok(partition_by), + } = window_fun.as_ref(); + Ok(partition_by) + } + Expr::Alias(Alias { expr, .. }) => match expr.as_ref() { + Expr::WindowFunction(window_fun) => { + let WindowFunction { + params: WindowFunctionParams { partition_by, .. }, + .. + } = window_fun.as_ref(); + Ok(partition_by) + } expr => exec_err!("Impossibly got non-window expr {expr:?}"), }, expr => exec_err!("Impossibly got non-window expr {expr:?}"), @@ -399,9 +404,9 @@ impl RecursiveUnnestRewriter<'_> { // Full context, we are trying to plan the execution as InnerProjection->Unnest->OuterProjection // inside unnest execution, each column inside the inner projection // will be transformed into new columns. 
Thus we need to keep track of these placeholding column names - let placeholder_name = format!("{UNNEST_PLACEHOLDER}({})", inner_expr_name); + let placeholder_name = format!("{UNNEST_PLACEHOLDER}({inner_expr_name})"); let post_unnest_name = - format!("{UNNEST_PLACEHOLDER}({},depth={})", inner_expr_name, level); + format!("{UNNEST_PLACEHOLDER}({inner_expr_name},depth={level})"); // This is due to the fact that unnest transformation should keep the original // column name as is, to comply with group by and order by let placeholder_column = Column::from_name(placeholder_name.clone()); @@ -681,7 +686,7 @@ mod tests { "{}=>[{}]", i.0, vec.iter() - .map(|i| format!("{}", i)) + .map(|i| format!("{i}")) .collect::>() .join(", ") ), diff --git a/datafusion/sql/tests/cases/diagnostic.rs b/datafusion/sql/tests/cases/diagnostic.rs index d08fe787948a..b3fc5dea9eff 100644 --- a/datafusion/sql/tests/cases/diagnostic.rs +++ b/datafusion/sql/tests/cases/diagnostic.rs @@ -20,16 +20,17 @@ use insta::assert_snapshot; use std::{collections::HashMap, sync::Arc}; use datafusion_common::{Diagnostic, Location, Result, Span}; -use datafusion_sql::planner::{ParserOptions, SqlToRel}; +use datafusion_sql::{ + parser::{DFParser, DFParserBuilder}, + planner::{ParserOptions, SqlToRel}, +}; use regex::Regex; -use sqlparser::{dialect::GenericDialect, parser::Parser}; use crate::{MockContextProvider, MockSessionState}; fn do_query(sql: &'static str) -> Diagnostic { - let dialect = GenericDialect {}; - let statement = Parser::new(&dialect) - .try_with_sql(sql) + let statement = DFParserBuilder::new(sql) + .build() .expect("unable to create parser") .parse_statement() .expect("unable to parse query"); @@ -41,7 +42,7 @@ fn do_query(sql: &'static str) -> Diagnostic { .with_scalar_function(Arc::new(string::concat().as_ref().clone())); let context = MockContextProvider { state }; let sql_to_rel = SqlToRel::new_with_options(&context, options); - match sql_to_rel.sql_statement_to_plan(statement) { + match sql_to_rel.statement_to_plan(statement) { Ok(_) => panic!("expected error"), Err(err) => match err.diagnostic() { Some(diag) => diag.clone(), @@ -366,3 +367,25 @@ fn test_unary_op_plus_with_non_column() -> Result<()> { assert_eq!(diag.span, None); Ok(()) } + +#[test] +fn test_syntax_error() -> Result<()> { + // create a table with a column of type varchar + let query = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV PARTITIONED BY (c1, p1 /*int*/int/*int*/) LOCATION 'foo.csv'"; + let spans = get_spans(query); + match DFParser::parse_sql(query) { + Ok(_) => panic!("expected error"), + Err(err) => match err.diagnostic() { + Some(diag) => { + let diag = diag.clone(); + assert_snapshot!(diag.message, @"Expected: ',' or ')' after partition definition, found: int at Line: 1, Column: 77"); + println!("{spans:?}"); + assert_eq!(diag.span, Some(spans["int"])); + Ok(()) + } + None => { + panic!("expected diagnostic") + } + }, + } +} diff --git a/datafusion/sql/tests/cases/mod.rs b/datafusion/sql/tests/cases/mod.rs index b3eedcdc41e3..426d188f633c 100644 --- a/datafusion/sql/tests/cases/mod.rs +++ b/datafusion/sql/tests/cases/mod.rs @@ -17,4 +17,5 @@ mod collection; mod diagnostic; +mod params; mod plan_to_sql; diff --git a/datafusion/sql/tests/cases/params.rs b/datafusion/sql/tests/cases/params.rs new file mode 100644 index 000000000000..15e7d923a91a --- /dev/null +++ b/datafusion/sql/tests/cases/params.rs @@ -0,0 +1,886 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan; +use arrow::datatypes::DataType; +use datafusion_common::{assert_contains, ParamValues, ScalarValue}; +use datafusion_expr::{LogicalPlan, Prepare, Statement}; +use insta::assert_snapshot; +use std::collections::HashMap; + +pub struct ParameterTest<'a> { + pub sql: &'a str, + pub expected_types: Vec<(&'a str, Option)>, + pub param_values: Vec, +} + +impl ParameterTest<'_> { + pub fn run(&self) -> String { + let plan = logical_plan(self.sql).unwrap(); + + let actual_types = plan.get_parameter_types().unwrap(); + let expected_types: HashMap> = self + .expected_types + .iter() + .map(|(k, v)| (k.to_string(), v.clone())) + .collect(); + + assert_eq!(actual_types, expected_types); + + let plan_with_params = plan + .clone() + .with_param_values(self.param_values.clone()) + .unwrap(); + + format!("** Initial Plan:\n{plan}\n** Final Plan:\n{plan_with_params}") + } +} + +fn generate_prepare_stmt_and_data_types(sql: &str) -> (LogicalPlan, String) { + let plan = logical_plan(sql).unwrap(); + let data_types = match &plan { + LogicalPlan::Statement(Statement::Prepare(Prepare { data_types, .. })) => { + format!("{data_types:?}") + } + _ => panic!("Expected a Prepare statement"), + }; + (plan, data_types) +} + +#[test] +fn test_prepare_statement_to_plan_panic_param_format() { + // param is not number following the $ sign + // panic due to error returned from the parser + let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = $foo"; + + assert_snapshot!( + logical_plan(sql).unwrap_err().strip_backtrace(), + @r###" + Error during planning: Invalid placeholder, not a number: $foo + "### + ); +} + +#[test] +fn test_prepare_statement_to_plan_panic_param_zero() { + // param is zero following the $ sign + // panic due to error returned from the parser + let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = $0"; + + assert_snapshot!( + logical_plan(sql).unwrap_err().strip_backtrace(), + @r###" + Error during planning: Invalid placeholder, zero is not a valid index: $0 + "### + ); +} + +#[test] +fn test_prepare_statement_to_plan_panic_prepare_wrong_syntax() { + // param is not number following the $ sign + // panic due to error returned from the parser + let sql = "PREPARE AS SELECT id, age FROM person WHERE age = $foo"; + assert!(logical_plan(sql) + .unwrap_err() + .strip_backtrace() + .contains("Expected: AS, found: SELECT")) +} + +#[test] +fn test_prepare_statement_to_plan_panic_no_relation_and_constant_param() { + let sql = "PREPARE my_plan(INT) AS SELECT id + $1"; + + let plan = logical_plan(sql).unwrap_err().strip_backtrace(); + assert_snapshot!( + plan, + @r"Schema error: No field named id." 
+ ); +} + +#[test] +fn test_prepare_statement_should_infer_types() { + // only provide 1 data type while using 2 params + let sql = "PREPARE my_plan(INT) AS SELECT 1 + $1 + $2"; + let plan = logical_plan(sql).unwrap(); + let actual_types = plan.get_parameter_types().unwrap(); + let expected_types = HashMap::from([ + ("$1".to_string(), Some(DataType::Int32)), + ("$2".to_string(), Some(DataType::Int64)), + ]); + assert_eq!(actual_types, expected_types); +} + +#[test] +fn test_non_prepare_statement_should_infer_types() { + // Non prepared statements (like SELECT) should also have their parameter types inferred + let sql = "SELECT 1 + $1"; + let plan = logical_plan(sql).unwrap(); + let actual_types = plan.get_parameter_types().unwrap(); + let expected_types = HashMap::from([ + // constant 1 is inferred to be int64 + ("$1".to_string(), Some(DataType::Int64)), + ]); + assert_eq!(actual_types, expected_types); +} + +#[test] +#[should_panic( + expected = "Expected: [NOT] NULL | TRUE | FALSE | DISTINCT | [form] NORMALIZED FROM after IS, found: $1" +)] +fn test_prepare_statement_to_plan_panic_is_param() { + let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age is $1"; + logical_plan(sql).unwrap(); +} + +#[test] +fn test_prepare_statement_to_plan_no_param() { + // no embedded parameter but still declare it + let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = 10"; + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + @r#" + Prepare: "my_plan" [Int32] + Projection: person.id, person.age + Filter: person.age = Int64(10) + TableScan: person + "# + ); + assert_snapshot!(dt, @r#"[Int32]"#); + + /////////////////// + // replace params with values + let param_values = vec![ScalarValue::Int32(Some(10))]; + let plan_with_params = plan.with_param_values(param_values).unwrap(); + assert_snapshot!( + plan_with_params, + @r" + Projection: person.id, person.age + Filter: person.age = Int64(10) + TableScan: person + " + ); + + ////////////////////////////////////////// + // no embedded parameter and no declare it + let sql = "PREPARE my_plan AS SELECT id, age FROM person WHERE age = 10"; + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + @r#" + Prepare: "my_plan" [] + Projection: person.id, person.age + Filter: person.age = Int64(10) + TableScan: person + "# + ); + assert_snapshot!(dt, @r#"[]"#); + + /////////////////// + // replace params with values + let param_values: Vec = vec![]; + let plan_with_params = plan.with_param_values(param_values).unwrap(); + assert_snapshot!( + plan_with_params, + @r" + Projection: person.id, person.age + Filter: person.age = Int64(10) + TableScan: person + " + ); +} + +#[test] +fn test_prepare_statement_to_plan_one_param_no_value_panic() { + // no embedded parameter but still declare it + let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = 10"; + let plan = logical_plan(sql).unwrap(); + // declare 1 param but provide 0 + let param_values: Vec = vec![]; + + assert_snapshot!( + plan.with_param_values(param_values) + .unwrap_err() + .strip_backtrace(), + @r###" + Error during planning: Expected 1 parameters, got 0 + "###); +} + +#[test] +fn test_prepare_statement_to_plan_one_param_one_value_different_type_panic() { + // no embedded parameter but still declare it + let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = 10"; + let plan = logical_plan(sql).unwrap(); + // declare 1 param but provide 0 + let param_values = 
vec![ScalarValue::Float64(Some(20.0))]; + + assert_snapshot!( + plan.with_param_values(param_values) + .unwrap_err() + .strip_backtrace(), + @r###" + Error during planning: Expected parameter of type Int32, got Float64 at index 0 + "### + ); +} + +#[test] +fn test_prepare_statement_to_plan_no_param_on_value_panic() { + // no embedded parameter but still declare it + let sql = "PREPARE my_plan AS SELECT id, age FROM person WHERE age = 10"; + let plan = logical_plan(sql).unwrap(); + // declare 1 param but provide 0 + let param_values = vec![ScalarValue::Int32(Some(10))]; + + assert_snapshot!( + plan.with_param_values(param_values) + .unwrap_err() + .strip_backtrace(), + @r###" + Error during planning: Expected 0 parameters, got 1 + "### + ); +} + +#[test] +fn test_prepare_statement_to_plan_params_as_constants() { + let sql = "PREPARE my_plan(INT) AS SELECT $1"; + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + @r#" + Prepare: "my_plan" [Int32] + Projection: $1 + EmptyRelation + "# + ); + assert_snapshot!(dt, @r#"[Int32]"#); + + /////////////////// + // replace params with values + let param_values = vec![ScalarValue::Int32(Some(10))]; + let plan_with_params = plan.with_param_values(param_values).unwrap(); + assert_snapshot!( + plan_with_params, + @r" + Projection: Int32(10) AS $1 + EmptyRelation + " + ); + + /////////////////////////////////////// + let sql = "PREPARE my_plan(INT) AS SELECT 1 + $1"; + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + @r#" + Prepare: "my_plan" [Int32] + Projection: Int64(1) + $1 + EmptyRelation + "# + ); + assert_snapshot!(dt, @r#"[Int32]"#); + + /////////////////// + // replace params with values + let param_values = vec![ScalarValue::Int32(Some(10))]; + let plan_with_params = plan.with_param_values(param_values).unwrap(); + assert_snapshot!( + plan_with_params, + @r" + Projection: Int64(1) + Int32(10) AS Int64(1) + $1 + EmptyRelation + " + ); + + /////////////////////////////////////// + let sql = "PREPARE my_plan(INT, DOUBLE) AS SELECT 1 + $1 + $2"; + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + @r#" + Prepare: "my_plan" [Int32, Float64] + Projection: Int64(1) + $1 + $2 + EmptyRelation + "# + ); + assert_snapshot!(dt, @r#"[Int32, Float64]"#); + + /////////////////// + // replace params with values + let param_values = vec![ + ScalarValue::Int32(Some(10)), + ScalarValue::Float64(Some(10.0)), + ]; + let plan_with_params = plan.with_param_values(param_values).unwrap(); + assert_snapshot!( + plan_with_params, + @r" + Projection: Int64(1) + Int32(10) + Float64(10) AS Int64(1) + $1 + $2 + EmptyRelation + " + ); +} + +#[test] +fn test_infer_types_from_join() { + let test = ParameterTest { + sql: + "SELECT id, order_id FROM person JOIN orders ON id = customer_id and age = $1", + expected_types: vec![("$1", Some(DataType::Int32))], + param_values: vec![ScalarValue::Int32(Some(10))], + }; + + assert_snapshot!( + test.run(), + @r" + ** Initial Plan: + Projection: person.id, orders.order_id + Inner Join: Filter: person.id = orders.customer_id AND person.age = $1 + TableScan: person + TableScan: orders + ** Final Plan: + Projection: person.id, orders.order_id + Inner Join: Filter: person.id = orders.customer_id AND person.age = Int32(10) + TableScan: person + TableScan: orders + " + ); +} + +#[test] +fn test_prepare_statement_infer_types_from_join() { + let test = ParameterTest { + sql: "PREPARE my_plan AS SELECT id, order_id FROM person JOIN 
orders ON id = customer_id and age = $1", + expected_types: vec![("$1", Some(DataType::Int32))], + param_values: vec![ScalarValue::Int32(Some(10))] + }; + + assert_snapshot!( + test.run(), + @r#" + ** Initial Plan: + Prepare: "my_plan" [Int32] + Projection: person.id, orders.order_id + Inner Join: Filter: person.id = orders.customer_id AND person.age = $1 + TableScan: person + TableScan: orders + ** Final Plan: + Projection: person.id, orders.order_id + Inner Join: Filter: person.id = orders.customer_id AND person.age = Int32(10) + TableScan: person + TableScan: orders + "# + ); +} + +#[test] +fn test_infer_types_from_predicate() { + let test = ParameterTest { + sql: "SELECT id, age FROM person WHERE age = $1", + expected_types: vec![("$1", Some(DataType::Int32))], + param_values: vec![ScalarValue::Int32(Some(10))], + }; + + assert_snapshot!( + test.run(), + @r" + ** Initial Plan: + Projection: person.id, person.age + Filter: person.age = $1 + TableScan: person + ** Final Plan: + Projection: person.id, person.age + Filter: person.age = Int32(10) + TableScan: person + " + ); +} + +#[test] +fn test_prepare_statement_infer_types_from_predicate() { + let test = ParameterTest { + sql: "PREPARE my_plan AS SELECT id, age FROM person WHERE age = $1", + expected_types: vec![("$1", Some(DataType::Int32))], + param_values: vec![ScalarValue::Int32(Some(10))], + }; + assert_snapshot!( + test.run(), + @r#" + ** Initial Plan: + Prepare: "my_plan" [Int32] + Projection: person.id, person.age + Filter: person.age = $1 + TableScan: person + ** Final Plan: + Projection: person.id, person.age + Filter: person.age = Int32(10) + TableScan: person + "# + ); +} + +#[test] +fn test_infer_types_from_between_predicate() { + let test = ParameterTest { + sql: "SELECT id, age FROM person WHERE age BETWEEN $1 AND $2", + expected_types: vec![ + ("$1", Some(DataType::Int32)), + ("$2", Some(DataType::Int32)), + ], + param_values: vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(30))], + }; + + assert_snapshot!( + test.run(), + @r" + ** Initial Plan: + Projection: person.id, person.age + Filter: person.age BETWEEN $1 AND $2 + TableScan: person + ** Final Plan: + Projection: person.id, person.age + Filter: person.age BETWEEN Int32(10) AND Int32(30) + TableScan: person + " + ); +} + +#[test] +fn test_prepare_statement_infer_types_from_between_predicate() { + let test = ParameterTest { + sql: "PREPARE my_plan AS SELECT id, age FROM person WHERE age BETWEEN $1 AND $2", + expected_types: vec![ + ("$1", Some(DataType::Int32)), + ("$2", Some(DataType::Int32)), + ], + param_values: vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(30))], + }; + assert_snapshot!( + test.run(), + @r#" + ** Initial Plan: + Prepare: "my_plan" [Int32, Int32] + Projection: person.id, person.age + Filter: person.age BETWEEN $1 AND $2 + TableScan: person + ** Final Plan: + Projection: person.id, person.age + Filter: person.age BETWEEN Int32(10) AND Int32(30) + TableScan: person + "# + ); +} + +#[test] +fn test_infer_types_subquery() { + let test = ParameterTest { + sql: "SELECT id, age FROM person WHERE age = (select max(age) from person where id = $1)", + expected_types: vec![("$1", Some(DataType::UInt32))], + param_values: vec![ScalarValue::UInt32(Some(10))] + }; + + assert_snapshot!( + test.run(), + @r" + ** Initial Plan: + Projection: person.id, person.age + Filter: person.age = () + Subquery: + Projection: max(person.age) + Aggregate: groupBy=[[]], aggr=[[max(person.age)]] + Filter: person.id = $1 + TableScan: person + TableScan: 
person + ** Final Plan: + Projection: person.id, person.age + Filter: person.age = () + Subquery: + Projection: max(person.age) + Aggregate: groupBy=[[]], aggr=[[max(person.age)]] + Filter: person.id = UInt32(10) + TableScan: person + TableScan: person + " + ); +} + +#[test] +fn test_prepare_statement_infer_types_subquery() { + let test = ParameterTest { + sql: "PREPARE my_plan AS SELECT id, age FROM person WHERE age = (select max(age) from person where id = $1)", + expected_types: vec![("$1", Some(DataType::UInt32))], + param_values: vec![ScalarValue::UInt32(Some(10))] + }; + + assert_snapshot!( + test.run(), + @r#" + ** Initial Plan: + Prepare: "my_plan" [UInt32] + Projection: person.id, person.age + Filter: person.age = () + Subquery: + Projection: max(person.age) + Aggregate: groupBy=[[]], aggr=[[max(person.age)]] + Filter: person.id = $1 + TableScan: person + TableScan: person + ** Final Plan: + Projection: person.id, person.age + Filter: person.age = () + Subquery: + Projection: max(person.age) + Aggregate: groupBy=[[]], aggr=[[max(person.age)]] + Filter: person.id = UInt32(10) + TableScan: person + TableScan: person + "# + ); +} + +#[test] +fn test_update_infer() { + let test = ParameterTest { + sql: "update person set age=$1 where id=$2", + expected_types: vec![ + ("$1", Some(DataType::Int32)), + ("$2", Some(DataType::UInt32)), + ], + param_values: vec![ScalarValue::Int32(Some(42)), ScalarValue::UInt32(Some(1))], + }; + + assert_snapshot!( + test.run(), + @r" + ** Initial Plan: + Dml: op=[Update] table=[person] + Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, $1 AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 + Filter: person.id = $2 + TableScan: person + ** Final Plan: + Dml: op=[Update] table=[person] + Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, Int32(42) AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 + Filter: person.id = UInt32(1) + TableScan: person + " + ); +} + +#[test] +fn test_prepare_statement_update_infer() { + let test = ParameterTest { + sql: "PREPARE my_plan AS update person set age=$1 where id=$2", + expected_types: vec![ + ("$1", Some(DataType::Int32)), + ("$2", Some(DataType::UInt32)), + ], + param_values: vec![ScalarValue::Int32(Some(42)), ScalarValue::UInt32(Some(1))], + }; + + assert_snapshot!( + test.run(), + @r#" + ** Initial Plan: + Prepare: "my_plan" [Int32, UInt32] + Dml: op=[Update] table=[person] + Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, $1 AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 + Filter: person.id = $2 + TableScan: person + ** Final Plan: + Dml: op=[Update] table=[person] + Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, Int32(42) AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 + Filter: person.id = UInt32(1) + TableScan: person + "# + ); +} + +#[test] +fn test_insert_infer() { + let test = ParameterTest { + sql: "insert into person (id, first_name, last_name) values ($1, $2, $3)", + expected_types: vec![ + ("$1", Some(DataType::UInt32)), + ("$2", Some(DataType::Utf8)), + ("$3", Some(DataType::Utf8)), + ], + param_values: vec![ + ScalarValue::UInt32(Some(1)), + ScalarValue::from("Alan"), + ScalarValue::from("Turing"), + ], + }; + + 
assert_snapshot!( + test.run(), + @r#" + ** Initial Plan: + Dml: op=[Insert Into] table=[person] + Projection: column1 AS id, column2 AS first_name, column3 AS last_name, CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀 + Values: ($1, $2, $3) + ** Final Plan: + Dml: op=[Insert Into] table=[person] + Projection: column1 AS id, column2 AS first_name, column3 AS last_name, CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀 + Values: (UInt32(1) AS $1, Utf8("Alan") AS $2, Utf8("Turing") AS $3) + "# + ); +} + +#[test] +fn test_prepare_statement_insert_infer() { + let test = ParameterTest { + sql: "PREPARE my_plan AS insert into person (id, first_name, last_name) values ($1, $2, $3)", + expected_types: vec![ + ("$1", Some(DataType::UInt32)), + ("$2", Some(DataType::Utf8)), + ("$3", Some(DataType::Utf8)), + ], + param_values: vec![ + ScalarValue::UInt32(Some(1)), + ScalarValue::from("Alan"), + ScalarValue::from("Turing"), + ] + }; + assert_snapshot!( + test.run(), + @r#" + ** Initial Plan: + Prepare: "my_plan" [UInt32, Utf8, Utf8] + Dml: op=[Insert Into] table=[person] + Projection: column1 AS id, column2 AS first_name, column3 AS last_name, CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀 + Values: ($1, $2, $3) + ** Final Plan: + Dml: op=[Insert Into] table=[person] + Projection: column1 AS id, column2 AS first_name, column3 AS last_name, CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀 + Values: (UInt32(1) AS $1, Utf8("Alan") AS $2, Utf8("Turing") AS $3) + "# + ); +} + +#[test] +fn test_prepare_statement_to_plan_one_param() { + let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = $1"; + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + @r#" + Prepare: "my_plan" [Int32] + Projection: person.id, person.age + Filter: person.age = $1 + TableScan: person + "# + ); + assert_snapshot!(dt, @r#"[Int32]"#); + + /////////////////// + // replace params with values + let param_values = vec![ScalarValue::Int32(Some(10))]; + + let plan_with_params = plan.with_param_values(param_values).unwrap(); + assert_snapshot!( + plan_with_params, + @r" + Projection: person.id, person.age + Filter: person.age = Int32(10) + TableScan: person + " + ); +} + +#[test] +fn test_prepare_statement_to_plan_data_type() { + let sql = "PREPARE my_plan(DOUBLE) AS SELECT id, age FROM person WHERE age = $1"; + + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + // age is defined as Int32 but prepare statement declares it as DOUBLE/Float64 + // Prepare statement and its logical plan should be created successfully + @r#" + Prepare: "my_plan" [Float64] + Projection: person.id, person.age + Filter: person.age = $1 + TableScan: person + "# + ); + assert_snapshot!(dt, @r#"[Float64]"#); + + /////////////////// + // replace params with values still succeed and use Float64 + let param_values = vec![ScalarValue::Float64(Some(10.0))]; + + let plan_with_params = plan.with_param_values(param_values).unwrap(); + assert_snapshot!( + plan_with_params, + @r" + Projection: 
person.id, person.age + Filter: person.age = Float64(10) + TableScan: person + " + ); +} + +#[test] +fn test_prepare_statement_to_plan_multi_params() { + let sql = "PREPARE my_plan(INT, STRING, DOUBLE, INT, DOUBLE, STRING) AS + SELECT id, age, $6 + FROM person + WHERE age IN ($1, $4) AND salary > $3 and salary < $5 OR first_name < $2"; + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + @r#" + Prepare: "my_plan" [Int32, Utf8View, Float64, Int32, Float64, Utf8View] + Projection: person.id, person.age, $6 + Filter: person.age IN ([$1, $4]) AND person.salary > $3 AND person.salary < $5 OR person.first_name < $2 + TableScan: person + "# + ); + assert_snapshot!(dt, @r#"[Int32, Utf8View, Float64, Int32, Float64, Utf8View]"#); + + /////////////////// + // replace params with values + let param_values = vec![ + ScalarValue::Int32(Some(10)), + ScalarValue::Utf8View(Some("abc".into())), + ScalarValue::Float64(Some(100.0)), + ScalarValue::Int32(Some(20)), + ScalarValue::Float64(Some(200.0)), + ScalarValue::Utf8View(Some("xyz".into())), + ]; + + let plan_with_params = plan.with_param_values(param_values).unwrap(); + assert_snapshot!( + plan_with_params, + @r#" + Projection: person.id, person.age, Utf8View("xyz") AS $6 + Filter: person.age IN ([Int32(10), Int32(20)]) AND person.salary > Float64(100) AND person.salary < Float64(200) OR person.first_name < Utf8View("abc") + TableScan: person + "# + ); +} + +#[test] +fn test_prepare_statement_to_plan_having() { + let sql = "PREPARE my_plan(INT, DOUBLE, DOUBLE, DOUBLE) AS + SELECT id, sum(age) + FROM person \ + WHERE salary > $2 + GROUP BY id + HAVING sum(age) < $1 AND sum(age) > 10 OR sum(age) in ($3, $4)\ + "; + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + @r#" + Prepare: "my_plan" [Int32, Float64, Float64, Float64] + Projection: person.id, sum(person.age) + Filter: sum(person.age) < $1 AND sum(person.age) > Int64(10) OR sum(person.age) IN ([$3, $4]) + Aggregate: groupBy=[[person.id]], aggr=[[sum(person.age)]] + Filter: person.salary > $2 + TableScan: person + "# + ); + assert_snapshot!(dt, @r#"[Int32, Float64, Float64, Float64]"#); + + /////////////////// + // replace params with values + let param_values = vec![ + ScalarValue::Int32(Some(10)), + ScalarValue::Float64(Some(100.0)), + ScalarValue::Float64(Some(200.0)), + ScalarValue::Float64(Some(300.0)), + ]; + + let plan_with_params = plan.with_param_values(param_values).unwrap(); + assert_snapshot!( + plan_with_params, + @r#" + Projection: person.id, sum(person.age) + Filter: sum(person.age) < Int32(10) AND sum(person.age) > Int64(10) OR sum(person.age) IN ([Float64(200), Float64(300)]) + Aggregate: groupBy=[[person.id]], aggr=[[sum(person.age)]] + Filter: person.salary > Float64(100) + TableScan: person + "# + ); +} + +#[test] +fn test_prepare_statement_to_plan_limit() { + let sql = "PREPARE my_plan(BIGINT, BIGINT) AS + SELECT id FROM person \ + OFFSET $1 LIMIT $2"; + let (plan, dt) = generate_prepare_stmt_and_data_types(sql); + assert_snapshot!( + plan, + @r#" + Prepare: "my_plan" [Int64, Int64] + Limit: skip=$1, fetch=$2 + Projection: person.id + TableScan: person + "# + ); + assert_snapshot!(dt, @r#"[Int64, Int64]"#); + + // replace params with values + let param_values = vec![ScalarValue::Int64(Some(10)), ScalarValue::Int64(Some(200))]; + let plan_with_params = plan.with_param_values(param_values).unwrap(); + assert_snapshot!( + plan_with_params, + @r#" + Limit: skip=10, fetch=200 + Projection: person.id + 
TableScan: person + "# + ); +} + +#[test] +fn test_prepare_statement_unknown_list_param() { + let sql = "SELECT id from person where id = $2"; + let plan = logical_plan(sql).unwrap(); + let param_values = ParamValues::List(vec![]); + let err = plan.replace_params_with_values(¶m_values).unwrap_err(); + assert_contains!( + err.to_string(), + "Error during planning: No value found for placeholder with id $2" + ); +} + +#[test] +fn test_prepare_statement_unknown_hash_param() { + let sql = "SELECT id from person where id = $bar"; + let plan = logical_plan(sql).unwrap(); + let param_values = ParamValues::Map(HashMap::new()); + let err = plan.replace_params_with_values(¶m_values).unwrap_err(); + assert_contains!( + err.to_string(), + "Error during planning: No value found for placeholder with name $bar" + ); +} + +#[test] +fn test_prepare_statement_bad_list_idx() { + let sql = "SELECT id from person where id = $foo"; + let plan = logical_plan(sql).unwrap(); + let param_values = ParamValues::List(vec![]); + + let err = plan.replace_params_with_values(¶m_values).unwrap_err(); + assert_contains!(err.to_string(), "Error during planning: Failed to parse placeholder id: invalid digit found in string"); +} diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 6568618ed757..d1af54a6f4ad 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -34,8 +34,8 @@ use datafusion_functions_nested::map::map_udf; use datafusion_functions_window::rank::rank_udwf; use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_sql::unparser::dialect::{ - CustomDialectBuilder, DefaultDialect as UnparserDefaultDialect, DefaultDialect, - Dialect as UnparserDialect, MySqlDialect as UnparserMySqlDialect, + BigQueryDialect, CustomDialectBuilder, DefaultDialect as UnparserDefaultDialect, + DefaultDialect, Dialect as UnparserDialect, MySqlDialect as UnparserMySqlDialect, PostgreSqlDialect as UnparserPostgreSqlDialect, SqliteDialect, }; use datafusion_sql::unparser::{expr_to_sql, plan_to_sql, Unparser}; @@ -170,6 +170,13 @@ fn roundtrip_statement() -> Result<()> { UNION ALL SELECT j3_string AS col1, j3_id AS id FROM j3 ) AS subquery GROUP BY col1, id ORDER BY col1 ASC, id ASC"#, + r#"SELECT col1, id FROM ( + SELECT j1_string AS col1, j1_id AS id FROM j1 + UNION + SELECT j2_string AS col1, j2_id AS id FROM j2 + UNION + SELECT j3_string AS col1, j3_id AS id FROM j3 + ) AS subquery ORDER BY col1 ASC, id ASC"#, "SELECT id, count(*) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), last_name, sum(id) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), first_name from person", @@ -691,7 +698,7 @@ fn roundtrip_statement_with_dialect_27() -> Result<(), DataFusionError> { sql: "SELECT * FROM UNNEST([1,2,3])", parser_dialect: GenericDialect {}, unparser_dialect: UnparserDefaultDialect {}, - expected: @r#"SELECT "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))" FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))")"#, + expected: @r#"SELECT "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))" FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") AS derived_projection ("UNNEST(make_array(Int64(1),Int64(2),Int64(3)))")"#, ); Ok(()) } @@ -713,7 +720,7 @@ fn roundtrip_statement_with_dialect_29() -> Result<(), DataFusionError> { sql: "SELECT * FROM UNNEST([1,2,3]), j1", parser_dialect: 
GenericDialect {}, unparser_dialect: UnparserDefaultDialect {}, - expected: @r#"SELECT "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))", j1.j1_id, j1.j1_string FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") CROSS JOIN j1"#, + expected: @r#"SELECT "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))", j1.j1_id, j1.j1_string FROM (SELECT UNNEST([1, 2, 3]) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") AS derived_projection ("UNNEST(make_array(Int64(1),Int64(2),Int64(3)))") CROSS JOIN j1"#, ); Ok(()) } @@ -916,6 +923,41 @@ fn roundtrip_statement_with_dialect_45() -> Result<(), DataFusionError> { Ok(()) } +#[test] +fn roundtrip_statement_with_dialect_special_char_alias() -> Result<(), DataFusionError> { + roundtrip_statement_with_dialect_helper!( + sql: "select min(a) as \"min(a)\" from (select 1 as a)", + parser_dialect: GenericDialect {}, + unparser_dialect: BigQueryDialect {}, + expected: @r#"SELECT min(`a`) AS `min_40a_41` FROM (SELECT 1 AS `a`)"#, + ); + roundtrip_statement_with_dialect_helper!( + sql: "select a as \"a*\", b as \"b@\" from (select 1 as a , 2 as b)", + parser_dialect: GenericDialect {}, + unparser_dialect: BigQueryDialect {}, + expected: @r#"SELECT `a` AS `a_42`, `b` AS `b_64` FROM (SELECT 1 AS `a`, 2 AS `b`)"#, + ); + roundtrip_statement_with_dialect_helper!( + sql: "select a as \"a*\", b , c as \"c@\" from (select 1 as a , 2 as b, 3 as c)", + parser_dialect: GenericDialect {}, + unparser_dialect: BigQueryDialect {}, + expected: @r#"SELECT `a` AS `a_42`, `b`, `c` AS `c_64` FROM (SELECT 1 AS `a`, 2 AS `b`, 3 AS `c`)"#, + ); + roundtrip_statement_with_dialect_helper!( + sql: "select * from (select a as \"a*\", b as \"b@\" from (select 1 as a , 2 as b)) where \"a*\" = 1", + parser_dialect: GenericDialect {}, + unparser_dialect: BigQueryDialect {}, + expected: @r#"SELECT `a_42`, `b_64` FROM (SELECT `a` AS `a_42`, `b` AS `b_64` FROM (SELECT 1 AS `a`, 2 AS `b`)) WHERE (`a_42` = 1)"#, + ); + roundtrip_statement_with_dialect_helper!( + sql: "select * from (select a as \"a*\", b as \"b@\" from (select 1 as a , 2 as b)) where \"a*\" = 1", + parser_dialect: GenericDialect {}, + unparser_dialect: UnparserDefaultDialect {}, + expected: @r#"SELECT "a*", "b@" FROM (SELECT a AS "a*", b AS "b@" FROM (SELECT 1 AS a, 2 AS b)) WHERE ("a*" = 1)"#, + ); + Ok(()) +} + #[test] fn test_unnest_logical_plan() -> Result<()> { let query = "select unnest(struct_col), unnest(array_col), struct_col, array_col from unnest_table"; @@ -1820,6 +1862,51 @@ fn test_order_by_to_sql_3() { ); } +#[test] +fn test_complex_order_by_with_grouping() -> Result<()> { + let state = MockSessionState::default().with_aggregate_function(grouping_udaf()); + + let context = MockContextProvider { state }; + let sql_to_rel = SqlToRel::new(&context); + + // This SQL is based on a simplified version of the TPC-DS query 36. + let statement = Parser::new(&GenericDialect {}) + .try_with_sql( + r#"SELECT + j1_id, + j1_string, + grouping(j1_id) + grouping(j1_string) as lochierarchy + FROM + j1 + GROUP BY + ROLLUP (j1_id, j1_string) + ORDER BY + grouping(j1_id) + grouping(j1_string) DESC, + CASE + WHEN grouping(j1_id) + grouping(j1_string) = 0 THEN j1_id + END + LIMIT 100"#, + )? 
+ .parse_statement()?; + + let plan = sql_to_rel.sql_statement_to_plan(statement)?; + let unparser = Unparser::default(); + let sql = unparser.plan_to_sql(&plan)?; + insta::with_settings!({ + filters => vec![ + // Force a deterministic order for the grouping pairs + (r#"grouping\(j1\.(?:j1_id|j1_string)\),\s*grouping\(j1\.(?:j1_id|j1_string)\)"#, "grouping(j1.j1_string), grouping(j1.j1_id)") + ], + }, { + assert_snapshot!( + sql, + @r#"SELECT j1.j1_id, j1.j1_string, lochierarchy FROM (SELECT j1.j1_id, j1.j1_string, (grouping(j1.j1_id) + grouping(j1.j1_string)) AS lochierarchy, grouping(j1.j1_string), grouping(j1.j1_id) FROM j1 GROUP BY ROLLUP (j1.j1_id, j1.j1_string) ORDER BY (grouping(j1.j1_id) + grouping(j1.j1_string)) DESC NULLS FIRST, CASE WHEN ((grouping(j1.j1_id) + grouping(j1.j1_string)) = 0) THEN j1.j1_id END ASC NULLS LAST) LIMIT 100"# + ); + }); + + Ok(()) +} + #[test] fn test_aggregation_to_sql() { let sql = r#"SELECT id, first_name, @@ -2342,3 +2429,170 @@ fn test_unparse_right_anti_join() -> Result<()> { ); Ok(()) } + +#[test] +fn test_unparse_cross_join_with_table_scan_projection() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("k", DataType::Int32, false), + Field::new("v", DataType::Int32, false), + ]); + // Cross Join: + // SubqueryAlias: t1 + // TableScan: test projection=[v] + // SubqueryAlias: t2 + // TableScan: test projection=[v] + let table_scan1 = table_scan(Some("test"), &schema, Some(vec![1]))?.build()?; + let table_scan2 = table_scan(Some("test"), &schema, Some(vec![1]))?.build()?; + let plan = LogicalPlanBuilder::from(subquery_alias(table_scan1, "t1")?) + .cross_join(subquery_alias(table_scan2, "t2")?)? + .build()?; + let unparser = Unparser::new(&UnparserPostgreSqlDialect {}); + let sql = unparser.plan_to_sql(&plan)?; + assert_snapshot!( + sql, + @r#"SELECT "t1"."v", "t2"."v" FROM "test" AS "t1" CROSS JOIN "test" AS "t2""# + ); + Ok(()) +} + +#[test] +fn test_unparse_inner_join_with_table_scan_projection() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("k", DataType::Int32, false), + Field::new("v", DataType::Int32, false), + ]); + // Inner Join: + // SubqueryAlias: t1 + // TableScan: test projection=[v] + // SubqueryAlias: t2 + // TableScan: test projection=[v] + let table_scan1 = table_scan(Some("test"), &schema, Some(vec![1]))?.build()?; + let table_scan2 = table_scan(Some("test"), &schema, Some(vec![1]))?.build()?; + let plan = LogicalPlanBuilder::from(subquery_alias(table_scan1, "t1")?) + .join_on( + subquery_alias(table_scan2, "t2")?, + datafusion_expr::JoinType::Inner, + vec![col("t1.v").eq(col("t2.v"))], + )? + .build()?; + let unparser = Unparser::new(&UnparserPostgreSqlDialect {}); + let sql = unparser.plan_to_sql(&plan)?; + assert_snapshot!( + sql, + @r#"SELECT "t1"."v", "t2"."v" FROM "test" AS "t1" INNER JOIN "test" AS "t2" ON ("t1"."v" = "t2"."v")"# + ); + Ok(()) +} + +#[test] +fn test_unparse_left_semi_join_with_table_scan_projection() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("k", DataType::Int32, false), + Field::new("v", DataType::Int32, false), + ]); + // LeftSemi Join: + // SubqueryAlias: t1 + // TableScan: test projection=[v] + // SubqueryAlias: t2 + // TableScan: test projection=[v] + let table_scan1 = table_scan(Some("test"), &schema, Some(vec![1]))?.build()?; + let table_scan2 = table_scan(Some("test"), &schema, Some(vec![1]))?.build()?; + let plan = LogicalPlanBuilder::from(subquery_alias(table_scan1, "t1")?) 
+ .join_on( + subquery_alias(table_scan2, "t2")?, + datafusion_expr::JoinType::LeftSemi, + vec![col("t1.v").eq(col("t2.v"))], + )? + .build()?; + let unparser = Unparser::new(&UnparserPostgreSqlDialect {}); + let sql = unparser.plan_to_sql(&plan)?; + assert_snapshot!( + sql, + @r#"SELECT "t1"."v" FROM "test" AS "t1" WHERE EXISTS (SELECT 1 FROM "test" AS "t2" WHERE ("t1"."v" = "t2"."v"))"# + ); + Ok(()) +} + +#[test] +fn test_like_filter() { + let statement = generate_round_trip_statement( + GenericDialect {}, + r#"SELECT first_name FROM person WHERE first_name LIKE '%John%'"#, + ); + assert_snapshot!( + statement, + @"SELECT person.first_name FROM person WHERE person.first_name LIKE '%John%'" + ); +} + +#[test] +fn test_ilike_filter() { + let statement = generate_round_trip_statement( + GenericDialect {}, + r#"SELECT first_name FROM person WHERE first_name ILIKE '%john%'"#, + ); + assert_snapshot!( + statement, + @"SELECT person.first_name FROM person WHERE person.first_name ILIKE '%john%'" + ); +} + +#[test] +fn test_not_like_filter() { + let statement = generate_round_trip_statement( + GenericDialect {}, + r#"SELECT first_name FROM person WHERE first_name NOT LIKE 'A%'"#, + ); + assert_snapshot!( + statement, + @"SELECT person.first_name FROM person WHERE person.first_name NOT LIKE 'A%'" + ); +} + +#[test] +fn test_not_ilike_filter() { + let statement = generate_round_trip_statement( + GenericDialect {}, + r#"SELECT first_name FROM person WHERE first_name NOT ILIKE 'a%'"#, + ); + assert_snapshot!( + statement, + @"SELECT person.first_name FROM person WHERE person.first_name NOT ILIKE 'a%'" + ); +} + +#[test] +fn test_like_filter_with_escape() { + let statement = generate_round_trip_statement( + GenericDialect {}, + r#"SELECT first_name FROM person WHERE first_name LIKE 'A!_%' ESCAPE '!'"#, + ); + assert_snapshot!( + statement, + @"SELECT person.first_name FROM person WHERE person.first_name LIKE 'A!_%' ESCAPE '!'" + ); +} + +#[test] +fn test_not_like_filter_with_escape() { + let statement = generate_round_trip_statement( + GenericDialect {}, + r#"SELECT first_name FROM person WHERE first_name NOT LIKE 'A!_%' ESCAPE '!'"#, + ); + assert_snapshot!( + statement, + @"SELECT person.first_name FROM person WHERE person.first_name NOT LIKE 'A!_%' ESCAPE '!'" + ); +} + +#[test] +fn test_not_ilike_filter_with_escape() { + let statement = generate_round_trip_statement( + GenericDialect {}, + r#"SELECT first_name FROM person WHERE first_name NOT ILIKE 'A!_%' ESCAPE '!'"#, + ); + assert_snapshot!( + statement, + @"SELECT person.first_name FROM person WHERE person.first_name NOT ILIKE 'A!_%' ESCAPE '!'" + ); +} diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 2804a1de0606..c82239d9b455 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -17,21 +17,16 @@ use std::any::Any; #[cfg(test)] -use std::collections::HashMap; use std::sync::Arc; use std::vec; use arrow::datatypes::{TimeUnit::Nanosecond, *}; use common::MockContextProvider; -use datafusion_common::{ - assert_contains, DataFusionError, ParamValues, Result, ScalarValue, -}; +use datafusion_common::{assert_contains, DataFusionError, Result}; use datafusion_expr::{ - col, - logical_plan::{LogicalPlan, Prepare}, - test::function_stub::sum_udaf, - ColumnarValue, CreateIndex, DdlStatement, ScalarFunctionArgs, ScalarUDF, - ScalarUDFImpl, Signature, Statement, Volatility, + col, logical_plan::LogicalPlan, test::function_stub::sum_udaf, ColumnarValue, + 
CreateIndex, DdlStatement, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_functions::{string, unicode}; use datafusion_sql::{ @@ -761,7 +756,7 @@ fn plan_delete() { plan, @r#" Dml: op=[Delete] table=[person] - Filter: id = Int64(1) + Filter: person.id = Int64(1) TableScan: person "# ); @@ -776,7 +771,7 @@ fn plan_delete_quoted_identifier_case_sensitive() { plan, @r#" Dml: op=[Delete] table=[SomeCatalog.SomeSchema.UPPERCASE_test] - Filter: Id = Int64(1) + Filter: SomeCatalog.SomeSchema.UPPERCASE_test.Id = Int64(1) TableScan: SomeCatalog.SomeSchema.UPPERCASE_test "# ); @@ -3360,7 +3355,7 @@ fn parse_decimals_parser_options() -> ParserOptions { parse_float_as_decimal: true, enable_ident_normalization: false, support_varchar_with_length: false, - map_varchar_to_utf8view: false, + map_string_types_to_utf8view: true, enable_options_value_normalization: false, collect_spans: false, } @@ -3371,7 +3366,7 @@ fn ident_normalization_parser_options_no_ident_normalization() -> ParserOptions parse_float_as_decimal: true, enable_ident_normalization: false, support_varchar_with_length: false, - map_varchar_to_utf8view: false, + map_string_types_to_utf8view: true, enable_options_value_normalization: false, collect_spans: false, } @@ -3382,23 +3377,12 @@ fn ident_normalization_parser_options_ident_normalization() -> ParserOptions { parse_float_as_decimal: true, enable_ident_normalization: true, support_varchar_with_length: false, - map_varchar_to_utf8view: false, + map_string_types_to_utf8view: true, enable_options_value_normalization: false, collect_spans: false, } } -fn generate_prepare_stmt_and_data_types(sql: &str) -> (LogicalPlan, String) { - let plan = logical_plan(sql).unwrap(); - let data_types = match &plan { - LogicalPlan::Statement(Statement::Prepare(Prepare { data_types, .. 
})) => { - format!("{data_types:?}") - } - _ => panic!("Expected a Prepare statement"), - }; - (plan, data_types) -} - #[test] fn select_partially_qualified_column() { let sql = "SELECT person.first_name FROM public.person"; @@ -4330,712 +4314,6 @@ Projection: p1.id, p1.age, p2.id ); } -#[test] -fn test_prepare_statement_to_plan_panic_param_format() { - // param is not number following the $ sign - // panic due to error returned from the parser - let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = $foo"; - - assert_snapshot!( - logical_plan(sql).unwrap_err().strip_backtrace(), - @r###" - Error during planning: Invalid placeholder, not a number: $foo - "### - ); -} - -#[test] -fn test_prepare_statement_to_plan_panic_param_zero() { - // param is zero following the $ sign - // panic due to error returned from the parser - let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = $0"; - - assert_snapshot!( - logical_plan(sql).unwrap_err().strip_backtrace(), - @r###" - Error during planning: Invalid placeholder, zero is not a valid index: $0 - "### - ); -} - -#[test] -fn test_prepare_statement_to_plan_panic_prepare_wrong_syntax() { - // param is not number following the $ sign - // panic due to error returned from the parser - let sql = "PREPARE AS SELECT id, age FROM person WHERE age = $foo"; - assert!(logical_plan(sql) - .unwrap_err() - .strip_backtrace() - .contains("Expected: AS, found: SELECT")) -} - -#[test] -fn test_prepare_statement_to_plan_panic_no_relation_and_constant_param() { - let sql = "PREPARE my_plan(INT) AS SELECT id + $1"; - - let plan = logical_plan(sql).unwrap_err().strip_backtrace(); - assert_snapshot!( - plan, - @r"Schema error: No field named id." - ); -} - -#[test] -fn test_prepare_statement_should_infer_types() { - // only provide 1 data type while using 2 params - let sql = "PREPARE my_plan(INT) AS SELECT 1 + $1 + $2"; - let plan = logical_plan(sql).unwrap(); - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([ - ("$1".to_string(), Some(DataType::Int32)), - ("$2".to_string(), Some(DataType::Int64)), - ]); - assert_eq!(actual_types, expected_types); -} - -#[test] -fn test_non_prepare_statement_should_infer_types() { - // Non prepared statements (like SELECT) should also have their parameter types inferred - let sql = "SELECT 1 + $1"; - let plan = logical_plan(sql).unwrap(); - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([ - // constant 1 is inferred to be int64 - ("$1".to_string(), Some(DataType::Int64)), - ]); - assert_eq!(actual_types, expected_types); -} - -#[test] -#[should_panic( - expected = "Expected: [NOT] NULL | TRUE | FALSE | DISTINCT | [form] NORMALIZED FROM after IS, found: $1" -)] -fn test_prepare_statement_to_plan_panic_is_param() { - let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age is $1"; - logical_plan(sql).unwrap(); -} - -#[test] -fn test_prepare_statement_to_plan_no_param() { - // no embedded parameter but still declare it - let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = 10"; - let (plan, dt) = generate_prepare_stmt_and_data_types(sql); - assert_snapshot!( - plan, - @r#" - Prepare: "my_plan" [Int32] - Projection: person.id, person.age - Filter: person.age = Int64(10) - TableScan: person - "# - ); - assert_snapshot!(dt, @r#"[Int32]"#); - - /////////////////// - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10))]; - let plan_with_params = 
plan.with_param_values(param_values).unwrap(); - assert_snapshot!( - plan_with_params, - @r" - Projection: person.id, person.age - Filter: person.age = Int64(10) - TableScan: person - " - ); - - ////////////////////////////////////////// - // no embedded parameter and no declare it - let sql = "PREPARE my_plan AS SELECT id, age FROM person WHERE age = 10"; - let (plan, dt) = generate_prepare_stmt_and_data_types(sql); - assert_snapshot!( - plan, - @r#" - Prepare: "my_plan" [] - Projection: person.id, person.age - Filter: person.age = Int64(10) - TableScan: person - "# - ); - assert_snapshot!(dt, @r#"[]"#); - - /////////////////// - // replace params with values - let param_values: Vec = vec![]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - assert_snapshot!( - plan_with_params, - @r" - Projection: person.id, person.age - Filter: person.age = Int64(10) - TableScan: person - " - ); -} - -#[test] -fn test_prepare_statement_to_plan_one_param_no_value_panic() { - // no embedded parameter but still declare it - let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = 10"; - let plan = logical_plan(sql).unwrap(); - // declare 1 param but provide 0 - let param_values: Vec = vec![]; - - assert_snapshot!( - plan.with_param_values(param_values) - .unwrap_err() - .strip_backtrace(), - @r###" - Error during planning: Expected 1 parameters, got 0 - "###); -} - -#[test] -fn test_prepare_statement_to_plan_one_param_one_value_different_type_panic() { - // no embedded parameter but still declare it - let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = 10"; - let plan = logical_plan(sql).unwrap(); - // declare 1 param but provide 0 - let param_values = vec![ScalarValue::Float64(Some(20.0))]; - - assert_snapshot!( - plan.with_param_values(param_values) - .unwrap_err() - .strip_backtrace(), - @r###" - Error during planning: Expected parameter of type Int32, got Float64 at index 0 - "### - ); -} - -#[test] -fn test_prepare_statement_to_plan_no_param_on_value_panic() { - // no embedded parameter but still declare it - let sql = "PREPARE my_plan AS SELECT id, age FROM person WHERE age = 10"; - let plan = logical_plan(sql).unwrap(); - // declare 1 param but provide 0 - let param_values = vec![ScalarValue::Int32(Some(10))]; - - assert_snapshot!( - plan.with_param_values(param_values) - .unwrap_err() - .strip_backtrace(), - @r###" - Error during planning: Expected 0 parameters, got 1 - "### - ); -} - -#[test] -fn test_prepare_statement_to_plan_params_as_constants() { - let sql = "PREPARE my_plan(INT) AS SELECT $1"; - let (plan, dt) = generate_prepare_stmt_and_data_types(sql); - assert_snapshot!( - plan, - @r#" - Prepare: "my_plan" [Int32] - Projection: $1 - EmptyRelation - "# - ); - assert_snapshot!(dt, @r#"[Int32]"#); - - /////////////////// - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - assert_snapshot!( - plan_with_params, - @r" - Projection: Int32(10) AS $1 - EmptyRelation - " - ); - - /////////////////////////////////////// - let sql = "PREPARE my_plan(INT) AS SELECT 1 + $1"; - let (plan, dt) = generate_prepare_stmt_and_data_types(sql); - assert_snapshot!( - plan, - @r#" - Prepare: "my_plan" [Int32] - Projection: Int64(1) + $1 - EmptyRelation - "# - ); - assert_snapshot!(dt, @r#"[Int32]"#); - - /////////////////// - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10))]; - let plan_with_params = 
plan.with_param_values(param_values).unwrap(); - assert_snapshot!( - plan_with_params, - @r" - Projection: Int64(1) + Int32(10) AS Int64(1) + $1 - EmptyRelation - " - ); - - /////////////////////////////////////// - let sql = "PREPARE my_plan(INT, DOUBLE) AS SELECT 1 + $1 + $2"; - let (plan, dt) = generate_prepare_stmt_and_data_types(sql); - assert_snapshot!( - plan, - @r#" - Prepare: "my_plan" [Int32, Float64] - Projection: Int64(1) + $1 + $2 - EmptyRelation - "# - ); - assert_snapshot!(dt, @r#"[Int32, Float64]"#); - - /////////////////// - // replace params with values - let param_values = vec![ - ScalarValue::Int32(Some(10)), - ScalarValue::Float64(Some(10.0)), - ]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - assert_snapshot!( - plan_with_params, - @r" - Projection: Int64(1) + Int32(10) + Float64(10) AS Int64(1) + $1 + $2 - EmptyRelation - " - ); -} - -#[test] -fn test_infer_types_from_join() { - let sql = - "SELECT id, order_id FROM person JOIN orders ON id = customer_id and age = $1"; - - let plan = logical_plan(sql).unwrap(); - assert_snapshot!( - plan, - @r#" - Projection: person.id, orders.order_id - Inner Join: Filter: person.id = orders.customer_id AND person.age = $1 - TableScan: person - TableScan: orders - "# - ); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([("$1".to_string(), Some(DataType::Int32))]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - - assert_snapshot!( - plan_with_params, - @r" - Projection: person.id, orders.order_id - Inner Join: Filter: person.id = orders.customer_id AND person.age = Int32(10) - TableScan: person - TableScan: orders - " - ); -} - -#[test] -fn test_infer_types_from_predicate() { - let sql = "SELECT id, age FROM person WHERE age = $1"; - let plan = logical_plan(sql).unwrap(); - assert_snapshot!( - plan, - @r#" - Projection: person.id, person.age - Filter: person.age = $1 - TableScan: person - "# - ); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([("$1".to_string(), Some(DataType::Int32))]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - - assert_snapshot!( - plan_with_params, - @r" - Projection: person.id, person.age - Filter: person.age = Int32(10) - TableScan: person - " - ); -} - -#[test] -fn test_infer_types_from_between_predicate() { - let sql = "SELECT id, age FROM person WHERE age BETWEEN $1 AND $2"; - - let plan = logical_plan(sql).unwrap(); - assert_snapshot!( - plan, - @r#" - Projection: person.id, person.age - Filter: person.age BETWEEN $1 AND $2 - TableScan: person - "# - ); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([ - ("$1".to_string(), Some(DataType::Int32)), - ("$2".to_string(), Some(DataType::Int32)), - ]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10)), ScalarValue::Int32(Some(30))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - - assert_snapshot!( - plan_with_params, - @r" - Projection: person.id, person.age - Filter: person.age BETWEEN Int32(10) AND Int32(30) - TableScan: person - " - ); -} - -#[test] -fn 
test_infer_types_subquery() { - let sql = "SELECT id, age FROM person WHERE age = (select max(age) from person where id = $1)"; - - let plan = logical_plan(sql).unwrap(); - assert_snapshot!( - plan, - @r#" - Projection: person.id, person.age - Filter: person.age = () - Subquery: - Projection: max(person.age) - Aggregate: groupBy=[[]], aggr=[[max(person.age)]] - Filter: person.id = $1 - TableScan: person - TableScan: person - "# - ); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([("$1".to_string(), Some(DataType::UInt32))]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ScalarValue::UInt32(Some(10))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - - assert_snapshot!( - plan_with_params, - @r" - Projection: person.id, person.age - Filter: person.age = () - Subquery: - Projection: max(person.age) - Aggregate: groupBy=[[]], aggr=[[max(person.age)]] - Filter: person.id = UInt32(10) - TableScan: person - TableScan: person - " - ); -} - -#[test] -fn test_update_infer() { - let sql = "update person set age=$1 where id=$2"; - - let plan = logical_plan(sql).unwrap(); - assert_snapshot!( - plan, - @r#" - Dml: op=[Update] table=[person] - Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, $1 AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 - Filter: person.id = $2 - TableScan: person - "# - ); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([ - ("$1".to_string(), Some(DataType::Int32)), - ("$2".to_string(), Some(DataType::UInt32)), - ]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(42)), ScalarValue::UInt32(Some(1))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - - assert_snapshot!( - plan_with_params, - @r" - Dml: op=[Update] table=[person] - Projection: person.id AS id, person.first_name AS first_name, person.last_name AS last_name, Int32(42) AS age, person.state AS state, person.salary AS salary, person.birth_date AS birth_date, person.😀 AS 😀 - Filter: person.id = UInt32(1) - TableScan: person - " - ); -} - -#[test] -fn test_insert_infer() { - let sql = "insert into person (id, first_name, last_name) values ($1, $2, $3)"; - let plan = logical_plan(sql).unwrap(); - assert_snapshot!( - plan, - @r#" - Dml: op=[Insert Into] table=[person] - Projection: column1 AS id, column2 AS first_name, column3 AS last_name, CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀 - Values: ($1, $2, $3) - "# - ); - - let actual_types = plan.get_parameter_types().unwrap(); - let expected_types = HashMap::from([ - ("$1".to_string(), Some(DataType::UInt32)), - ("$2".to_string(), Some(DataType::Utf8)), - ("$3".to_string(), Some(DataType::Utf8)), - ]); - assert_eq!(actual_types, expected_types); - - // replace params with values - let param_values = vec![ - ScalarValue::UInt32(Some(1)), - ScalarValue::from("Alan"), - ScalarValue::from("Turing"), - ]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - assert_snapshot!( - plan_with_params, - @r#" - Dml: op=[Insert Into] table=[person] - Projection: column1 AS id, column2 AS first_name, column3 AS last_name, CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) 
AS state, CAST(NULL AS Float64) AS salary, CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀 - Values: (UInt32(1) AS $1, Utf8("Alan") AS $2, Utf8("Turing") AS $3) - "# - ); -} - -#[test] -fn test_prepare_statement_to_plan_one_param() { - let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = $1"; - let (plan, dt) = generate_prepare_stmt_and_data_types(sql); - assert_snapshot!( - plan, - @r#" - Prepare: "my_plan" [Int32] - Projection: person.id, person.age - Filter: person.age = $1 - TableScan: person - "# - ); - assert_snapshot!(dt, @r#"[Int32]"#); - - /////////////////// - // replace params with values - let param_values = vec![ScalarValue::Int32(Some(10))]; - - let plan_with_params = plan.with_param_values(param_values).unwrap(); - assert_snapshot!( - plan_with_params, - @r" - Projection: person.id, person.age - Filter: person.age = Int32(10) - TableScan: person - " - ); -} - -#[test] -fn test_prepare_statement_to_plan_data_type() { - let sql = "PREPARE my_plan(DOUBLE) AS SELECT id, age FROM person WHERE age = $1"; - - let (plan, dt) = generate_prepare_stmt_and_data_types(sql); - assert_snapshot!( - plan, - // age is defined as Int32 but prepare statement declares it as DOUBLE/Float64 - // Prepare statement and its logical plan should be created successfully - @r#" - Prepare: "my_plan" [Float64] - Projection: person.id, person.age - Filter: person.age = $1 - TableScan: person - "# - ); - assert_snapshot!(dt, @r#"[Float64]"#); - - /////////////////// - // replace params with values still succeed and use Float64 - let param_values = vec![ScalarValue::Float64(Some(10.0))]; - - let plan_with_params = plan.with_param_values(param_values).unwrap(); - assert_snapshot!( - plan_with_params, - @r" - Projection: person.id, person.age - Filter: person.age = Float64(10) - TableScan: person - " - ); -} - -#[test] -fn test_prepare_statement_to_plan_multi_params() { - let sql = "PREPARE my_plan(INT, STRING, DOUBLE, INT, DOUBLE, STRING) AS - SELECT id, age, $6 - FROM person - WHERE age IN ($1, $4) AND salary > $3 and salary < $5 OR first_name < $2"; - let (plan, dt) = generate_prepare_stmt_and_data_types(sql); - assert_snapshot!( - plan, - @r#" - Prepare: "my_plan" [Int32, Utf8, Float64, Int32, Float64, Utf8] - Projection: person.id, person.age, $6 - Filter: person.age IN ([$1, $4]) AND person.salary > $3 AND person.salary < $5 OR person.first_name < $2 - TableScan: person - "# - ); - assert_snapshot!(dt, @r#"[Int32, Utf8, Float64, Int32, Float64, Utf8]"#); - - /////////////////// - // replace params with values - let param_values = vec![ - ScalarValue::Int32(Some(10)), - ScalarValue::from("abc"), - ScalarValue::Float64(Some(100.0)), - ScalarValue::Int32(Some(20)), - ScalarValue::Float64(Some(200.0)), - ScalarValue::from("xyz"), - ]; - - let plan_with_params = plan.with_param_values(param_values).unwrap(); - assert_snapshot!( - plan_with_params, - @r#" - Projection: person.id, person.age, Utf8("xyz") AS $6 - Filter: person.age IN ([Int32(10), Int32(20)]) AND person.salary > Float64(100) AND person.salary < Float64(200) OR person.first_name < Utf8("abc") - TableScan: person - "# - ); -} - -#[test] -fn test_prepare_statement_to_plan_having() { - let sql = "PREPARE my_plan(INT, DOUBLE, DOUBLE, DOUBLE) AS - SELECT id, sum(age) - FROM person \ - WHERE salary > $2 - GROUP BY id - HAVING sum(age) < $1 AND sum(age) > 10 OR sum(age) in ($3, $4)\ - "; - let (plan, dt) = generate_prepare_stmt_and_data_types(sql); - assert_snapshot!( - plan, - @r#" - Prepare: 
"my_plan" [Int32, Float64, Float64, Float64] - Projection: person.id, sum(person.age) - Filter: sum(person.age) < $1 AND sum(person.age) > Int64(10) OR sum(person.age) IN ([$3, $4]) - Aggregate: groupBy=[[person.id]], aggr=[[sum(person.age)]] - Filter: person.salary > $2 - TableScan: person - "# - ); - assert_snapshot!(dt, @r#"[Int32, Float64, Float64, Float64]"#); - - /////////////////// - // replace params with values - let param_values = vec![ - ScalarValue::Int32(Some(10)), - ScalarValue::Float64(Some(100.0)), - ScalarValue::Float64(Some(200.0)), - ScalarValue::Float64(Some(300.0)), - ]; - - let plan_with_params = plan.with_param_values(param_values).unwrap(); - assert_snapshot!( - plan_with_params, - @r#" - Projection: person.id, sum(person.age) - Filter: sum(person.age) < Int32(10) AND sum(person.age) > Int64(10) OR sum(person.age) IN ([Float64(200), Float64(300)]) - Aggregate: groupBy=[[person.id]], aggr=[[sum(person.age)]] - Filter: person.salary > Float64(100) - TableScan: person - "# - ); -} - -#[test] -fn test_prepare_statement_to_plan_limit() { - let sql = "PREPARE my_plan(BIGINT, BIGINT) AS - SELECT id FROM person \ - OFFSET $1 LIMIT $2"; - let (plan, dt) = generate_prepare_stmt_and_data_types(sql); - assert_snapshot!( - plan, - @r#" - Prepare: "my_plan" [Int64, Int64] - Limit: skip=$1, fetch=$2 - Projection: person.id - TableScan: person - "# - ); - assert_snapshot!(dt, @r#"[Int64, Int64]"#); - - // replace params with values - let param_values = vec![ScalarValue::Int64(Some(10)), ScalarValue::Int64(Some(200))]; - let plan_with_params = plan.with_param_values(param_values).unwrap(); - assert_snapshot!( - plan_with_params, - @r#" - Limit: skip=10, fetch=200 - Projection: person.id - TableScan: person - "# - ); -} - -#[test] -fn test_prepare_statement_unknown_list_param() { - let sql = "SELECT id from person where id = $2"; - let plan = logical_plan(sql).unwrap(); - let param_values = ParamValues::List(vec![]); - let err = plan.replace_params_with_values(¶m_values).unwrap_err(); - assert_contains!( - err.to_string(), - "Error during planning: No value found for placeholder with id $2" - ); -} - -#[test] -fn test_prepare_statement_unknown_hash_param() { - let sql = "SELECT id from person where id = $bar"; - let plan = logical_plan(sql).unwrap(); - let param_values = ParamValues::Map(HashMap::new()); - let err = plan.replace_params_with_values(¶m_values).unwrap_err(); - assert_contains!( - err.to_string(), - "Error during planning: No value found for placeholder with name $bar" - ); -} - -#[test] -fn test_prepare_statement_bad_list_idx() { - let sql = "SELECT id from person where id = $foo"; - let plan = logical_plan(sql).unwrap(); - let param_values = ParamValues::List(vec![]); - - let err = plan.replace_params_with_values(¶m_values).unwrap_err(); - assert_contains!(err.to_string(), "Error during planning: Failed to parse placeholder id: invalid digit found in string"); -} - #[test] fn test_inner_join_with_cast_key() { let sql = "SELECT person.id, person.age diff --git a/datafusion/sqllogictest/Cargo.toml b/datafusion/sqllogictest/Cargo.toml index 16cd3d5b3aa4..54c53f7375c4 100644 --- a/datafusion/sqllogictest/Cargo.toml +++ b/datafusion/sqllogictest/Cargo.toml @@ -42,8 +42,10 @@ async-trait = { workspace = true } bigdecimal = { workspace = true } bytes = { workspace = true, optional = true } chrono = { workspace = true, optional = true } -clap = { version = "4.5.35", features = ["derive", "env"] } +clap = { version = "4.5.40", features = ["derive", "env"] } datafusion = { 
workspace = true, default-features = true, features = ["avro"] } +datafusion-spark = { workspace = true, default-features = true } +datafusion-substrait = { workspace = true, default-features = true } futures = { workspace = true } half = { workspace = true, default-features = true } indicatif = "0.17" @@ -52,14 +54,14 @@ log = { workspace = true } object_store = { workspace = true } postgres-protocol = { version = "0.6.7", optional = true } postgres-types = { version = "0.2.8", features = ["derive", "with-chrono-0_4"], optional = true } -rust_decimal = { version = "1.37.1", features = ["tokio-pg"] } +rust_decimal = { version = "1.37.2", features = ["tokio-pg"] } # When updating the following dependency verify that sqlite test file regeneration works correctly # by running the regenerate_sqlite_files.sh script. -sqllogictest = "0.28.0" +sqllogictest = "0.28.3" sqlparser = { workspace = true } tempfile = { workspace = true } -testcontainers = { version = "0.23", features = ["default"], optional = true } -testcontainers-modules = { version = "0.11", features = ["postgres"], optional = true } +testcontainers = { version = "0.24", features = ["default"], optional = true } +testcontainers-modules = { version = "0.12", features = ["postgres"], optional = true } thiserror = "2.0.12" tokio = { workspace = true } tokio-postgres = { version = "0.7.12", optional = true } diff --git a/datafusion/sqllogictest/README.md b/datafusion/sqllogictest/README.md index 77162f4001ae..3fdb29c9d5cd 100644 --- a/datafusion/sqllogictest/README.md +++ b/datafusion/sqllogictest/README.md @@ -156,6 +156,14 @@ sqllogictests also supports `cargo test` style substring matches on file names t cargo test --test sqllogictests -- information ``` +Additionally, executing specific tests within a file is supported. Tests are identified by their line number within +the .slt file; for example, the following command will run the test on line `709` of `information.slt` along +with any other preparatory statements: + +```shell +cargo test --test sqllogictests -- information:709 +``` + ## Running tests: Postgres compatibility Test files that start with prefix `pg_compat_` verify compatibility @@ -283,6 +291,27 @@ Tests that need to write temporary files should write (only) to this directory to ensure they do not interfere with others concurrently running tests. +## Running tests: Substrait round-trip mode + +This mode will run all the .slt test files in validation mode, adding a Substrait conversion round-trip for each +generated DataFusion logical plan (SQL statement → DF logical → Substrait → DF logical → DF physical → execute). + +Not all statements are round-tripped: statements such as CREATE, INSERT, SET or EXPLAIN are issued as-is, +while any other statement is converted to Substrait and back before execution.
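+
+The round trip itself is roughly the following (a simplified sketch of what the round-trip runner does for
+each query; the `round_trip` helper is only illustrative and assumes a `SessionContext` with the relevant
+tables already registered):
+
+```rust
+use datafusion::error::Result;
+use datafusion::physical_plan::collect;
+use datafusion::prelude::SessionContext;
+use datafusion_substrait::logical_plan::consumer::from_substrait_plan;
+use datafusion_substrait::logical_plan::producer::to_substrait_plan;
+
+// Illustrative helper, not part of the test harness API.
+async fn round_trip(ctx: &SessionContext, sql: &str) -> Result<()> {
+    let df = ctx.sql(sql).await?; // SQL statement -> DF logical plan
+    let state = ctx.state();
+    let substrait = to_substrait_plan(df.logical_plan(), &state)?; // DF logical -> Substrait
+    let logical = from_substrait_plan(&state, &substrait).await?; // Substrait -> DF logical
+    let physical = state.create_physical_plan(&logical).await?; // DF logical -> DF physical
+    let _batches = collect(physical, ctx.task_ctx()).await?; // execute
+    Ok(())
+}
+```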
+ +_WARNING_: as there are still a lot of failures in this mode (https://github.com/apache/datafusion/issues/16248), +it is not enforced in the CI, instead, it needs to be run manually with the following command: + +```shell +cargo test --test sqllogictests -- --substrait-round-trip +``` + +For focusing on one specific failing test, a file:line filter can be used: + +```shell +cargo test --test sqllogictests -- --substrait-round-trip binary.slt:23 +``` + ## `.slt` file format [`sqllogictest`] was originally written for SQLite to verify the diff --git a/datafusion/sqllogictest/bin/sqllogictests.rs b/datafusion/sqllogictest/bin/sqllogictests.rs index 21dfe2ee08f4..d5fce1a7cdb2 100644 --- a/datafusion/sqllogictest/bin/sqllogictests.rs +++ b/datafusion/sqllogictest/bin/sqllogictests.rs @@ -20,8 +20,9 @@ use datafusion::common::instant::Instant; use datafusion::common::utils::get_available_parallelism; use datafusion::common::{exec_err, DataFusionError, Result}; use datafusion_sqllogictest::{ - df_value_validator, read_dir_recursive, setup_scratch_dir, value_normalizer, - DataFusion, TestContext, + df_value_validator, read_dir_recursive, setup_scratch_dir, should_skip_file, + should_skip_record, value_normalizer, DataFusion, DataFusionSubstraitRoundTrip, + Filter, TestContext, }; use futures::stream::StreamExt; use indicatif::{ @@ -31,8 +32,8 @@ use itertools::Itertools; use log::Level::Info; use log::{info, log_enabled}; use sqllogictest::{ - parse_file, strict_column_validator, AsyncDB, Condition, Normalizer, Record, - Validator, + parse_file, strict_column_validator, AsyncDB, Condition, MakeConnection, Normalizer, + Record, Validator, }; #[cfg(feature = "postgres")] @@ -50,6 +51,7 @@ const TEST_DIRECTORY: &str = "test_files/"; const DATAFUSION_TESTING_TEST_DIRECTORY: &str = "../../datafusion-testing/data/"; const PG_COMPAT_FILE_PREFIX: &str = "pg_compat_"; const SQLITE_PREFIX: &str = "sqlite"; +const ERRS_PER_FILE_LIMIT: usize = 10; pub fn main() -> Result<()> { tokio::runtime::Builder::new_multi_thread() @@ -101,6 +103,7 @@ async fn run_tests() -> Result<()> { // to stdout and return OK so they can continue listing other tests. return Ok(()); } + options.warn_on_ignored(); #[cfg(feature = "postgres")] @@ -134,27 +137,49 @@ async fn run_tests() -> Result<()> { let m_clone = m.clone(); let m_style_clone = m_style.clone(); + let filters = options.filters.clone(); SpawnedTask::spawn(async move { - match (options.postgres_runner, options.complete) { - (false, false) => { - run_test_file(test_file, validator, m_clone, m_style_clone) - .await? + match ( + options.postgres_runner, + options.complete, + options.substrait_round_trip, + ) { + (_, _, true) => { + run_test_file_substrait_round_trip( + test_file, + validator, + m_clone, + m_style_clone, + filters.as_ref(), + ) + .await? } - (false, true) => { + (false, false, _) => { + run_test_file( + test_file, + validator, + m_clone, + m_style_clone, + filters.as_ref(), + ) + .await? + } + (false, true, _) => { run_complete_file(test_file, validator, m_clone, m_style_clone) .await? } - (true, false) => { + (true, false, _) => { run_test_file_with_postgres( test_file, validator, m_clone, m_style_clone, + filters.as_ref(), ) .await? 
} - (true, true) => { + (true, true, _) => { run_complete_file_with_postgres( test_file, validator, @@ -201,11 +226,51 @@ async fn run_tests() -> Result<()> { } } +async fn run_test_file_substrait_round_trip( + test_file: TestFile, + validator: Validator, + mp: MultiProgress, + mp_style: ProgressStyle, + filters: &[Filter], +) -> Result<()> { + let TestFile { + path, + relative_path, + } = test_file; + let Some(test_ctx) = TestContext::try_new_for_test_file(&relative_path).await else { + info!("Skipping: {}", path.display()); + return Ok(()); + }; + setup_scratch_dir(&relative_path)?; + + let count: u64 = get_record_count(&path, "DatafusionSubstraitRoundTrip".to_string()); + let pb = mp.add(ProgressBar::new(count)); + + pb.set_style(mp_style); + pb.set_message(format!("{:?}", &relative_path)); + + let mut runner = sqllogictest::Runner::new(|| async { + Ok(DataFusionSubstraitRoundTrip::new( + test_ctx.session_ctx().clone(), + relative_path.clone(), + pb.clone(), + )) + }); + runner.add_label("DatafusionSubstraitRoundTrip"); + runner.with_column_validator(strict_column_validator); + runner.with_normalizer(value_normalizer); + runner.with_validator(validator); + let res = run_file_in_runner(path, runner, filters).await; + pb.finish_and_clear(); + res +} + async fn run_test_file( test_file: TestFile, validator: Validator, mp: MultiProgress, mp_style: ProgressStyle, + filters: &[Filter], ) -> Result<()> { let TestFile { path, @@ -234,15 +299,49 @@ async fn run_test_file( runner.with_column_validator(strict_column_validator); runner.with_normalizer(value_normalizer); runner.with_validator(validator); + let result = run_file_in_runner(path, runner, filters).await; + pb.finish_and_clear(); + result +} - let res = runner - .run_file_async(path) - .await - .map_err(|e| DataFusionError::External(Box::new(e))); +async fn run_file_in_runner>( + path: PathBuf, + mut runner: sqllogictest::Runner, + filters: &[Filter], +) -> Result<()> { + let path = path.canonicalize()?; + let records = + parse_file(&path).map_err(|e| DataFusionError::External(Box::new(e)))?; + let mut errs = vec![]; + for record in records.into_iter() { + if let Record::Halt { .. } = record { + break; + } + if should_skip_record::(&record, filters) { + continue; + } + if let Err(err) = runner.run_async(record).await { + errs.push(format!("{err}")); + } + } - pb.finish_and_clear(); + if !errs.is_empty() { + let mut msg = format!("{} errors in file {}\n\n", errs.len(), path.display()); + for (i, err) in errs.iter().enumerate() { + if i >= ERRS_PER_FILE_LIMIT { + msg.push_str(&format!( + "... other {} errors in {} not shown ...\n\n", + errs.len() - ERRS_PER_FILE_LIMIT, + path.display() + )); + break; + } + msg.push_str(&format!("{}. 
{err}\n\n", i + 1)); + } + return Err(DataFusionError::External(msg.into())); + } - res + Ok(()) } fn get_record_count(path: &PathBuf, label: String) -> u64 { @@ -287,6 +386,7 @@ async fn run_test_file_with_postgres( validator: Validator, mp: MultiProgress, mp_style: ProgressStyle, + filters: &[Filter], ) -> Result<()> { use datafusion_sqllogictest::Postgres; let TestFile { @@ -308,14 +408,9 @@ async fn run_test_file_with_postgres( runner.with_column_validator(strict_column_validator); runner.with_normalizer(value_normalizer); runner.with_validator(validator); - runner - .run_file_async(path) - .await - .map_err(|e| DataFusionError::External(Box::new(e)))?; - + let result = run_file_in_runner(path, runner, filters).await; pb.finish_and_clear(); - - Ok(()) + result } #[cfg(not(feature = "postgres"))] @@ -324,6 +419,7 @@ async fn run_test_file_with_postgres( _validator: Validator, _mp: MultiProgress, _mp_style: ProgressStyle, + _filters: &[Filter], ) -> Result<()> { use datafusion::common::plan_err; plan_err!("Can not run with postgres as postgres feature is not enabled") @@ -537,14 +633,25 @@ struct Options { )] postgres_runner: bool, + #[clap( + long, + conflicts_with = "complete", + conflicts_with = "postgres_runner", + help = "Before executing each query, convert its logical plan to Substrait and from Substrait back to its logical plan" + )] + substrait_round_trip: bool, + #[clap(long, env = "INCLUDE_SQLITE", help = "Include sqlite files")] include_sqlite: bool, #[clap(long, env = "INCLUDE_TPCH", help = "Include tpch files")] include_tpch: bool, - #[clap(action, help = "test filter (substring match on filenames)")] - filters: Vec, + #[clap( + action, + help = "test filter (substring match on filenames with optional :{line_number} suffix)" + )] + filters: Vec, #[clap( long, @@ -597,15 +704,7 @@ impl Options { /// filter and that does a substring match on each input. 
returns /// true f this path should be run fn check_test_file(&self, path: &Path) -> bool { - if self.filters.is_empty() { - return true; - } - - // otherwise check if any filter matches - let path_string = path.to_string_lossy(); - self.filters - .iter() - .any(|filter| path_string.contains(filter)) + !should_skip_file(path, &self.filters) } /// Postgres runner executes only tests in files with specific names or in diff --git a/datafusion/sqllogictest/src/engines/conversion.rs b/datafusion/sqllogictest/src/engines/conversion.rs index 516ec69e0b07..92ab64059bbd 100644 --- a/datafusion/sqllogictest/src/engines/conversion.rs +++ b/datafusion/sqllogictest/src/engines/conversion.rs @@ -49,7 +49,7 @@ pub(crate) fn f16_to_str(value: f16) -> String { } else if value == f16::NEG_INFINITY { "-Infinity".to_string() } else { - big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap()) + big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap(), None) } } @@ -63,7 +63,7 @@ pub(crate) fn f32_to_str(value: f32) -> String { } else if value == f32::NEG_INFINITY { "-Infinity".to_string() } else { - big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap()) + big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap(), None) } } @@ -77,7 +77,21 @@ pub(crate) fn f64_to_str(value: f64) -> String { } else if value == f64::NEG_INFINITY { "-Infinity".to_string() } else { - big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap()) + big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap(), None) + } +} + +pub(crate) fn spark_f64_to_str(value: f64) -> String { + if value.is_nan() { + // The sign of NaN can be different depending on platform. + // So the string representation of NaN ignores the sign. + "NaN".to_string() + } else if value == f64::INFINITY { + "Infinity".to_string() + } else if value == f64::NEG_INFINITY { + "-Infinity".to_string() + } else { + big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap(), Some(15)) } } @@ -86,6 +100,7 @@ pub(crate) fn decimal_128_to_str(value: i128, scale: i8) -> String { big_decimal_to_str( BigDecimal::from_str(&Decimal128Type::format_decimal(value, precision, scale)) .unwrap(), + None, ) } @@ -94,17 +109,21 @@ pub(crate) fn decimal_256_to_str(value: i256, scale: i8) -> String { big_decimal_to_str( BigDecimal::from_str(&Decimal256Type::format_decimal(value, precision, scale)) .unwrap(), + None, ) } #[cfg(feature = "postgres")] pub(crate) fn decimal_to_str(value: Decimal) -> String { - big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap()) + big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap(), None) } -pub(crate) fn big_decimal_to_str(value: BigDecimal) -> String { +/// Converts a `BigDecimal` to its plain string representation, optionally rounding to a specified number of decimal places. +/// +/// If `round_digits` is `None`, the value is rounded to 12 decimal places by default. +pub(crate) fn big_decimal_to_str(value: BigDecimal, round_digits: Option) -> String { // Round the value to limit the number of decimal places - let value = value.round(12).normalized(); + let value = value.round(round_digits.unwrap_or(12)).normalized(); // Format the value to a string value.to_plain_string() } @@ -115,12 +134,12 @@ mod tests { use bigdecimal::{num_bigint::BigInt, BigDecimal}; macro_rules! 
assert_decimal_str_eq { - ($integer:expr, $scale:expr, $expected:expr) => { + ($integer:expr, $scale:expr, $round_digits:expr, $expected:expr) => { assert_eq!( - big_decimal_to_str(BigDecimal::from_bigint( - BigInt::from($integer), - $scale - )), + big_decimal_to_str( + BigDecimal::from_bigint(BigInt::from($integer), $scale), + $round_digits + ), $expected ); }; @@ -128,44 +147,51 @@ mod tests { #[test] fn test_big_decimal_to_str() { - assert_decimal_str_eq!(110, 3, "0.11"); - assert_decimal_str_eq!(11, 3, "0.011"); - assert_decimal_str_eq!(11, 2, "0.11"); - assert_decimal_str_eq!(11, 1, "1.1"); - assert_decimal_str_eq!(11, 0, "11"); - assert_decimal_str_eq!(11, -1, "110"); - assert_decimal_str_eq!(0, 0, "0"); + assert_decimal_str_eq!(110, 3, None, "0.11"); + assert_decimal_str_eq!(11, 3, None, "0.011"); + assert_decimal_str_eq!(11, 2, None, "0.11"); + assert_decimal_str_eq!(11, 1, None, "1.1"); + assert_decimal_str_eq!(11, 0, None, "11"); + assert_decimal_str_eq!(11, -1, None, "110"); + assert_decimal_str_eq!(0, 0, None, "0"); assert_decimal_str_eq!( 12345678901234567890123456789012345678_i128, 0, + None, "12345678901234567890123456789012345678" ); assert_decimal_str_eq!( 12345678901234567890123456789012345678_i128, 38, + None, "0.123456789012" ); // Negative cases - assert_decimal_str_eq!(-110, 3, "-0.11"); - assert_decimal_str_eq!(-11, 3, "-0.011"); - assert_decimal_str_eq!(-11, 2, "-0.11"); - assert_decimal_str_eq!(-11, 1, "-1.1"); - assert_decimal_str_eq!(-11, 0, "-11"); - assert_decimal_str_eq!(-11, -1, "-110"); + assert_decimal_str_eq!(-110, 3, None, "-0.11"); + assert_decimal_str_eq!(-11, 3, None, "-0.011"); + assert_decimal_str_eq!(-11, 2, None, "-0.11"); + assert_decimal_str_eq!(-11, 1, None, "-1.1"); + assert_decimal_str_eq!(-11, 0, None, "-11"); + assert_decimal_str_eq!(-11, -1, None, "-110"); assert_decimal_str_eq!( -12345678901234567890123456789012345678_i128, 0, + None, "-12345678901234567890123456789012345678" ); assert_decimal_str_eq!( -12345678901234567890123456789012345678_i128, 38, + None, "-0.123456789012" ); // Round to 12 decimal places // 1.0000000000011 -> 1.000000000001 - assert_decimal_str_eq!(10_i128.pow(13) + 11, 13, "1.000000000001"); + assert_decimal_str_eq!(10_i128.pow(13) + 11, 13, None, "1.000000000001"); + assert_decimal_str_eq!(10_i128.pow(13) + 11, 13, Some(12), "1.000000000001"); + + assert_decimal_str_eq!(10_i128.pow(13) + 11, 13, Some(13), "1.0000000000011"); } } diff --git a/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs b/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs index eeb34186ea20..0d832bb3062d 100644 --- a/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs +++ b/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs @@ -22,13 +22,16 @@ use arrow::array::{Array, AsArray}; use arrow::datatypes::Fields; use arrow::util::display::ArrayFormatter; use arrow::{array, array::ArrayRef, datatypes::DataType, record_batch::RecordBatch}; -use datafusion::common::format::DEFAULT_CLI_FORMAT_OPTIONS; use datafusion::common::DataFusionError; +use datafusion::config::ConfigField; use std::path::PathBuf; use std::sync::LazyLock; /// Converts `batches` to a result as expected by sqllogictest. -pub fn convert_batches(batches: Vec) -> Result>> { +pub fn convert_batches( + batches: Vec, + is_spark_path: bool, +) -> Result>> { if batches.is_empty() { Ok(vec![]) } else { @@ -46,7 +49,16 @@ pub fn convert_batches(batches: Vec) -> Result>> { ))); } - let new_rows = convert_batch(batch)? 
+ // Convert a single batch to a `Vec>` for comparison, flatten expanded rows, and normalize each. + let new_rows = (0..batch.num_rows()) + .map(|row| { + batch + .columns() + .iter() + .map(|col| cell_to_string(col, row, is_spark_path)) + .collect::>>() + }) + .collect::>>>()? .into_iter() .flat_map(expand_row) .map(normalize_paths); @@ -162,19 +174,6 @@ static WORKSPACE_ROOT: LazyLock = LazyLock::new(|| { object_store::path::Path::parse(sanitized_workplace_root).unwrap() }); -/// Convert a single batch to a `Vec>` for comparison -fn convert_batch(batch: RecordBatch) -> Result>> { - (0..batch.num_rows()) - .map(|row| { - batch - .columns() - .iter() - .map(|col| cell_to_string(col, row)) - .collect::>>() - }) - .collect() -} - macro_rules! get_row_value { ($array_type:ty, $column: ident, $row: ident) => {{ let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); @@ -193,7 +192,7 @@ macro_rules! get_row_value { /// /// Floating numbers are rounded to have a consistent representation with the Postgres runner. /// -pub fn cell_to_string(col: &ArrayRef, row: usize) -> Result { +pub fn cell_to_string(col: &ArrayRef, row: usize, is_spark_path: bool) -> Result { if !col.is_valid(row) { // represent any null value with the string "NULL" Ok(NULL_STR.to_string()) @@ -210,7 +209,12 @@ pub fn cell_to_string(col: &ArrayRef, row: usize) -> Result { Ok(f32_to_str(get_row_value!(array::Float32Array, col, row))) } DataType::Float64 => { - Ok(f64_to_str(get_row_value!(array::Float64Array, col, row))) + let result = get_row_value!(array::Float64Array, col, row); + if is_spark_path { + Ok(spark_f64_to_str(result)) + } else { + Ok(f64_to_str(result)) + } } DataType::Decimal128(_, scale) => { let value = get_row_value!(array::Decimal128Array, col, row); @@ -236,12 +240,20 @@ pub fn cell_to_string(col: &ArrayRef, row: usize) -> Result { DataType::Dictionary(_, _) => { let dict = col.as_any_dictionary(); let key = dict.normalized_keys()[row]; - Ok(cell_to_string(dict.values(), key)?) + Ok(cell_to_string(dict.values(), key, is_spark_path)?) 
} _ => { - let f = - ArrayFormatter::try_new(col.as_ref(), &DEFAULT_CLI_FORMAT_OPTIONS); - Ok(f.unwrap().value(row).to_string()) + let mut datafusion_format_options = + datafusion::config::FormatOptions::default(); + + datafusion_format_options.set("null", "NULL").unwrap(); + + let arrow_format_options: arrow::util::display::FormatOptions = + (&datafusion_format_options).try_into().unwrap(); + + let f = ArrayFormatter::try_new(col.as_ref(), &arrow_format_options)?; + + Ok(f.value(row).to_string()) } } .map_err(DFSqlLogicTestError::Arrow) @@ -280,7 +292,9 @@ pub fn convert_schema_to_types(columns: &Fields) -> Vec { if key_type.is_integer() { // mapping dictionary string types to Text match value_type.as_ref() { - DataType::Utf8 | DataType::LargeUtf8 => DFColumnType::Text, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => { + DFColumnType::Text + } _ => DFColumnType::Another, } } else { diff --git a/datafusion/sqllogictest/src/engines/datafusion_engine/runner.rs b/datafusion/sqllogictest/src/engines/datafusion_engine/runner.rs index a3a29eda2ee9..a01ac7e2f985 100644 --- a/datafusion/sqllogictest/src/engines/datafusion_engine/runner.rs +++ b/datafusion/sqllogictest/src/engines/datafusion_engine/runner.rs @@ -31,6 +31,7 @@ use sqllogictest::DBOutput; use tokio::time::Instant; use crate::engines::output::{DFColumnType, DFOutput}; +use crate::is_spark_path; pub struct DataFusion { ctx: SessionContext, @@ -79,7 +80,7 @@ impl sqllogictest::AsyncDB for DataFusion { } let start = Instant::now(); - let result = run_query(&self.ctx, sql).await; + let result = run_query(&self.ctx, is_spark_path(&self.relative_path), sql).await; let duration = start.elapsed(); if duration.gt(&Duration::from_millis(500)) { @@ -115,7 +116,11 @@ impl sqllogictest::AsyncDB for DataFusion { async fn shutdown(&mut self) {} } -async fn run_query(ctx: &SessionContext, sql: impl Into) -> Result { +async fn run_query( + ctx: &SessionContext, + is_spark_path: bool, + sql: impl Into, +) -> Result { let df = ctx.sql(sql.into().as_str()).await?; let task_ctx = Arc::new(df.task_ctx()); let plan = df.create_physical_plan().await?; @@ -123,7 +128,7 @@ async fn run_query(ctx: &SessionContext, sql: impl Into) -> Result = collect(stream).await?; - let rows = normalize::convert_batches(results)?; + let rows = normalize::convert_batches(results, is_spark_path)?; if rows.is_empty() && types.is_empty() { Ok(DBOutput::StatementComplete(0)) diff --git a/datafusion/sqllogictest/src/engines/datafusion_substrait_roundtrip_engine/mod.rs b/datafusion/sqllogictest/src/engines/datafusion_substrait_roundtrip_engine/mod.rs new file mode 100644 index 000000000000..9ff077c67d8c --- /dev/null +++ b/datafusion/sqllogictest/src/engines/datafusion_substrait_roundtrip_engine/mod.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +mod runner; + +pub use runner::*; diff --git a/datafusion/sqllogictest/src/engines/datafusion_substrait_roundtrip_engine/runner.rs b/datafusion/sqllogictest/src/engines/datafusion_substrait_roundtrip_engine/runner.rs new file mode 100644 index 000000000000..9d3854755352 --- /dev/null +++ b/datafusion/sqllogictest/src/engines/datafusion_substrait_roundtrip_engine/runner.rs @@ -0,0 +1,154 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; +use std::{path::PathBuf, time::Duration}; + +use crate::engines::datafusion_engine::Result; +use crate::engines::output::{DFColumnType, DFOutput}; +use crate::{convert_batches, convert_schema_to_types, DFSqlLogicTestError}; +use arrow::record_batch::RecordBatch; +use async_trait::async_trait; +use datafusion::logical_expr::LogicalPlan; +use datafusion::physical_plan::common::collect; +use datafusion::physical_plan::execute_stream; +use datafusion::prelude::SessionContext; +use datafusion_substrait::logical_plan::consumer::from_substrait_plan; +use datafusion_substrait::logical_plan::producer::to_substrait_plan; +use indicatif::ProgressBar; +use log::Level::{Debug, Info}; +use log::{debug, log_enabled, warn}; +use sqllogictest::DBOutput; +use tokio::time::Instant; + +pub struct DataFusionSubstraitRoundTrip { + ctx: SessionContext, + relative_path: PathBuf, + pb: ProgressBar, +} + +impl DataFusionSubstraitRoundTrip { + pub fn new(ctx: SessionContext, relative_path: PathBuf, pb: ProgressBar) -> Self { + Self { + ctx, + relative_path, + pb, + } + } + + fn update_slow_count(&self) { + let msg = self.pb.message(); + let split: Vec<&str> = msg.split(" ").collect(); + let mut current_count = 0; + + if split.len() > 2 { + // third match will be current slow count + current_count = split[2].parse::().unwrap(); + } + + current_count += 1; + + self.pb + .set_message(format!("{} - {} took > 500 ms", split[0], current_count)); + } +} + +#[async_trait] +impl sqllogictest::AsyncDB for DataFusionSubstraitRoundTrip { + type Error = DFSqlLogicTestError; + type ColumnType = DFColumnType; + + async fn run(&mut self, sql: &str) -> Result { + if log_enabled!(Debug) { + debug!( + "[{}] Running query: \"{}\"", + self.relative_path.display(), + sql + ); + } + + let start = Instant::now(); + let result = run_query_substrait_round_trip(&self.ctx, sql).await; + let duration = start.elapsed(); + + if duration.gt(&Duration::from_millis(500)) { + self.update_slow_count(); + } + + self.pb.inc(1); + + if log_enabled!(Info) && duration.gt(&Duration::from_secs(2)) { + warn!( + "[{}] Running query took more than 2 sec ({duration:?}): \"{sql}\"", + self.relative_path.display() + ); + } + + result + } + + /// Engine 
name of current database. + fn engine_name(&self) -> &str { + "DataFusionSubstraitRoundTrip" + } + + /// `DataFusion` calls this function to perform sleep. + /// + /// The default implementation is `std::thread::sleep`, which is universal to any async runtime + /// but would block the current thread. If you are running in tokio runtime, you should override + /// this by `tokio::time::sleep`. + async fn sleep(dur: Duration) { + tokio::time::sleep(dur).await; + } + + async fn shutdown(&mut self) {} +} + +async fn run_query_substrait_round_trip( + ctx: &SessionContext, + sql: impl Into<String>, +) -> Result<DFOutput> { + let df = ctx.sql(sql.into().as_str()).await?; + let task_ctx = Arc::new(df.task_ctx()); + + let state = ctx.state(); + let round_tripped_plan = match df.logical_plan() { + // Substrait does not handle these plans + LogicalPlan::Ddl(_) + | LogicalPlan::Explain(_) + | LogicalPlan::Dml(_) + | LogicalPlan::Copy(_) + | LogicalPlan::Statement(_) => df.logical_plan().clone(), + // For any other plan, convert to Substrait + logical_plan => { + let plan = to_substrait_plan(logical_plan, &state)?; + from_substrait_plan(&state, &plan).await? + } + }; + + let physical_plan = state.create_physical_plan(&round_tripped_plan).await?; + let stream = execute_stream(physical_plan, task_ctx)?; + let types = convert_schema_to_types(stream.schema().fields()); + let results: Vec<RecordBatch> = collect(stream).await?; + let rows = convert_batches(results, false)?; + + if rows.is_empty() && types.is_empty() { + Ok(DBOutput::StatementComplete(0)) + } else { + Ok(DBOutput::Rows { types, rows }) + } +} diff --git a/datafusion/sqllogictest/src/engines/mod.rs b/datafusion/sqllogictest/src/engines/mod.rs index 3569dea70176..ef6335ddbed6 100644 --- a/datafusion/sqllogictest/src/engines/mod.rs +++ b/datafusion/sqllogictest/src/engines/mod.rs @@ -18,12 +18,14 @@ /// Implementation of sqllogictest for datafusion. mod conversion; mod datafusion_engine; +mod datafusion_substrait_roundtrip_engine; mod output; pub use datafusion_engine::convert_batches; pub use datafusion_engine::convert_schema_to_types; pub use datafusion_engine::DFSqlLogicTestError; pub use datafusion_engine::DataFusion; +pub use datafusion_substrait_roundtrip_engine::DataFusionSubstraitRoundTrip; pub use output::DFColumnType; pub use output::DFOutput; diff --git a/datafusion/sqllogictest/src/filters.rs b/datafusion/sqllogictest/src/filters.rs new file mode 100644 index 000000000000..44482236f7c5 --- /dev/null +++ b/datafusion/sqllogictest/src/filters.rs @@ -0,0 +1,170 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
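+ +// NOTE (editorial sketch, not part of the upstream patch): the `Filter` type below is +// built from plain strings via `FromStr`. Assuming a hypothetical file name and line +// number, usage looks like this: +// +// let by_line: Filter = "aggregate.slt:1432".parse().unwrap(); // file substring + line +// let by_file: Filter = "array.slt".parse().unwrap(); // whole file substring +// +// `should_skip_file` and `should_skip_record` then use such filters to decide which +// .slt files and which individual records fall outside the requested subset.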
+ +use datafusion::sql::parser::{DFParserBuilder, Statement}; +use sqllogictest::{AsyncDB, Record}; +use sqlparser::ast::{SetExpr, Statement as SqlStatement}; +use sqlparser::dialect::dialect_from_str; +use std::path::Path; +use std::str::FromStr; + +/// Filter specification that determines whether a certain sqllogictest record in +/// a certain file should be filtered. In order for a [`Filter`] to match a test case: +/// +/// - The test must belong to a file whose absolute path contains the `file_substring` substring. +/// - If a `line_number` is specified, the test must start on that exact line number. +/// +/// If a [`Filter`] matches a specific test case, the record is executed; if there's +/// no match, the record is skipped. +/// +/// Filters can be parsed from strings of the form `<file_substring>:line_number`. For example, +/// `foo.slt:100` matches any test whose name contains `foo.slt` and the test starts on line +/// number 100. +#[derive(Debug, Clone)] +pub struct Filter { + file_substring: String, + line_number: Option<u32>, +} + +impl FromStr for Filter { + type Err = String; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + let parts: Vec<&str> = s.rsplitn(2, ':').collect(); + if parts.len() == 2 { + match parts[0].parse::<u32>() { + Ok(line) => Ok(Filter { + file_substring: parts[1].to_string(), + line_number: Some(line), + }), + Err(_) => Err(format!("Cannot parse line number from '{s}'")), + } + } else { + Ok(Filter { + file_substring: s.to_string(), + line_number: None, + }) + } + } +} + +/// Given a list of [`Filter`]s, determines if the whole file in the provided +/// path can be skipped. +/// +/// - If there's at least 1 filter whose `file_substring` is contained in the provided path, +/// it returns false, since that file may contain records that match a filter and must run. +/// - If the provided filter list is empty, it returns false. +/// - Otherwise it returns true and the whole file can be skipped. +pub fn should_skip_file(path: &Path, filters: &[Filter]) -> bool { + if filters.is_empty() { + return false; + } + + let path_string = path.to_string_lossy(); + for filter in filters { + if path_string.contains(&filter.file_substring) { + return false; + } + } + true +} + +/// Determines whether a certain sqllogictest record should be skipped given the provided +/// filters. +/// +/// If there's at least 1 matching filter, or the filter list is empty, it returns false. +/// +/// There are certain records that will never be skipped even if they are not matched +/// by any filters, like CREATE TABLE, INSERT INTO, DROP or SELECT * INTO statements, +/// as they populate tables necessary for other tests to work. +pub fn should_skip_record<D: AsyncDB>( + record: &Record<D::ColumnType>, + filters: &[Filter], +) -> bool { + if filters.is_empty() { + return false; + } + + let (sql, loc) = match record { + Record::Statement { sql, loc, .. } => (sql, loc), + Record::Query { sql, loc, .. } => (sql, loc), + _ => return false, + }; + + let statement = if let Some(statement) = parse_or_none(sql, "Postgres") { + statement + } else if let Some(statement) = parse_or_none(sql, "generic") { + statement + } else { + return false; + }; + + if !statement_is_skippable(&statement) { + return false; + } + + for filter in filters { + if !loc.file().contains(&filter.file_substring) { + continue; + } + if let Some(line_num) = filter.line_number { + if loc.line() != line_num { + continue; + } + } + + // This filter matches both file name substring and the exact + // line number (if one was provided), so don't skip it. + return false; + } + + true +} + +fn statement_is_skippable(statement: &Statement) -> bool { + // Only SQL statements can be skipped. 
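+ // (Editor's note) DataFusion-specific statements produced by `DFParser`, such as + // `CREATE EXTERNAL TABLE`, are not wrapped in `Statement::Statement`; they take the + // `else` branch below and are never skipped, since they may set up state that later + // records depend on.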
+ let Statement::Statement(sql_stmt) = statement else { + return false; + }; + + // Cannot skip SELECT INTO statements, as they can also create tables + // that further test cases will use. + if let SqlStatement::Query(v) = sql_stmt.as_ref() { + if let SetExpr::Select(v) = v.body.as_ref() { + if v.into.is_some() { + return false; + } + } + } + + // Only SELECT and EXPLAIN statements can be skipped, as any other + // statement might be populating tables that future test cases will use. + matches!( + sql_stmt.as_ref(), + SqlStatement::Query(_) | SqlStatement::Explain { .. } + ) +} + +fn parse_or_none(sql: &str, dialect: &str) -> Option<Statement> { + let Ok(Ok(Some(statement))) = DFParserBuilder::new(sql) + .with_dialect(dialect_from_str(dialect).unwrap().as_ref()) + .build() + .map(|mut v| v.parse_statements().map(|mut v| v.pop_front())) + else { + return None; + }; + Some(statement) +} diff --git a/datafusion/sqllogictest/src/lib.rs b/datafusion/sqllogictest/src/lib.rs index 1a208aa3fac2..3c786d6bdaac 100644 --- a/datafusion/sqllogictest/src/lib.rs +++ b/datafusion/sqllogictest/src/lib.rs @@ -34,12 +34,15 @@ pub use engines::DFColumnType; pub use engines::DFOutput; pub use engines::DFSqlLogicTestError; pub use engines::DataFusion; +pub use engines::DataFusionSubstraitRoundTrip; #[cfg(feature = "postgres")] pub use engines::Postgres; +mod filters; mod test_context; mod util; +pub use filters::*; pub use test_context::TestContext; pub use util::*; diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs index ce819f186454..143e3ef1a89b 100644 --- a/datafusion/sqllogictest/src/test_context.rs +++ b/datafusion/sqllogictest/src/test_context.rs @@ -40,8 +40,11 @@ use datafusion::{ prelude::{CsvReadOptions, SessionContext}, }; +use crate::is_spark_path; use async_trait::async_trait; use datafusion::common::cast::as_float64_array; +use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion::execution::SessionStateBuilder; use log::info; use tempfile::TempDir; @@ -70,8 +73,20 @@ impl TestContext { let config = SessionConfig::new() // hardcode target partitions so plans are deterministic .with_target_partitions(4); + let runtime = Arc::new(RuntimeEnv::default()); + let mut state = SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_default_features() + .build(); + + if is_spark_path(relative_path) { + info!("Registering Spark functions"); + datafusion_spark::register_all(&mut state) + .expect("Can not register Spark functions"); + } - let mut test_ctx = TestContext::new(SessionContext::new_with_config(config)); + let mut test_ctx = TestContext::new(SessionContext::new_with_state(state)); let file_name = relative_path.file_name().unwrap().to_str().unwrap(); match file_name { @@ -122,6 +137,7 @@ impl TestContext { info!("Using default SessionContext"); } }; + Some(test_ctx) } @@ -223,14 +239,14 @@ pub async fn register_temp_table(ctx: &SessionContext) { self } - fn table_type(&self) -> TableType { - self.0 - } - fn schema(&self) -> SchemaRef { unimplemented!() } + fn table_type(&self) -> TableType { + self.0 + } + async fn scan( &self, _state: &dyn Session, @@ -410,10 +426,24 @@ fn create_example_udf() -> ScalarUDF { fn register_union_table(ctx: &SessionContext) { let union = UnionArray::try_new( - UnionFields::new(vec![3], vec![Field::new("int", DataType::Int32, false)]), - ScalarBuffer::from(vec![3, 3]), + UnionFields::new( + // typeids: 3 for int, 1 for string + vec![3, 1], + vec![ + Field::new("int", DataType::Int32, 
false), + Field::new("string", DataType::Utf8, false), + ], + ), + ScalarBuffer::from(vec![3, 1, 3]), None, - vec![Arc::new(Int32Array::from(vec![1, 2]))], + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec![ + Some("foo"), + Some("bar"), + Some("baz"), + ])), + ], ) .unwrap(); diff --git a/datafusion/sqllogictest/src/util.rs b/datafusion/sqllogictest/src/util.rs index 5ae640cc98a9..695fe463fa67 100644 --- a/datafusion/sqllogictest/src/util.rs +++ b/datafusion/sqllogictest/src/util.rs @@ -106,3 +106,7 @@ pub fn df_value_validator( normalized_actual == normalized_expected } + +pub fn is_spark_path(relative_path: &Path) -> bool { + relative_path.starts_with("spark/") +} diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index a55ac079aa74..3f064485e51a 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -132,37 +132,48 @@ statement error DataFusion error: Schema error: Schema contains duplicate unqual SELECT approx_distinct(c9) count_c9, approx_distinct(cast(c9 as varchar)) count_c9_str FROM aggregate_test_100 # csv_query_approx_percentile_cont_with_weight -statement error DataFusion error: Error during planning: Failed to coerce arguments to satisfy a call to 'approx_percentile_cont_with_weight' function: coercion from \[Utf8, Int8, Float64\] to the signature OneOf(.*) failed(.|\n)* -SELECT approx_percentile_cont_with_weight(c1, c2, 0.95) FROM aggregate_test_100 +statement error Failed to coerce arguments to satisfy a call to 'approx_percentile_cont_with_weight' function +SELECT approx_percentile_cont_with_weight(c2, 0.95) WITHIN GROUP (ORDER BY c1) FROM aggregate_test_100 -statement error DataFusion error: Error during planning: Failed to coerce arguments to satisfy a call to 'approx_percentile_cont_with_weight' function: coercion from \[Int16, Utf8, Float64\] to the signature OneOf(.*) failed(.|\n)* -SELECT approx_percentile_cont_with_weight(c3, c1, 0.95) FROM aggregate_test_100 +statement error Failed to coerce arguments to satisfy a call to 'approx_percentile_cont_with_weight' function +SELECT approx_percentile_cont_with_weight(c1, 0.95) WITHIN GROUP (ORDER BY c3) FROM aggregate_test_100 -statement error DataFusion error: Error during planning: Failed to coerce arguments to satisfy a call to 'approx_percentile_cont_with_weight' function: coercion from \[Int16, Int8, Utf8\] to the signature OneOf(.*) failed(.|\n)* -SELECT approx_percentile_cont_with_weight(c3, c2, c1) FROM aggregate_test_100 +statement error Failed to coerce arguments to satisfy a call to 'approx_percentile_cont_with_weight' function +SELECT approx_percentile_cont_with_weight(c2, c1) WITHIN GROUP (ORDER BY c3) FROM aggregate_test_100 # csv_query_approx_percentile_cont_with_histogram_bins statement error DataFusion error: This feature is not implemented: Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal \(got data type Int64\)\. 
-SELECT c1, approx_percentile_cont(c3, 0.95, -1000) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +SELECT c1, approx_percentile_cont(0.95, -1000) WITHIN GROUP (ORDER BY c3) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 -statement error DataFusion error: Error during planning: Failed to coerce arguments to satisfy a call to 'approx_percentile_cont' function: coercion from \[Int16, Float64, Utf8\] to the signature OneOf(.*) failed(.|\n)* -SELECT approx_percentile_cont(c3, 0.95, c1) FROM aggregate_test_100 +statement error Failed to coerce arguments to satisfy a call to 'approx_percentile_cont' function +SELECT approx_percentile_cont(0.95, c1) WITHIN GROUP (ORDER BY c3) FROM aggregate_test_100 statement error DataFusion error: Error during planning: Failed to coerce arguments to satisfy a call to 'approx_percentile_cont' function: coercion from \[Int16, Float64, Float64\] to the signature OneOf(.*) failed(.|\n)* -SELECT approx_percentile_cont(c3, 0.95, 111.1) FROM aggregate_test_100 +SELECT approx_percentile_cont(0.95, 111.1) WITHIN GROUP (ORDER BY c3) FROM aggregate_test_100 statement error DataFusion error: Error during planning: Failed to coerce arguments to satisfy a call to 'approx_percentile_cont' function: coercion from \[Float64, Float64, Float64\] to the signature OneOf(.*) failed(.|\n)* -SELECT approx_percentile_cont(c12, 0.95, 111.1) FROM aggregate_test_100 +SELECT approx_percentile_cont(0.95, 111.1) WITHIN GROUP (ORDER BY c12) FROM aggregate_test_100 statement error DataFusion error: This feature is not implemented: Percentile value for 'APPROX_PERCENTILE_CONT' must be a literal -SELECT approx_percentile_cont(c12, c12) FROM aggregate_test_100 +SELECT approx_percentile_cont(c12) WITHIN GROUP (ORDER BY c12) FROM aggregate_test_100 statement error DataFusion error: This feature is not implemented: Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be a literal -SELECT approx_percentile_cont(c12, 0.95, c5) FROM aggregate_test_100 +SELECT approx_percentile_cont(0.95, c5) WITHIN GROUP (ORDER BY c12) FROM aggregate_test_100 + +statement error DataFusion error: Error during planning: \[IGNORE | RESPECT\] NULLS are not permitted for approx_percentile_cont +SELECT approx_percentile_cont(0.95) WITHIN GROUP (ORDER BY c5) IGNORE NULLS FROM aggregate_test_100 + +statement error DataFusion error: Error during planning: \[IGNORE | RESPECT\] NULLS are not permitted for approx_percentile_cont +SELECT approx_percentile_cont(0.95) WITHIN GROUP (ORDER BY c5) RESPECT NULLS FROM aggregate_test_100 + +statement error DataFusion error: This feature is not implemented: Only a single ordering expression is permitted in a WITHIN GROUP clause +SELECT approx_percentile_cont(0.95) WITHIN GROUP (ORDER BY c5, c12) FROM aggregate_test_100 # Not supported over sliding windows -query error This feature is not implemented: Aggregate can not be used as a sliding accumulator because `retract_batch` is not implemented -SELECT approx_percentile_cont(c3, 0.5) OVER (ROWS BETWEEN 4 PRECEDING AND CURRENT ROW) +query error DataFusion error: Error during planning: OVER and WITHIN GROUP clause cannot be used together. 
OVER is for window functions, whereas WITHIN GROUP is for ordered set aggregate functions +SELECT approx_percentile_cont(0.5) +WITHIN GROUP (ORDER BY c3) +OVER (ROWS BETWEEN 4 PRECEDING AND CURRENT ROW) FROM aggregate_test_100 # array agg can use order by @@ -289,17 +300,19 @@ CREATE TABLE array_agg_distinct_list_table AS VALUES ('b', [1,0]), ('b', [1,0]), ('b', [1,0]), - ('b', [0,1]) + ('b', [0,1]), + (NULL, [0,1]), + ('b', NULL) ; # Apply array_sort to have deterministic result, higher dimension nested array also works but not for array sort, # so they are covered in `datafusion/functions-aggregate/src/array_agg.rs` query ?? select array_sort(c1), array_sort(c2) from ( - select array_agg(distinct column1) as c1, array_agg(distinct column2) as c2 from array_agg_distinct_list_table + select array_agg(distinct column1) as c1, array_agg(distinct column2) ignore nulls as c2 from array_agg_distinct_list_table ); ---- -[b, w] [[0, 1], [1, 0]] +[NULL, b, w] [[0, 1], [1, 0]] statement ok drop table array_agg_distinct_list_table; @@ -1276,173 +1289,173 @@ SELECT approx_distinct(c9) AS a, approx_distinct(c9) AS b FROM aggregate_test_10 #csv_query_approx_percentile_cont (c2) query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c2, 0.1) AS DOUBLE) / 1.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c2) AS DOUBLE) / 1.0) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c2, 0.5) AS DOUBLE) / 3.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY c2) AS DOUBLE) / 3.0) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c2, 0.9) AS DOUBLE) / 5.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.9) WITHIN GROUP (ORDER BY c2) AS DOUBLE) / 5.0) < 0.05) AS q FROM aggregate_test_100 ---- true # csv_query_approx_percentile_cont (c3) query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c3, 0.1) AS DOUBLE) / -95.3) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c3) AS DOUBLE) / -95.3) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c3, 0.5) AS DOUBLE) / 15.5) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY c3) AS DOUBLE) / 15.5) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c3, 0.9) AS DOUBLE) / 102.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.9) WITHIN GROUP (ORDER BY c3) AS DOUBLE) / 102.0) < 0.05) AS q FROM aggregate_test_100 ---- true # csv_query_approx_percentile_cont (c4) query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c4, 0.1) AS DOUBLE) / -22925.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c4) AS DOUBLE) / -22925.0) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c4, 0.5) AS DOUBLE) / 4599.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY c4) AS DOUBLE) / 4599.0) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c4, 0.9) AS DOUBLE) / 25334.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.9) WITHIN GROUP (ORDER 
BY c4) AS DOUBLE) / 25334.0) < 0.05) AS q FROM aggregate_test_100 ---- true # csv_query_approx_percentile_cont (c5) query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c5, 0.1) AS DOUBLE) / -1882606710.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c5) AS DOUBLE) / -1882606710.0) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c5, 0.5) AS DOUBLE) / 377164262.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY c5) AS DOUBLE) / 377164262.0) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c5, 0.9) AS DOUBLE) / 1991374996.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.9) WITHIN GROUP (ORDER BY c5) AS DOUBLE) / 1991374996.0) < 0.05) AS q FROM aggregate_test_100 ---- true # csv_query_approx_percentile_cont (c6) query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c6, 0.1) AS DOUBLE) / -7250000000000000000) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c6) AS DOUBLE) / -7250000000000000000) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c6, 0.5) AS DOUBLE) / 1130000000000000000) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY c6) AS DOUBLE) / 1130000000000000000) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c6, 0.9) AS DOUBLE) / 7370000000000000000) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.9) WITHIN GROUP (ORDER BY c6) AS DOUBLE) / 7370000000000000000) < 0.05) AS q FROM aggregate_test_100 ---- true # csv_query_approx_percentile_cont (c7) query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c7, 0.1) AS DOUBLE) / 18.9) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c7) AS DOUBLE) / 18.9) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c7, 0.5) AS DOUBLE) / 134.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY c7) AS DOUBLE) / 134.0) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c7, 0.9) AS DOUBLE) / 231.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.9) WITHIN GROUP (ORDER BY c7) AS DOUBLE) / 231.0) < 0.05) AS q FROM aggregate_test_100 ---- true # csv_query_approx_percentile_cont (c8) query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c8, 0.1) AS DOUBLE) / 2671.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c8) AS DOUBLE) / 2671.0) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c8, 0.5) AS DOUBLE) / 30634.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY c8) AS DOUBLE) / 30634.0) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c8, 0.9) AS DOUBLE) / 57518.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.9) WITHIN GROUP (ORDER BY c8) AS DOUBLE) / 57518.0) < 0.05) AS q FROM aggregate_test_100 ---- true # 
csv_query_approx_percentile_cont (c9) query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c9, 0.1) AS DOUBLE) / 472608672.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c9) AS DOUBLE) / 472608672.0) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c9, 0.5) AS DOUBLE) / 2365817608.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY c9) AS DOUBLE) / 2365817608.0) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c9, 0.9) AS DOUBLE) / 3776538487.0) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.9) WITHIN GROUP (ORDER BY c9) AS DOUBLE) / 3776538487.0) < 0.05) AS q FROM aggregate_test_100 ---- true # csv_query_approx_percentile_cont (c10) query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c10, 0.1) AS DOUBLE) / 1830000000000000000) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c10) AS DOUBLE) / 1830000000000000000) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c10, 0.5) AS DOUBLE) / 9300000000000000000) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY c10) AS DOUBLE) / 9300000000000000000) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c10, 0.9) AS DOUBLE) / 16100000000000000000) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.9) WITHIN GROUP (ORDER BY c10) AS DOUBLE) / 16100000000000000000) < 0.05) AS q FROM aggregate_test_100 ---- true # csv_query_approx_percentile_cont (c11) query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c11, 0.1) AS DOUBLE) / 0.109) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c11) AS DOUBLE) / 0.109) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c11, 0.5) AS DOUBLE) / 0.491) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY c11) AS DOUBLE) / 0.491) < 0.05) AS q FROM aggregate_test_100 ---- true query B -SELECT (ABS(1 - CAST(approx_percentile_cont(c11, 0.9) AS DOUBLE) / 0.834) < 0.05) AS q FROM aggregate_test_100 +SELECT (ABS(1 - CAST(approx_percentile_cont(0.9) WITHIN GROUP (ORDER BY c11) AS DOUBLE) / 0.834) < 0.05) AS q FROM aggregate_test_100 ---- true # percentile_cont_with_nulls query I -SELECT APPROX_PERCENTILE_CONT(v, 0.5) FROM (VALUES (1), (2), (3), (NULL), (NULL), (NULL)) as t (v); +SELECT APPROX_PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY v) FROM (VALUES (1), (2), (3), (NULL), (NULL), (NULL)) as t (v); ---- 2 # percentile_cont_with_nulls_only query I -SELECT APPROX_PERCENTILE_CONT(v, 0.5) FROM (VALUES (CAST(NULL as INT))) as t (v); +SELECT APPROX_PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY v) FROM (VALUES (CAST(NULL as INT))) as t (v); ---- NULL @@ -1465,7 +1478,7 @@ NaN # ISSUE: https://github.com/apache/datafusion/issues/11870 query R -select APPROX_PERCENTILE_CONT(v2, 0.8) from tmp_percentile_cont; +select APPROX_PERCENTILE_CONT(0.8) WITHIN GROUP (ORDER BY v2) from tmp_percentile_cont; ---- NaN @@ -1473,10 +1486,10 @@ NaN # Note: `approx_percentile_cont_with_weight()` uses the same implementation as `approx_percentile_cont()` query R SELECT 
APPROX_PERCENTILE_CONT_WITH_WEIGHT( - v2, '+Inf'::Double, 0.9 ) +WITHIN GROUP (ORDER BY v2) FROM tmp_percentile_cont; ---- NaN @@ -1495,7 +1508,7 @@ INSERT INTO t1 VALUES (TRUE); # ISSUE: https://github.com/apache/datafusion/issues/12716 # This test verifies that approx_percentile_cont_with_weight does not panic when given 'NaN' and returns 'inf' query R -SELECT approx_percentile_cont_with_weight('NaN'::DOUBLE, 0, 0) FROM t1 WHERE t1.v1; +SELECT approx_percentile_cont_with_weight(0, 0) WITHIN GROUP (ORDER BY 'NaN'::DOUBLE) FROM t1 WHERE t1.v1; ---- Infinity @@ -1722,7 +1735,7 @@ b NULL NULL 7732.315789473684 # csv_query_approx_percentile_cont_with_weight query TI -SELECT c1, approx_percentile_cont(c3, 0.95) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +SELECT c1, approx_percentile_cont(0.95) WITHIN GROUP (ORDER BY c3) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ---- a 73 b 68 @@ -1730,9 +1743,18 @@ c 122 d 124 e 115 +query TI +SELECT c1, approx_percentile_cont(0.95) WITHIN GROUP (ORDER BY c3 DESC) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +---- +a -101 +b -114 +c -109 +d -98 +e -93 + # csv_query_approx_percentile_cont_with_weight (2) query TI -SELECT c1, approx_percentile_cont_with_weight(c3, 1, 0.95) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +SELECT c1, approx_percentile_cont_with_weight(1, 0.95) WITHIN GROUP (ORDER BY c3) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ---- a 73 b 68 @@ -1740,9 +1762,18 @@ c 122 d 124 e 115 +query TI +SELECT c1, approx_percentile_cont_with_weight(1, 0.95) WITHIN GROUP (ORDER BY c3 DESC) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +---- +a -101 +b -114 +c -109 +d -98 +e -93 + # csv_query_approx_percentile_cont_with_histogram_bins query TI -SELECT c1, approx_percentile_cont(c3, 0.95, 200) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +SELECT c1, approx_percentile_cont(0.95, 200) WITHIN GROUP (ORDER BY c3) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ---- a 73 b 68 @@ -1751,7 +1782,7 @@ d 124 e 115 query TI -SELECT c1, approx_percentile_cont_with_weight(c3, c2, 0.95) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +SELECT c1, approx_percentile_cont_with_weight(c2, 0.95) WITHIN GROUP (ORDER BY c3) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ---- a 74 b 68 @@ -2247,10 +2278,10 @@ create table t (c string) as values query T select arrow_typeof(c) from t; ---- -Utf8 -Utf8 -Utf8 -Utf8 +Utf8View +Utf8View +Utf8View +Utf8View query IT select count(c), arrow_typeof(count(c)) from t; @@ -3041,7 +3072,7 @@ SELECT COUNT(DISTINCT c1) FROM test # test_approx_percentile_cont_decimal_support query TI -SELECT c1, approx_percentile_cont(c2, cast(0.85 as decimal(10,2))) apc FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +SELECT c1, approx_percentile_cont(cast(0.85 as decimal(10,2))) WITHIN GROUP (ORDER BY c2) apc FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ---- a 4 b 5 @@ -3194,6 +3225,33 @@ select array_agg(column1) from t; statement ok drop table t; +# array_agg_ignore_nulls +statement ok +create table t as values (NULL, ''), (1, 'c'), (2, 'a'), (NULL, 'b'), (4, NULL), (NULL, NULL), (5, 'a'); + +query ? +select array_agg(column1) ignore nulls as c1 from t; +---- +[1, 2, 4, 5] + +query II +select count(*), array_length(array_agg(distinct column2) ignore nulls) from t; +---- +7 4 + +query ? +select array_agg(column2 order by column1) ignore nulls from t; +---- +[c, a, a, , b] + +query ? 
+select array_agg(DISTINCT column2 order by column2) ignore nulls from t; +---- +[, a, b, c] + +statement ok +drop table t; + # variance_single_value query RRRR select var(sq.column1), var_pop(sq.column1), stddev(sq.column1), stddev_pop(sq.column1) from (values (1.0)) as sq; @@ -4952,23 +5010,128 @@ set datafusion.sql_parser.dialect = 'Generic'; ## Multiple distinct aggregates and dictionaries statement ok -create table dict_test as values (1, arrow_cast('foo', 'Dictionary(Int32, Utf8)')), (2, arrow_cast('bar', 'Dictionary(Int32, Utf8)')); +create table dict_test as values (1, arrow_cast('foo', 'Dictionary(Int32, Utf8)')), (1, arrow_cast('foo', 'Dictionary(Int32, Utf8)')), (2, arrow_cast('bar', 'Dictionary(Int32, Utf8)')), (1, arrow_cast('bar', 'Dictionary(Int32, Utf8)')); query IT -select * from dict_test; +select * from dict_test order by column1, column2; ---- +1 bar +1 foo 1 foo 2 bar query II -select count(distinct column1), count(distinct column2) from dict_test group by column1; +select count(distinct column1), count(distinct column2) from dict_test group by column1 order by column1; ---- -1 1 +1 2 1 1 statement ok drop table dict_test; +## count distinct dictionary with null values +statement ok +create table dict_null_test as + select arrow_cast(NULL, 'Dictionary(Int32, Utf8)') as d + from (values (1), (2), (3), (4), (5)); + +query I +select count(distinct d) from dict_null_test; +---- +0 + +statement ok +drop table dict_null_test; + +# avg_duration + +statement ok +create table d as values + (arrow_cast(1, 'Duration(Second)'), arrow_cast(2, 'Duration(Millisecond)'), arrow_cast(3, 'Duration(Microsecond)'), arrow_cast(4, 'Duration(Nanosecond)'), 1), + (arrow_cast(11, 'Duration(Second)'), arrow_cast(22, 'Duration(Millisecond)'), arrow_cast(33, 'Duration(Microsecond)'), arrow_cast(44, 'Duration(Nanosecond)'), 1); + +query ???? +SELECT avg(column1), avg(column2), avg(column3), avg(column4) FROM d; +---- +0 days 0 hours 0 mins 6 secs 0 days 0 hours 0 mins 0.012 secs 0 days 0 hours 0 mins 0.000018 secs 0 days 0 hours 0 mins 0.000000024 secs + +query ????I +SELECT avg(column1), avg(column2), avg(column3), avg(column4), column5 FROM d GROUP BY column5; +---- +0 days 0 hours 0 mins 6 secs 0 days 0 hours 0 mins 0.012 secs 0 days 0 hours 0 mins 0.000018 secs 0 days 0 hours 0 mins 0.000000024 secs 1 + +statement ok +drop table d; + +statement ok +create table d as values + (arrow_cast(1, 'Duration(Second)'), arrow_cast(2, 'Duration(Millisecond)'), arrow_cast(3, 'Duration(Microsecond)'), arrow_cast(4, 'Duration(Nanosecond)'), 1), + (arrow_cast(11, 'Duration(Second)'), arrow_cast(22, 'Duration(Millisecond)'), arrow_cast(33, 'Duration(Microsecond)'), arrow_cast(44, 'Duration(Nanosecond)'), 1), + (arrow_cast(5, 'Duration(Second)'), arrow_cast(10, 'Duration(Millisecond)'), arrow_cast(15, 'Duration(Microsecond)'), arrow_cast(20, 'Duration(Nanosecond)'), 2), + (arrow_cast(25, 'Duration(Second)'), arrow_cast(50, 'Duration(Millisecond)'), arrow_cast(75, 'Duration(Microsecond)'), arrow_cast(100, 'Duration(Nanosecond)'), 2), + (NULL, NULL, NULL, NULL, 1), + (NULL, NULL, NULL, NULL, 2); + + +query I? rowsort +SELECT column5, avg(column1) FROM d GROUP BY column5; +---- +1 0 days 0 hours 0 mins 6 secs +2 0 days 0 hours 0 mins 15 secs + +query I?? 
rowsort +SELECT column5, column1, avg(column1) OVER (PARTITION BY column5 ORDER BY column1 ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) as window_avg +FROM d WHERE column1 IS NOT NULL; +---- +1 0 days 0 hours 0 mins 1 secs 0 days 0 hours 0 mins 1 secs +1 0 days 0 hours 0 mins 11 secs 0 days 0 hours 0 mins 6 secs +2 0 days 0 hours 0 mins 25 secs 0 days 0 hours 0 mins 15 secs +2 0 days 0 hours 0 mins 5 secs 0 days 0 hours 0 mins 5 secs + +# Cumulative average window function +query I?? +SELECT column5, column1, avg(column1) OVER (ORDER BY column5, column1 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as cumulative_avg +FROM d WHERE column1 IS NOT NULL; +---- +1 0 days 0 hours 0 mins 1 secs 0 days 0 hours 0 mins 1 secs +1 0 days 0 hours 0 mins 11 secs 0 days 0 hours 0 mins 6 secs +2 0 days 0 hours 0 mins 5 secs 0 days 0 hours 0 mins 5 secs +2 0 days 0 hours 0 mins 25 secs 0 days 0 hours 0 mins 10 secs + +# Centered average window function +query I?? +SELECT column5, column1, avg(column1) OVER (ORDER BY column5 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as centered_avg +FROM d WHERE column1 IS NOT NULL; +---- +1 0 days 0 hours 0 mins 1 secs 0 days 0 hours 0 mins 6 secs +1 0 days 0 hours 0 mins 11 secs 0 days 0 hours 0 mins 5 secs +2 0 days 0 hours 0 mins 5 secs 0 days 0 hours 0 mins 13 secs +2 0 days 0 hours 0 mins 25 secs 0 days 0 hours 0 mins 15 secs + +statement ok +drop table d; + +statement ok +create table dn as values + (arrow_cast(10, 'Duration(Second)'), 'a', 1), + (arrow_cast(20, 'Duration(Second)'), 'a', 2), + (NULL, 'b', 1), + (arrow_cast(40, 'Duration(Second)'), 'b', 2), + (arrow_cast(50, 'Duration(Second)'), 'c', 1), + (NULL, 'c', 2); + +query T?I +SELECT column2, avg(column1), column3 FROM dn GROUP BY column2, column3 ORDER BY column2, column3; +---- +a 0 days 0 hours 0 mins 10 secs 1 +a 0 days 0 hours 0 mins 20 secs 2 +b NULL 1 +b 0 days 0 hours 0 mins 40 secs 2 +c 0 days 0 hours 0 mins 50 secs 1 +c NULL 2 + +statement ok +drop table dn; # Prepare the table with dictionary values for testing statement ok @@ -5159,13 +5322,13 @@ physical_plan 06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3], file_type=csv, has_header=true query I -SELECT DISTINCT c3 FROM aggregate_test_100 group by c3 limit 5; +SELECT DISTINCT c3 FROM aggregate_test_100 group by c3 order by c3 limit 5; ---- -1 --40 -29 --85 --82 +-117 +-111 +-107 +-106 +-101 query TT EXPLAIN SELECT c2, c3 FROM aggregate_test_100 group by c2, c3 limit 5 offset 4; @@ -5183,13 +5346,13 @@ physical_plan 06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], file_type=csv, has_header=true query II -SELECT c2, c3 FROM aggregate_test_100 group by c2, c3 limit 5 offset 4; +SELECT c2, c3 FROM aggregate_test_100 group by c2, c3 order by c2, c3 limit 5 offset 4; ---- -5 -82 -4 -111 -3 104 -3 13 -1 38 +1 -56 +1 -25 +1 -24 +1 -8 +1 -5 # The limit should only apply to the aggregations which group by c3 query TT @@ -5218,12 +5381,12 @@ physical_plan 13)------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], file_type=csv, has_header=true query I -SELECT DISTINCT c3 FROM aggregate_test_100 WHERE c3 between 10 and 20 group by c2, c3 limit 4; +SELECT DISTINCT c3 FROM aggregate_test_100 WHERE c3 between 10 and 20 group by c3 order by c3 limit 4; ---- -13 -17 12 +13 14 +17 # An aggregate expression causes 
the limit to not be pushed to the aggregation query TT @@ -5268,11 +5431,11 @@ physical_plan 11)--------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], file_type=csv, has_header=true query II -SELECT DISTINCT c3, c2 FROM aggregate_test_100 group by c2, c3 limit 3 offset 10; +SELECT DISTINCT c3, c2 FROM aggregate_test_100 group by c3, c2 order by c3, c2 limit 3 offset 10; ---- -57 1 --54 4 -112 3 +-95 3 +-94 5 +-90 4 query TT EXPLAIN SELECT c2, c3 FROM aggregate_test_100 group by rollup(c2, c3) limit 3; @@ -6758,7 +6921,7 @@ group1 0.0003 # median with all nulls statement ok create table group_median_all_nulls( - a STRING NOT NULL, + a STRING NOT NULL, b INT ) AS VALUES ( 'group0', NULL), @@ -6796,3 +6959,100 @@ select c2, count(*) from test WHERE 1 = 1 group by c2; 5 1 6 1 +# Min/Max struct +query ?? rowsort +WITH t AS (SELECT i as c1, i + 1 as c2 FROM generate_series(1, 10) t(i)) +SELECT MIN(c), MAX(c) FROM (SELECT STRUCT(c1 AS 'a', c2 AS 'b') AS c FROM t) +---- +{a: 1, b: 2} {a: 10, b: 11} + +# Min/Max struct with NULL +query ?? rowsort +WITH t AS (SELECT i as c1, i + 1 as c2 FROM generate_series(1, 10) t(i)) +SELECT MIN(c), MAX(c) FROM (SELECT CASE WHEN c1 % 2 == 0 THEN STRUCT(c1 AS 'a', c2 AS 'b') ELSE NULL END AS c FROM t) +---- +{a: 2, b: 3} {a: 10, b: 11} + +# Min/Max struct with two recordbatch +query ?? rowsort +SELECT MIN(c), MAX(c) FROM (SELECT STRUCT(1 as 'a', 2 as 'b') AS c UNION SELECT STRUCT(3 as 'a', 4 as 'b') AS c ) +---- +{a: 1, b: 2} {a: 3, b: 4} + +# Min/Max struct empty +query ?? rowsort +SELECT MIN(c), MAX(c) FROM (SELECT * FROM (SELECT STRUCT(1 as 'a', 2 as 'b') AS c) LIMIT 0) +---- +NULL NULL + +# Min/Max group struct +query I?? rowsort +WITH t AS (SELECT i as c1, i + 1 as c2 FROM generate_series(1, 10) t(i)) +SELECT key, MIN(c), MAX(c) FROM (SELECT STRUCT(c1 AS 'a', c2 AS 'b') AS c, (c1 % 2) AS key FROM t) GROUP BY key +---- +0 {a: 2, b: 3} {a: 10, b: 11} +1 {a: 1, b: 2} {a: 9, b: 10} + +# Min/Max group struct with NULL +query I?? rowsort +WITH t AS (SELECT i as c1, i + 1 as c2 FROM generate_series(1, 10) t(i)) +SELECT key, MIN(c), MAX(c) FROM (SELECT CASE WHEN c1 % 2 == 0 THEN STRUCT(c1 AS 'a', c2 AS 'b') ELSE NULL END AS c, (c1 % 2) AS key FROM t) GROUP BY key +---- +0 {a: 2, b: 3} {a: 10, b: 11} +1 NULL NULL + +# Min/Max group struct with NULL +query I?? rowsort +WITH t AS (SELECT i as c1, i + 1 as c2 FROM generate_series(1, 10) t(i)) +SELECT key, MIN(c), MAX(c) FROM (SELECT CASE WHEN c1 % 3 == 0 THEN STRUCT(c1 AS 'a', c2 AS 'b') ELSE NULL END AS c, (c1 % 2) AS key FROM t) GROUP BY key +---- +0 {a: 6, b: 7} {a: 6, b: 7} +1 {a: 3, b: 4} {a: 9, b: 10} + +# Min/Max struct empty +query ?? rowsort +WITH t AS (SELECT i as c1, i + 1 as c2 FROM generate_series(1, 10) t(i)) +SELECT MIN(c), MAX(c) FROM (SELECT STRUCT(c1 AS 'a', c2 AS 'b') AS c, (c1 % 2) AS key FROM t LIMIT 0) GROUP BY key +---- + +# Min/Max aggregation on struct with a single field +query ?? +WITH t AS (SELECT i as c1 FROM generate_series(1, 10) t(i)) +SELECT MIN(c), MAX(c) FROM (SELECT STRUCT(c1 AS 'a') AS c FROM t); +---- +{a: 1} {a: 10} + +# Min/Max aggregation on struct with identical first fields but different last fields +query ?? 
+SELECT MIN(column1),MAX(column1) FROM ( +VALUES + (STRUCT(1 AS 'a',2 AS 'b', 3 AS 'c')), + (STRUCT(1 AS 'a',2 AS 'b', 4 AS 'c')) +); +---- +{a: 1, b: 2, c: 3} {a: 1, b: 2, c: 4} + +query TI +SELECT column1, COUNT(DISTINCT column2) FROM ( +VALUES + ('x', arrow_cast('NAN','Float64')), + ('x', arrow_cast('NAN','Float64')) +) GROUP BY 1 ORDER BY 1; +---- +x 1 + +query ? +SELECT array_agg(a_varchar) WITHIN GROUP (ORDER BY a_varchar) +FROM (VALUES ('a'), ('d'), ('c'), ('a')) t(a_varchar); +---- +[a, a, c, d] + +query ? +SELECT array_agg(DISTINCT a_varchar) WITHIN GROUP (ORDER BY a_varchar) +FROM (VALUES ('a'), ('d'), ('c'), ('a')) t(a_varchar); +---- +[a, c, d] + +query error Error during planning: ORDER BY and WITHIN GROUP clauses cannot be used together in the same aggregate function +SELECT array_agg(a_varchar order by a_varchar) WITHIN GROUP (ORDER BY a_varchar) +FROM (VALUES ('a'), ('d'), ('c'), ('a')) t(a_varchar); diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index f165d3bf66ba..a3d9c3e1d9c1 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -310,7 +310,7 @@ AS VALUES statement ok CREATE TABLE fixed_size_array_has_table_2D AS VALUES - (arrow_cast(make_array([1,2], [3,4]), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array(1,3), 'FixedSizeList(2, Int64)'), arrow_cast(make_array([1,2,3], [4,5], [6,7]), 'FixedSizeList(3, List(Int64))'), arrow_cast(make_array([4,5], [6,7], [1,2]), 'FixedSizeList(3, List(Int64))')), + (arrow_cast(make_array([1,2], [3,4]), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array(1,3), 'FixedSizeList(2, Int64)'), arrow_cast(make_array([1,2,3], [4,5], [6,7]), 'FixedSizeList(3, List(Int64))'), arrow_cast(make_array([4,5], [6,7], [1,2,3]), 'FixedSizeList(3, List(Int64))')), (arrow_cast(make_array([3,4], [5]), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array(5, 3), 'FixedSizeList(2, Int64)'), arrow_cast(make_array([1,2,3,4], [5,6,7], [8,9,10]), 'FixedSizeList(3, List(Int64))'), arrow_cast(make_array([1,2,3], [5,6,7], [8,9,10]), 'FixedSizeList(3, List(Int64))')) ; @@ -362,6 +362,14 @@ AS VALUES (make_array(NULL, NULL, NULL), 2) ; +statement ok +CREATE TABLE array_has_table_empty +AS VALUES + (make_array(1, 3, 5), 1), + (make_array(), 1), + (NULL, 1) +; + statement ok CREATE TABLE array_distinct_table_1D AS VALUES @@ -1204,7 +1212,7 @@ select array_element([1, 2], NULL); ---- NULL -query I +query ? select array_element(NULL, 2); ---- NULL @@ -1435,6 +1443,12 @@ NULL 23 NULL 43 5 NULL +# array_element of empty array +query T +select coalesce(array_element([], 1), array_element(NULL, 1), 'ok'); +---- +ok + ## array_max # array_max scalar function #1 (with positive index) @@ -1448,7 +1462,7 @@ select array_max(make_array(5, 3, 4, NULL, 6, NULL)); ---- 6 -query I +query ? select array_max(make_array(NULL, NULL)); ---- NULL @@ -1512,7 +1526,7 @@ select array_max(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)')), ar ---- 3 1 -query I +query ? select array_max(make_array()); ---- NULL @@ -1521,6 +1535,91 @@ NULL query error DataFusion error: Error during planning: 'array_max' does not support zero arguments select array_max(); +## array_min + +query I +select array_min(make_array(5, 3, 6, 4)); +---- +3 + +query I +select array_min(make_array(5, 3, 4, NULL, 6, NULL)); +---- +3 + +query ? 
+select array_min(make_array(NULL, NULL)); +---- +NULL + +query T +select array_min(make_array('h', 'e', 'o', 'l', 'l')); +---- +e + +query T +select array_min(make_array('h', 'e', 'l', NULL, 'l', 'o', NULL)); +---- +e + +query B +select array_min(make_array(false, true, false, true)); +---- +false + +query B +select array_min(make_array(false, true, NULL, false, true)); +---- +false + +query D +select array_min(make_array(DATE '1992-09-01', DATE '1993-03-01', DATE '1999-05-01', DATE '1985-11-01')); +---- +1985-11-01 + +query D +select array_min(make_array(DATE '1995-09-01', DATE '1999-05-01', DATE '1993-03-01', NULL)); +---- +1993-03-01 + +query P +select array_min(make_array(TIMESTAMP '1992-09-01', TIMESTAMP '1995-06-01', TIMESTAMP '1984-10-01')); +---- +1984-10-01T00:00:00 + +query P +select array_min(make_array(NULL, TIMESTAMP '1996-10-01', TIMESTAMP '1995-06-01')); +---- +1995-06-01T00:00:00 + +query R +select array_min(make_array(5.1, -3.2, 6.3, 4.9)); +---- +-3.2 + +query ?I +select input, array_min(input) from (select make_array(d - 1, d, d + 1) input from (values (0), (10), (20), (30), (NULL)) t(d)) +---- +[-1, 0, 1] -1 +[9, 10, 11] 9 +[19, 20, 21] 19 +[29, 30, 31] 29 +[NULL, NULL, NULL] NULL + +query II +select array_min(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)')), array_min(arrow_cast(make_array(1), 'FixedSizeList(1, Int64)')); +---- +1 1 + +query ? +select array_min(make_array()); +---- +NULL + +# Testing with empty arguments should result in an error +query error DataFusion error: Error during planning: 'array_min' does not support zero arguments +select array_min(); + ## array_pop_back (aliases: `list_pop_back`) @@ -2177,7 +2276,7 @@ select array_any_value(1), array_any_value('a'), array_any_value(NULL); # array_any_value scalar function #1 (with null and non-null elements) -query ITII +query IT?I select array_any_value(make_array(NULL, 1, 2, 3, 4, 5)), array_any_value(make_array(NULL, 'h', 'e', 'l', 'l', 'o')), array_any_value(make_array(NULL, NULL)), array_any_value(make_array(NULL, NULL, 1, 2, 3)); ---- 1 h NULL 1 @@ -2348,6 +2447,11 @@ NULL [NULL, 51, 52, 54, 55, 56, 57, 58, 59, 60] [61, 62, 63, 64, 65, 66, 67, 68, 69, 70] +# test with empty table +query ? +select array_sort(column1, 'DESC', 'NULLS FIRST') from arrays_values where false; +---- + # test with empty array query ? select array_sort([]); @@ -2435,11 +2539,15 @@ select array_append(null, 1); ---- [1] -query error +query ? select array_append(null, [2, 3]); +---- +[[2, 3]] -query error +query ? select array_append(null, [[4]]); +---- +[[[4]]] query ???? select @@ -2716,8 +2824,10 @@ select array_prepend(null, [[1,2,3]]); # DuckDB: [[]] # ClickHouse: [[]] # TODO: We may also return [[]] -query error +query ? select array_prepend([], []); +---- +[[]] query ? select array_prepend(null, null); @@ -3053,6 +3163,42 @@ select array_concat([]); ---- [] +# test with NULL array +query ? +select array_concat(NULL::integer[]); +---- +NULL + +# test with multiple NULL arrays +query ? +select array_concat(NULL::integer[], NULL::integer[]); +---- +NULL + +# test with NULL LargeList +query ? +select array_concat(arrow_cast(NULL::string[], 'LargeList(Utf8)')); +---- +NULL + +# test with NULL FixedSizeList +query ? +select array_concat(arrow_cast(NULL::string[], 'FixedSizeList(2, Utf8)')); +---- +NULL + +# test with mix of NULL and empty arrays +query ? +select array_concat(NULL::integer[], []); +---- +[] + +# test with mix of NULL and non-empty arrays +query ? 
+select array_concat(NULL::integer[], [1, 2, 3]); +---- +[1, 2, 3] + # Concatenating strings arrays query ? select array_concat( @@ -3080,22 +3226,25 @@ select array_concat( ---- [1, 2, 3] -# Concatenating Mixed types (doesn't work) -query error DataFusion error: Error during planning: It is not possible to concatenate arrays of different types\. Expected: List\(Field \{ name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\), got: List\(Field \{ name: "item", data_type: LargeUtf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) +# Concatenating Mixed types +query ? select array_concat( [arrow_cast('1', 'Utf8'), arrow_cast('2', 'Utf8')], [arrow_cast('3', 'LargeUtf8')] ); +---- +[1, 2, 3] -# Concatenating Mixed types (doesn't work) -query error DataFusion error: Error during planning: It is not possible to concatenate arrays of different types\. Expected: List\(Field \{ name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\), got: List\(Field \{ name: "item", data_type: Utf8View, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) -select array_concat( - [arrow_cast('1', 'Utf8'), arrow_cast('2', 'Utf8')], - [arrow_cast('3', 'Utf8View')] -); +# Concatenating Mixed types +query ?T +select + array_concat([arrow_cast('1', 'Utf8'), arrow_cast('2', 'Utf8')], [arrow_cast('3', 'Utf8View')]), + arrow_typeof(array_concat([arrow_cast('1', 'Utf8'), arrow_cast('2', 'Utf8')], [arrow_cast('3', 'Utf8View')])); +---- +[1, 2, 3] List(Field { name: "item", data_type: Utf8View, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) # array_concat error -query error DataFusion error: Error during planning: The array_concat function can only accept list as the args\. +query error DataFusion error: Error during planning: Execution error: Function 'array_concat' user-defined coercion failed with "Error during planning: array_concat does not support type Int64" select array_concat(1, 2); # array_concat scalar function #1 @@ -3347,10 +3496,16 @@ select array_concat(make_array(column3), column1, column2) from arrays_values_v2 ## array_position (aliases: `list_position`, `array_indexof`, `list_indexof`) ## array_position with NULL (follow PostgreSQL) -#query I -#select array_position([1, 2, 3, 4, 5], null), array_position(NULL, 1); -#---- -#NULL NULL +query II +select array_position([1, 2, 3, 4, 5], arrow_cast(NULL, 'Int64')), array_position(arrow_cast(NULL, 'List(Int64)'), 1); +---- +NULL NULL + +# array_position with no match (incl. 
empty array) returns NULL +query II +select array_position([], 1), array_position([2], 1); +---- +NULL NULL # array_position scalar function #1 query III @@ -3406,15 +3561,11 @@ SELECT array_position(arrow_cast([1, 1, 100, 1, 1], 'LargeList(Int32)'), 100) ---- 3 -query I +query error DataFusion error: Error during planning: Failed to coerce arguments to satisfy a call to 'array_position' function: coercion from SELECT array_position([1, 2, 3], 'foo') ----- -NULL -query I +query error DataFusion error: Error during planning: Failed to coerce arguments to satisfy a call to 'array_position' function: coercion from SELECT array_position([1, 2, 3], 'foo', 2) ----- -NULL # list_position scalar function #5 (function alias `array_position`) query III @@ -4376,7 +4527,8 @@ select array_union(arrow_cast([1, 2, 3, 4], 'LargeList(Int64)'), arrow_cast([5, statement ok CREATE TABLE arrays_with_repeating_elements_for_union AS VALUES - ([1], [2]), + ([0, 1, 1], []), + ([1, 1], [2]), ([2, 3], [3]), ([3], [3, 4]) ; @@ -4384,6 +4536,7 @@ AS VALUES query ? select array_union(column1, column2) from arrays_with_repeating_elements_for_union; ---- +[0, 1] [1, 2] [2, 3] [3, 4] @@ -4391,6 +4544,7 @@ select array_union(column1, column2) from arrays_with_repeating_elements_for_uni query ? select array_union(arrow_cast(column1, 'LargeList(Int64)'), arrow_cast(column2, 'LargeList(Int64)')) from arrays_with_repeating_elements_for_union; ---- +[0, 1] [1, 2] [2, 3] [3, 4] @@ -4413,12 +4567,10 @@ select array_union(arrow_cast([], 'LargeList(Int64)'), arrow_cast([], 'LargeList query ? select array_union([[null]], []); ---- -[[NULL]] +[[]] -query ? +query error DataFusion error: Error during planning: Failed to coerce arguments to satisfy a call to 'array_union' function: select array_union(arrow_cast([[null]], 'LargeList(List(Int64))'), arrow_cast([], 'LargeList(Int64)')); ----- -[[NULL]] # array_union scalar function #8 query ? 
@@ -5223,6 +5375,19 @@ NULL 10 NULL 10 NULL 10 +# array_length for fixed sized list + +query III +select array_length(arrow_cast(make_array(1, 2, 3, 4, 5), 'FixedSizeList(5, Int64)')), array_length(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)')), array_length(arrow_cast(make_array([1, 2], [3, 4], [5, 6]), 'FixedSizeList(3, List(Int64))')); +---- +5 3 3 + +query III +select array_length(arrow_cast(make_array(1, 2, 3, 4, 5), 'FixedSizeList(5, Int64)'), 1), array_length(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)'), 1), array_length(arrow_cast(make_array([1, 2], [3, 4], [5, 6]), 'FixedSizeList(3, List(Int64))'), 1); +---- +5 3 3 + + query RRR select array_distance([2], [3]), list_distance([1], [2]), list_distance([1], [-2]); ---- @@ -5670,6 +5835,30 @@ false false false +# array_has([1, 3, 5], 1) -> true (array contains element) +# array_has([], 1) -> false (empty array, not null) +# array_has(null, 1) -> null (null array) +query B +select array_has(column1, column2) +from array_has_table_empty; +---- +true +false +NULL + +# Test for issue: array_has should return false for empty arrays, not null +# This test demonstrates the correct behavior with COALESCE to show the distinction +# array_has([1, 3, 5], 1) -> 'true' +# array_has([], 1) -> 'false' (empty array should return false) +# array_has(null, 1) -> 'null' (null array should return null) +query ?T +SELECT column1, COALESCE(CAST(array_has(column1, column2) AS VARCHAR), 'null') +from array_has_table_empty; +---- +[1, 3, 5] true +[] false +NULL null + query B select array_has(column1, column2) from fixed_size_array_has_table_1D; @@ -5677,14 +5866,13 @@ from fixed_size_array_has_table_1D; true false -#TODO: array_has_all and array_has_any cannot handle FixedSizeList -#query BB -#select array_has_all(column3, column4), -# array_has_any(column5, column6) -#from fixed_size_array_has_table_1D; -#---- -#true true -#false false +query BB +select array_has_all(column3, column4), + array_has_any(column5, column6) +from fixed_size_array_has_table_1D; +---- +true true +false false query BBB select array_has(column1, column2), @@ -5711,14 +5899,13 @@ from fixed_size_array_has_table_1D_Float; true false -#TODO: array_has_all and array_has_any cannot handle FixedSizeList -#query BB -#select array_has_all(column3, column4), -# array_has_any(column5, column6) -#from fixed_size_array_has_table_1D_Float; -#---- -#true true -#false true +query BB +select array_has_all(column3, column4), + array_has_any(column5, column6) +from fixed_size_array_has_table_1D_Float; +---- +true true +false true query BBB select array_has(column1, column2), @@ -5745,14 +5932,27 @@ from fixed_size_array_has_table_1D_Boolean; false true -#TODO: array_has_all and array_has_any cannot handle FixedSizeList -#query BB -#select array_has_all(column3, column4), -# array_has_any(column5, column6) -#from fixed_size_array_has_table_1D_Boolean; -#---- -#true true -#true true +query BB +select array_has_all(column3, column4), + array_has_any(column5, column6) +from fixed_size_array_has_table_1D_Boolean; +---- +true true +true true + +query BBBBBBBB +select array_has_all(column3, arrow_cast(column4,'LargeList(Boolean)')), + array_has_any(column5, arrow_cast(column6,'LargeList(Boolean)')), + array_has_all(column3, arrow_cast(column4,'List(Boolean)')), + array_has_any(column5, arrow_cast(column6,'List(Boolean)')), + array_has_all(arrow_cast(column3, 'LargeList(Boolean)'), column4), + array_has_any(arrow_cast(column5, 'LargeList(Boolean)'), column6), + 
array_has_all(arrow_cast(column3, 'List(Boolean)'), column4), + array_has_any(arrow_cast(column5, 'List(Boolean)'), column6) +from fixed_size_array_has_table_1D_Boolean; +---- +true true true true true true true true +true true true true true true true true query BBB select array_has(column1, column2), @@ -5802,13 +6002,12 @@ from fixed_size_array_has_table_2D; false false -#TODO: array_has_all and array_has_any cannot handle FixedSizeList -#query B -#select array_has_all(arrow_cast(column3, 'LargeList(List(Int64))'), arrow_cast(column4, 'LargeList(List(Int64))')) -#from fixed_size_array_has_table_2D; -#---- -#true -#false +query B +select array_has_all(arrow_cast(column3, 'LargeList(List(Int64))'), arrow_cast(column4, 'LargeList(List(Int64))')) +from fixed_size_array_has_table_2D; +---- +true +false query B select array_has_all(column1, column2) @@ -5824,13 +6023,12 @@ from array_has_table_2D_float; true false -#TODO: array_has_all and array_has_any cannot handle FixedSizeList -#query B -#select array_has_all(column1, column2) -#from fixed_size_array_has_table_2D_float; -#---- -#false -#false +query B +select array_has_all(column1, column2) +from fixed_size_array_has_table_2D_float; +---- +false +false query B select array_has(column1, column2) from array_has_table_3D; @@ -5895,6 +6093,13 @@ NULL NULL false false false false NULL false false false false NULL +# Row 1: [[NULL,2],[3,NULL]], [1.1,2.2,3.3], ['L','o','r','e','m'] +# Row 2: [[3,4],[5,6]], [NULL,5.5,6.6], ['i','p',NULL,'u','m'] +# Row 3: [[5,6],[7,8]], [7.7,8.8,9.9], ['d',NULL,'l','o','r'] +# Row 4: [[7,NULL],[9,10]], [10.1,NULL,12.2], ['s','i','t','a','b'] +# Row 5: NULL, [13.3,14.4,15.5], ['a','m','e','t','x'] +# Row 6: [[11,12],[13,14]], NULL, [',','a','b','c','d'] +# Row 7: [[15,16],[NULL,18]], [16.6,17.7,18.8], NULL query BBBB select array_has(column1, make_array(5, 6)), array_has(column1, make_array(7, NULL)), @@ -5906,9 +6111,9 @@ false false false true true false true false true false false true false true false false -false false false false -false false false false -false false false false +NULL NULL false false +false false NULL false +false false false NULL query BBBB select array_has_all(make_array(1,2,3), []), @@ -5955,24 +6160,23 @@ select array_has_all(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), arrow_ca ---- true false true false false false true true false false true false true -#TODO: array_has_all and array_has_any cannot handle FixedSizeList -#query BBBBBBBBBBBBB -#select array_has_all(arrow_cast(make_array(1,2,3), 'FixedSizeList(3, Int64)'), arrow_cast(make_array(1, 3), 'FixedSizeList(2, Int64)')), -# array_has_all(arrow_cast(make_array(1,2,3),'FixedSizeList(3, Int64)'), arrow_cast(make_array(1, 4), 'FixedSizeList(2, Int64)')), -# array_has_all(arrow_cast(make_array([1,2], [3,4]), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array([1,2]), 'FixedSizeList(1, List(Int64))')), -# array_has_all(arrow_cast(make_array([1,2], [3,4]), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array([1,3]), 'FixedSizeList(1, List(Int64))')), -# array_has_all(arrow_cast(make_array([1,2], [3,4]), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array([1,2], [3,4], [5,6]), 'FixedSizeList(3, List(Int64))')), -# array_has_all(arrow_cast(make_array([[1,2,3]]), 'FixedSizeList(1, List(List(Int64)))'), arrow_cast(make_array([[1]]), 'FixedSizeList(1, List(List(Int64)))')), -# array_has_all(arrow_cast(make_array([[1,2,3]]), 'FixedSizeList(1, List(List(Int64)))'), arrow_cast(make_array([[1,2,3]]), 'FixedSizeList(1, 
List(List(Int64)))')), -# array_has_any(arrow_cast(make_array(1,2,3),'FixedSizeList(3, Int64)'), arrow_cast(make_array(1,10,100), 'FixedSizeList(3, Int64)')), -# array_has_any(arrow_cast(make_array(1,2,3),'FixedSizeList(3, Int64)'), arrow_cast(make_array(10, 100),'FixedSizeList(2, Int64)')), -# array_has_any(arrow_cast(make_array([1,2], [3,4]), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array([1,10], [10,4]), 'FixedSizeList(2, List(Int64))')), -# array_has_any(arrow_cast(make_array([1,2], [3,4]), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array([10,20], [3,4]), 'FixedSizeList(2, List(Int64))')), -# array_has_any(arrow_cast(make_array([[1,2,3]]), 'FixedSizeList(1, List(List(Int64)))'), arrow_cast(make_array([[1,2,3], [4,5,6]]), 'FixedSizeList(1, List(List(Int64)))')), -# array_has_any(arrow_cast(make_array([[1,2,3]]), 'FixedSizeList(1, List(List(Int64)))'), arrow_cast(make_array([[1,2,3]], [[4,5,6]]), 'FixedSizeList(2, List(List(Int64)))')) -#; -#---- -#true false true false false false true true false false true false true +query BBBBBBBBBBBBB +select array_has_all(arrow_cast(make_array(1,2,3), 'FixedSizeList(3, Int64)'), arrow_cast(make_array(1, 3), 'FixedSizeList(2, Int64)')), + array_has_all(arrow_cast(make_array(1,2,3),'FixedSizeList(3, Int64)'), arrow_cast(make_array(1, 4), 'FixedSizeList(2, Int64)')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array([1,2]), 'FixedSizeList(1, List(Int64))')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array([1,3]), 'FixedSizeList(1, List(Int64))')), + array_has_all(arrow_cast(make_array([1,2], [3,4]), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array([1,2], [3,4], [5,6]), 'FixedSizeList(3, List(Int64))')), + array_has_all(arrow_cast(make_array([[1,2,3]]), 'FixedSizeList(1, List(List(Int64)))'), arrow_cast(make_array([[1]]), 'FixedSizeList(1, List(List(Int64)))')), + array_has_all(arrow_cast(make_array([[1,2,3]]), 'FixedSizeList(1, List(List(Int64)))'), arrow_cast(make_array([[1,2,3]]), 'FixedSizeList(1, List(List(Int64)))')), + array_has_any(arrow_cast(make_array(1,2,3),'FixedSizeList(3, Int64)'), arrow_cast(make_array(1,10,100), 'FixedSizeList(3, Int64)')), + array_has_any(arrow_cast(make_array(1,2,3),'FixedSizeList(3, Int64)'), arrow_cast(make_array(10, 100),'FixedSizeList(2, Int64)')), + array_has_any(arrow_cast(make_array([1,2], [3,4]), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array([1,10], [10,4]), 'FixedSizeList(2, List(Int64))')), + array_has_any(arrow_cast(make_array([1,2], [3,4]), 'FixedSizeList(2, List(Int64))'), arrow_cast(make_array([10,20], [3,4]), 'FixedSizeList(2, List(Int64))')), + array_has_any(arrow_cast(make_array([[1,2,3]]), 'FixedSizeList(1, List(List(Int64)))'), arrow_cast(make_array([[1,2,3], [4,5,6]]), 'FixedSizeList(1, List(List(Int64)))')), + array_has_any(arrow_cast(make_array([[1,2,3]]), 'FixedSizeList(1, List(List(Int64)))'), arrow_cast(make_array([[1,2,3]], [[4,5,6]]), 'FixedSizeList(2, List(List(Int64)))')) +; +---- +true false true false false false true true false false true false true # rewrite various array_has operations to InList where the haystack is a literal list # NB that `col in (a, b, c)` is simplified to OR if there are <= 3 elements, so we make 4-element haystack lists @@ -5993,8 +6197,8 @@ logical_plan 03)----SubqueryAlias: test 04)------SubqueryAlias: t 05)--------Projection: -06)----------Filter: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), 
Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")]) -07)------------TableScan: tmp_table projection=[value] +06)----------Filter: substr(CAST(md5(CAST(generate_series().value AS Utf8View)) AS Utf8View), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")]) +07)------------TableScan: generate_series() projection=[value] physical_plan 01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)] 02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] @@ -6002,7 +6206,7 @@ physical_plan 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] 05)--------ProjectionExec: expr=[] 06)----------CoalesceBatchesExec: target_batch_size=8192 -07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal { value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value: Utf8View("c") }]) +07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("a"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("b"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("c"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]) 08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] @@ -6022,8 +6226,8 @@ logical_plan 03)----SubqueryAlias: test 04)------SubqueryAlias: t 05)--------Projection: -06)----------Filter: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")]) -07)------------TableScan: tmp_table projection=[value] +06)----------Filter: substr(CAST(md5(CAST(generate_series().value AS Utf8View)) AS Utf8View), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")]) +07)------------TableScan: generate_series() projection=[value] physical_plan 01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)] 02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] @@ -6031,7 +6235,7 @@ physical_plan 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] 05)--------ProjectionExec: expr=[] 06)----------CoalesceBatchesExec: target_batch_size=8192 -07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal { value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value: Utf8View("c") }]) +07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("a"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: 
false, metadata: {} } }, Literal { value: Utf8View("b"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("c"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]) 08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] @@ -6051,8 +6255,8 @@ logical_plan 03)----SubqueryAlias: test 04)------SubqueryAlias: t 05)--------Projection: -06)----------Filter: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")]) -07)------------TableScan: tmp_table projection=[value] +06)----------Filter: substr(CAST(md5(CAST(generate_series().value AS Utf8View)) AS Utf8View), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")]) +07)------------TableScan: generate_series() projection=[value] physical_plan 01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)] 02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] @@ -6060,7 +6264,7 @@ physical_plan 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] 05)--------ProjectionExec: expr=[] 06)----------CoalesceBatchesExec: target_batch_size=8192 -07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal { value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value: Utf8View("c") }]) +07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("a"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("b"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("c"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]) 08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] @@ -6082,8 +6286,8 @@ logical_plan 03)----SubqueryAlias: test 04)------SubqueryAlias: t 05)--------Projection: -06)----------Filter: array_has(LargeList([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]), substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), Int64(1), Int64(32))) -07)------------TableScan: tmp_table projection=[value] +06)----------Filter: array_has(LargeList([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c]), substr(CAST(md5(CAST(generate_series().value AS Utf8View)) AS Utf8View), Int64(1), Int64(32))) +07)------------TableScan: generate_series() projection=[value] physical_plan 01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)] 02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] @@ -6091,7 +6295,7 @@ physical_plan 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] 05)--------ProjectionExec: expr=[] 
06)----------CoalesceBatchesExec: target_batch_size=8192 -07)------------FilterExec: array_has([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c], substr(md5(CAST(value@0 AS Utf8)), 1, 32)) +07)------------FilterExec: array_has([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c], substr(md5(CAST(value@0 AS Utf8View)), 1, 32)) 08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] @@ -6111,8 +6315,8 @@ logical_plan 03)----SubqueryAlias: test 04)------SubqueryAlias: t 05)--------Projection: -06)----------Filter: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")]) -07)------------TableScan: tmp_table projection=[value] +06)----------Filter: substr(CAST(md5(CAST(generate_series().value AS Utf8View)) AS Utf8View), Int64(1), Int64(32)) IN ([Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), Utf8View("a"), Utf8View("b"), Utf8View("c")]) +07)------------TableScan: generate_series() projection=[value] physical_plan 01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)] 02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] @@ -6120,7 +6324,7 @@ physical_plan 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] 05)--------ProjectionExec: expr=[] 06)----------CoalesceBatchesExec: target_batch_size=8192 -07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278") }, Literal { value: Utf8View("a") }, Literal { value: Utf8View("b") }, Literal { value: Utf8View("c") }]) +07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN ([Literal { value: Utf8View("7f4b18de3cfeb9b4ac78c381ee2ad278"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("a"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("b"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("c"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]) 08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] @@ -6130,7 +6334,7 @@ select count(*) from test WHERE array_has([needle], needle); ---- 100000 -# The optimizer does not currently eliminate the filter; +# The optimizer does not currently eliminate the filter; # Instead, it's rewritten as `IS NULL OR NOT NULL` due to SQL null semantics query TT explain with test AS (SELECT substr(md5(i::text)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) @@ -6141,9 +6345,9 @@ logical_plan 02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] 03)----SubqueryAlias: test 04)------SubqueryAlias: t -05)--------Projection: -06)----------Filter: substr(CAST(md5(CAST(tmp_table.value AS Utf8)) AS Utf8), Int64(1), Int64(32)) IS NOT NULL OR Boolean(NULL) -07)------------TableScan: tmp_table projection=[value] +05)--------Projection: +06)----------Filter: substr(CAST(md5(CAST(generate_series().value AS Utf8View)) AS Utf8View), Int64(1), 
Int64(32)) IS NOT NULL OR Boolean(NULL) +07)------------TableScan: generate_series() projection=[value] physical_plan 01)ProjectionExec: expr=[count(Int64(1))@0 as count(*)] 02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] @@ -6151,7 +6355,7 @@ physical_plan 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] 05)--------ProjectionExec: expr=[] 06)----------CoalesceBatchesExec: target_batch_size=8192 -07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8)), 1, 32) IS NOT NULL OR NULL +07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IS NOT NULL OR NULL 08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] @@ -6427,12 +6631,12 @@ select array_intersect(arrow_cast([1, 1, 2, 2, 3, 3], 'LargeList(Int64)'), null) query ? select array_intersect(null, [1, 1, 2, 2, 3, 3]); ---- -NULL +[] query ? select array_intersect(null, arrow_cast([1, 1, 2, 2, 3, 3], 'LargeList(Int64)')); ---- -NULL +[] query ? select array_intersect([], null); @@ -6457,12 +6661,12 @@ select array_intersect(arrow_cast([], 'LargeList(Int64)'), null); query ? select array_intersect(null, []); ---- -NULL +[] query ? select array_intersect(null, arrow_cast([], 'LargeList(Int64)')); ---- -NULL +[] query ? select array_intersect(null, null); @@ -7285,12 +7489,10 @@ select array_concat(column1, [7]) from arrays_values_v2; # flatten -#TODO: https://github.com/apache/datafusion/issues/7142 -# follow DuckDB -#query ? -#select flatten(NULL); -#---- -#NULL +query ? +select flatten(NULL); +---- +NULL # flatten with scalar values #1 query ??? @@ -7298,21 +7500,21 @@ select flatten(make_array(1, 2, 1, 3, 2)), flatten(make_array([1], [2, 3], [null], make_array(4, null, 5))), flatten(make_array([[1.1]], [[2.2]], [[3.3], [4.4]])); ---- -[1, 2, 1, 3, 2] [1, 2, 3, NULL, 4, NULL, 5] [1.1, 2.2, 3.3, 4.4] +[1, 2, 1, 3, 2] [1, 2, 3, NULL, 4, NULL, 5] [[1.1], [2.2], [3.3], [4.4]] query ??? select flatten(arrow_cast(make_array(1, 2, 1, 3, 2), 'LargeList(Int64)')), flatten(arrow_cast(make_array([1], [2, 3], [null], make_array(4, null, 5)), 'LargeList(LargeList(Int64))')), flatten(arrow_cast(make_array([[1.1]], [[2.2]], [[3.3], [4.4]]), 'LargeList(LargeList(LargeList(Float64)))')); ---- -[1, 2, 1, 3, 2] [1, 2, 3, NULL, 4, NULL, 5] [1.1, 2.2, 3.3, 4.4] +[1, 2, 1, 3, 2] [1, 2, 3, NULL, 4, NULL, 5] [[1.1], [2.2], [3.3], [4.4]] query ??? select flatten(arrow_cast(make_array(1, 2, 1, 3, 2), 'FixedSizeList(5, Int64)')), flatten(arrow_cast(make_array([1], [2, 3], [null], make_array(4, null, 5)), 'FixedSizeList(4, List(Int64))')), flatten(arrow_cast(make_array([[1.1], [2.2]], [[3.3], [4.4]]), 'FixedSizeList(2, List(List(Float64)))')); ---- -[1, 2, 1, 3, 2] [1, 2, 3, NULL, 4, NULL, 5] [1.1, 2.2, 3.3, 4.4] +[1, 2, 1, 3, 2] [1, 2, 3, NULL, 4, NULL, 5] [[1.1], [2.2], [3.3], [4.4]] # flatten with column values query ???? @@ -7322,8 +7524,8 @@ select flatten(column1), flatten(column4) from flatten_table; ---- -[1, 2, 3] [1, 2, 3, 4, 5, 6] [1, 2, 3] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4] -[1, 2, 3, 4, 5, 6] [8] [1, 2, 3] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] +[1, 2, 3] [[1, 2, 3], [4, 5], [6]] [[[1]], [[2, 3]]] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4] +[1, 2, 3, 4, 5, 6] [[8]] [[[1, 2]], [[3]]] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] query ???? 
select flatten(column1), @@ -7332,8 +7534,8 @@ select flatten(column1), flatten(column4) from large_flatten_table; ---- -[1, 2, 3] [1, 2, 3, 4, 5, 6] [1, 2, 3] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4] -[1, 2, 3, 4, 5, 6] [8] [1, 2, 3] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] +[1, 2, 3] [[1, 2, 3], [4, 5], [6]] [[[1]], [[2, 3]]] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4] +[1, 2, 3, 4, 5, 6] [[8]] [[[1, 2]], [[3]]] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] query ???? select flatten(column1), @@ -7342,8 +7544,19 @@ select flatten(column1), flatten(column4) from fixed_size_flatten_table; ---- -[1, 2, 3] [1, 2, 3, 4, 5, 6] [1, 2, 3] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4] -[1, 2, 3, 4, 5, 6] [8, 9, 10, 11, 12, 13] [1, 2, 3] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] +[1, 2, 3] [[1, 2, 3], [4, 5], [6]] [[[1]], [[2, 3]]] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4] +[1, 2, 3, 4, 5, 6] [[8], [9, 10], [11, 12, 13]] [[[1, 2]], [[3]]] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] + +# flatten with different inner list type +query ?????? +select flatten(arrow_cast(make_array([1, 2], [3, 4]), 'List(FixedSizeList(2, Int64))')), + flatten(arrow_cast(make_array([[1, 2]], [[3, 4]]), 'List(FixedSizeList(1, List(Int64)))')), + flatten(arrow_cast(make_array([1, 2], [3, 4]), 'LargeList(List(Int64))')), + flatten(arrow_cast(make_array([[1, 2]], [[3, 4]]), 'LargeList(List(List(Int64)))')), + flatten(arrow_cast(make_array([1, 2], [3, 4]), 'LargeList(FixedSizeList(2, Int64))')), + flatten(arrow_cast(make_array([[1, 2]], [[3, 4]]), 'LargeList(FixedSizeList(1, List(Int64)))')) +---- +[1, 2, 3, 4] [[1, 2], [3, 4]] [1, 2, 3, 4] [[1, 2], [3, 4]] [1, 2, 3, 4] [[1, 2], [3, 4]] ## empty (aliases: `array_empty`, `list_empty`) # empty scalar function #1 @@ -7764,11 +7977,13 @@ select array_reverse(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)')), array ---- [3, 2, 1] [1] -#TODO: support after FixedSizeList type coercion -#query ?? -#select array_reverse(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)')), array_reverse(arrow_cast(make_array(1), 'FixedSizeList(1, Int64)')); -#---- -#[3, 2, 1] [1] +query ???? +select array_reverse(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)')), + array_reverse(arrow_cast(make_array(1), 'FixedSizeList(1, Int64)')), + array_reverse(arrow_cast(make_array(1, NULL, 3), 'FixedSizeList(3, Int64)')), + array_reverse(arrow_cast(make_array(NULL, NULL, NULL), 'FixedSizeList(3, Int64)')); +---- +[3, 2, 1] [1] [3, NULL, 1] [NULL, NULL, NULL] query ?? select array_reverse(NULL), array_reverse([]); @@ -7824,7 +8039,7 @@ List(Field { name: "item", data_type: List(Field { name: "item", data_type: Int3 query ??T select [1,2,3]::int[], [['1']]::int[][], arrow_typeof([]::text[]); ---- -[1, 2, 3] [[1]] List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) +[1, 2, 3] [[1]] List(Field { name: "item", data_type: Utf8View, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) # test empty arrays return length # issue: https://github.com/apache/datafusion/pull/12459 diff --git a/datafusion/sqllogictest/test_files/array_query.slt b/datafusion/sqllogictest/test_files/array_query.slt index 8fde295e6051..65d4fa495e3b 100644 --- a/datafusion/sqllogictest/test_files/array_query.slt +++ b/datafusion/sqllogictest/test_files/array_query.slt @@ -108,11 +108,15 @@ SELECT * FROM data WHERE column2 is not distinct from null; # Aggregates ########### -query error Internal error: Min/Max accumulator not implemented for type List +query ? 
SELECT min(column1) FROM data; +---- +[1, 2, 3] -query error Internal error: Min/Max accumulator not implemented for type List +query ? SELECT max(column1) FROM data; +---- +[2, 3] query I SELECT count(column1) FROM data; diff --git a/datafusion/sqllogictest/test_files/arrow_files.slt b/datafusion/sqllogictest/test_files/arrow_files.slt index 30f322cf98fc..62453ec4bf3e 100644 --- a/datafusion/sqllogictest/test_files/arrow_files.slt +++ b/datafusion/sqllogictest/test_files/arrow_files.slt @@ -19,6 +19,11 @@ ## Arrow Files Format support ############# +# We using fixed arrow file to test for sqllogictests, and this arrow field is writing with arrow-ipc utf8, +# so when we decode to read it's also loading utf8. +# Currently, so we disable the map_string_types_to_utf8view +statement ok +set datafusion.sql_parser.map_string_types_to_utf8view = false; statement ok diff --git a/datafusion/sqllogictest/test_files/avro.slt b/datafusion/sqllogictest/test_files/avro.slt index 1b4150b074cc..2ad60c0082e8 100644 --- a/datafusion/sqllogictest/test_files/avro.slt +++ b/datafusion/sqllogictest/test_files/avro.slt @@ -15,6 +15,10 @@ # specific language governing permissions and limitations # under the License. +# Currently, the avro not support Utf8View type, so we disable the map_string_types_to_utf8view +# After https://github.com/apache/arrow-rs/issues/7262 released, we can remove this setting +statement ok +set datafusion.sql_parser.map_string_types_to_utf8view = false; statement ok CREATE EXTERNAL TABLE alltypes_plain ( @@ -253,3 +257,13 @@ physical_plan 04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))] 05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/avro/alltypes_plain.avro]]}, file_type=avro + +# test column projection order from avro file +query ITII +SELECT id, string_col, int_col, bigint_col FROM alltypes_plain ORDER BY id LIMIT 5 +---- +0 0 0 0 +1 1 1 10 +2 0 0 0 +3 1 1 10 +4 0 0 0 diff --git a/datafusion/sqllogictest/test_files/binary.slt b/datafusion/sqllogictest/test_files/binary.slt index 5c5f9d510e55..1077c32e46f3 100644 --- a/datafusion/sqllogictest/test_files/binary.slt +++ b/datafusion/sqllogictest/test_files/binary.slt @@ -147,8 +147,45 @@ query error DataFusion error: Error during planning: Cannot infer common argumen SELECT column1, column1 = arrow_cast(X'0102', 'FixedSizeBinary(2)') FROM t # Comparison to different sized Binary -query error DataFusion error: Error during planning: Cannot infer common argument type for comparison operation FixedSizeBinary\(3\) = Binary +query ?B SELECT column1, column1 = X'0102' FROM t +---- +000102 false +003102 false +NULL NULL +ff0102 false +000102 false + +query ?B +SELECT column1, column1 = X'000102' FROM t +---- +000102 true +003102 false +NULL NULL +ff0102 false +000102 true + +query ?B +SELECT arrow_cast(column1, 'FixedSizeBinary(3)'), arrow_cast(column1, 'FixedSizeBinary(3)') = arrow_cast(arrow_cast(X'000102', 'FixedSizeBinary(3)'), 'BinaryView') FROM t; +---- +000102 true +003102 false +NULL NULL +ff0102 false +000102 true + +# Plan should not have a cast of the column (should have casted the literal +# to FixedSizeBinary as that is much faster) + +query TT +explain SELECT column1, column1 = X'000102' FROM t +---- +logical_plan +01)Projection: t.column1, t.column1 = FixedSizeBinary(3, "0,1,2") AS t.column1 = Binary("0,1,2") +02)--TableScan: t projection=[column1] +physical_plan +01)ProjectionExec: expr=[column1@0 
as column1, column1@0 = 000102 as t.column1 = Binary("0,1,2")] +02)--DataSourceExec: partitions=1, partition_sizes=[1] statement ok drop table t_source diff --git a/datafusion/sqllogictest/test_files/clickbench_extended.slt b/datafusion/sqllogictest/test_files/clickbench_extended.slt new file mode 100644 index 000000000000..ee3e33551ee3 --- /dev/null +++ b/datafusion/sqllogictest/test_files/clickbench_extended.slt @@ -0,0 +1,66 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +# DataFusion specific ClickBench "Extended" Queries +# See data provenance notes in clickbench.slt + +statement ok +CREATE EXTERNAL TABLE hits +STORED AS PARQUET +LOCATION '../core/tests/data/clickbench_hits_10.parquet'; + +# If you change any of these queries, please change the corresponding query in +# benchmarks/queries/clickbench/extended.sql and update the README. + +query III +SELECT COUNT(DISTINCT "SearchPhrase"), COUNT(DISTINCT "MobilePhone"), COUNT(DISTINCT "MobilePhoneModel") FROM hits; +---- +1 1 1 + +query III +SELECT COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTINCT "BrowserLanguage") FROM hits; +---- +1 1 1 + +query TIIII +SELECT "BrowserCountry", COUNT(DISTINCT "SocialNetwork"), COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserLanguage"), COUNT(DISTINCT "SocialAction") FROM hits GROUP BY 1 ORDER BY 2 DESC LIMIT 10; +---- +� 1 1 1 1 + +query IIIRRRR +SELECT "SocialSourceNetworkID", "RegionID", COUNT(*), AVG("Age"), AVG("ParamPrice"), STDDEV("ParamPrice") as s, VAR("ParamPrice") FROM hits GROUP BY "SocialSourceNetworkID", "RegionID" HAVING s IS NOT NULL ORDER BY s DESC LIMIT 10; +---- +0 839 6 0 0 0 0 +0 197 2 0 0 0 0 + +query IIIIII +SELECT "ClientIP", "WatchID", COUNT(*) c, MIN("ResponseStartTiming") tmin, MEDIAN("ResponseStartTiming") tmed, MAX("ResponseStartTiming") tmax FROM hits WHERE "JavaEnable" = 0 GROUP BY "ClientIP", "WatchID" HAVING c > 1 ORDER BY tmed DESC LIMIT 10; +---- + +query IIIIII +SELECT "ClientIP", "WatchID", COUNT(*) c, MIN("ResponseStartTiming") tmin, APPROX_PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY "ResponseStartTiming") tp95, MAX("ResponseStartTiming") tmax FROM 'hits' WHERE "JavaEnable" = 0 GROUP BY "ClientIP", "WatchID" HAVING c > 1 ORDER BY tp95 DESC LIMIT 10; +---- + +query I +SELECT COUNT(*) AS ShareCount FROM hits WHERE "IsMobile" = 1 AND "MobilePhoneModel" LIKE 'iPhone%' AND "SocialAction" = 'share' AND "SocialSourceNetworkID" IN (5, 12) AND "ClientTimeZone" BETWEEN -5 AND 5 AND regexp_match("Referer", '\/campaign\/(spring|summer)_promo') IS NOT NULL AND CASE WHEN split_part(split_part("URL", 'resolution=', 2), '&', 1) ~ '^\d+$' THEN split_part(split_part("URL", 'resolution=', 2), '&', 1)::INT ELSE 0 END > 1920 AND levenshtein(CAST("UTMSource" AS STRING), CAST("UTMCampaign" AS STRING)) < 3; +---- +0 + 
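+# Note: the extended queries above are meant to exercise DataFusion-specific
+# functions (COUNT(DISTINCT ...), MEDIAN, APPROX_PERCENTILE_CONT ... WITHIN GROUP
+# (ORDER BY ...), regexp_match, split_part, levenshtein) against the small
+# clickbench_hits_10.parquet sample, so most of them are expected to return few
+# or no rows here; the corresponding benchmark queries are kept in
+# benchmarks/queries/clickbench/extended.sql (see the note above).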
+ +statement ok +drop table hits; diff --git a/datafusion/sqllogictest/test_files/coalesce.slt b/datafusion/sqllogictest/test_files/coalesce.slt index e7cf31dc690b..9740bade5e27 100644 --- a/datafusion/sqllogictest/test_files/coalesce.slt +++ b/datafusion/sqllogictest/test_files/coalesce.slt @@ -260,8 +260,8 @@ select arrow_typeof(coalesce(c, arrow_cast('b', 'Dictionary(Int32, Utf8)'))) from t; ---- -a Dictionary(Int32, Utf8) -b Dictionary(Int32, Utf8) +a Utf8View +b Utf8View statement ok drop table t; diff --git a/datafusion/sqllogictest/test_files/copy.slt b/datafusion/sqllogictest/test_files/copy.slt index 925f96bd4ac0..5eeb05e814ac 100644 --- a/datafusion/sqllogictest/test_files/copy.slt +++ b/datafusion/sqllogictest/test_files/copy.slt @@ -637,7 +637,7 @@ query error DataFusion error: SQL error: ParserError\("Expected: \), found: EOF" COPY (select col2, sum(col1) from source_table # Copy from table with non literal -query error DataFusion error: SQL error: ParserError\("Unexpected token \("\) +query error DataFusion error: SQL error: ParserError\("Expected: end of statement or ;, found: \( at Line: 1, Column: 44"\) COPY source_table to '/tmp/table.parquet' (row_group_size 55 + 102); # Copy using execution.keep_partition_by_columns with an invalid value diff --git a/datafusion/sqllogictest/test_files/create_external_table.slt b/datafusion/sqllogictest/test_files/create_external_table.slt index bb66aef2514c..03cb5edb5fcc 100644 --- a/datafusion/sqllogictest/test_files/create_external_table.slt +++ b/datafusion/sqllogictest/test_files/create_external_table.slt @@ -77,7 +77,7 @@ statement error DataFusion error: SQL error: ParserError\("Expected: HEADER, fou CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV WITH LOCATION 'foo.csv'; # Unrecognized random clause -statement error DataFusion error: SQL error: ParserError\("Unexpected token FOOBAR"\) +statement error DataFusion error: SQL error: ParserError\("Expected: end of statement or ;, found: FOOBAR at Line: 1, Column: 47"\) CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV FOOBAR BARBAR BARFOO LOCATION 'foo.csv'; # Missing partition column diff --git a/datafusion/sqllogictest/test_files/cte.slt b/datafusion/sqllogictest/test_files/cte.slt index e019af9775a4..32320a06f4fb 100644 --- a/datafusion/sqllogictest/test_files/cte.slt +++ b/datafusion/sqllogictest/test_files/cte.slt @@ -722,7 +722,7 @@ logical_plan 03)----Projection: Int64(1) AS val 04)------EmptyRelation 05)----Projection: Int64(2) AS val -06)------Cross Join: +06)------Cross Join: 07)--------Filter: recursive_cte.val < Int64(2) 08)----------TableScan: recursive_cte 09)--------SubqueryAlias: sub_cte diff --git a/datafusion/sqllogictest/test_files/ddl.slt b/datafusion/sqllogictest/test_files/ddl.slt index 088d0155a66f..81f2955eff49 100644 --- a/datafusion/sqllogictest/test_files/ddl.slt +++ b/datafusion/sqllogictest/test_files/ddl.slt @@ -819,7 +819,7 @@ show columns FROM table_with_pk; ---- datafusion public table_with_pk sn Int32 NO datafusion public table_with_pk ts Timestamp(Nanosecond, Some("+00:00")) NO -datafusion public table_with_pk currency Utf8 NO +datafusion public table_with_pk currency Utf8View NO datafusion public table_with_pk amount Float32 YES statement ok @@ -828,18 +828,18 @@ drop table table_with_pk; statement ok set datafusion.catalog.information_schema = false; -# Test VARCHAR is mapped to Utf8View during SQL planning when setting map_varchar_to_utf8view to true +# Test VARCHAR is mapped to Utf8View during SQL planning when setting map_string_types_to_utf8view to 
true statement ok CREATE TABLE t1(c1 VARCHAR(10) NOT NULL, c2 VARCHAR); query TTT DESCRIBE t1; ---- -c1 Utf8 NO -c2 Utf8 YES +c1 Utf8View NO +c2 Utf8View YES statement ok -set datafusion.sql_parser.map_varchar_to_utf8view = true; +set datafusion.sql_parser.map_string_types_to_utf8view = true; statement ok CREATE TABLE t2(c1 VARCHAR(10) NOT NULL, c2 VARCHAR); diff --git a/datafusion/sqllogictest/test_files/delete.slt b/datafusion/sqllogictest/test_files/delete.slt new file mode 100644 index 000000000000..258318f09423 --- /dev/null +++ b/datafusion/sqllogictest/test_files/delete.slt @@ -0,0 +1,107 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +## Delete Tests +########## + +statement ok +create table t1(a int, b varchar, c double, d int); + +# Turn off the optimizer to make the logical plan closer to the initial one +statement ok +set datafusion.optimizer.max_passes = 0; + + +# Delete all +query TT +explain delete from t1; +---- +logical_plan +01)Dml: op=[Delete] table=[t1] +02)--TableScan: t1 +physical_plan_error This feature is not implemented: Unsupported logical plan: Dml(Delete) + + +# Filtered by existing columns +query TT +explain delete from t1 where a = 1 and b = 2 and c > 3 and d != 4; +---- +logical_plan +01)Dml: op=[Delete] table=[t1] +02)--Filter: CAST(t1.a AS Int64) = Int64(1) AND t1.b = CAST(Int64(2) AS Utf8View) AND t1.c > CAST(Int64(3) AS Float64) AND CAST(t1.d AS Int64) != Int64(4) +03)----TableScan: t1 +physical_plan_error This feature is not implemented: Unsupported logical plan: Dml(Delete) + + +# Filtered by existing columns, using qualified and unqualified names +query TT +explain delete from t1 where t1.a = 1 and b = 2 and t1.c > 3 and d != 4; +---- +logical_plan +01)Dml: op=[Delete] table=[t1] +02)--Filter: CAST(t1.a AS Int64) = Int64(1) AND t1.b = CAST(Int64(2) AS Utf8View) AND t1.c > CAST(Int64(3) AS Float64) AND CAST(t1.d AS Int64) != Int64(4) +03)----TableScan: t1 +physical_plan_error This feature is not implemented: Unsupported logical plan: Dml(Delete) + + +# Filtered by a mix of columns and literal predicates +query TT +explain delete from t1 where a = 1 and 1 = 1 and true; +---- +logical_plan +01)Dml: op=[Delete] table=[t1] +02)--Filter: CAST(t1.a AS Int64) = Int64(1) AND Int64(1) = Int64(1) AND Boolean(true) +03)----TableScan: t1 +physical_plan_error This feature is not implemented: Unsupported logical plan: Dml(Delete) + + +# Deleting by columns that do not exist returns an error +query error DataFusion error: Schema error: No field named e. Valid fields are t1.a, t1.b, t1.c, t1.d. 
+explain delete from t1 where e = 1; + + +# Filtering using subqueries + +statement ok +create table t2(a int, b varchar, c double, d int); + +query TT +explain delete from t1 where a = (select max(a) from t2 where t1.b = t2.b); +---- +logical_plan +01)Dml: op=[Delete] table=[t1] +02)--Filter: t1.a = () +03)----Subquery: +04)------Projection: max(t2.a) +05)--------Aggregate: groupBy=[[]], aggr=[[max(t2.a)]] +06)----------Filter: outer_ref(t1.b) = t2.b +07)------------TableScan: t2 +08)----TableScan: t1 +physical_plan_error This feature is not implemented: Physical plan does not support logical expression ScalarSubquery() + +query TT +explain delete from t1 where a in (select a from t2); +---- +logical_plan +01)Dml: op=[Delete] table=[t1] +02)--Filter: t1.a IN () +03)----Subquery: +04)------Projection: t2.a +05)--------TableScan: t2 +06)----TableScan: t1 +physical_plan_error This feature is not implemented: Physical plan does not support logical expression InSubquery(InSubquery { expr: Column(Column { relation: Some(Bare { table: "t1" }), name: "a" }), subquery: , negated: false }) diff --git a/datafusion/sqllogictest/test_files/describe.slt b/datafusion/sqllogictest/test_files/describe.slt index e4cb30628eec..de5208b5483a 100644 --- a/datafusion/sqllogictest/test_files/describe.slt +++ b/datafusion/sqllogictest/test_files/describe.slt @@ -86,3 +86,33 @@ string_col Utf8View YES timestamp_col Timestamp(Nanosecond, None) YES year Int32 YES month Int32 YES + +# Test DESC alias functionality +statement ok +CREATE TABLE test_desc_table (id INT, name VARCHAR); + +# Test DESC works the same as DESCRIBE +query TTT +DESC test_desc_table; +---- +id Int32 YES +name Utf8View YES + +query TTT +DESCRIBE test_desc_table; +---- +id Int32 YES +name Utf8View YES + +# Test with qualified table names +statement ok +CREATE TABLE public.test_qualified (col1 INT); + +query TTT +DESC public.test_qualified; +---- +col1 Int32 YES + +# Test error cases +statement error +DESC nonexistent_table; diff --git a/datafusion/sqllogictest/test_files/dictionary.slt b/datafusion/sqllogictest/test_files/dictionary.slt index 1769f42c2d2a..d241e61f33ff 100644 --- a/datafusion/sqllogictest/test_files/dictionary.slt +++ b/datafusion/sqllogictest/test_files/dictionary.slt @@ -456,4 +456,4 @@ statement ok CREATE TABLE test0 AS VALUES ('foo',1), ('bar',2), ('foo',3); statement ok -COPY (SELECT arrow_cast(column1, 'Dictionary(Int32, Utf8)') AS column1, column2 FROM test0) TO 'test_files/scratch/copy/part_dict_test' STORED AS PARQUET PARTITIONED BY (column1); \ No newline at end of file +COPY (SELECT arrow_cast(column1, 'Dictionary(Int32, Utf8)') AS column1, column2 FROM test0) TO 'test_files/scratch/copy/part_dict_test' STORED AS PARQUET PARTITIONED BY (column1); diff --git a/datafusion/sqllogictest/test_files/encrypted_parquet.slt b/datafusion/sqllogictest/test_files/encrypted_parquet.slt new file mode 100644 index 000000000000..d580b7d1ad2b --- /dev/null +++ b/datafusion/sqllogictest/test_files/encrypted_parquet.slt @@ -0,0 +1,89 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Test parquet encryption and decryption in DataFusion SQL. +# See datafusion/common/src/config.rs for equivalent rust code + +statement count 0 +CREATE EXTERNAL TABLE encrypted_parquet_table +( +double_field double, +float_field float +) +STORED AS PARQUET LOCATION 'test_files/scratch/encrypted_parquet/' OPTIONS ( + -- Configure encryption for reading and writing Parquet files + -- Encryption properties + 'format.crypto.file_encryption.encrypt_footer' 'true', + 'format.crypto.file_encryption.footer_key_as_hex' '30313233343536373839303132333435', -- b"0123456789012345" + 'format.crypto.file_encryption.column_key_as_hex::double_field' '31323334353637383930313233343530', -- b"1234567890123450" + 'format.crypto.file_encryption.column_key_as_hex::float_field' '31323334353637383930313233343531', -- b"1234567890123451" + -- Decryption properties + 'format.crypto.file_decryption.footer_key_as_hex' '30313233343536373839303132333435', -- b"0123456789012345" + 'format.crypto.file_decryption.column_key_as_hex::double_field' '31323334353637383930313233343530', -- b"1234567890123450" + 'format.crypto.file_decryption.column_key_as_hex::float_field' '31323334353637383930313233343531', -- b"1234567890123451" +) + +statement count 0 +CREATE TABLE temp_table ( + double_field double, + float_field float +) + +query I +INSERT INTO temp_table VALUES(-1.0, -1.0) +---- +1 + +query I +INSERT INTO temp_table VALUES(1.0, 2.0) +---- +1 + +query I +INSERT INTO temp_table VALUES(3.0, 4.0) +---- +1 + +query I +INSERT INTO temp_table VALUES(5.0, 6.0) +---- +1 + +query I +INSERT INTO TABLE encrypted_parquet_table(double_field, float_field) SELECT * FROM temp_table +---- +4 + +query RR +SELECT * FROM encrypted_parquet_table +WHERE double_field > 0.0 AND float_field > 0.0 +ORDER BY double_field +---- +1 2 +3 4 +5 6 + +statement count 0 +CREATE EXTERNAL TABLE parquet_table +( +double_field double, +float_field float +) +STORED AS PARQUET LOCATION 'test_files/scratch/encrypted_parquet/' + +query error DataFusion error: Parquet error: Parquet error: Parquet file has an encrypted footer but decryption properties were not provided +SELECT * FROM parquet_table diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index deff793e5110..50575a3aba4d 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -183,6 +183,7 @@ logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE logical_plan after eliminate_join SAME TEXT AS ABOVE logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE logical_plan after scalar_subquery_to_join SAME TEXT AS ABOVE +logical_plan after decorrelate_lateral_join SAME TEXT AS ABOVE logical_plan after extract_equijoin_predicate SAME TEXT AS ABOVE logical_plan after eliminate_duplicated_expr SAME TEXT AS ABOVE logical_plan after eliminate_filter SAME TEXT AS ABOVE @@ -204,6 +205,7 @@ logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE logical_plan after eliminate_join SAME TEXT AS ABOVE logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE 
logical_plan after scalar_subquery_to_join SAME TEXT AS ABOVE +logical_plan after decorrelate_lateral_join SAME TEXT AS ABOVE logical_plan after extract_equijoin_predicate SAME TEXT AS ABOVE logical_plan after eliminate_duplicated_expr SAME TEXT AS ABOVE logical_plan after eliminate_filter SAME TEXT AS ABOVE @@ -229,16 +231,20 @@ physical_plan after OutputRequirements physical_plan after aggregate_statistics SAME TEXT AS ABOVE physical_plan after join_selection SAME TEXT AS ABOVE physical_plan after LimitedDistinctAggregation SAME TEXT AS ABOVE +physical_plan after FilterPushdown SAME TEXT AS ABOVE physical_plan after EnforceDistribution SAME TEXT AS ABOVE physical_plan after CombinePartialFinalAggregate SAME TEXT AS ABOVE physical_plan after EnforceSorting SAME TEXT AS ABOVE physical_plan after OptimizeAggregateOrder SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE physical_plan after coalesce_batches SAME TEXT AS ABOVE +physical_plan after coalesce_async_exec_input SAME TEXT AS ABOVE physical_plan after OutputRequirements DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], file_type=csv, has_header=true physical_plan after LimitAggregation SAME TEXT AS ABOVE physical_plan after LimitPushdown SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE +physical_plan after EnsureCooperative SAME TEXT AS ABOVE +physical_plan after FilterPushdown(Post) SAME TEXT AS ABOVE physical_plan after SanityCheckPlan SAME TEXT AS ABOVE physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], file_type=csv, has_header=true physical_plan_with_stats DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], file_type=csv, has_header=true, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:)]] @@ -303,18 +309,22 @@ physical_plan after OutputRequirements physical_plan after aggregate_statistics SAME TEXT AS ABOVE physical_plan after join_selection SAME TEXT AS ABOVE physical_plan after LimitedDistinctAggregation SAME TEXT AS ABOVE +physical_plan after FilterPushdown SAME TEXT AS ABOVE physical_plan after EnforceDistribution SAME TEXT AS ABOVE physical_plan after CombinePartialFinalAggregate SAME TEXT AS ABOVE physical_plan after EnforceSorting SAME TEXT AS ABOVE physical_plan after OptimizeAggregateOrder SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE physical_plan after coalesce_batches SAME TEXT AS ABOVE +physical_plan after coalesce_async_exec_input SAME TEXT AS ABOVE physical_plan after OutputRequirements 01)GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Exact(671), [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] 02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, statistics=[Rows=Exact(8), Bytes=Exact(671), [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] physical_plan after LimitAggregation SAME TEXT AS ABOVE physical_plan after LimitPushdown DataSourceExec: file_groups={1 group: 
[[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, statistics=[Rows=Exact(8), Bytes=Exact(671), [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] physical_plan after ProjectionPushdown SAME TEXT AS ABOVE +physical_plan after EnsureCooperative SAME TEXT AS ABOVE +physical_plan after FilterPushdown(Post) SAME TEXT AS ABOVE physical_plan after SanityCheckPlan SAME TEXT AS ABOVE physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, statistics=[Rows=Exact(8), Bytes=Exact(671), [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] physical_plan_with_schema DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:BinaryView;N, string_col:BinaryView;N, timestamp_col:Timestamp(Nanosecond, None);N] @@ -343,18 +353,22 @@ physical_plan after OutputRequirements physical_plan after aggregate_statistics SAME TEXT AS ABOVE physical_plan after join_selection SAME TEXT AS ABOVE physical_plan after LimitedDistinctAggregation SAME TEXT AS ABOVE +physical_plan after FilterPushdown SAME TEXT AS ABOVE physical_plan after EnforceDistribution SAME TEXT AS ABOVE physical_plan after CombinePartialFinalAggregate SAME TEXT AS ABOVE physical_plan after EnforceSorting SAME TEXT AS ABOVE physical_plan after OptimizeAggregateOrder SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE physical_plan after coalesce_batches SAME TEXT AS ABOVE +physical_plan after coalesce_async_exec_input SAME TEXT AS ABOVE physical_plan after OutputRequirements 01)GlobalLimitExec: skip=0, fetch=10 02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet physical_plan after LimitAggregation SAME TEXT AS ABOVE physical_plan after LimitPushdown DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet physical_plan after ProjectionPushdown SAME TEXT AS ABOVE +physical_plan after EnsureCooperative SAME TEXT AS ABOVE +physical_plan after FilterPushdown(Post) SAME TEXT AS ABOVE physical_plan after SanityCheckPlan SAME TEXT AS ABOVE physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, 
timestamp_col], limit=10, file_type=parquet physical_plan_with_stats DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, file_type=parquet, statistics=[Rows=Exact(8), Bytes=Exact(671), [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] @@ -415,7 +429,7 @@ logical_plan 01)LeftSemi Join: 02)--TableScan: t1 projection=[a] 03)--SubqueryAlias: __correlated_sq_1 -04)----Projection: +04)----Projection: 05)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] 06)--------TableScan: t2 projection=[] physical_plan diff --git a/datafusion/sqllogictest/test_files/explain_tree.slt b/datafusion/sqllogictest/test_files/explain_tree.slt index 9041541297a9..f4188f4cb395 100644 --- a/datafusion/sqllogictest/test_files/explain_tree.slt +++ b/datafusion/sqllogictest/test_files/explain_tree.slt @@ -180,8 +180,8 @@ physical_plan 13)┌─────────────┴─────────────┐ 14)│ RepartitionExec │ 15)│ -------------------- │ -16)│ input_partition_count: │ -17)│ 1 │ +16)│ partition_count(in->out): │ +17)│ 1 -> 4 │ 18)│ │ 19)│ partitioning_scheme: │ 20)│ RoundRobinBatch(4) │ @@ -218,8 +218,8 @@ physical_plan 18)┌─────────────┴─────────────┐ 19)│ RepartitionExec │ 20)│ -------------------- │ -21)│ input_partition_count: │ -22)│ 4 │ +21)│ partition_count(in->out): │ +22)│ 4 -> 4 │ 23)│ │ 24)│ partitioning_scheme: │ 25)│ Hash([string_col@0], 4) │ @@ -236,8 +236,8 @@ physical_plan 36)┌─────────────┴─────────────┐ 37)│ RepartitionExec │ 38)│ -------------------- │ -39)│ input_partition_count: │ -40)│ 1 │ +39)│ partition_count(in->out): │ +40)│ 1 -> 4 │ 41)│ │ 42)│ partitioning_scheme: │ 43)│ RoundRobinBatch(4) │ @@ -280,7 +280,7 @@ physical_plan 06)┌─────────────┴─────────────┐ 07)│ DataSourceExec │ 08)│ -------------------- │ -09)│ bytes: 3120 │ +09)│ bytes: 1040 │ 10)│ format: memory │ 11)│ rows: 2 │ 12)└───────────────────────────┘ @@ -291,47 +291,40 @@ explain SELECT table1.string_col, table2.date_col FROM table1 JOIN table2 ON tab ---- physical_plan 01)┌───────────────────────────┐ -02)│ CoalesceBatchesExec │ +02)│ ProjectionExec │ 03)│ -------------------- │ -04)│ target_batch_size: │ -05)│ 8192 │ -06)└─────────────┬─────────────┘ -07)┌─────────────┴─────────────┐ -08)│ HashJoinExec │ -09)│ -------------------- │ -10)│ on: ├──────────────┐ -11)│ (int_col = int_col) │ │ -12)└─────────────┬─────────────┘ │ -13)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -14)│ CoalesceBatchesExec ││ CoalesceBatchesExec │ -15)│ -------------------- ││ -------------------- │ -16)│ target_batch_size: ││ target_batch_size: │ -17)│ 8192 ││ 8192 │ -18)└─────────────┬─────────────┘└─────────────┬─────────────┘ -19)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -20)│ RepartitionExec ││ RepartitionExec │ -21)│ -------------------- ││ -------------------- │ -22)│ input_partition_count: ││ input_partition_count: │ -23)│ 4 ││ 4 │ -24)│ ││ │ -25)│ partitioning_scheme: ││ partitioning_scheme: │ -26)│ Hash([int_col@0], 4) ││ Hash([int_col@0], 4) │ -27)└─────────────┬─────────────┘└─────────────┬─────────────┘ -28)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -29)│ RepartitionExec ││ RepartitionExec │ -30)│ -------------------- ││ -------------------- │ -31)│ input_partition_count: ││ input_partition_count: │ -32)│ 1 ││ 1 │ -33)│ ││ │ -34)│ 
partitioning_scheme: ││ partitioning_scheme: │ -35)│ RoundRobinBatch(4) ││ RoundRobinBatch(4) │ -36)└─────────────┬─────────────┘└─────────────┬─────────────┘ -37)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -38)│ DataSourceExec ││ DataSourceExec │ -39)│ -------------------- ││ -------------------- │ -40)│ files: 1 ││ files: 1 │ -41)│ format: csv ││ format: parquet │ -42)└───────────────────────────┘└───────────────────────────┘ +04)│ date_col: date_col │ +05)│ │ +06)│ string_col: │ +07)│ string_col │ +08)└─────────────┬─────────────┘ +09)┌─────────────┴─────────────┐ +10)│ CoalesceBatchesExec │ +11)│ -------------------- │ +12)│ target_batch_size: │ +13)│ 8192 │ +14)└─────────────┬─────────────┘ +15)┌─────────────┴─────────────┐ +16)│ HashJoinExec │ +17)│ -------------------- │ +18)│ on: ├──────────────┐ +19)│ (int_col = int_col) │ │ +20)└─────────────┬─────────────┘ │ +21)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ +22)│ DataSourceExec ││ RepartitionExec │ +23)│ -------------------- ││ -------------------- │ +24)│ files: 1 ││ partition_count(in->out): │ +25)│ format: parquet ││ 1 -> 4 │ +26)│ ││ │ +27)│ ││ partitioning_scheme: │ +28)│ ││ RoundRobinBatch(4) │ +29)└───────────────────────────┘└─────────────┬─────────────┘ +30)-----------------------------┌─────────────┴─────────────┐ +31)-----------------------------│ DataSourceExec │ +32)-----------------------------│ -------------------- │ +33)-----------------------------│ files: 1 │ +34)-----------------------------│ format: csv │ +35)-----------------------------└───────────────────────────┘ # 3 Joins query TT @@ -365,48 +358,41 @@ physical_plan 19)│ (int_col = int_col) │ │ 20)└─────────────┬─────────────┘ │ 21)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -22)│ DataSourceExec ││ CoalesceBatchesExec │ +22)│ DataSourceExec ││ ProjectionExec │ 23)│ -------------------- ││ -------------------- │ -24)│ bytes: 1560 ││ target_batch_size: │ -25)│ format: memory ││ 8192 │ +24)│ bytes: 520 ││ date_col: date_col │ +25)│ format: memory ││ int_col: int_col │ 26)│ rows: 1 ││ │ -27)└───────────────────────────┘└─────────────┬─────────────┘ -28)-----------------------------┌─────────────┴─────────────┐ -29)-----------------------------│ HashJoinExec │ -30)-----------------------------│ -------------------- │ -31)-----------------------------│ on: ├──────────────┐ -32)-----------------------------│ (int_col = int_col) │ │ -33)-----------------------------└─────────────┬─────────────┘ │ -34)-----------------------------┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -35)-----------------------------│ CoalesceBatchesExec ││ CoalesceBatchesExec │ -36)-----------------------------│ -------------------- ││ -------------------- │ -37)-----------------------------│ target_batch_size: ││ target_batch_size: │ -38)-----------------------------│ 8192 ││ 8192 │ -39)-----------------------------└─────────────┬─────────────┘└─────────────┬─────────────┘ -40)-----------------------------┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -41)-----------------------------│ RepartitionExec ││ RepartitionExec │ -42)-----------------------------│ -------------------- ││ -------------------- │ -43)-----------------------------│ input_partition_count: ││ input_partition_count: │ -44)-----------------------------│ 4 ││ 4 │ -45)-----------------------------│ ││ │ -46)-----------------------------│ partitioning_scheme: ││ partitioning_scheme: │ -47)-----------------------------│ Hash([int_col@0], 4) ││ 
Hash([int_col@0], 4) │ -48)-----------------------------└─────────────┬─────────────┘└─────────────┬─────────────┘ -49)-----------------------------┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -50)-----------------------------│ RepartitionExec ││ RepartitionExec │ -51)-----------------------------│ -------------------- ││ -------------------- │ -52)-----------------------------│ input_partition_count: ││ input_partition_count: │ -53)-----------------------------│ 1 ││ 1 │ -54)-----------------------------│ ││ │ -55)-----------------------------│ partitioning_scheme: ││ partitioning_scheme: │ -56)-----------------------------│ RoundRobinBatch(4) ││ RoundRobinBatch(4) │ -57)-----------------------------└─────────────┬─────────────┘└─────────────┬─────────────┘ -58)-----------------------------┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -59)-----------------------------│ DataSourceExec ││ DataSourceExec │ -60)-----------------------------│ -------------------- ││ -------------------- │ -61)-----------------------------│ files: 1 ││ files: 1 │ -62)-----------------------------│ format: csv ││ format: parquet │ -63)-----------------------------└───────────────────────────┘└───────────────────────────┘ +27)│ ││ string_col: │ +28)│ ││ string_col │ +29)└───────────────────────────┘└─────────────┬─────────────┘ +30)-----------------------------┌─────────────┴─────────────┐ +31)-----------------------------│ CoalesceBatchesExec │ +32)-----------------------------│ -------------------- │ +33)-----------------------------│ target_batch_size: │ +34)-----------------------------│ 8192 │ +35)-----------------------------└─────────────┬─────────────┘ +36)-----------------------------┌─────────────┴─────────────┐ +37)-----------------------------│ HashJoinExec │ +38)-----------------------------│ -------------------- │ +39)-----------------------------│ on: ├──────────────┐ +40)-----------------------------│ (int_col = int_col) │ │ +41)-----------------------------└─────────────┬─────────────┘ │ +42)-----------------------------┌─────────────┴─────────────┐┌─────────────┴─────────────┐ +43)-----------------------------│ DataSourceExec ││ RepartitionExec │ +44)-----------------------------│ -------------------- ││ -------------------- │ +45)-----------------------------│ files: 1 ││ partition_count(in->out): │ +46)-----------------------------│ format: parquet ││ 1 -> 4 │ +47)-----------------------------│ ││ │ +48)-----------------------------│ ││ partitioning_scheme: │ +49)-----------------------------│ ││ RoundRobinBatch(4) │ +50)-----------------------------└───────────────────────────┘└─────────────┬─────────────┘ +51)----------------------------------------------------------┌─────────────┴─────────────┐ +52)----------------------------------------------------------│ DataSourceExec │ +53)----------------------------------------------------------│ -------------------- │ +54)----------------------------------------------------------│ files: 1 │ +55)----------------------------------------------------------│ format: csv │ +56)----------------------------------------------------------└───────────────────────────┘ # Long Filter (demonstrate what happens with wrapping) query TT @@ -434,8 +420,8 @@ physical_plan 17)┌─────────────┴─────────────┐ 18)│ RepartitionExec │ 19)│ -------------------- │ -20)│ input_partition_count: │ -21)│ 1 │ +20)│ partition_count(in->out): │ +21)│ 1 -> 4 │ 22)│ │ 23)│ partitioning_scheme: │ 24)│ RoundRobinBatch(4) │ @@ -496,8 +482,8 @@ physical_plan 
41)┌─────────────┴─────────────┐ 42)│ RepartitionExec │ 43)│ -------------------- │ -44)│ input_partition_count: │ -45)│ 1 │ +44)│ partition_count(in->out): │ +45)│ 1 -> 4 │ 46)│ │ 47)│ partitioning_scheme: │ 48)│ RoundRobinBatch(4) │ @@ -530,8 +516,8 @@ physical_plan 13)┌─────────────┴─────────────┐ 14)│ RepartitionExec │ 15)│ -------------------- │ -16)│ input_partition_count: │ -17)│ 1 │ +16)│ partition_count(in->out): │ +17)│ 1 -> 4 │ 18)│ │ 19)│ partitioning_scheme: │ 20)│ RoundRobinBatch(4) │ @@ -566,8 +552,8 @@ physical_plan 15)┌─────────────┴─────────────┐ 16)│ RepartitionExec │ 17)│ -------------------- │ -18)│ input_partition_count: │ -19)│ 1 │ +18)│ partition_count(in->out): │ +19)│ 1 -> 4 │ 20)│ │ 21)│ partitioning_scheme: │ 22)│ RoundRobinBatch(4) │ @@ -599,8 +585,8 @@ physical_plan 13)┌─────────────┴─────────────┐ 14)│ RepartitionExec │ 15)│ -------------------- │ -16)│ input_partition_count: │ -17)│ 1 │ +16)│ partition_count(in->out): │ +17)│ 1 -> 4 │ 18)│ │ 19)│ partitioning_scheme: │ 20)│ RoundRobinBatch(4) │ @@ -633,8 +619,8 @@ physical_plan 13)┌─────────────┴─────────────┐ 14)│ RepartitionExec │ 15)│ -------------------- │ -16)│ input_partition_count: │ -17)│ 1 │ +16)│ partition_count(in->out): │ +17)│ 1 -> 4 │ 18)│ │ 19)│ partitioning_scheme: │ 20)│ RoundRobinBatch(4) │ @@ -669,7 +655,7 @@ physical_plan 13)┌─────────────┴─────────────┐ 14)│ DataSourceExec │ 15)│ -------------------- │ -16)│ bytes: 1560 │ +16)│ bytes: 520 │ 17)│ format: memory │ 18)│ rows: 1 │ 19)└───────────────────────────┘ @@ -694,8 +680,8 @@ physical_plan 13)┌─────────────┴─────────────┐ 14)│ RepartitionExec │ 15)│ -------------------- │ -16)│ input_partition_count: │ -17)│ 1 │ +16)│ partition_count(in->out): │ +17)│ 1 -> 4 │ 18)│ │ 19)│ partitioning_scheme: │ 20)│ RoundRobinBatch(4) │ @@ -727,8 +713,8 @@ physical_plan 13)┌─────────────┴─────────────┐ 14)│ RepartitionExec │ 15)│ -------------------- │ -16)│ input_partition_count: │ -17)│ 1 │ +16)│ partition_count(in->out): │ +17)│ 1 -> 4 │ 18)│ │ 19)│ partitioning_scheme: │ 20)│ RoundRobinBatch(4) │ @@ -922,8 +908,8 @@ physical_plan 13)┌─────────────┴─────────────┐ 14)│ RepartitionExec │ 15)│ -------------------- │ -16)│ input_partition_count: │ -17)│ 1 │ +16)│ partition_count(in->out): │ +17)│ 1 -> 4 │ 18)│ │ 19)│ partitioning_scheme: │ 20)│ RoundRobinBatch(4) │ @@ -1029,21 +1015,11 @@ physical_plan 11)│ bigint_col │ 12)└─────────────┬─────────────┘ 13)┌─────────────┴─────────────┐ -14)│ RepartitionExec │ +14)│ DataSourceExec │ 15)│ -------------------- │ -16)│ input_partition_count: │ -17)│ 1 │ -18)│ │ -19)│ partitioning_scheme: │ -20)│ RoundRobinBatch(4) │ -21)└─────────────┬─────────────┘ -22)┌─────────────┴─────────────┐ -23)│ DataSourceExec │ -24)│ -------------------- │ -25)│ files: 1 │ -26)│ format: parquet │ -27)└───────────────────────────┘ - +16)│ files: 1 │ +17)│ format: parquet │ +18)└───────────────────────────┘ # Query with projection on memory query TT @@ -1065,7 +1041,7 @@ physical_plan 13)┌─────────────┴─────────────┐ 14)│ DataSourceExec │ 15)│ -------------------- │ -16)│ bytes: 1560 │ +16)│ bytes: 520 │ 17)│ format: memory │ 18)│ rows: 1 │ 19)└───────────────────────────┘ @@ -1089,8 +1065,8 @@ physical_plan 12)┌─────────────┴─────────────┐ 13)│ RepartitionExec │ 14)│ -------------------- │ -15)│ input_partition_count: │ -16)│ 1 │ +15)│ partition_count(in->out): │ +16)│ 1 -> 4 │ 17)│ │ 18)│ partitioning_scheme: │ 19)│ RoundRobinBatch(4) │ @@ -1123,8 +1099,8 @@ physical_plan 13)┌─────────────┴─────────────┐ 14)│ RepartitionExec │ 
15)│ -------------------- │ -16)│ input_partition_count: │ -17)│ 1 │ +16)│ partition_count(in->out): │ +17)│ 1 -> 4 │ 18)│ │ 19)│ partitioning_scheme: │ 20)│ RoundRobinBatch(4) │ @@ -1186,69 +1162,46 @@ explain select * from table1 inner join table2 on table1.int_col = table2.int_co ---- physical_plan 01)┌───────────────────────────┐ -02)│ CoalesceBatchesExec │ +02)│ ProjectionExec │ 03)│ -------------------- │ -04)│ target_batch_size: │ -05)│ 8192 │ -06)└─────────────┬─────────────┘ -07)┌─────────────┴─────────────┐ -08)│ HashJoinExec │ -09)│ -------------------- │ -10)│ on: │ -11)│ (int_col = int_col), (CAST├──────────────┐ -12)│ (table1.string_col AS │ │ -13)│ Utf8View) = │ │ -14)│ string_col) │ │ -15)└─────────────┬─────────────┘ │ -16)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -17)│ CoalesceBatchesExec ││ CoalesceBatchesExec │ -18)│ -------------------- ││ -------------------- │ -19)│ target_batch_size: ││ target_batch_size: │ -20)│ 8192 ││ 8192 │ -21)└─────────────┬─────────────┘└─────────────┬─────────────┘ -22)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -23)│ RepartitionExec ││ RepartitionExec │ -24)│ -------------------- ││ -------------------- │ -25)│ input_partition_count: ││ input_partition_count: │ -26)│ 4 ││ 4 │ -27)│ ││ │ -28)│ partitioning_scheme: ││ partitioning_scheme: │ -29)│ Hash([int_col@0, CAST ││ Hash([int_col@0, │ -30)│ (table1.string_col ││ string_col@1], │ -31)│ AS Utf8View)@4], 4) ││ 4) │ -32)└─────────────┬─────────────┘└─────────────┬─────────────┘ -33)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -34)│ ProjectionExec ││ RepartitionExec │ -35)│ -------------------- ││ -------------------- │ -36)│ CAST(table1.string_col AS ││ input_partition_count: │ -37)│ Utf8View): ││ 1 │ -38)│ CAST(string_col AS ││ │ -39)│ Utf8View) ││ partitioning_scheme: │ -40)│ ││ RoundRobinBatch(4) │ -41)│ bigint_col: ││ │ -42)│ bigint_col ││ │ -43)│ ││ │ -44)│ date_col: date_col ││ │ -45)│ int_col: int_col ││ │ -46)│ ││ │ -47)│ string_col: ││ │ -48)│ string_col ││ │ -49)└─────────────┬─────────────┘└─────────────┬─────────────┘ -50)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -51)│ RepartitionExec ││ DataSourceExec │ -52)│ -------------------- ││ -------------------- │ -53)│ input_partition_count: ││ files: 1 │ -54)│ 1 ││ format: parquet │ -55)│ ││ │ -56)│ partitioning_scheme: ││ │ -57)│ RoundRobinBatch(4) ││ │ -58)└─────────────┬─────────────┘└───────────────────────────┘ -59)┌─────────────┴─────────────┐ -60)│ DataSourceExec │ -61)│ -------------------- │ -62)│ files: 1 │ -63)│ format: csv │ -64)└───────────────────────────┘ +04)│ bigint_col: │ +05)│ bigint_col │ +06)│ │ +07)│ date_col: date_col │ +08)│ int_col: int_col │ +09)│ │ +10)│ string_col: │ +11)│ string_col │ +12)└─────────────┬─────────────┘ +13)┌─────────────┴─────────────┐ +14)│ CoalesceBatchesExec │ +15)│ -------------------- │ +16)│ target_batch_size: │ +17)│ 8192 │ +18)└─────────────┬─────────────┘ +19)┌─────────────┴─────────────┐ +20)│ HashJoinExec │ +21)│ -------------------- │ +22)│ on: │ +23)│ (int_col = int_col), ├──────────────┐ +24)│ (string_col = │ │ +25)│ string_col) │ │ +26)└─────────────┬─────────────┘ │ +27)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ +28)│ DataSourceExec ││ RepartitionExec │ +29)│ -------------------- ││ -------------------- │ +30)│ files: 1 ││ partition_count(in->out): │ +31)│ format: parquet ││ 1 -> 4 │ +32)│ ││ │ +33)│ ││ partitioning_scheme: │ +34)│ ││ RoundRobinBatch(4) │ 
+35)└───────────────────────────┘└─────────────┬─────────────┘ +36)-----------------------------┌─────────────┴─────────────┐ +37)-----------------------------│ DataSourceExec │ +38)-----------------------------│ -------------------- │ +39)-----------------------------│ files: 1 │ +40)-----------------------------│ format: csv │ +41)-----------------------------└───────────────────────────┘ # Query with outer hash join. query TT @@ -1256,71 +1209,48 @@ explain select * from table1 left outer join table2 on table1.int_col = table2.i ---- physical_plan 01)┌───────────────────────────┐ -02)│ CoalesceBatchesExec │ +02)│ ProjectionExec │ 03)│ -------------------- │ -04)│ target_batch_size: │ -05)│ 8192 │ -06)└─────────────┬─────────────┘ -07)┌─────────────┴─────────────┐ -08)│ HashJoinExec │ -09)│ -------------------- │ -10)│ join_type: Left │ -11)│ │ -12)│ on: ├──────────────┐ -13)│ (int_col = int_col), (CAST│ │ -14)│ (table1.string_col AS │ │ -15)│ Utf8View) = │ │ -16)│ string_col) │ │ -17)└─────────────┬─────────────┘ │ -18)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -19)│ CoalesceBatchesExec ││ CoalesceBatchesExec │ -20)│ -------------------- ││ -------------------- │ -21)│ target_batch_size: ││ target_batch_size: │ -22)│ 8192 ││ 8192 │ -23)└─────────────┬─────────────┘└─────────────┬─────────────┘ -24)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -25)│ RepartitionExec ││ RepartitionExec │ -26)│ -------------------- ││ -------------------- │ -27)│ input_partition_count: ││ input_partition_count: │ -28)│ 4 ││ 4 │ -29)│ ││ │ -30)│ partitioning_scheme: ││ partitioning_scheme: │ -31)│ Hash([int_col@0, CAST ││ Hash([int_col@0, │ -32)│ (table1.string_col ││ string_col@1], │ -33)│ AS Utf8View)@4], 4) ││ 4) │ -34)└─────────────┬─────────────┘└─────────────┬─────────────┘ -35)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -36)│ ProjectionExec ││ RepartitionExec │ -37)│ -------------------- ││ -------------------- │ -38)│ CAST(table1.string_col AS ││ input_partition_count: │ -39)│ Utf8View): ││ 1 │ -40)│ CAST(string_col AS ││ │ -41)│ Utf8View) ││ partitioning_scheme: │ -42)│ ││ RoundRobinBatch(4) │ -43)│ bigint_col: ││ │ -44)│ bigint_col ││ │ -45)│ ││ │ -46)│ date_col: date_col ││ │ -47)│ int_col: int_col ││ │ -48)│ ││ │ -49)│ string_col: ││ │ -50)│ string_col ││ │ -51)└─────────────┬─────────────┘└─────────────┬─────────────┘ -52)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -53)│ RepartitionExec ││ DataSourceExec │ -54)│ -------------------- ││ -------------------- │ -55)│ input_partition_count: ││ files: 1 │ -56)│ 1 ││ format: parquet │ -57)│ ││ │ -58)│ partitioning_scheme: ││ │ -59)│ RoundRobinBatch(4) ││ │ -60)└─────────────┬─────────────┘└───────────────────────────┘ -61)┌─────────────┴─────────────┐ -62)│ DataSourceExec │ -63)│ -------------------- │ -64)│ files: 1 │ -65)│ format: csv │ -66)└───────────────────────────┘ +04)│ bigint_col: │ +05)│ bigint_col │ +06)│ │ +07)│ date_col: date_col │ +08)│ int_col: int_col │ +09)│ │ +10)│ string_col: │ +11)│ string_col │ +12)└─────────────┬─────────────┘ +13)┌─────────────┴─────────────┐ +14)│ CoalesceBatchesExec │ +15)│ -------------------- │ +16)│ target_batch_size: │ +17)│ 8192 │ +18)└─────────────┬─────────────┘ +19)┌─────────────┴─────────────┐ +20)│ HashJoinExec │ +21)│ -------------------- │ +22)│ join_type: Right │ +23)│ │ +24)│ on: ├──────────────┐ +25)│ (int_col = int_col), │ │ +26)│ (string_col = │ │ +27)│ string_col) │ │ +28)└─────────────┬─────────────┘ │ 
+29)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ +30)│ DataSourceExec ││ RepartitionExec │ +31)│ -------------------- ││ -------------------- │ +32)│ files: 1 ││ partition_count(in->out): │ +33)│ format: parquet ││ 1 -> 4 │ +34)│ ││ │ +35)│ ││ partitioning_scheme: │ +36)│ ││ RoundRobinBatch(4) │ +37)└───────────────────────────┘└─────────────┬─────────────┘ +38)-----------------------------┌─────────────┴─────────────┐ +39)-----------------------------│ DataSourceExec │ +40)-----------------------------│ -------------------- │ +41)-----------------------------│ files: 1 │ +42)-----------------------------│ format: csv │ +43)-----------------------------└───────────────────────────┘ # Query with nested loop join. query TT @@ -1339,35 +1269,8 @@ physical_plan 10)│ format: csv ││ │ 11)└───────────────────────────┘└─────────────┬─────────────┘ 12)-----------------------------┌─────────────┴─────────────┐ -13)-----------------------------│ AggregateExec │ -14)-----------------------------│ -------------------- │ -15)-----------------------------│ aggr: count(1) │ -16)-----------------------------│ mode: Final │ -17)-----------------------------└─────────────┬─────────────┘ -18)-----------------------------┌─────────────┴─────────────┐ -19)-----------------------------│ CoalescePartitionsExec │ -20)-----------------------------└─────────────┬─────────────┘ -21)-----------------------------┌─────────────┴─────────────┐ -22)-----------------------------│ AggregateExec │ -23)-----------------------------│ -------------------- │ -24)-----------------------------│ aggr: count(1) │ -25)-----------------------------│ mode: Partial │ -26)-----------------------------└─────────────┬─────────────┘ -27)-----------------------------┌─────────────┴─────────────┐ -28)-----------------------------│ RepartitionExec │ -29)-----------------------------│ -------------------- │ -30)-----------------------------│ input_partition_count: │ -31)-----------------------------│ 1 │ -32)-----------------------------│ │ -33)-----------------------------│ partitioning_scheme: │ -34)-----------------------------│ RoundRobinBatch(4) │ -35)-----------------------------└─────────────┬─────────────┘ -36)-----------------------------┌─────────────┴─────────────┐ -37)-----------------------------│ DataSourceExec │ -38)-----------------------------│ -------------------- │ -39)-----------------------------│ files: 1 │ -40)-----------------------------│ format: parquet │ -41)-----------------------------└───────────────────────────┘ +13)-----------------------------│ PlaceholderRowExec │ +14)-----------------------------└───────────────────────────┘ # Query with cross join. 
query TT @@ -1378,21 +1281,11 @@ physical_plan 02)│ CrossJoinExec ├──────────────┐ 03)└─────────────┬─────────────┘ │ 04)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ -05)│ DataSourceExec ││ RepartitionExec │ +05)│ DataSourceExec ││ DataSourceExec │ 06)│ -------------------- ││ -------------------- │ -07)│ files: 1 ││ input_partition_count: │ -08)│ format: csv ││ 1 │ -09)│ ││ │ -10)│ ││ partitioning_scheme: │ -11)│ ││ RoundRobinBatch(4) │ -12)└───────────────────────────┘└─────────────┬─────────────┘ -13)-----------------------------┌─────────────┴─────────────┐ -14)-----------------------------│ DataSourceExec │ -15)-----------------------------│ -------------------- │ -16)-----------------------------│ files: 1 │ -17)-----------------------------│ format: parquet │ -18)-----------------------------└───────────────────────────┘ - +07)│ files: 1 ││ files: 1 │ +08)│ format: csv ││ format: parquet │ +09)└───────────────────────────┘└───────────────────────────┘ # Query with sort merge join. statement ok @@ -1505,8 +1398,8 @@ physical_plan 33)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ 34)│ RepartitionExec ││ RepartitionExec │ 35)│ -------------------- ││ -------------------- │ -36)│ input_partition_count: ││ input_partition_count: │ -37)│ 4 ││ 4 │ +36)│ partition_count(in->out): ││ partition_count(in->out): │ +37)│ 4 -> 4 ││ 4 -> 4 │ 38)│ ││ │ 39)│ partitioning_scheme: ││ partitioning_scheme: │ 40)│ Hash([name@0], 4) ││ Hash([name@0], 4) │ @@ -1514,8 +1407,8 @@ physical_plan 42)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ 43)│ RepartitionExec ││ RepartitionExec │ 44)│ -------------------- ││ -------------------- │ -45)│ input_partition_count: ││ input_partition_count: │ -46)│ 1 ││ 1 │ +45)│ partition_count(in->out): ││ partition_count(in->out): │ +46)│ 1 -> 4 ││ 1 -> 4 │ 47)│ ││ │ 48)│ partitioning_scheme: ││ partitioning_scheme: │ 49)│ RoundRobinBatch(4) ││ RoundRobinBatch(4) │ @@ -1529,7 +1422,7 @@ physical_plan 57)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ 58)│ DataSourceExec ││ DataSourceExec │ 59)│ -------------------- ││ -------------------- │ -60)│ bytes: 1320 ││ bytes: 1312 │ +60)│ bytes: 296 ││ bytes: 288 │ 61)│ format: memory ││ format: memory │ 62)│ rows: 1 ││ rows: 1 │ 63)└───────────────────────────┘└───────────────────────────┘ @@ -1548,14 +1441,14 @@ physical_plan 04)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ 05)│ DataSourceExec ││ ProjectionExec │ 06)│ -------------------- ││ -------------------- │ -07)│ bytes: 1320 ││ id: CAST(id AS Int32) │ +07)│ bytes: 296 ││ id: CAST(id AS Int32) │ 08)│ format: memory ││ name: name │ 09)│ rows: 1 ││ │ 10)└───────────────────────────┘└─────────────┬─────────────┘ 11)-----------------------------┌─────────────┴─────────────┐ 12)-----------------------------│ DataSourceExec │ 13)-----------------------------│ -------------------- │ -14)-----------------------------│ bytes: 1312 │ +14)-----------------------------│ bytes: 288 │ 15)-----------------------------│ format: memory │ 16)-----------------------------│ rows: 1 │ 17)-----------------------------└───────────────────────────┘ @@ -1606,8 +1499,8 @@ physical_plan 18)┌─────────────┴─────────────┐ 19)│ RepartitionExec │ 20)│ -------------------- │ -21)│ input_partition_count: │ -22)│ 1 │ +21)│ partition_count(in->out): │ +22)│ 1 -> 4 │ 23)│ │ 24)│ partitioning_scheme: │ 25)│ RoundRobinBatch(4) │ @@ -1648,8 +1541,8 @@ physical_plan 19)┌─────────────┴─────────────┐ 20)│ RepartitionExec │ 21)│ -------------------- │ 
-22)│ input_partition_count: │ -23)│ 1 │ +22)│ partition_count(in->out): │ +23)│ 1 -> 4 │ 24)│ │ 25)│ partitioning_scheme: │ 26)│ RoundRobinBatch(4) │ @@ -1689,8 +1582,8 @@ physical_plan 19)┌─────────────┴─────────────┐ 20)│ RepartitionExec │ 21)│ -------------------- │ -22)│ input_partition_count: │ -23)│ 1 │ +22)│ partition_count(in->out): │ +23)│ 1 -> 4 │ 24)│ │ 25)│ partitioning_scheme: │ 26)│ RoundRobinBatch(4) │ @@ -1728,8 +1621,8 @@ physical_plan 17)┌─────────────┴─────────────┐ 18)│ RepartitionExec │ 19)│ -------------------- │ -20)│ input_partition_count: │ -21)│ 1 │ +20)│ partition_count(in->out): │ +21)│ 1 -> 4 │ 22)│ │ 23)│ partitioning_scheme: │ 24)│ RoundRobinBatch(4) │ @@ -1771,8 +1664,8 @@ physical_plan 20)┌─────────────┴─────────────┐ 21)│ RepartitionExec │ 22)│ -------------------- │ -23)│ input_partition_count: │ -24)│ 1 │ +23)│ partition_count(in->out): │ +24)│ 1 -> 4 │ 25)│ │ 26)│ partitioning_scheme: │ 27)│ RoundRobinBatch(4) │ @@ -1815,8 +1708,8 @@ physical_plan 19)┌─────────────┴─────────────┐ 20)│ RepartitionExec │ 21)│ -------------------- │ -22)│ input_partition_count: │ -23)│ 1 │ +22)│ partition_count(in->out): │ +23)│ 1 -> 4 │ 24)│ │ 25)│ partitioning_scheme: │ 26)│ RoundRobinBatch(4) │ @@ -1869,8 +1762,8 @@ physical_plan 25)-----------------------------┌─────────────┴─────────────┐ 26)-----------------------------│ RepartitionExec │ 27)-----------------------------│ -------------------- │ -28)-----------------------------│ input_partition_count: │ -29)-----------------------------│ 1 │ +28)-----------------------------│ partition_count(in->out): │ +29)-----------------------------│ 1 -> 4 │ 30)-----------------------------│ │ 31)-----------------------------│ partitioning_scheme: │ 32)-----------------------------│ RoundRobinBatch(4) │ @@ -1983,8 +1876,8 @@ physical_plan 22)┌─────────────┴─────────────┐ 23)│ RepartitionExec │ 24)│ -------------------- │ -25)│ input_partition_count: │ -26)│ 1 │ +25)│ partition_count(in->out): │ +26)│ 1 -> 4 │ 27)│ │ 28)│ partitioning_scheme: │ 29)│ RoundRobinBatch(4) │ @@ -2062,8 +1955,8 @@ physical_plan 19)┌─────────────┴─────────────┐ 20)│ RepartitionExec │ 21)│ -------------------- │ -22)│ input_partition_count: │ -23)│ 1 │ +22)│ partition_count(in->out): │ +23)│ 1 -> 4 │ 24)│ │ 25)│ partitioning_scheme: │ 26)│ RoundRobinBatch(4) │ diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index e4d0b7233856..67a2af11870d 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -698,6 +698,11 @@ SELECT to_hex(2147483647) ---- 7fffffff +query T +SELECT to_hex(CAST(2147483647 as BIGINT UNSIGNED)) +---- +7fffffff + query T SELECT to_hex(9223372036854775807) ---- diff --git a/datafusion/sqllogictest/test_files/filter_without_sort_exec.slt b/datafusion/sqllogictest/test_files/filter_without_sort_exec.slt index d96044fda8c0..a09d8ce26ddf 100644 --- a/datafusion/sqllogictest/test_files/filter_without_sort_exec.slt +++ b/datafusion/sqllogictest/test_files/filter_without_sort_exec.slt @@ -34,7 +34,7 @@ ORDER BY "date", "time"; ---- logical_plan 01)Sort: data.date ASC NULLS LAST, data.time ASC NULLS LAST -02)--Filter: data.ticker = Utf8("A") +02)--Filter: data.ticker = Utf8View("A") 03)----TableScan: data projection=[date, ticker, time] physical_plan 01)SortPreservingMergeExec: [date@0 ASC NULLS LAST, time@2 ASC NULLS LAST] @@ -51,7 +51,7 @@ ORDER BY "time" ---- logical_plan 01)Sort: data.time ASC NULLS LAST -02)--Filter: data.ticker = 
Utf8("A") AND CAST(data.time AS Date32) = data.date +02)--Filter: data.ticker = Utf8View("A") AND CAST(data.time AS Date32) = data.date 03)----TableScan: data projection=[date, ticker, time] physical_plan 01)SortPreservingMergeExec: [time@2 ASC NULLS LAST] @@ -68,7 +68,7 @@ ORDER BY "date" ---- logical_plan 01)Sort: data.date ASC NULLS LAST -02)--Filter: data.ticker = Utf8("A") AND CAST(data.time AS Date32) = data.date +02)--Filter: data.ticker = Utf8View("A") AND CAST(data.time AS Date32) = data.date 03)----TableScan: data projection=[date, ticker, time] physical_plan 01)SortPreservingMergeExec: [date@0 ASC NULLS LAST] @@ -85,7 +85,7 @@ ORDER BY "ticker" ---- logical_plan 01)Sort: data.ticker ASC NULLS LAST -02)--Filter: data.ticker = Utf8("A") AND CAST(data.time AS Date32) = data.date +02)--Filter: data.ticker = Utf8View("A") AND CAST(data.time AS Date32) = data.date 03)----TableScan: data projection=[date, ticker, time] physical_plan 01)CoalescePartitionsExec @@ -102,7 +102,7 @@ ORDER BY "time", "date"; ---- logical_plan 01)Sort: data.time ASC NULLS LAST, data.date ASC NULLS LAST -02)--Filter: data.ticker = Utf8("A") AND CAST(data.time AS Date32) = data.date +02)--Filter: data.ticker = Utf8View("A") AND CAST(data.time AS Date32) = data.date 03)----TableScan: data projection=[date, ticker, time] physical_plan 01)SortPreservingMergeExec: [time@2 ASC NULLS LAST, date@0 ASC NULLS LAST] @@ -120,7 +120,7 @@ ORDER BY "time" ---- logical_plan 01)Sort: data.time ASC NULLS LAST -02)--Filter: data.ticker = Utf8("A") AND CAST(data.time AS Date32) != data.date +02)--Filter: data.ticker = Utf8View("A") AND CAST(data.time AS Date32) != data.date 03)----TableScan: data projection=[date, ticker, time] # no relation between time & date @@ -132,7 +132,7 @@ ORDER BY "time" ---- logical_plan 01)Sort: data.time ASC NULLS LAST -02)--Filter: data.ticker = Utf8("A") +02)--Filter: data.ticker = Utf8View("A") 03)----TableScan: data projection=[date, ticker, time] # query diff --git a/datafusion/sqllogictest/test_files/float16.slt b/datafusion/sqllogictest/test_files/float16.slt new file mode 100644 index 000000000000..5e59c730f078 --- /dev/null +++ b/datafusion/sqllogictest/test_files/float16.slt @@ -0,0 +1,92 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# Basic tests Tests for Float16 Type + +statement ok +create table floats as values (1.0), (2.0), (3.0), (NULL), ('Nan'); + +statement ok +create table float16s as select arrow_cast(column1, 'Float16') as column1 from floats; + +query RT +select column1, arrow_typeof(column1) as type from float16s; +---- +1 Float16 +2 Float16 +3 Float16 +NULL Float16 +NaN Float16 + +# Test coercions with arithmetic + +query RRRRRR +SELECT + column1 + 1::tinyint as column1_plus_int8, + column1 + 1::smallint as column1_plus_int16, + column1 + 1::int as column1_plus_int32, + column1 + 1::bigint as column1_plus_int64, + column1 + 1.0::float as column1_plus_float32, + column1 + 1.0 as column1_plus_float64 +FROM float16s; +---- +2 2 2 2 2 2 +3 3 3 3 3 3 +4 4 4 4 4 4 +NULL NULL NULL NULL NULL NULL +NaN NaN NaN NaN NaN NaN + +# Try coercing with literal NULL +query error +select column1 + NULL from float16s; +---- +DataFusion error: type_coercion +caused by +Error during planning: Cannot automatically convert Null to Float16 + + +# Test coercions with equality +query BBBBBB +SELECT + column1 = 1::tinyint as column1_equals_int8, + column1 = 1::smallint as column1_equals_int16, + column1 = 1::int as column1_equals_int32, + column1 = 1::bigint as column1_equals_int64, + column1 = 1.0::float as column1_equals_float32, + column1 = 1.0 as column1_equals_float64 +FROM float16s; +---- +true true true true true true +false false false false false false +false false false false false false +NULL NULL NULL NULL NULL NULL +false false false false false false + + +# Try coercing with literal NULL +query error +select column1 = NULL from float16s; +---- +DataFusion error: Error during planning: Cannot infer common argument type for comparison operation Float16 = Null + + +# Cleanup +statement ok +drop table floats; + +statement ok +drop table float16s; diff --git a/datafusion/sqllogictest/test_files/imdb.slt b/datafusion/sqllogictest/test_files/imdb.slt new file mode 100644 index 000000000000..c17f9c47c745 --- /dev/null +++ b/datafusion/sqllogictest/test_files/imdb.slt @@ -0,0 +1,4040 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file contains IMDB test queries against a small sample dataset. +# The test creates tables with sample data and runs all the IMDB benchmark queries. 
+ +# company_type table +statement ok +CREATE TABLE company_type ( + id INT NOT NULL, + kind VARCHAR NOT NULL +); + +statement ok +INSERT INTO company_type VALUES + (1, 'production companies'), + (2, 'distributors'), + (3, 'special effects companies'), + (4, 'other companies'), + (5, 'miscellaneous companies'), + (6, 'film distributors'), + (7, 'theaters'), + (8, 'sales companies'), + (9, 'producers'), + (10, 'publishers'), + (11, 'visual effects companies'), + (12, 'makeup departments'), + (13, 'costume designers'), + (14, 'movie studios'), + (15, 'sound departments'), + (16, 'talent agencies'), + (17, 'casting companies'), + (18, 'film commissions'), + (19, 'production services'), + (20, 'digital effects studios'); + +# info_type table +statement ok +CREATE TABLE info_type ( + id INT NOT NULL, + info VARCHAR NOT NULL +); + +statement ok +INSERT INTO info_type VALUES + (1, 'runtimes'), + (2, 'color info'), + (3, 'genres'), + (4, 'languages'), + (5, 'certificates'), + (6, 'sound mix'), + (7, 'countries'), + (8, 'top 250 rank'), + (9, 'bottom 10 rank'), + (10, 'release dates'), + (11, 'filming locations'), + (12, 'production companies'), + (13, 'technical info'), + (14, 'trivia'), + (15, 'goofs'), + (16, 'martial-arts'), + (17, 'quotes'), + (18, 'movie connections'), + (19, 'plot description'), + (20, 'biography'), + (21, 'plot summary'), + (22, 'box office'), + (23, 'ratings'), + (24, 'taglines'), + (25, 'keywords'), + (26, 'soundtrack'), + (27, 'votes'), + (28, 'height'), + (30, 'mini biography'), + (31, 'budget'), + (32, 'rating'); + +# title table +statement ok +CREATE TABLE title ( + id INT NOT NULL, + title VARCHAR NOT NULL, + imdb_index VARCHAR, + kind_id INT NOT NULL, + production_year INT, + imdb_id INT, + phonetic_code VARCHAR, + episode_of_id INT, + season_nr INT, + episode_nr INT, + series_years VARCHAR, + md5sum VARCHAR +); + +statement ok +INSERT INTO title VALUES + (1, 'The Shawshank Redemption', NULL, 1, 1994, 111161, NULL, NULL, NULL, NULL, NULL, NULL), + (2, 'The Godfather', NULL, 1, 1985, 68646, NULL, NULL, NULL, NULL, NULL, NULL), + (3, 'The Dark Knight', NULL, 1, 2008, 468569, NULL, NULL, NULL, NULL, NULL, NULL), + (4, 'The Godfather Part II', NULL, 1, 2012, 71562, NULL, NULL, NULL, NULL, NULL, NULL), + (5, 'Pulp Fiction', NULL, 1, 1994, 110912, NULL, NULL, NULL, NULL, NULL, NULL), + (6, 'Schindler''s List', NULL, 1, 1993, 108052, NULL, NULL, NULL, NULL, NULL, NULL), + (7, 'The Lord of the Rings: The Return of the King', NULL, 1, 2003, 167260, NULL, NULL, NULL, NULL, NULL, NULL), + (8, '12 Angry Men', NULL, 1, 1957, 50083, NULL, NULL, NULL, NULL, NULL, NULL), + (9, 'Inception', NULL, 1, 2010, 1375666, NULL, NULL, NULL, NULL, NULL, NULL), + (10, 'Fight Club', NULL, 1, 1999, 137523, NULL, NULL, NULL, NULL, NULL, NULL), + (11, 'The Matrix', NULL, 1, 2014, 133093, NULL, NULL, NULL, NULL, NULL, NULL), + (12, 'Goodfellas', NULL, 1, 1990, 99685, NULL, NULL, NULL, NULL, NULL, NULL), + (13, 'Avengers: Endgame', NULL, 1, 2019, 4154796, NULL, NULL, NULL, NULL, NULL, NULL), + (14, 'Interstellar', NULL, 1, 2014, 816692, NULL, NULL, NULL, NULL, NULL, NULL), + (15, 'The Silence of the Lambs', NULL, 1, 1991, 102926, NULL, NULL, NULL, NULL, NULL, NULL), + (16, 'Saving Private Ryan', NULL, 1, 1998, 120815, NULL, NULL, NULL, NULL, NULL, NULL), + (17, 'The Green Mile', NULL, 1, 1999, 120689, NULL, NULL, NULL, NULL, NULL, NULL), + (18, 'Forrest Gump', NULL, 1, 1994, 109830, NULL, NULL, NULL, NULL, NULL, NULL), + (19, 'Joker', NULL, 1, 2019, 7286456, NULL, NULL, NULL, NULL, NULL, NULL), + 
(20, 'Parasite', NULL, 1, 2019, 6751668, NULL, NULL, NULL, NULL, NULL, NULL), + (21, 'The Iron Giant', NULL, 1, 1999, 129167, NULL, NULL, NULL, NULL, NULL, NULL), + (22, 'Spider-Man: Into the Spider-Verse', NULL, 1, 2018, 4633694, NULL, NULL, NULL, NULL, NULL, NULL), + (23, 'Iron Man', NULL, 1, 2008, 371746, NULL, NULL, NULL, NULL, NULL, NULL), + (24, 'Black Panther', NULL, 1, 2018, 1825683, NULL, NULL, NULL, NULL, NULL, NULL), + (25, 'Titanic', NULL, 1, 1997, 120338, NULL, NULL, NULL, NULL, NULL, NULL), + (26, 'Kung Fu Panda 2', NULL, 1, 2011, 0441773, NULL, NULL, NULL, NULL, NULL, NULL), + (27, 'Halloween', NULL, 1, 2008, 1311067, NULL, NULL, NULL, NULL, NULL, NULL), + (28, 'Breaking Bad', NULL, 2, 2003, 903254, NULL, NULL, NULL, NULL, NULL, NULL), + (29, 'Breaking Bad: The Final Season', NULL, 2, 2007, 903255, NULL, NULL, NULL, NULL, NULL, NULL), + (30, 'Amsterdam Detective', NULL, 2, 2005, 905001, NULL, NULL, NULL, NULL, NULL, NULL), + (31, 'Amsterdam Detective: Cold Case', NULL, 2, 2007, 905002, NULL, NULL, NULL, NULL, NULL, NULL), + (32, 'Saw IV', NULL, 1, 2007, 905003, NULL, NULL, NULL, NULL, NULL, NULL), + (33, 'Shrek 2', NULL, 1, 2004, 906001, NULL, NULL, NULL, NULL, NULL, NULL), + (35, 'Dark Blood', NULL, 1, 2005, 907001, NULL, NULL, NULL, NULL, NULL, NULL), + (36, 'The Nordic Murders', NULL, 1, 2008, 908002, NULL, NULL, NULL, NULL, NULL, NULL), + (37, 'Scandinavian Crime', NULL, 1, 2009, 909001, NULL, NULL, NULL, NULL, NULL, NULL), + (38, 'The Western Sequel', NULL, 1, 1998, NULL, NULL, NULL, NULL, NULL, NULL, NULL), + (39, 'Marvel Superhero Epic', NULL, 1, 2010, NULL, NULL, NULL, NULL, NULL, NULL, NULL), + (40, 'The Champion', NULL, 1, 2016, 999555, NULL, NULL, NULL, NULL, NULL, NULL), + (41, 'Champion Boxer', NULL, 1, 2018, 999556, NULL, NULL, NULL, NULL, NULL, NULL), + (42, 'Avatar', NULL, 5, 2010, 499549, NULL, NULL, NULL, NULL, NULL, NULL), + (43, 'The Godfather Connection', NULL, 1, 1985, 68647, NULL, NULL, NULL, NULL, NULL, NULL), + (44, 'Digital Connection', NULL, 1, 2005, 888999, NULL, NULL, NULL, NULL, NULL, NULL), + (45, 'Berlin Noir', NULL, 1, 2010, NULL, NULL, NULL, NULL, NULL, NULL, NULL), + (46, 'YouTube Documentary', NULL, 1, 2008, 777999, NULL, NULL, NULL, NULL, NULL, NULL), + (47, 'The Swedish Murder Case', NULL, 1, 2012, 666777, NULL, NULL, NULL, NULL, NULL, NULL), + (48, 'Nordic Noir', NULL, 1, 2015, 555666, NULL, NULL, NULL, NULL, NULL, NULL), + (49, 'Derek Jacobi Story', NULL, 1, 1982, 444555, NULL, NULL, NULL, NULL, NULL, NULL), + (50, 'Woman in Black', NULL, 1, 2010, 987654, NULL, NULL, NULL, NULL, NULL, NULL), + (51, 'Kung Fu Panda', NULL, 1, 2008, 441772, NULL, NULL, NULL, NULL, NULL, NULL), + (52, 'Bruno', NULL, 1, 2009, NULL, NULL, NULL, NULL, NULL, NULL, NULL), + (53, 'Character Series', NULL, 2, 2020, 999888, NULL, NULL, NULL, 55, NULL, NULL), + (54, 'Vampire Chronicles', NULL, 1, 2015, 999999, NULL, NULL, NULL, NULL, NULL, NULL), + (55, 'Alien Invasion', NULL, 1, 2020, 888888, NULL, NULL, NULL, NULL, NULL, NULL), + (56, 'Dragon Warriors', NULL, 1, 2015, 888889, NULL, NULL, NULL, NULL, NULL, NULL), + (57, 'One Piece: Grand Adventure', NULL, 1, 2007, 777777, NULL, NULL, NULL, NULL, NULL, NULL), + (58, 'Moscow Nights', NULL, 1, 2010, 777778, NULL, NULL, NULL, NULL, NULL, NULL), + (59, 'Money Talks', NULL, 1, 1998, 888888, NULL, NULL, NULL, NULL, NULL, NULL), + (60, 'Fox Novel Movie', NULL, 1, 2005, 777888, NULL, NULL, NULL, NULL, NULL, NULL), + (61, 'Bad Movie Sequel', NULL, 1, 2010, 888777, NULL, NULL, NULL, NULL, NULL, NULL); + +# movie_companies 
table +statement ok +CREATE TABLE movie_companies ( + id INT NOT NULL, + movie_id INT NOT NULL, + company_id INT NOT NULL, + company_type_id INT NOT NULL, + note VARCHAR +); + +statement ok +INSERT INTO movie_companies VALUES + (1, 1, 4, 1, '(presents) (co-production)'), + (2, 2, 5, 1, '(presents)'), + (3, 3, 6, 1, '(co-production)'), + (4, 4, 7, 1, '(as Metro-Goldwyn-Mayer Pictures)'), + (5, 5, 8, 1, '(presents) (co-production)'), + (6, 6, 9, 1, '(presents)'), + (7, 7, 10, 1, '(co-production)'), + (8, 8, 11, 2, '(distributor)'), + (9, 9, 12, 1, '(presents) (co-production)'), + (10, 10, 13, 1, '(presents)'), + (11, 11, 14, 1, '(presents) (co-production)'), + (12, 12, 15, 1, '(presents)'), + (13, 13, 16, 1, '(co-production)'), + (14, 14, 17, 1, '(presents)'), + (15, 15, 18, 1, '(co-production)'), + (16, 16, 19, 1, '(presents)'), + (17, 17, 20, 1, '(co-production)'), + (18, 18, 21, 1, '(presents)'), + (19, 19, 22, 1, '(co-production)'), + (20, 20, 23, 1, '(presents)'), + (21, 21, 24, 1, '(presents) (co-production)'), + (22, 22, 25, 1, '(presents)'), + (23, 23, 26, 1, '(co-production)'), + (24, 24, 27, 1, '(presents)'), + (25, 25, 28, 1, '(presents) (co-production)'), + (26, 3, 35, 1, '(as Warner Bros. Pictures)'), + (27, 9, 35, 1, '(as Warner Bros. Pictures)'), + (28, 23, 14, 1, '(as Marvel Studios)'), + (29, 24, 14, 1, '(as Marvel Studios)'), + (30, 13, 14, 1, '(as Marvel Studios)'), + (31, 26, 23, 1, '(as DreamWorks Animation)'), + (32, 3, 6, 2, '(distributor)'), + (33, 2, 8, 2, '(distributor)'), + (34, 3, 6, 1, '(as Warner Bros.) (2008) (USA) (worldwide)'), + (35, 44, 36, 1, NULL), + (36, 40, 9, 1, '(production) (USA) (2016)'), + (37, 56, 18, 1, '(production)'), + (38, 2, 6, 1, NULL), + (39, 13, 14, 2, '(as Marvel Studios)'), + (40, 19, 25, 1, '(co-production)'), + (41, 23, 26, 1, '(co-production)'), + (42, 19, 27, 1, '(co-production)'), + (43, 11, 18, 1, '(theatrical) (France)'), + (44, 11, 8, 1, '(VHS) (USA) (1994)'), + (45, 11, 4, 1, '(USA)'), + (46, 9, 28, 1, '(co-production)'), + (47, 28, 5, 1, '(production)'), + (48, 29, 5, 1, '(production)'), + (49, 30, 29, 1, '(production)'), + (50, 31, 30, 1, '(production)'), + (51, 27, 22, 1, '(production)'), + (52, 32, 22, 1, '(distribution) (Blu-ray)'), + (53, 33, 31, 1, '(production)'), + (54, 33, 31, 2, '(distribution)'), + (55, 35, 32, 1, NULL), + (56, 36, 33, 1, '(production) (2008)'), + (57, 37, 34, 1, '(production) (2009) (Norway)'), + (58, 38, 35, 1, NULL), + (59, 25, 9, 1, '(production)'), + (60, 52, 19, 1, NULL), + (61, 26, 37, 1, '(voice: English version)'), + (62, 21, 3, 1, '(production) (Japan) (anime)'), + (63, 57, 2, 1, '(production) (Japan) (2007) (anime)'), + (64, 58, 1, 1, '(production) (Russia) (2010)'), + (65, 59, 35, 1, NULL), + (66, 60, 13, 2, '(distribution) (DVD) (US)'), + (67, 61, 14, 1, '(production)'), + (68, 41, 9, 1, '(production) (USA) (2018)'), + (69, 46, 16, 1, '(production) (2008) (worldwide)'), + (70, 51, 31, 1, '(production) (2008) (USA) (worldwide)'), + (71, 45, 32, 1, 'Studio (2000) Berlin'), + (72, 53, 6, 1, '(production) (2020) (USA)'), + (73, 62, 9, 1, '(production) (USA) (2010) (worldwide)'); + +# movie_info_idx table +statement ok +CREATE TABLE movie_info_idx ( + id INT NOT NULL, + movie_id INT NOT NULL, + info_type_id INT NOT NULL, + info VARCHAR NOT NULL, + note VARCHAR +); + +statement ok +INSERT INTO movie_info_idx VALUES + (1, 1, 8, '1', NULL), + (2, 2, 8, '2', NULL), + (3, 3, 8, '3', NULL), + (4, 4, 8, '4', NULL), + (5, 5, 8, '5', NULL), + (6, 6, 8, '6', NULL), + (7, 7, 8, '7', NULL), + (8, 8, 8, 
'8', NULL), + (9, 9, 8, '9', NULL), + (10, 10, 8, '10', NULL), + (11, 11, 8, '11', NULL), + (12, 12, 8, '12', NULL), + (13, 13, 8, '13', NULL), + (14, 14, 8, '14', NULL), + (15, 15, 8, '15', NULL), + (16, 16, 8, '16', NULL), + (17, 17, 8, '17', NULL), + (18, 18, 8, '18', NULL), + (19, 19, 8, '19', NULL), + (20, 20, 8, '20', NULL), + (21, 21, 8, '21', NULL), + (22, 22, 8, '22', NULL), + (23, 23, 8, '23', NULL), + (24, 24, 8, '24', NULL), + (25, 25, 8, '25', NULL), + (26, 40, 32, '8.6', NULL), + (27, 41, 32, '7.5', NULL), + (28, 45, 32, '6.8', NULL), + (29, 45, 22, '$10,000,000', NULL), + (30, 1, 22, '9.3', NULL), + (31, 2, 22, '9.2', NULL), + (32, 1, 27, '2,345,678', NULL), + (33, 3, 22, '9.0', NULL), + (34, 9, 22, '8.8', NULL), + (35, 23, 22, '8.5', NULL), + (36, 20, 9, '1', NULL), + (37, 25, 9, '2', NULL), + (38, 3, 9, '10', NULL), + (39, 28, 32, '8.2', NULL), + (40, 29, 32, '2.8', NULL), + (41, 30, 32, '8.5', NULL), + (42, 31, 32, '2.5', NULL), + (43, 27, 27, '45000', NULL), + (44, 32, 27, '52000', NULL), + (45, 33, 27, '120000', NULL), + (46, 35, 32, '7.2', NULL), + (47, 36, 32, '7.8', NULL), + (48, 37, 32, '7.5', NULL), + (49, 37, 27, '100000', NULL), + (50, 39, 32, '8.5', NULL), + (51, 54, 27, '1000', NULL), + (52, 3, 3002, '500', NULL), + (53, 3, 999, '9.5', NULL), + (54, 4, 999, '9.1', NULL), + (55, 13, 999, '8.9', NULL), + (56, 3, 32, '9.5', NULL), + (57, 4, 32, '9.1', NULL), + (58, 13, 32, '8.9', NULL), + (59, 4, 32, '9.3', NULL), + (60, 61, 9, '3', NULL), + (61, 35, 22, '8.4', NULL), + (62, 50, 32, '8.5', NULL), + (63, 48, 32, '7.5', NULL), + (64, 48, 27, '85000', NULL), + (65, 47, 32, '7.8', NULL), + (66, 46, 3, 'Documentary', NULL), + (67, 46, 10, 'USA: 2008-05-15', 'internet release'); + +# movie_info table +statement ok +CREATE TABLE movie_info ( + id INT NOT NULL, + movie_id INT NOT NULL, + info_type_id INT NOT NULL, + info VARCHAR NOT NULL, + note VARCHAR +); + +statement ok +INSERT INTO movie_info VALUES + (1, 1, 1, '113', NULL), + (2, 4, 7, 'Germany', NULL), + (3, 3, 7, 'Bulgaria', NULL), + (4, 2, 1, '175', NULL), + (5, 3, 1, '152', NULL), + (6, 4, 1, '202', NULL), + (7, 5, 1, '154', NULL), + (8, 6, 1, '195', NULL), + (9, 7, 1, '201', NULL), + (10, 8, 1, '139', NULL), + (11, 9, 1, '148', NULL), + (12, 10, 1, '139', NULL), + (13, 11, 1, '136', NULL), + (14, 12, 1, '146', NULL), + (15, 13, 1, '181', NULL), + (16, 14, 1, '141', NULL), + (17, 15, 1, '159', NULL), + (18, 16, 1, '150', NULL), + (19, 17, 1, '156', NULL), + (20, 18, 1, '164', NULL), + (21, 19, 1, '122', NULL), + (22, 20, 1, '140', NULL), + (23, 40, 1, '125', NULL), + (24, 21, 1, '86', NULL), + (25, 22, 1, '117', NULL), + (26, 23, 1, '126', NULL), + (27, 24, 1, '134', NULL), + (28, 25, 1, '194', NULL), + (29, 1, 10, '1994-10-14', 'internet release'), + (30, 2, 10, '1972-03-24', 'internet release'), + (31, 3, 10, '2008-07-18', 'internet release'), + (32, 9, 10, '2010-07-16', 'internet release'), + (33, 13, 10, '2019-04-26', 'internet release'), + (34, 23, 10, '2008-05-02', 'internet release'), + (35, 24, 10, '2018-02-16', 'internet release'), + (36, 1, 2, 'Color', NULL), + (37, 3, 2, 'Color', NULL), + (38, 8, 2, 'Black and White', NULL), + (39, 9, 2, 'Color', NULL), + (40, 1, 19, 'Story about hope and redemption', NULL), + (41, 3, 19, 'Batman faces his greatest challenge', NULL), + (42, 19, 19, 'Origin story of the Batman villain', NULL), + (43, 1, 3, 'Drama', NULL), + (44, 3, 3, 'Action', NULL), + (45, 3, 3, 'Crime', NULL), + (46, 3, 3, 'Drama', NULL), + (47, 9, 3, 'Action', NULL), + (48, 9, 3, 'Adventure', NULL), 
+ (49, 9, 3, 'Sci-Fi', NULL), + (50, 23, 3, 'Action', NULL), + (51, 23, 3, 'Adventure', NULL), + (52, 23, 3, 'Sci-Fi', NULL), + (53, 24, 3, 'Action', NULL), + (54, 24, 3, 'Adventure', NULL), + (55, 9, 7, 'Germany', NULL), + (56, 19, 7, 'German', NULL), + (57, 24, 7, 'Germany', NULL), + (58, 13, 7, 'USA', NULL), + (59, 3, 7, 'USA', NULL), + (60, 3, 22, '2343110', NULL), + (61, 3, 27, '2343110', NULL), + (62, 26, 10, 'USA:2011-05-26', NULL), + (63, 19, 20, 'Batman faces his greatest challenge', NULL), + (64, 3, 3, 'Drama', NULL), + (65, 13, 3, 'Action', NULL), + (66, 13, 19, 'Epic conclusion to the Infinity Saga', NULL), + (67, 2, 8, '1972-03-24', 'Released via internet in 2001'), + (68, 13, 4, 'English', NULL), + (69, 13, 3, 'Animation', NULL), + (70, 26, 3, 'Animation', NULL), + (71, 27, 3, '$15 million', NULL), + (72, 27, 3, 'Horror', NULL), + (73, 32, 3, 'Horror', NULL), + (74, 33, 10, 'USA: 2004', NULL), + (75, 33, 3, 'Animation', NULL), + (76, 35, 7, 'Germany', NULL), + (77, 35, 10, '2005-09-15', NULL), + (78, 44, 10, 'USA: 15 May 2005', 'This movie explores internet culture and digital connections that emerged in the early 2000s.'), + (79, 40, 10, '2016-08-12', 'internet release'), + (80, 1, 31, '$25,000,000', NULL), + (81, 45, 7, 'Germany', NULL), + (82, 45, 32, 'Germany', NULL), + (83, 13, 32, '8.5', NULL), + (84, 3, 32, '9.2', NULL), + (85, 3, 102, '9.2', NULL), + (86, 3, 25, 'sequel', NULL), + (87, 3, 102, '9.2', NULL), + (88, 3, 102, '9.2', NULL), + (89, 4, 102, '9.5', NULL), + (90, 33, 102, '8.7', NULL), + (91, 4, 32, '9.5', NULL), + (92, 11, 32, '8.7', NULL), + (93, 3, 32, '9.2', NULL), + (94, 3, 102, '9.2', NULL), + (95, 3, 32, '9.0', NULL), + (96, 26, 32, '8.2', NULL), + (97, 26, 32, '8.5', NULL), + (98, 27, 27, '8231', NULL), + (99, 27, 10, '2008-10-31', NULL), + (100, 13, 1, '182', NULL), + (101, 11, 2, 'Germany', NULL), + (102, 11, 1, '120', NULL), + (103, 3, 3, 'Drama', NULL), + (104, 11, 7, 'USA', NULL), + (105, 11, 7, 'Bulgaria', NULL), + (106, 50, 3, 'Horror', NULL), + (107, 36, 7, 'Sweden', NULL), + (108, 37, 7, 'Norway', NULL), + (109, 38, 7, 'Sweden', NULL), + (110, 54, 3, 'Horror', NULL), + (111, 55, 3, 'Sci-Fi', NULL), + (112, 56, 30, 'Japan:2015-06-15', NULL), + (113, 56, 30, 'USA:2015-07-20', NULL), + (114, 26, 10, 'Japan:2011-05-29', NULL), + (115, 26, 10, 'USA:2011-05-26', NULL), + (116, 61, 31, '$500,000', NULL), + (117, 41, 10, '2018-05-25', 'USA theatrical release'), + (118, 41, 7, 'Germany', 'Filmed on location'), + (119, 48, 7, 'Sweden', 'Filmed on location'), + (120, 48, 10, '2015-06-15', 'theatrical release'), + (121, 48, 3, 'Thriller', NULL), + (122, 47, 7, 'Sweden', 'Principal filming location'), + (123, 47, 10, '2012-09-21', 'theatrical release'), + (124, 47, 3, 'Crime', NULL), + (125, 47, 3, 'Thriller', NULL), + (126, 47, 7, 'Sweden', NULL), + (127, 3, 10, 'USA: 2008-07-14', 'internet release'), + (128, 46, 10, 'USA: 2008-05-15', 'internet release'), + (129, 40, 10, 'USA:\ 2006', 'internet release'), + (130, 51, 10, 'USA: 2008-06-06', 'theatrical release'), + (131, 51, 10, 'Japan: 2007-12-20', 'preview screening'); + +# kind_type table +statement ok +CREATE TABLE kind_type ( + id INT NOT NULL, + kind VARCHAR NOT NULL +); + +statement ok +INSERT INTO kind_type VALUES + (1, 'movie'), + (2, 'tv series'), + (3, 'video movie'), + (4, 'tv movie'), + (5, 'video game'), + (6, 'episode'), + (7, 'documentary'), + (8, 'short movie'), + (9, 'tv mini series'), + (10, 'reality-tv'); + +# cast_info table +statement ok +CREATE TABLE cast_info ( + id INT NOT NULL, + 
person_id INT NOT NULL, + movie_id INT NOT NULL, + person_role_id INT, + note VARCHAR, + nr_order INT, + role_id INT NOT NULL +); + +statement ok +INSERT INTO cast_info VALUES + (1, 29, 53, NULL, NULL, 1, 1), + (2, 3, 1, 54, NULL, 1, 1), + (3, 3, 1, NULL, '(producer)', 1, 3), + (4, 4, 2, 2, NULL, 1, 1), + (5, 5, 3, 3, NULL, 1, 1), + (6, 6, 4, 4, NULL, 1, 1), + (7, 2, 50, NULL, '(writer)', 1, 4), + (8, 18, 51, 15, '(voice)', 1, 2), + (9, 1, 19, NULL, NULL, 1, 1), + (10, 6, 100, 1985, '(as Special Actor)', 1, 1), + (11, 15, 19, NULL, NULL, 1, 1), + (12, 8, 5, 5, NULL, 1, 1), + (13, 9, 6, 6, NULL, 1, 1), + (14, 10, 7, 7, NULL, 1, 1), + (15, 11, 8, 8, NULL, 1, 1), + (16, 12, 9, 9, NULL, 1, 1), + (17, 13, 10, 10, NULL, 1, 1), + (18, 14, 9, 55, NULL, 1, 1), + (19, 14, 14, 29, NULL, 1, 1), + (20, 27, 58, 28, '(producer)', 1, 1), + (21, 16, 3, 23, '(producer)', 2, 1), + (22, 20, 49, NULL, NULL, 1, 1), + (23, 13, 23, 14, NULL, 1, 1), + (24, 28, 13, NULL, '(costume design)', 1, 7), + (25, 25, 58, 31, '(voice) (uncredited)', 1, 1), + (26, 18, 3, 24, '(voice)', 1, 2), + (27, 29, 26, 24, '(voice)', 1, 2), + (28, 13, 13, 47, '(writer)', 1, 1), + (29, 17, 3, 25, '(producer)', 3, 8), + (30, 18, 3, 11, '(voice)', 1, 2), + (31, 18, 26, 11, '(voice)', 1, 2), + (32, 18, 26, 12, '(voice: original film)', 1, 2), + (33, 22, 27, 12, '(writer)', 4, 8), + (34, 23, 32, 12, '(writer)', 4, 8), + (35, 21, 33, 13, '(voice)', 2, 2), + (36, 21, 33, 13, '(voice: English version)', 2, 2), + (37, 21, 33, 13, '(voice) (uncredited)', 2, 2), + (38, 22, 39, 25, 'Superman', 1, 1), + (39, 22, 39, 26, 'Ironman', 1, 1), + (40, 22, 39, 27, 'Spiderman', 1, 1), + (41, 19, 52, NULL, NULL, 2, 1), + (42, 14, 19, NULL, NULL, 3, 1), + (43, 6, 2, 2, NULL, 1, 1), + (44, 16, 54, NULL, '(writer)', 1, 4), + (45, 24, 55, NULL, '(director)', 1, 8), + (46, 25, 56, 29, '(voice: English version)', 1, 2), + (47, 18, 26, 30, '(voice: English version)', 1, 2), + (48, 26, 21, 24, '(voice: English version)', 1, 2), + (49, 26, 57, 25, '(voice: English version)', 1, 2), + (50, 27, 25, NULL, NULL, 1, 4), + (51, 18, 62, 32, '(voice)', 1, 2); + +# char_name table +statement ok +CREATE TABLE char_name ( + id INT NOT NULL, + name VARCHAR NOT NULL, + imdb_index VARCHAR, + imdb_id INT, + name_pcode_nf VARCHAR, + surname_pcode VARCHAR, + md5sum VARCHAR +); + +statement ok +INSERT INTO char_name VALUES + (1, 'Andy Dufresne', NULL, NULL, NULL, NULL, NULL), + (2, 'Don Vito Corleone', NULL, NULL, NULL, NULL, NULL), + (3, 'Joker', NULL, NULL, NULL, NULL, NULL), + (4, 'Michael Corleone', NULL, NULL, NULL, NULL, NULL), + (5, 'Vincent Vega', NULL, NULL, NULL, NULL, NULL), + (6, 'Oskar Schindler', NULL, NULL, NULL, NULL, NULL), + (7, 'Gandalf', NULL, NULL, NULL, NULL, NULL), + (8, 'Juror 8', NULL, NULL, NULL, NULL, NULL), + (9, 'Cobb', NULL, NULL, NULL, NULL, NULL), + (10, 'Tyler Durden', NULL, NULL, NULL, NULL, NULL), + (11, 'Batman''s Assistant', NULL, NULL, NULL, NULL, NULL), + (12, 'Tiger', NULL, NULL, NULL, NULL, NULL), + (13, 'Queen', NULL, NULL, NULL, NULL, NULL), + (14, 'Iron Man', NULL, NULL, NULL, NULL, NULL), + (15, 'Master Tigress', NULL, NULL, NULL, NULL, NULL), + (16, 'Dom Cobb', NULL, NULL, NULL, NULL, NULL), + (17, 'Rachel Dawes', NULL, NULL, NULL, NULL, NULL), + (18, 'Arthur Fleck', NULL, NULL, NULL, NULL, NULL), + (19, 'Pepper Potts', NULL, NULL, NULL, NULL, NULL), + (20, 'T''Challa', NULL, NULL, NULL, NULL, NULL), + (21, 'Steve Rogers', NULL, NULL, NULL, NULL, NULL), + (22, 'Ellis Boyd Redding', NULL, NULL, NULL, NULL, NULL), + (23, 'Bruce Wayne', NULL, 
NULL, NULL, NULL, NULL), + (24, 'Tigress', NULL, NULL, NULL, NULL, NULL), + (25, 'Superman', NULL, NULL, NULL, NULL, NULL), + (26, 'Ironman', NULL, NULL, NULL, NULL, NULL), + (27, 'Spiderman', NULL, NULL, NULL, NULL, NULL), + (28, 'Director', NULL, NULL, NULL, NULL, NULL), + (29, 'Tiger Warrior', NULL, NULL, NULL, NULL, NULL), + (30, 'Tigress', NULL, NULL, NULL, NULL, NULL), + (31, 'Nikolai', NULL, NULL, NULL, NULL, NULL), + (32, 'Princess Dragon', NULL, NULL, NULL, NULL, NULL); + +# keyword table +statement ok +CREATE TABLE keyword ( + id INT NOT NULL, + keyword VARCHAR NOT NULL, + phonetic_code VARCHAR +); + +statement ok +INSERT INTO keyword VALUES + (1, 'prison', NULL), + (2, 'mafia', NULL), + (3, 'superhero', NULL), + (4, 'sequel', NULL), + (5, 'crime', NULL), + (6, 'holocaust', NULL), + (7, 'fantasy', NULL), + (8, 'jury', NULL), + (9, 'dream', NULL), + (10, 'fight', NULL), + (11, 'marvel-cinematic-universe', NULL), + (12, 'character-name-in-title', NULL), + (13, 'female-name-in-title', NULL), + (14, 'murder', NULL), + (15, 'noir', NULL), + (16, 'space', NULL), + (17, 'time-travel', NULL), + (18, 'artificial-intelligence', NULL), + (19, 'robot', NULL), + (20, 'alien', NULL), + (21, '10,000-mile-club', NULL), + (22, 'martial-arts', NULL), + (23, 'computer-animation', NULL), + (24, 'violence', NULL), + (25, 'based-on-novel', NULL), + (26, 'nerd', NULL), + (27, 'marvel-comics', NULL), + (28, 'based-on-comic', NULL), + (29, 'superhero-movie', NULL); + +# movie_keyword table +statement ok +CREATE TABLE movie_keyword ( + id INT NOT NULL, + movie_id INT NOT NULL, + keyword_id INT NOT NULL +); + +statement ok +INSERT INTO movie_keyword VALUES + (1, 1, 1), + (2, 2, 2), + (3, 3, 3), + (4, 4, 4), + (5, 5, 5), + (6, 6, 6), + (7, 7, 7), + (8, 8, 8), + (9, 9, 9), + (10, 10, 10), + (11, 3, 5), + (12, 19, 3), + (13, 19, 12), + (14, 23, 11), + (15, 13, 11), + (16, 24, 11), + (17, 11, 1), + (18, 11, 20), + (19, 11, 20), + (20, 14, 16), + (21, 9, 3), + (22, 3, 14), + (23, 25, 13), + (24, 23, 12), + (25, 2, 4), + (26, 23, 19), + (27, 19, 5), + (28, 23, 3), + (29, 23, 28), + (30, 3, 4), + (31, 3, 4), + (32, 2, 4), + (33, 4, 4), + (34, 11, 4), + (35, 3, 3), + (36, 26, 16), + (37, 13, 11), + (38, 13, 3), + (39, 13, 4), + (40, 9, 17), + (41, 9, 18), + (42, 3, 12), + (43, 13, 13), + (44, 26, 21), + (45, 24, 3), + (46, 9, 14), + (47, 2, 4), + (48, 14, 21), + (49, 27, 14), + (50, 32, 14), + (51, 33, 23), + (52, 33, 23), + (55, 35, 24), + (56, 36, 14), + (57, 36, 25), + (58, 35, 4), + (59, 37, 14), + (60, 37, 25), + (61, 45, 24), + (62, 2, 4), + (63, 14, 21), + (64, 27, 14), + (65, 32, 14), + (66, 33, 23), + (67, 33, 23), + (68, 35, 24), + (69, 38, 4), + (70, 39, 3), + (71, 39, 27), + (72, 39, 28), + (73, 39, 29), + (74, 44, 26), + (75, 52, 12), + (76, 54, 14), + (77, 55, 20), + (78, 55, 16), + (79, 56, 22), + (80, 26, 22), + (81, 3, 4), + (82, 4, 4), + (83, 13, 4), + (84, 3, 4), + (85, 40, 29), + (86, 4, 4), + (87, 13, 4), + (88, 59, 4), + (89, 60, 25), + (90, 48, 14), + (91, 47, 14), + (92, 45, 24), + (93, 46, 3), + (94, 53, 12); + +# company_name table +statement ok +CREATE TABLE company_name ( + id INT NOT NULL, + name VARCHAR NOT NULL, + country_code VARCHAR, + imdb_id INT, + name_pcode_nf VARCHAR, + name_pcode_sf VARCHAR, + md5sum VARCHAR +); + +statement ok +INSERT INTO company_name VALUES + (1, 'Mosfilm', '[ru]', NULL, NULL, NULL, NULL), + (2, 'Toei Animation', '[jp]', NULL, NULL, NULL, NULL), + (3, 'Tokyo Animation Studio', '[jp]', NULL, NULL, NULL, NULL), + (4, 'Castle Rock Entertainment', '[us]', 
NULL, NULL, NULL, NULL), + (5, 'Paramount Pictures', '[us]', NULL, NULL, NULL, NULL), + (6, 'Warner Bros.', '[us]', NULL, NULL, NULL, NULL), + (7, 'Metro-Goldwyn-Mayer', '[us]', NULL, NULL, NULL, NULL), + (8, 'Miramax Films', '[us]', NULL, NULL, NULL, NULL), + (9, 'Universal Pictures', '[us]', NULL, NULL, NULL, NULL), + (10, 'New Line Cinema', '[us]', NULL, NULL, NULL, NULL), + (11, 'United Artists', '[us]', NULL, NULL, NULL, NULL), + (12, 'Columbia Pictures', '[us]', NULL, NULL, NULL, NULL), + (13, 'Twentieth Century Fox', '[us]', NULL, NULL, NULL, NULL), + (14, 'Marvel Studios', '[us]', NULL, NULL, NULL, NULL), + (15, 'DC Films', '[us]', NULL, NULL, NULL, NULL), + (16, 'YouTube', '[us]', NULL, NULL, NULL, NULL), + (17, 'DreamWorks Pictures', '[us]', NULL, NULL, NULL, NULL), + (18, 'Walt Disney Pictures', '[us]', NULL, NULL, NULL, NULL), + (19, 'Netflix', '[us]', NULL, NULL, NULL, NULL), + (20, 'Amazon Studios', '[us]', NULL, NULL, NULL, NULL), + (21, 'A24', '[us]', NULL, NULL, NULL, NULL), + (22, 'Lionsgate Films', '[us]', NULL, NULL, NULL, NULL), + (23, 'DreamWorks Animation', '[us]', NULL, NULL, NULL, NULL), + (24, 'Sony Pictures', '[us]', NULL, NULL, NULL, NULL), + (25, 'Bavaria Film', '[de]', NULL, NULL, NULL, NULL), + (26, 'Dutch FilmWorks', '[nl]', NULL, NULL, NULL, NULL), + (27, 'San Marino Films', '[sm]', NULL, NULL, NULL, NULL), + (28, 'Legendary Pictures', '[us]', NULL, NULL, NULL, NULL), + (29, 'Dutch Entertainment Group', '[nl]', NULL, NULL, NULL, NULL), + (30, 'Amsterdam Studios', '[nl]', NULL, NULL, NULL, NULL), + (31, 'DreamWorks Animation', '[us]', NULL, NULL, NULL, NULL), + (32, 'Berlin Film Studio', '[de]', NULL, NULL, NULL, NULL), + (33, 'Stockholm Productions', '[se]', NULL, NULL, NULL, NULL), + (34, 'Oslo Films', '[no]', NULL, NULL, NULL, NULL), + (35, 'Warner Bros. 
Pictures', '[us]', NULL, NULL, NULL, NULL), + (36, 'Silicon Entertainment', '[us]', NULL, NULL, NULL, NULL), + (37, 'DreamWorks Animation', '[us]', NULL, NULL, NULL, NULL); + +# name table for actors/directors information +statement ok +CREATE TABLE name ( + id INT NOT NULL, + name VARCHAR NOT NULL, + imdb_index VARCHAR, + imdb_id INT, + gender VARCHAR, + name_pcode_cf VARCHAR, + name_pcode_nf VARCHAR, + surname_pcode VARCHAR, + md5sum VARCHAR +); + +statement ok +INSERT INTO name VALUES + (1, 'Xavier Thompson', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (2, 'Susan Hill', NULL, NULL, 'f', NULL, NULL, NULL, NULL), + (3, 'Tim Robbins', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (4, 'Marlon Brando', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (5, 'Heath Ledger', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (6, 'Al Pacino', NULL, NULL, 'm', 'A', NULL, NULL, NULL), + (7, 'Downey Pacino', NULL, NULL, 'm', 'D', NULL, NULL, NULL), + (8, 'John Travolta', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (9, 'Liam Neeson', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (10, 'Ian McKellen', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (11, 'Henry Fonda', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (12, 'Leonardo DiCaprio', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (13, 'Downey Robert Jr.', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (14, 'Zach Wilson', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (15, 'Bert Wilson', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (29, 'Alex Morgan', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (16, 'Christian Bale', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (17, 'Christopher Nolan', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (18, 'Angelina Jolie', NULL, NULL, 'f', NULL, NULL, NULL, NULL), + (19, 'Brad Wilson', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (20, 'Derek Jacobi', NULL, NULL, 'm', 'D624', NULL, NULL, NULL), + (21, 'Anne Hathaway', NULL, NULL, 'f', NULL, NULL, NULL, NULL), + (22, 'John Carpenter', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (23, 'James Wan', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (24, 'Ridley Scott', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (25, 'Angelina Jolie', NULL, NULL, 'f', NULL, NULL, NULL, NULL), + (26, 'Yoko Tanaka', NULL, NULL, 'f', NULL, NULL, NULL, NULL), + (27, 'James Cameron', NULL, NULL, 'm', NULL, NULL, NULL, NULL), + (28, 'Edith Head', NULL, NULL, 'f', NULL, NULL, NULL, NULL), + (29, 'Anne Hathaway', NULL, NULL, 'f', NULL, NULL, NULL, NULL); + +# aka_name table +statement ok +CREATE TABLE aka_name ( + id INT NOT NULL, + person_id INT NOT NULL, + name VARCHAR NOT NULL, + imdb_index VARCHAR, + name_pcode_cf VARCHAR, + name_pcode_nf VARCHAR, + surname_pcode VARCHAR, + md5sum VARCHAR +); + +statement ok +INSERT INTO aka_name VALUES + (1, 2, 'Marlon Brando Jr.', NULL, NULL, NULL, NULL, NULL), + (2, 2, 'Marlon Brando', NULL, NULL, NULL, NULL, NULL), + (3, 3, 'Heath Andrew Ledger', NULL, NULL, NULL, NULL, NULL), + (4, 6, 'Alfredo James Pacino', NULL, NULL, NULL, NULL, NULL), + (5, 5, 'John Joseph Travolta', NULL, NULL, NULL, NULL, NULL), + (6, 6, 'Liam John Neeson', NULL, NULL, NULL, NULL, NULL), + (7, 7, 'Ian Murray McKellen', NULL, NULL, NULL, NULL, NULL), + (8, 8, 'Henry Jaynes Fonda', NULL, NULL, NULL, NULL, NULL), + (9, 9, 'Leonardo Wilhelm DiCaprio', NULL, NULL, NULL, NULL, NULL), + (10, 10, 'Robert John Downey Jr.', NULL, NULL, NULL, NULL, NULL), + (11, 16, 'Christian Charles Philip Bale', NULL, NULL, NULL, NULL, NULL), + (12, 29, 'Christopher Jonathan James Nolan', NULL, NULL, NULL, NULL, NULL), + (13, 47, 'Joaquin Rafael 
Bottom', NULL, NULL, NULL, NULL, NULL), + (14, 26, 'Yoko Shimizu', NULL, NULL, NULL, NULL, NULL), + (15, 48, 'Chadwick Aaron Boseman', NULL, NULL, NULL, NULL, NULL), + (16, 29, 'Scarlett Ingrid Johansson', NULL, NULL, NULL, NULL, NULL), + (17, 31, 'Christopher Robert Evans', NULL, NULL, NULL, NULL, NULL), + (18, 32, 'Christopher Hemsworth', NULL, NULL, NULL, NULL, NULL), + (19, 33, 'Mark Alan Ruffalo', NULL, NULL, NULL, NULL, NULL), + (20, 20, 'Sir Derek Jacobi', NULL, NULL, NULL, NULL, NULL), + (21, 34, 'Samuel Leroy Jackson', NULL, NULL, NULL, NULL, NULL), + (22, 35, 'Gwyneth Kate Paltrow', NULL, NULL, NULL, NULL, NULL), + (23, 36, 'Thomas William Hiddleston', NULL, NULL, NULL, NULL, NULL), + (24, 37, 'Morgan Porterfield Freeman', NULL, NULL, NULL, NULL, NULL), + (25, 38, 'William Bradley Pitt', NULL, NULL, NULL, NULL, NULL), + (26, 39, 'Edward John Norton Jr.', NULL, NULL, NULL, NULL, NULL), + (27, 40, 'Marion Cotillard', NULL, NULL, NULL, NULL, NULL), + (28, 41, 'Joseph Leonard Gordon-Levitt', NULL, NULL, NULL, NULL, NULL), + (29, 42, 'Matthew David McConaughey', NULL, NULL, NULL, NULL, NULL), + (30, 43, 'Anne Jacqueline Hathaway', NULL, NULL, NULL, NULL, NULL), + (31, 44, 'Kevin Feige', NULL, NULL, NULL, NULL, NULL), + (32, 45, 'Margaret Ruth Gyllenhaal', NULL, NULL, NULL, NULL, NULL), + (33, 46, 'Kate Elizabeth Winslet', NULL, NULL, NULL, NULL, NULL), + (34, 28, 'E. Head', NULL, NULL, NULL, NULL, NULL), + (35, 29, 'Anne Jacqueline Hathaway', NULL, NULL, NULL, NULL, NULL), + (36, 29, 'Alexander Morgan', NULL, NULL, NULL, NULL, NULL), + (37, 2, 'Brando, M.', NULL, NULL, NULL, NULL, NULL), + (38, 21, 'Annie Hathaway', NULL, NULL, NULL, NULL, NULL), + (39, 21, 'Annie H', NULL, NULL, NULL, NULL, NULL), + (40, 25, 'Angie Jolie', NULL, NULL, NULL, NULL, NULL), + (41, 27, 'Jim Cameron', NULL, NULL, NULL, NULL, NULL), + (42, 18, 'Angelina Jolie', NULL, NULL, NULL, NULL, NULL); + +# role_type table +statement ok +CREATE TABLE role_type ( + id INT NOT NULL, + role VARCHAR NOT NULL +); + +statement ok +INSERT INTO role_type VALUES + (1, 'actor'), + (2, 'actress'), + (3, 'producer'), + (4, 'writer'), + (5, 'cinematographer'), + (6, 'composer'), + (7, 'costume designer'), + (8, 'director'), + (9, 'editor'), + (10, 'miscellaneous crew'); + +# link_type table +statement ok +CREATE TABLE link_type ( + id INT NOT NULL, + link VARCHAR NOT NULL +); + +statement ok +INSERT INTO link_type VALUES + (1, 'sequel'), + (2, 'follows'), + (3, 'remake of'), + (4, 'version of'), + (5, 'spin off from'), + (6, 'reference to'), + (7, 'featured in'), + (8, 'spoofed in'), + (9, 'edited into'), + (10, 'alternate language version of'), + (11, 'features'); + +# movie_link table +statement ok +CREATE TABLE movie_link ( + id INT NOT NULL, + movie_id INT NOT NULL, + linked_movie_id INT NOT NULL, + link_type_id INT NOT NULL +); + +statement ok +INSERT INTO movie_link VALUES + (1, 2, 4, 1), + (2, 3, 5, 6), + (3, 6, 7, 4), + (4, 8, 9, 8), + (5, 10, 1, 3), + (6, 28, 29, 1), + (7, 30, 31, 2), + (8, 1, 3, 6), + (9, 23, 13, 1), + (10, 13, 24, 2), + (11, 20, 3, 1), + (12, 3, 22, 1), + (13, 2, 4, 2), + (14, 19, 19, 6), + (15, 14, 16, 6), + (16, 13, 23, 2), + (17, 25, 9, 4), + (18, 17, 1, 8), + (19, 24, 23, 2), + (20, 21, 22, 1), + (21, 15, 9, 6), + (22, 11, 13, 1), + (23, 13, 11, 2), + (24, 100, 100, 7), + (25, 1, 2, 7), + (26, 23, 2, 7), + (27, 14, 25, 9), + (28, 4, 6, 4), + (29, 5, 8, 6), + (30, 7, 10, 6), + (31, 9, 2, 8), + (32, 38, 39, 2), + (33, 59, 5, 2), + (34, 60, 9, 2), + (35, 49, 49, 11), + (36, 35, 36, 2); + +# 
complete_cast table +statement ok +CREATE TABLE complete_cast ( + id INT NOT NULL, + movie_id INT NOT NULL, + subject_id INT NOT NULL, + status_id INT NOT NULL +); + +statement ok +INSERT INTO complete_cast VALUES + (1, 1, 1, 1), + (2, 2, 1, 1), + (3, 3, 1, 1), + (4, 4, 1, 1), + (5, 5, 1, 1), + (6, 6, 1, 1), + (7, 7, 1, 1), + (8, 8, 1, 1), + (9, 9, 1, 1), + (10, 10, 1, 1), + (11, 11, 1, 1), + (12, 12, 1, 1), + (13, 13, 1, 1), + (14, 14, 1, 1), + (15, 15, 1, 1), + (16, 16, 1, 1), + (17, 17, 1, 1), + (18, 18, 1, 1), + (19, 19, 1, 2), + (20, 20, 2, 1), + (21, 21, 1, 1), + (22, 22, 1, 1), + (23, 23, 1, 3), + (24, 24, 1, 1), + (25, 25, 1, 1), + (26, 26, 1, 1), + (27, 13, 2, 4), + (28, 44, 1, 4), + (29, 33, 1, 4), + (30, 31, 1, 1), + (31, 32, 1, 4), + (32, 33, 1, 4), + (33, 35, 2, 3), + (34, 36, 2, 3), + (35, 37, 1, 4), + (36, 37, 1, 3), + (37, 38, 1, 3), + (38, 39, 1, 3), + (39, 39, 1, 11), + (40, 40, 1, 4); + +# comp_cast_type table +statement ok +CREATE TABLE comp_cast_type ( + id INT NOT NULL, + kind VARCHAR NOT NULL +); + +statement ok +INSERT INTO comp_cast_type VALUES + (1, 'cast'), + (2, 'crew'), + (3, 'complete'), + (4, 'complete+verified'), + (5, 'pending'), + (6, 'unverified'), + (7, 'uncredited cast'), + (8, 'uncredited crew'), + (9, 'unverified cast'), + (10, 'unverified crew'), + (11, 'complete cast'); + +# person_info table +statement ok +CREATE TABLE person_info ( + id INT NOT NULL, + person_id INT NOT NULL, + info_type_id INT NOT NULL, + info VARCHAR NOT NULL, + note VARCHAR +); + +statement ok +INSERT INTO person_info VALUES + (1, 1, 3, 'actor,producer', NULL), + (2, 2, 3, 'actor,director', NULL), + (3, 3, 3, 'actor', NULL), + (4, 6, 3, 'actor,producer', NULL), + (5, 5, 3, 'actor', NULL), + (6, 6, 3, 'actor', NULL), + (7, 7, 3, 'actor', NULL), + (8, 8, 3, 'actor', NULL), + (9, 20, 30, 'Renowned Shakespearean actor and stage performer', 'Volker Boehm'), + (10, 10, 3, 'actor,producer', 'marvel-cinematic-universe'), + (11, 3, 1, 'Won Academy Award for portrayal of Joker', NULL), + (12, 10, 1, 'Played Iron Man in the Marvel Cinematic Universe', NULL), + (13, 16, 3, 'actor', NULL), + (14, 16, 1, 'Played Batman in The Dark Knight trilogy', NULL), + (15, 29, 3, 'director,producer,writer', NULL), + (16, 29, 1, 'Directed The Dark Knight trilogy', NULL), + (17, 47, 3, 'actor', NULL), + (18, 47, 1, 'Won Academy Award for portrayal of Joker', NULL), + (19, 48, 3, 'actor', NULL), + (20, 48, 1, 'Played Black Panther in the Marvel Cinematic Universe', NULL), + (21, 29, 3, 'actress', NULL), + (22, 29, 1, 'Played Black Widow in the Marvel Cinematic Universe', NULL), + (23, 31, 3, 'actor', NULL), + (24, 31, 1, 'Played Captain America in the Marvel Cinematic Universe', NULL), + (25, 32, 3, 'actor', NULL), + (26, 32, 1, 'Played Thor in the Marvel Cinematic Universe', NULL), + (27, 9, 1, 'Won Academy Award for The Revenant', NULL), + (28, 9, 7, '1974-11-11', NULL), + (29, 10, 7, '1965-04-04', NULL), + (30, 16, 7, '1974-01-30', NULL), + (31, 47, 7, '1974-10-28', NULL), + (32, 48, 7, '1976-11-29', NULL), + (33, 29, 7, '1984-11-22', NULL), + (34, 31, 7, '1981-06-13', NULL), + (35, 32, 7, '1983-08-11', NULL), + (36, 21, 14, 'Won an Oscar for Les Miserables.', 'IMDB staff'), + (37, 21, 14, 'Voiced Queen in Shrek 2.', 'IMDB staff'), + (38, 21, 28, '5 ft 8 in (1.73 m)', 'IMDB staff'), + (39, 6, 30, 'Famous for his role in The Godfather', 'Volker Boehm'); + +# aka_title table +statement ok +CREATE TABLE aka_title ( + id INT NOT NULL, + movie_id INT NOT NULL, + title VARCHAR, + imdb_index VARCHAR, + kind_id 
INT NOT NULL, + production_year INT, + phonetic_code VARCHAR, + episode_of_id INT, + season_nr INT, + episode_nr INT, + note VARCHAR, + md5sum VARCHAR +); + +statement ok +INSERT INTO aka_title VALUES + (1, 1, 'Shawshank', NULL, 1, 1994, NULL, NULL, NULL, NULL, NULL, NULL), + (2, 2, 'Der Pate', NULL, 1, 1972, NULL, NULL, NULL, NULL, 'German title', NULL), + (3, 3, 'The Dark Knight', NULL, 1, 2008, NULL, NULL, NULL, NULL, NULL, NULL), + (4, 4, 'Der Pate II', NULL, 1, 1974, NULL, NULL, NULL, NULL, 'German title', NULL), + (5, 5, 'Pulp Fiction', NULL, 1, 1994, NULL, NULL, NULL, NULL, NULL, NULL), + (6, 6, 'La lista di Schindler', NULL, 1, 1993, NULL, NULL, NULL, NULL, 'Italian title', NULL), + (7, 7, 'LOTR: ROTK', NULL, 1, 2003, NULL, NULL, NULL, NULL, 'Abbreviated', NULL), + (8, 8, '12 Angry Men', NULL, 1, 1957, NULL, NULL, NULL, NULL, NULL, NULL), + (9, 9, 'Dream Heist', NULL, 1, 2010, NULL, NULL, NULL, NULL, 'Working title', NULL), + (10, 10, 'Fight Club', NULL, 1, 1999, NULL, NULL, NULL, NULL, NULL, NULL), + (11, 3, 'Batman: The Dark Knight', NULL, 1, 2008, NULL, NULL, NULL, NULL, 'Full title', NULL), + (12, 13, 'Avengers 4', NULL, 1, 2019, NULL, NULL, NULL, NULL, 'Abbreviated', NULL), + (13, 19, 'The Joker', NULL, 1, 2019, NULL, NULL, NULL, NULL, 'Working title', NULL), + (14, 23, 'Iron Man: Birth of a Hero', NULL, 1, 2008, NULL, NULL, NULL, NULL, 'Extended title', NULL), + (15, 24, 'Black Panther: Wakanda Forever', NULL, 1, 2018, NULL, NULL, NULL, NULL, 'Alternate title', NULL), + (16, 11, 'Avengers 3', NULL, 1, 2018, NULL, NULL, NULL, NULL, 'Abbreviated', NULL), + (17, 3, 'Batman 2', NULL, 1, 2008, NULL, NULL, NULL, NULL, 'Sequel numbering', NULL), + (18, 20, 'Batman: Year One', NULL, 1, 2005, NULL, NULL, NULL, NULL, 'Working title', NULL), + (19, 14, 'Journey to the Stars', NULL, 1, 2014, NULL, NULL, NULL, NULL, 'Working title', NULL), + (20, 25, 'Rose and Jack', NULL, 1, 1997, NULL, NULL, NULL, NULL, 'Character-based title', NULL), + (21, 19, 'Joker: A Descent Into Madness', NULL, 1, 2019, NULL, NULL, NULL, NULL, 'Extended title', NULL), + (22, 22, 'Batman 3', NULL, 1, 2012, NULL, NULL, NULL, NULL, 'Sequel numbering', NULL), + (23, 1, 'The Shawshank Redemption', NULL, 1, 1994, NULL, NULL, NULL, NULL, 'Full title', NULL), + (24, 19, 'El Joker', NULL, 1, 2019, NULL, NULL, NULL, NULL, 'Spanish title', NULL), + (25, 13, 'Los Vengadores: Endgame', NULL, 1, 2019, NULL, NULL, NULL, NULL, 'Spanish title', NULL), + (26, 19, 'The Batman', NULL, 1, 2022, NULL, NULL, NULL, NULL, 'Working title', NULL), + (27, 41, 'Champion Boxer: The Rise of a Legend', NULL, 1, 2018, NULL, NULL, NULL, NULL, 'Extended title', NULL), + (28, 47, 'The Swedish Murder Case', NULL, 1, 2012, NULL, NULL, NULL, NULL, 'Full title', NULL), + (29, 46, 'Viral Documentary', NULL, 1, 2008, NULL, NULL, NULL, NULL, 'Alternate title', NULL), + (30, 45, 'Berlin Noir', NULL, 1, 2010, 989898, NULL, NULL, NULL, NULL, NULL), + (31, 44, 'Digital Connection', NULL, 1, 2005, NULL, NULL, NULL, NULL, NULL, NULL), + (32, 62, 'Animated Feature', NULL, 1, 2010, 123456, NULL, NULL, NULL, NULL, NULL); + +# 1a - Query with production companies and top 250 rank +query TTI +SELECT MIN(mc.note) AS production_note, MIN(t.title) AS movie_title, MIN(t.production_year) AS movie_year +FROM company_type AS ct, info_type AS it, movie_companies AS mc, movie_info_idx AS mi_idx, title AS t +WHERE ct.kind = 'production companies' + AND it.info = 'top 250 rank' + AND mc.note not like '%(as Metro-Goldwyn-Mayer Pictures)%' and (mc.note like '%(co-production)%' or 
mc.note like '%(presents)%') + AND ct.id = mc.company_type_id + AND t.id = mc.movie_id + AND t.id = mi_idx.movie_id + AND mc.movie_id = mi_idx.movie_id + AND it.id = mi_idx.info_type_id +---- +(co-production) Avengers: Endgame 1985 + +# 1b - Query with production companies and bottom 10 rank +query TTI +SELECT MIN(mc.note) AS production_note, MIN(t.title) AS movie_title, MIN(t.production_year) AS movie_year +FROM company_type AS ct, info_type AS it, movie_companies AS mc, movie_info_idx AS mi_idx, title AS t +WHERE ct.kind = 'production companies' + AND it.info = 'bottom 10 rank' + AND mc.note not like '%(as Metro-Goldwyn-Mayer Pictures)%' + AND t.production_year between 2005 and 2010 + AND ct.id = mc.company_type_id + AND t.id = mc.movie_id + AND t.id = mi_idx.movie_id + AND mc.movie_id = mi_idx.movie_id + AND it.id = mi_idx.info_type_id +---- +(as Warner Bros. Pictures) Bad Movie Sequel 2008 + +# 1c - Query with production companies, co-production notes, and top 250 rank (after 2010) +query TTI +SELECT MIN(mc.note) AS production_note, MIN(t.title) AS movie_title, MIN(t.production_year) AS movie_year +FROM company_type AS ct, info_type AS it, movie_companies AS mc, movie_info_idx AS mi_idx, title AS t +WHERE ct.kind = 'production companies' + AND it.info = 'top 250 rank' + AND mc.note not like '%(as Metro-Goldwyn-Mayer Pictures)%' and (mc.note like '%(co-production)%') + AND t.production_year >2010 + AND ct.id = mc.company_type_id + AND t.id = mc.movie_id + AND t.id = mi_idx.movie_id + AND mc.movie_id = mi_idx.movie_id + AND it.id = mi_idx.info_type_id +---- +(co-production) Avengers: Endgame 2014 + +# 1d - Query with production companies and bottom 10 rank (production year after 2000) +query TTI +SELECT MIN(mc.note) AS production_note, MIN(t.title) AS movie_title, MIN(t.production_year) AS movie_year +FROM company_type AS ct, info_type AS it, movie_companies AS mc, movie_info_idx AS mi_idx, title AS t +WHERE ct.kind = 'production companies' + AND it.info = 'bottom 10 rank' + AND mc.note not like '%(as Metro-Goldwyn-Mayer Pictures)%' + AND t.production_year >2000 + AND ct.id = mc.company_type_id + AND t.id = mc.movie_id + AND t.id = mi_idx.movie_id + AND mc.movie_id = mi_idx.movie_id + AND it.id = mi_idx.info_type_id +---- +(as Warner Bros. 
Pictures) Bad Movie Sequel 2008 + +# 2a - Query with German companies and character-name-in-title +query T +SELECT MIN(t.title) AS movie_title +FROM company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, title AS t +WHERE cn.country_code ='[de]' + AND k.keyword ='character-name-in-title' + AND cn.id = mc.company_id + AND mc.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND mc.movie_id = mk.movie_id +---- +Joker + +# 2b - Query with Dutch companies and character-name-in-title +query T +SELECT MIN(t.title) AS movie_title +FROM company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, title AS t +WHERE cn.country_code ='[nl]' + AND k.keyword ='character-name-in-title' + AND cn.id = mc.company_id + AND mc.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND mc.movie_id = mk.movie_id +---- +Iron Man + +# 2c - Query with San Marino ([sm]) companies and character-name-in-title +query T +SELECT MIN(t.title) AS movie_title +FROM company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, title AS t +WHERE cn.country_code ='[sm]' + AND k.keyword ='character-name-in-title' + AND cn.id = mc.company_id + AND mc.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND mc.movie_id = mk.movie_id +---- +Joker + +# 2d - Query with US companies and character-name-in-title +query T +SELECT MIN(t.title) AS movie_title +FROM company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, title AS t +WHERE cn.country_code ='[us]' + AND k.keyword ='character-name-in-title' + AND cn.id = mc.company_id + AND mc.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND mc.movie_id = mk.movie_id +---- +Bruno + +# 3a - Query with sequel keyword and Nordic/German country info after 2005 +query T +SELECT MIN(t.title) AS movie_title +FROM keyword AS k, movie_info AS mi, movie_keyword AS mk, title AS t +WHERE k.keyword like '%sequel%' + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German') + AND t.production_year > 2005 + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND mk.movie_id = mi.movie_id + AND k.id = mk.keyword_id +---- +The Godfather Part II + +# 3b - Query with Bulgarian movies +query T +SELECT MIN(t.title) AS movie_title +FROM keyword AS k, movie_info AS mi, movie_keyword AS mk, title AS t +WHERE k.keyword like '%sequel%' + AND mi.info IN ('Bulgaria') + AND t.production_year > 2000 + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND mk.movie_id = mi.movie_id + AND k.id = mk.keyword_id +---- +The Dark Knight + +# 3c - Query with sequel keyword and a broader country list after 1990 +query T +SELECT MIN(t.title) AS movie_title +FROM keyword AS k, movie_info AS mi, movie_keyword AS mk, title AS t +WHERE k.keyword like '%sequel%' + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German', 'USA', 'American') + AND t.production_year > 1990 + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND mk.movie_id = mi.movie_id + AND k.id = mk.keyword_id +---- +Avengers: Endgame + +# 4a - Query with sequel keyword and rating above 5.0 after 2005 +query TT +SELECT MIN(mi_idx.info) AS rating, MIN(t.title) AS movie_title +FROM info_type AS it, keyword AS k, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE it.info ='rating' + AND k.keyword like '%sequel%' + AND mi_idx.info > '5.0' + AND t.production_year > 2005 + AND t.id = mi_idx.movie_id + AND t.id = mk.movie_id + AND mk.movie_id = mi_idx.movie_id + AND k.id = mk.keyword_id + AND it.id = mi_idx.info_type_id +---- +8.9 Avengers: Endgame + +# 
4b - Query with sequel keyword and rating above 9.0 after 2000 +query TT +SELECT MIN(mi_idx.info) AS rating, MIN(t.title) AS movie_title +FROM info_type AS it, keyword AS k, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE it.info ='rating' + AND k.keyword like '%sequel%' + AND mi_idx.info > '9.0' + AND t.production_year > 2000 + AND t.id = mi_idx.movie_id + AND t.id = mk.movie_id + AND mk.movie_id = mi_idx.movie_id + AND k.id = mk.keyword_id + AND it.id = mi_idx.info_type_id +---- +9.1 The Dark Knight + +# 4c - Query with sequel keyword and rating above 2.0 after 1990 +query TT +SELECT MIN(mi_idx.info) AS rating, MIN(t.title) AS movie_title +FROM info_type AS it, keyword AS k, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE it.info ='rating' + AND k.keyword like '%sequel%' + AND mi_idx.info > '2.0' + AND t.production_year > 1990 + AND t.id = mi_idx.movie_id + AND t.id = mk.movie_id + AND mk.movie_id = mi_idx.movie_id + AND k.id = mk.keyword_id + AND it.id = mi_idx.info_type_id +---- +7.2 Avengers: Endgame + +# 5a - Query for European movies with French theatrical company notes +query T +SELECT MIN(t.title) AS typical_european_movie +FROM company_type AS ct, info_type AS it, movie_companies AS mc, movie_info AS mi, title AS t +WHERE ct.kind = 'production companies' + AND mc.note like '%(theatrical)%' and mc.note like '%(France)%' + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German') + AND t.production_year > 2005 + AND t.id = mi.movie_id + AND t.id = mc.movie_id + AND mc.movie_id = mi.movie_id + AND ct.id = mc.company_type_id + AND it.id = mi.info_type_id +---- +The Matrix + +# 5b - Query for American VHS movies with 1994 company notes +query T +SELECT MIN(t.title) AS american_vhs_movie +FROM company_type AS ct, info_type AS it, movie_companies AS mc, movie_info AS mi, title AS t +WHERE ct.kind = 'production companies' + AND mc.note like '%(VHS)%' and mc.note like '%(USA)%' and mc.note like '%(1994)%' + AND mi.info IN ('USA', 'America') + AND t.production_year > 2000 + AND t.id = mi.movie_id + AND t.id = mc.movie_id + AND mc.movie_id = mi.movie_id + AND ct.id = mc.company_type_id + AND it.id = mi.info_type_id +---- +The Matrix + +# 5c - Query for American non-TV movies after 1990 +query T +SELECT MIN(t.title) AS american_movie +FROM company_type AS ct, info_type AS it, movie_companies AS mc, movie_info AS mi, title AS t +WHERE ct.kind = 'production companies' + AND mc.note not like '%(TV)%' and mc.note like '%(USA)%' + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German', 'USA', 'American') + AND t.production_year > 1990 + AND t.id = mi.movie_id + AND t.id = mc.movie_id + AND mc.movie_id = mi.movie_id + AND ct.id = mc.company_type_id + AND it.id = mi.info_type_id +---- +Champion Boxer + +# 6a - Query for Marvel movies with Robert Downey +query TTT +SELECT MIN(k.keyword) AS movie_keyword, MIN(n.name) AS actor_name, MIN(t.title) AS marvel_movie +FROM cast_info AS ci, keyword AS k, movie_keyword AS mk, name AS n, title AS t +WHERE k.keyword = 'marvel-cinematic-universe' + AND n.name LIKE '%Downey%Robert%' + AND t.production_year > 2010 + AND k.id = mk.keyword_id + AND t.id = mk.movie_id + AND t.id = ci.movie_id + AND ci.movie_id = mk.movie_id + AND n.id = ci.person_id +---- +marvel-cinematic-universe Downey Robert Jr. 
Avengers: Endgame + +# 6b - Query for male actors in movies after 2009 +query TTT +SELECT MIN(k.keyword) AS movie_keyword, MIN(n.name) AS actor_name, MIN(t.title) AS hero_movie +FROM cast_info AS ci, keyword AS k, movie_keyword AS mk, name AS n, title AS t +WHERE k.keyword in ('superhero', 'sequel', 'second-part', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence') + AND n.name LIKE '%Downey%Robert%' + AND t.production_year > 2014 + AND k.id = mk.keyword_id + AND t.id = mk.movie_id + AND t.id = ci.movie_id + AND ci.movie_id = mk.movie_id + AND n.id = ci.person_id +---- +sequel Downey Robert Jr. Avengers: Endgame + +# 6c - Query for superhero movies from specific year +query TTT +SELECT MIN(k.keyword) AS movie_keyword, MIN(n.name) AS actor_name, MIN(t.title) AS marvel_movie +FROM cast_info AS ci, keyword AS k, movie_keyword AS mk, name AS n, title AS t +WHERE k.keyword = 'marvel-cinematic-universe' + AND n.name LIKE '%Downey%Robert%' + AND t.production_year > 2014 + AND k.id = mk.keyword_id + AND t.id = mk.movie_id + AND t.id = ci.movie_id + AND ci.movie_id = mk.movie_id + AND n.id = ci.person_id +---- +marvel-cinematic-universe Downey Robert Jr. Avengers: Endgame + +# 6d - Query for specific director +query TTT +SELECT MIN(k.keyword) AS movie_keyword, MIN(n.name) AS actor_name, MIN(t.title) AS hero_movie +FROM cast_info AS ci, keyword AS k, movie_keyword AS mk, name AS n, title AS t +WHERE k.keyword in ('superhero', 'sequel', 'second-part', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence') + AND n.name LIKE '%Downey%Robert%' + AND t.production_year > 2000 + AND k.id = mk.keyword_id + AND t.id = mk.movie_id + AND t.id = ci.movie_id + AND ci.movie_id = mk.movie_id + AND n.id = ci.person_id +---- +based-on-comic Downey Robert Jr. Avengers: Endgame + +# 6e - Query for advanced superhero movies +query TTT +SELECT MIN(k.keyword) AS movie_keyword, MIN(n.name) AS actor_name, MIN(t.title) AS marvel_movie +FROM cast_info AS ci, keyword AS k, movie_keyword AS mk, name AS n, title AS t +WHERE k.keyword = 'marvel-cinematic-universe' + AND n.name LIKE '%Downey%Robert%' + AND t.production_year > 2000 + AND k.id = mk.keyword_id + AND t.id = mk.movie_id + AND t.id = ci.movie_id + AND ci.movie_id = mk.movie_id + AND n.id = ci.person_id +---- +marvel-cinematic-universe Downey Robert Jr. 
Avengers: Endgame + +# 6f - Query for complex superhero movies +query TTT +SELECT MIN(k.keyword) AS movie_keyword, MIN(n.name) AS actor_name, MIN(t.title) AS hero_movie +FROM cast_info AS ci, keyword AS k, movie_keyword AS mk, name AS n, title AS t +WHERE k.keyword in ('superhero', 'sequel', 'second-part', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence') + AND t.production_year > 2000 + AND k.id = mk.keyword_id + AND t.id = mk.movie_id + AND t.id = ci.movie_id + AND ci.movie_id = mk.movie_id + AND n.id = ci.person_id +---- +based-on-comic Al Pacino Avengers: Endgame + +# 7a - Query about character names +query TT +SELECT MIN(n.name) AS of_person, MIN(t.title) AS biography_movie +FROM aka_name AS an, cast_info AS ci, info_type AS it, link_type AS lt, movie_link AS ml, name AS n, person_info AS pi, title AS t +WHERE an.name LIKE '%a%' + AND it.info ='mini biography' + AND lt.link ='features' + AND n.name_pcode_cf BETWEEN 'A' + AND 'F' + AND (n.gender='m' OR (n.gender = 'f' + AND n.name LIKE 'B%')) + AND pi.note ='Volker Boehm' + AND t.production_year BETWEEN 1980 + AND 1995 + AND n.id = an.person_id + AND n.id = pi.person_id + AND ci.person_id = n.id + AND t.id = ci.movie_id + AND ml.linked_movie_id = t.id + AND lt.id = ml.link_type_id + AND it.id = pi.info_type_id + AND pi.person_id = an.person_id + AND pi.person_id = ci.person_id + AND an.person_id = ci.person_id + AND ci.movie_id = ml.linked_movie_id -- #Al Pacino The Godfather +---- +Derek Jacobi Derek Jacobi Story + +# 7b - Query for person with biography +query TT +SELECT MIN(n.name) AS of_person, MIN(t.title) AS biography_movie +FROM aka_name AS an, cast_info AS ci, info_type AS it, link_type AS lt, movie_link AS ml, name AS n, person_info AS pi, title AS t +WHERE an.name LIKE '%a%' + AND it.info ='mini biography' + AND lt.link ='features' + AND n.name_pcode_cf LIKE 'D%' + AND n.gender='m' + AND pi.note ='Volker Boehm' + AND t.production_year BETWEEN 1980 + AND 1984 + AND n.id = an.person_id + AND n.id = pi.person_id + AND ci.person_id = n.id + AND t.id = ci.movie_id + AND ml.linked_movie_id = t.id + AND lt.id = ml.link_type_id + AND it.id = pi.info_type_id + AND pi.person_id = an.person_id + AND pi.person_id = ci.person_id + AND an.person_id = ci.person_id + AND ci.movie_id = ml.linked_movie_id +---- +Derek Jacobi Derek Jacobi Story + +# 7c - Query for extended character names and biographies +query TT +SELECT MIN(n.name) AS cast_member_name, MIN(pi.info) AS cast_member_info +FROM aka_name AS an, cast_info AS ci, info_type AS it, link_type AS lt, movie_link AS ml, name AS n, person_info AS pi, title AS t +WHERE an.name is not NULL and (an.name LIKE '%a%' or an.name LIKE 'A%') + AND it.info ='mini biography' + AND lt.link in ('references', 'referenced in', 'features', 'featured in') + AND n.name_pcode_cf BETWEEN 'A' + AND 'F' + AND (n.gender='m' OR (n.gender = 'f' + AND n.name LIKE 'A%')) + AND pi.note is not NULL + AND t.production_year BETWEEN 1980 + AND 2010 + AND n.id = an.person_id + AND n.id = pi.person_id + AND ci.person_id = n.id + AND t.id = ci.movie_id + AND ml.linked_movie_id = t.id + AND lt.id = ml.link_type_id + AND it.id = pi.info_type_id + AND pi.person_id = an.person_id + AND pi.person_id = ci.person_id + AND an.person_id = ci.person_id + AND ci.movie_id = ml.linked_movie_id +---- +Al Pacino Famous for his role in The Godfather + +# 8a - Find movies by keyword +query TT +SELECT MIN(an1.name) AS actress_pseudonym, MIN(t.title) AS japanese_movie_dubbed +FROM aka_name AS an1, cast_info AS ci, 
company_name AS cn, movie_companies AS mc, name AS n1, role_type AS rt, title AS t +WHERE ci.note ='(voice: English version)' + AND cn.country_code ='[jp]' + AND mc.note like '%(Japan)%' and mc.note not like '%(USA)%' + AND n1.name like '%Yo%' and n1.name not like '%Yu%' + AND rt.role ='actress' + AND an1.person_id = n1.id + AND n1.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND ci.role_id = rt.id + AND an1.person_id = ci.person_id + AND ci.movie_id = mc.movie_id +---- +Yoko Shimizu One Piece: Grand Adventure + +# 8b - Query for anime voice actors +query TT +SELECT MIN(an.name) AS acress_pseudonym, MIN(t.title) AS japanese_anime_movie +FROM aka_name AS an, cast_info AS ci, company_name AS cn, movie_companies AS mc, name AS n, role_type AS rt, title AS t +WHERE ci.note ='(voice: English version)' + AND cn.country_code ='[jp]' + AND mc.note like '%(Japan)%' and mc.note not like '%(USA)%' and (mc.note like '%(2006)%' or mc.note like '%(2007)%') + AND n.name like '%Yo%' and n.name not like '%Yu%' + AND rt.role ='actress' + AND t.production_year between 2006 and 2007 and (t.title like 'One Piece%' or t.title like 'Dragon Ball Z%') + AND an.person_id = n.id + AND n.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND ci.role_id = rt.id + AND an.person_id = ci.person_id + AND ci.movie_id = mc.movie_id +---- +Yoko Shimizu One Piece: Grand Adventure + +# 8c - Query for extended movies by keyword and voice actors +query TT +SELECT MIN(a1.name) AS writer_pseudo_name, MIN(t.title) AS movie_title +FROM aka_name AS a1, cast_info AS ci, company_name AS cn, movie_companies AS mc, name AS n1, role_type AS rt, title AS t +WHERE cn.country_code ='[us]' + AND rt.role ='writer' + AND a1.person_id = n1.id + AND n1.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND ci.role_id = rt.id + AND a1.person_id = ci.person_id + AND ci.movie_id = mc.movie_id +---- +Jim Cameron Titanic + +# 8d - Query for specialized movies by keyword and voice actors +query TT +SELECT MIN(an1.name) AS costume_designer_pseudo, MIN(t.title) AS movie_with_costumes +FROM aka_name AS an1, cast_info AS ci, company_name AS cn, movie_companies AS mc, name AS n1, role_type AS rt, title AS t +WHERE cn.country_code ='[us]' + AND rt.role ='costume designer' + AND an1.person_id = n1.id + AND n1.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND ci.role_id = rt.id + AND an1.person_id = ci.person_id + AND ci.movie_id = mc.movie_id +---- +E. 
Head Avengers: Endgame + +# 9a - Query for movie sequels +query TTT +SELECT MIN(an.name) AS alternative_name, MIN(chn.name) AS character_name, MIN(t.title) AS movie +FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, movie_companies AS mc, name AS n, role_type AS rt, title AS t +WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') + AND cn.country_code ='[us]' + AND mc.note is not NULL and (mc.note like '%(USA)%' or mc.note like '%(worldwide)%') + AND n.gender ='f' and n.name like '%Ang%' + AND rt.role ='actress' + AND t.production_year between 2005 and 2015 + AND ci.movie_id = t.id + AND t.id = mc.movie_id + AND ci.movie_id = mc.movie_id + AND mc.company_id = cn.id + AND ci.role_id = rt.id + AND n.id = ci.person_id + AND chn.id = ci.person_role_id + AND an.person_id = n.id + AND an.person_id = ci.person_id +---- +Angelina Jolie Batman's Assistant Kung Fu Panda + +# 9b - Query for voice actors in American movies +query TTTT +SELECT MIN(an.name) AS alternative_name, MIN(chn.name) AS voiced_character, MIN(n.name) AS voicing_actress, MIN(t.title) AS american_movie +FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, movie_companies AS mc, name AS n, role_type AS rt, title AS t +WHERE ci.note = '(voice)' + AND cn.country_code ='[us]' + AND mc.note like '%(200%)%' and (mc.note like '%(USA)%' or mc.note like '%(worldwide)%') + AND n.gender ='f' and n.name like '%Angel%' + AND rt.role ='actress' + AND t.production_year between 2007 and 2010 + AND ci.movie_id = t.id + AND t.id = mc.movie_id + AND ci.movie_id = mc.movie_id + AND mc.company_id = cn.id + AND ci.role_id = rt.id + AND n.id = ci.person_id + AND chn.id = ci.person_role_id + AND an.person_id = n.id + AND an.person_id = ci.person_id +---- +Angelina Jolie Batman's Assistant Angelina Jolie Kung Fu Panda + +# 9c - Query for extended movie sequels and voice actors +query TTTT +SELECT MIN(an.name) AS alternative_name, MIN(chn.name) AS voiced_character_name, MIN(n.name) AS voicing_actress, MIN(t.title) AS american_movie +FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, movie_companies AS mc, name AS n, role_type AS rt, title AS t +WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') + AND cn.country_code ='[us]' + AND n.gender ='f' and n.name like '%An%' + AND rt.role ='actress' + AND ci.movie_id = t.id + AND t.id = mc.movie_id + AND ci.movie_id = mc.movie_id + AND mc.company_id = cn.id + AND ci.role_id = rt.id + AND n.id = ci.person_id + AND chn.id = ci.person_role_id + AND an.person_id = n.id + AND an.person_id = ci.person_id +---- +Alexander Morgan Batman's Assistant Angelina Jolie Dragon Warriors + +# 9d - Query for specialized movie sequels and voice actors +query TTTT +SELECT MIN(an.name) AS alternative_name, MIN(chn.name) AS voiced_char_name, MIN(n.name) AS voicing_actress, MIN(t.title) AS american_movie +FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, movie_companies AS mc, name AS n, role_type AS rt, title AS t +WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') + AND cn.country_code ='[us]' + AND n.gender ='f' + AND rt.role ='actress' + AND ci.movie_id = t.id + AND t.id = mc.movie_id + AND ci.movie_id = mc.movie_id + AND mc.company_id = cn.id + AND ci.role_id = rt.id + AND n.id = ci.person_id + AND chn.id = ci.person_role_id + AND an.person_id = n.id + AND 
an.person_id = ci.person_id +---- +Alexander Morgan Batman's Assistant Angelina Jolie Dragon Warriors + +# 10a - Query for cast combinations +query TT +SELECT MIN(chn.name) AS uncredited_voiced_character, MIN(t.title) AS russian_movie +FROM char_name AS chn, cast_info AS ci, company_name AS cn, company_type AS ct, movie_companies AS mc, role_type AS rt, title AS t +WHERE ci.note like '%(voice)%' and ci.note like '%(uncredited)%' + AND cn.country_code = '[ru]' + AND rt.role = 'actor' + AND t.production_year > 2005 + AND t.id = mc.movie_id + AND t.id = ci.movie_id + AND ci.movie_id = mc.movie_id + AND chn.id = ci.person_role_id + AND rt.id = ci.role_id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id +---- +Nikolai Moscow Nights + +# 10b - Query for Russian movie producers who are also actors +query TT +SELECT MIN(chn.name) AS character, MIN(t.title) AS russian_mov_with_actor_producer +FROM char_name AS chn, cast_info AS ci, company_name AS cn, company_type AS ct, movie_companies AS mc, role_type AS rt, title AS t +WHERE ci.note like '%(producer)%' + AND cn.country_code = '[ru]' + AND rt.role = 'actor' + AND t.production_year > 2000 + AND t.id = mc.movie_id + AND t.id = ci.movie_id + AND ci.movie_id = mc.movie_id + AND chn.id = ci.person_role_id + AND rt.id = ci.role_id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id +---- +Director Moscow Nights + +# 10c - Query for American producers in movies +query TT +SELECT MIN(chn.name) AS character, MIN(t.title) AS movie_with_american_producer +FROM char_name AS chn, cast_info AS ci, company_name AS cn, company_type AS ct, movie_companies AS mc, role_type AS rt, title AS t +WHERE ci.note like '%(producer)%' + AND cn.country_code = '[us]' + AND t.production_year > 1990 + AND t.id = mc.movie_id + AND t.id = ci.movie_id + AND ci.movie_id = mc.movie_id + AND chn.id = ci.person_role_id + AND rt.id = ci.role_id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id +---- +Bruce Wayne The Dark Knight + +# 11a - Query for non-Polish companies with sequels +query TTT +SELECT MIN(cn.name) AS from_company, MIN(lt.link) AS movie_link_type, MIN(t.title) AS non_polish_sequel_movie +FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_keyword AS mk, movie_link AS ml, title AS t +WHERE cn.country_code !='[pl]' + AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') + AND ct.kind ='production companies' + AND k.keyword ='sequel' + AND lt.link LIKE '%follow%' + AND mc.note IS NULL + AND t.production_year BETWEEN 1950 + AND 2000 + AND lt.id = ml.link_type_id + AND ml.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_type_id = ct.id + AND mc.company_id = cn.id + AND ml.movie_id = mk.movie_id + AND ml.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id +---- +Warner Bros. 
follows Money Talks + +# 11b - Query for non-Polish companies with Money sequels from 1998 +query TTT +SELECT MIN(cn.name) AS from_company, MIN(lt.link) AS movie_link_type, MIN(t.title) AS sequel_movie +FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_keyword AS mk, movie_link AS ml, title AS t +WHERE cn.country_code !='[pl]' + AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') + AND ct.kind ='production companies' + AND k.keyword ='sequel' + AND lt.link LIKE '%follows%' + AND mc.note IS NULL + AND t.production_year = 1998 and t.title like '%Money%' + AND lt.id = ml.link_type_id + AND ml.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_type_id = ct.id + AND mc.company_id = cn.id + AND ml.movie_id = mk.movie_id + AND ml.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id +---- +Warner Bros. Pictures follows Money Talks + +# 11c - Query for Fox movies based on novels +query TTT +SELECT MIN(cn.name) AS from_company, MIN(mc.note) AS production_note, MIN(t.title) AS movie_based_on_book +FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_keyword AS mk, movie_link AS ml, title AS t +WHERE cn.country_code !='[pl]' and (cn.name like '20th Century Fox%' or cn.name like 'Twentieth Century Fox%') + AND ct.kind != 'production companies' and ct.kind is not NULL + AND k.keyword in ('sequel', 'revenge', 'based-on-novel') + AND mc.note is not NULL + AND t.production_year > 1950 + AND lt.id = ml.link_type_id + AND ml.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_type_id = ct.id + AND mc.company_id = cn.id + AND ml.movie_id = mk.movie_id + AND ml.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id +---- +Twentieth Century Fox (distribution) (DVD) (US) Fox Novel Movie + +# 11d - Query for movies based on novels from non-Polish companies +query TTT +SELECT MIN(cn.name) AS from_company, MIN(mc.note) AS production_note, MIN(t.title) AS movie_based_on_book +FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_keyword AS mk, movie_link AS ml, title AS t +WHERE cn.country_code !='[pl]' + AND ct.kind != 'production companies' and ct.kind is not NULL + AND k.keyword in ('sequel', 'revenge', 'based-on-novel') + AND mc.note is not NULL + AND t.production_year > 1950 + AND lt.id = ml.link_type_id + AND ml.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_type_id = ct.id + AND mc.company_id = cn.id + AND ml.movie_id = mk.movie_id + AND ml.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id +---- +Marvel Studios (as Marvel Studios) Avengers: Endgame + +# 12a - Query for cast in movies with specific genres +query TTT +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS drama_horror_movie +FROM company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, title AS t +WHERE cn.country_code = '[us]' + AND ct.kind = 'production companies' + AND it1.info = 'genres' + AND it2.info = 'rating' + AND mi.info in ('Drama', 'Horror') + AND mi_idx.info > '8.0' + AND t.production_year between 2005 and 2008 + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND mi.info_type_id = it1.id + AND mi_idx.info_type_id = it2.id + AND t.id = mc.movie_id + AND ct.id = 
mc.company_type_id + AND cn.id = mc.company_id + AND mc.movie_id = mi.movie_id + AND mc.movie_id = mi_idx.movie_id + AND mi.movie_id = mi_idx.movie_id +---- +Warner Bros. 9.5 The Dark Knight + +# 12b - Query for unsuccessful movies with specific budget criteria +query TT +SELECT MIN(mi.info) AS budget, MIN(t.title) AS unsuccsessful_movie +FROM company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, title AS t +WHERE cn.country_code ='[us]' + AND ct.kind is not NULL and (ct.kind ='production companies' or ct.kind = 'distributors') + AND it1.info ='budget' + AND it2.info ='bottom 10 rank' + AND t.production_year >2000 + AND (t.title LIKE 'Birdemic%' OR t.title LIKE '%Movie%') + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND mi.info_type_id = it1.id + AND mi_idx.info_type_id = it2.id + AND t.id = mc.movie_id + AND ct.id = mc.company_type_id + AND cn.id = mc.company_id + AND mc.movie_id = mi.movie_id + AND mc.movie_id = mi_idx.movie_id + AND mi.movie_id = mi_idx.movie_id +---- +$500,000 Bad Movie Sequel + +# 12c - Query for highly rated mainstream movies +query TTT +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS mainstream_movie +FROM company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, title AS t +WHERE cn.country_code = '[us]' + AND ct.kind = 'production companies' + AND it1.info = 'genres' + AND it2.info = 'rating' + AND mi.info in ('Drama', 'Horror', 'Western', 'Family') + AND mi_idx.info > '7.0' + AND t.production_year between 2000 and 2010 + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND mi.info_type_id = it1.id + AND mi_idx.info_type_id = it2.id + AND t.id = mc.movie_id + AND ct.id = mc.company_type_id + AND cn.id = mc.company_id + AND mc.movie_id = mi.movie_id + AND mc.movie_id = mi_idx.movie_id + AND mi.movie_id = mi_idx.movie_id +---- +Warner Bros. 
9.5 The Dark Knight + +# 13a - Query for movies with specific genre combinations +query TTT +SELECT MIN(mi.info) AS release_date, MIN(miidx.info) AS rating, MIN(t.title) AS german_movie +FROM company_name AS cn, company_type AS ct, info_type AS it, info_type AS it2, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS miidx, title AS t +WHERE cn.country_code ='[de]' + AND ct.kind ='production companies' + AND it.info ='rating' + AND it2.info ='release dates' + AND kt.kind ='movie' + AND mi.movie_id = t.id + AND it2.id = mi.info_type_id + AND kt.id = t.kind_id + AND mc.movie_id = t.id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id + AND miidx.movie_id = t.id + AND it.id = miidx.info_type_id + AND mi.movie_id = miidx.movie_id + AND mi.movie_id = mc.movie_id + AND miidx.movie_id = mc.movie_id +---- +2005-09-15 7.2 Dark Blood + +# 13b - Query for movies about winning with specific criteria +query TTT +SELECT MIN(cn.name) AS producing_company, MIN(miidx.info) AS rating, MIN(t.title) AS movie_about_winning +FROM company_name AS cn, company_type AS ct, info_type AS it, info_type AS it2, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS miidx, title AS t +WHERE cn.country_code ='[us]' + AND ct.kind ='production companies' + AND it.info ='rating' + AND it2.info ='release dates' + AND kt.kind ='movie' + AND t.title != '' + AND (t.title LIKE '%Champion%' OR t.title LIKE '%Loser%') + AND mi.movie_id = t.id + AND it2.id = mi.info_type_id + AND kt.id = t.kind_id + AND mc.movie_id = t.id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id + AND miidx.movie_id = t.id + AND it.id = miidx.info_type_id + AND mi.movie_id = miidx.movie_id + AND mi.movie_id = mc.movie_id + AND miidx.movie_id = mc.movie_id +---- +Universal Pictures 7.5 Champion Boxer + +# 13c - Query for movies with Champion in the title +query TTT +SELECT MIN(cn.name) AS producing_company, MIN(miidx.info) AS rating, MIN(t.title) AS movie_about_winning +FROM company_name AS cn, company_type AS ct, info_type AS it, info_type AS it2, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS miidx, title AS t +WHERE cn.country_code ='[us]' + AND ct.kind ='production companies' + AND it.info ='rating' + AND it2.info ='release dates' + AND kt.kind ='movie' + AND t.title != '' + AND (t.title LIKE 'Champion%' OR t.title LIKE 'Loser%') + AND mi.movie_id = t.id + AND it2.id = mi.info_type_id + AND kt.id = t.kind_id + AND mc.movie_id = t.id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id + AND miidx.movie_id = t.id + AND it.id = miidx.info_type_id + AND mi.movie_id = miidx.movie_id + AND mi.movie_id = mc.movie_id + AND miidx.movie_id = mc.movie_id +---- +Universal Pictures 7.5 Champion Boxer + +# 13d - Query for all US movies +query TTT +SELECT MIN(cn.name) AS producing_company, MIN(miidx.info) AS rating, MIN(t.title) AS movie +FROM company_name AS cn, company_type AS ct, info_type AS it, info_type AS it2, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS miidx, title AS t +WHERE cn.country_code ='[us]' + AND ct.kind ='production companies' + AND it.info ='rating' + AND it2.info ='release dates' + AND kt.kind ='movie' + AND mi.movie_id = t.id + AND it2.id = mi.info_type_id + AND kt.id = t.kind_id + AND mc.movie_id = t.id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id + AND miidx.movie_id = t.id + AND it.id = miidx.info_type_id + AND mi.movie_id = miidx.movie_id + AND mi.movie_id = mc.movie_id + AND miidx.movie_id = 
mc.movie_id +---- +Marvel Studios 7.5 Avengers: Endgame + +# 14a - Query for actors in specific movie types +query TT +SELECT MIN(mi_idx.info) AS rating, MIN(t.title) AS northern_dark_movie +FROM info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE it1.info = 'countries' + AND it2.info = 'rating' + AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') + AND kt.kind = 'movie' + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German', 'USA', 'American') + AND mi_idx.info < '8.5' + AND t.production_year > 2010 + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mi_idx.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mi_idx.movie_id + AND mi.movie_id = mi_idx.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id +---- +7.5 Nordic Noir + +# 14b - Query for dark western productions with specific criteria +query TT +SELECT MIN(mi_idx.info) AS rating, MIN(t.title) AS western_dark_production +FROM info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE it1.info = 'countries' + AND it2.info = 'rating' + AND k.keyword in ('murder', 'murder-in-title') + AND kt.kind = 'movie' + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German', 'USA', 'American') + AND mi_idx.info > '6.0' + AND t.production_year > 2010 and (t.title like '%murder%' or t.title like '%Murder%' or t.title like '%Mord%') + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mi_idx.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mi_idx.movie_id + AND mi.movie_id = mi_idx.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id +---- +7.8 The Swedish Murder Case + +# 14c - Query for extended movie types and dark themes +query TT +SELECT MIN(mi_idx.info) AS rating, MIN(t.title) AS north_european_dark_production +FROM info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE it1.info = 'countries' + AND it2.info = 'rating' + AND k.keyword is not null and k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') + AND kt.kind in ('movie', 'episode') + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Danish', 'Norwegian', 'German', 'USA', 'American') + AND mi_idx.info < '8.5' + AND t.production_year > 2005 + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mi_idx.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mi_idx.movie_id + AND mi.movie_id = mi_idx.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id +---- +6.8 Berlin Noir + +# 15a - Query for US movies with internet releases +query TT +SELECT MIN(mi.info) AS release_date, MIN(t.title) AS internet_movie +FROM aka_title AS at, company_name AS cn, company_type AS ct, info_type AS it1, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, title AS t +WHERE cn.country_code = '[us]' + AND it1.info = 'release dates' + AND mc.note like '%(200%)%' and mc.note like '%(worldwide)%' + AND mi.note like '%internet%' + AND mi.info like 'USA:% 200%' + AND t.production_year > 2000 + AND t.id = 
at.movie_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mc.movie_id + AND mk.movie_id = at.movie_id + AND mi.movie_id = mc.movie_id + AND mi.movie_id = at.movie_id + AND mc.movie_id = at.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id +---- +USA: 2008-05-15 The Dark Knight + +# 15b - Query for YouTube movies with specific release criteria +query TT +SELECT MIN(mi.info) AS release_date, MIN(t.title) AS youtube_movie +FROM aka_title AS at, company_name AS cn, company_type AS ct, info_type AS it1, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, title AS t +WHERE cn.country_code = '[us]' and cn.name = 'YouTube' + AND it1.info = 'release dates' + AND mc.note like '%(200%)%' and mc.note like '%(worldwide)%' + AND mi.note like '%internet%' + AND mi.info like 'USA:% 200%' + AND t.production_year between 2005 and 2010 + AND t.id = at.movie_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mc.movie_id + AND mk.movie_id = at.movie_id + AND mi.movie_id = mc.movie_id + AND mi.movie_id = at.movie_id + AND mc.movie_id = at.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id +---- +USA: 2008-05-15 YouTube Documentary + +# 15c - Query for extended internet releases +query TT +SELECT MIN(mi.info) AS release_date, MIN(t.title) AS modern_american_internet_movie +FROM aka_title AS at, company_name AS cn, company_type AS ct, info_type AS it1, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, title AS t +WHERE cn.country_code = '[us]' + AND it1.info = 'release dates' + AND mi.note like '%internet%' + AND mi.info is not NULL and (mi.info like 'USA:% 199%' or mi.info like 'USA:% 200%') + AND t.production_year > 1990 + AND t.id = at.movie_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mc.movie_id + AND mk.movie_id = at.movie_id + AND mi.movie_id = mc.movie_id + AND mi.movie_id = at.movie_id + AND mc.movie_id = at.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id +---- +USA: 15 May 2005 Digital Connection + +# 15d - Query for specialized internet releases +query TT +SELECT MIN(at.title) AS aka_title, MIN(t.title) AS internet_movie_title +FROM aka_title AS at, company_name AS cn, company_type AS ct, info_type AS it1, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, title AS t +WHERE cn.country_code = '[us]' + AND it1.info = 'release dates' + AND mi.note like '%internet%' + AND t.production_year > 1990 + AND t.id = at.movie_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mc.movie_id + AND mk.movie_id = at.movie_id + AND mi.movie_id = mc.movie_id + AND mi.movie_id = at.movie_id + AND mc.movie_id = at.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id +---- +Avengers 4 Avengers: Endgame + +# 16a - Query for movies in specific languages +query TT +SELECT MIN(an.name) AS cool_actor_pseudonym, MIN(t.title) AS series_named_after_char +FROM aka_name AS an, cast_info AS ci, company_name AS cn, keyword AS k, 
movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t +WHERE cn.country_code ='[us]' + AND k.keyword ='character-name-in-title' + AND t.episode_nr >= 50 + AND t.episode_nr < 100 + AND an.person_id = n.id + AND n.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND an.person_id = ci.person_id + AND ci.movie_id = mc.movie_id + AND ci.movie_id = mk.movie_id + AND mc.movie_id = mk.movie_id +---- +Alexander Morgan Character Series + +# 16b - Query for series named after characters +query TT +SELECT MIN(an.name) AS cool_actor_pseudonym, MIN(t.title) AS series_named_after_char +FROM aka_name AS an, cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t +WHERE cn.country_code ='[us]' + AND k.keyword ='character-name-in-title' + AND an.person_id = n.id + AND n.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND an.person_id = ci.person_id + AND ci.movie_id = mc.movie_id + AND ci.movie_id = mk.movie_id + AND mc.movie_id = mk.movie_id +---- +Alexander Morgan Character Series + +# 16c - Query for extended languages and character-named series +query TT +SELECT MIN(an.name) AS cool_actor_pseudonym, MIN(t.title) AS series_named_after_char +FROM aka_name AS an, cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t +WHERE cn.country_code ='[us]' + AND k.keyword ='character-name-in-title' + AND t.episode_nr < 100 + AND an.person_id = n.id + AND n.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND an.person_id = ci.person_id + AND ci.movie_id = mc.movie_id + AND ci.movie_id = mk.movie_id + AND mc.movie_id = mk.movie_id +---- +Alexander Morgan Character Series + +# 16d - Query for specialized languages and character-named series +query TT +SELECT MIN(an.name) AS cool_actor_pseudonym, MIN(t.title) AS series_named_after_char +FROM aka_name AS an, cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t +WHERE cn.country_code ='[us]' + AND k.keyword ='character-name-in-title' + AND t.episode_nr >= 5 + AND t.episode_nr < 100 + AND an.person_id = n.id + AND n.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND an.person_id = ci.person_id + AND ci.movie_id = mc.movie_id + AND ci.movie_id = mk.movie_id + AND mc.movie_id = mk.movie_id +---- +Alexander Morgan Character Series + +# 17a - Query for actor/actress combinations +query TT +SELECT MIN(n.name) AS member_in_charnamed_american_movie, MIN(n.name) AS a1 +FROM cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t +WHERE cn.country_code ='[us]' + AND k.keyword ='character-name-in-title' + AND n.name LIKE 'B%' + AND n.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND ci.movie_id = mc.movie_id + AND ci.movie_id = mk.movie_id + AND mc.movie_id = mk.movie_id +---- +Bert Wilson Bert Wilson + +# 17b - Query for actors with names starting with Z in character-named movies +query TT +SELECT MIN(n.name) AS 
member_in_charnamed_movie, MIN(n.name) AS a1 +FROM cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t +WHERE k.keyword ='character-name-in-title' + AND n.name LIKE 'Z%' + AND n.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND ci.movie_id = mc.movie_id + AND ci.movie_id = mk.movie_id + AND mc.movie_id = mk.movie_id +---- +Zach Wilson Zach Wilson + +# 17c - Query for extended actor/actress combinations +query TT +SELECT MIN(n.name) AS member_in_charnamed_movie, MIN(n.name) AS a1 +FROM cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t +WHERE k.keyword ='character-name-in-title' + AND n.name LIKE 'X%' + AND n.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND ci.movie_id = mc.movie_id + AND ci.movie_id = mk.movie_id + AND mc.movie_id = mk.movie_id +---- +Xavier Thompson Xavier Thompson + +# 17d - Query for specialized actor/actress combinations +query T +SELECT MIN(n.name) AS member_in_charnamed_movie +FROM cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t +WHERE k.keyword ='character-name-in-title' + AND n.name LIKE '%Bert%' + AND n.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND ci.movie_id = mc.movie_id + AND ci.movie_id = mk.movie_id + AND mc.movie_id = mk.movie_id +---- +Bert Wilson + +# 17e - Query for advanced actor/actress combinations +query T +SELECT MIN(n.name) AS member_in_charnamed_movie +FROM cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t +WHERE cn.country_code ='[us]' + AND k.keyword ='character-name-in-title' + AND n.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND ci.movie_id = mc.movie_id + AND ci.movie_id = mk.movie_id + AND mc.movie_id = mk.movie_id +---- +Alex Morgan + +# 17f - Query for complex actor/actress combinations +query T +SELECT MIN(n.name) AS member_in_charnamed_movie +FROM cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t +WHERE k.keyword ='character-name-in-title' + AND n.name LIKE '%B%' + AND n.id = ci.person_id + AND ci.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_id = cn.id + AND ci.movie_id = mc.movie_id + AND ci.movie_id = mk.movie_id + AND mc.movie_id = mk.movie_id +---- +Bert Wilson + +# 18a - Query with complex genre filtering +query TTT +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(t.title) AS movie_title +FROM cast_info AS ci, info_type AS it1, info_type AS it2, movie_info AS mi, movie_info_idx AS mi_idx, name AS n, title AS t +WHERE ci.note in ('(producer)', '(executive producer)') + AND it1.info = 'budget' + AND it2.info = 'votes' + AND n.gender = 'm' and n.name like '%Tim%' + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND t.id = ci.movie_id + AND ci.movie_id = mi.movie_id + AND ci.movie_id = mi_idx.movie_id + AND mi.movie_id = mi_idx.movie_id + AND n.id = ci.person_id + AND it1.id = mi.info_type_id + 
AND it2.id = mi_idx.info_type_id +---- +$25,000,000 2,345,678 The Shawshank Redemption + +# 18b - Query for horror movies by female writers with high ratings +query TTT +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(t.title) AS movie_title +FROM cast_info AS ci, info_type AS it1, info_type AS it2, movie_info AS mi, movie_info_idx AS mi_idx, name AS n, title AS t +WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') + AND it1.info = 'genres' + AND it2.info = 'rating' + AND mi.info in ('Horror', 'Thriller') and mi.note is NULL + AND mi_idx.info > '8.0' + AND n.gender is not null and n.gender = 'f' + AND t.production_year between 2008 and 2014 + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND t.id = ci.movie_id + AND ci.movie_id = mi.movie_id + AND ci.movie_id = mi_idx.movie_id + AND mi.movie_id = mi_idx.movie_id + AND n.id = ci.person_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id +---- +Horror 8.5 Woman in Black + +# 18c - Query for extended genre filtering +query TTT +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(t.title) AS movie_title +FROM cast_info AS ci, info_type AS it1, info_type AS it2, movie_info AS mi, movie_info_idx AS mi_idx, name AS n, title AS t +WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') + AND it1.info = 'genres' + AND it2.info = 'votes' + AND mi.info in ('Horror', 'Action', 'Sci-Fi', 'Thriller', 'Crime', 'War') + AND n.gender = 'm' + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND t.id = ci.movie_id + AND ci.movie_id = mi.movie_id + AND ci.movie_id = mi_idx.movie_id + AND mi.movie_id = mi_idx.movie_id + AND n.id = ci.person_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id +---- +Horror 1000 Halloween + +# 19a - Query for character name patterns +query TT +SELECT MIN(n.name) AS voicing_actress, MIN(t.title) AS voiced_movie +FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, movie_companies AS mc, movie_info AS mi, name AS n, role_type AS rt, title AS t +WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') + AND cn.country_code ='[us]' + AND it.info = 'release dates' + AND mc.note is not NULL and (mc.note like '%(USA)%' or mc.note like '%(worldwide)%') + AND mi.info is not null and (mi.info like 'Japan:%200%' or mi.info like 'USA:%200%') + AND n.gender ='f' and n.name like '%Ang%' + AND rt.role ='actress' + AND t.production_year between 2005 and 2009 + AND t.id = mi.movie_id + AND t.id = mc.movie_id + AND t.id = ci.movie_id + AND mc.movie_id = ci.movie_id + AND mc.movie_id = mi.movie_id + AND mi.movie_id = ci.movie_id + AND cn.id = mc.company_id + AND it.id = mi.info_type_id + AND n.id = ci.person_id + AND rt.id = ci.role_id + AND n.id = an.person_id + AND ci.person_id = an.person_id + AND chn.id = ci.person_role_id +---- +Angelina Jolie Kung Fu Panda + +# 19b - Query for Angelina Jolie as voice actress in Kung Fu Panda series +query TT +SELECT MIN(n.name) AS voicing_actress, MIN(t.title) AS kung_fu_panda +FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, movie_companies AS mc, movie_info AS mi, name AS n, role_type AS rt, title AS t +WHERE ci.note = '(voice)' + AND cn.country_code ='[us]' + AND it.info = 'release dates' + AND mc.note like '%(200%)%' and (mc.note like '%(USA)%' or mc.note like '%(worldwide)%') + AND mi.info is not 
null and (mi.info like 'Japan:%2007%' or mi.info like 'USA:%2008%') + AND n.gender ='f' and n.name like '%Angel%' + AND rt.role ='actress' + AND t.production_year between 2007 and 2008 and t.title like '%Kung%Fu%Panda%' + AND t.id = mi.movie_id + AND t.id = mc.movie_id + AND t.id = ci.movie_id + AND mc.movie_id = ci.movie_id + AND mc.movie_id = mi.movie_id + AND mi.movie_id = ci.movie_id + AND cn.id = mc.company_id + AND it.id = mi.info_type_id + AND n.id = ci.person_id + AND rt.id = ci.role_id + AND n.id = an.person_id + AND ci.person_id = an.person_id + AND chn.id = ci.person_role_id +---- +Angelina Jolie Kung Fu Panda + +# 19c - Query for extended character patterns +query TT +SELECT MIN(n.name) AS voicing_actress, MIN(t.title) AS jap_engl_voiced_movie +FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, movie_companies AS mc, movie_info AS mi, name AS n, role_type AS rt, title AS t +WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') + AND cn.country_code ='[us]' + AND it.info = 'release dates' + AND mi.info is not null and (mi.info like 'Japan:%200%' or mi.info like 'USA:%200%') + AND n.gender ='f' and n.name like '%An%' + AND rt.role ='actress' + AND t.production_year > 2000 + AND t.id = mi.movie_id + AND t.id = mc.movie_id + AND t.id = ci.movie_id + AND mc.movie_id = ci.movie_id + AND mc.movie_id = mi.movie_id + AND mi.movie_id = ci.movie_id + AND cn.id = mc.company_id + AND it.id = mi.info_type_id + AND n.id = ci.person_id + AND rt.id = ci.role_id + AND n.id = an.person_id + AND ci.person_id = an.person_id + AND chn.id = ci.person_role_id +---- +Angelina Jolie Kung Fu Panda + +# 19d - Query for specialized character patterns +query TT +SELECT MIN(n.name) AS voicing_actress, MIN(t.title) AS jap_engl_voiced_movie +FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, movie_companies AS mc, movie_info AS mi, name AS n, role_type AS rt, title AS t +WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') + AND cn.country_code ='[us]' + AND it.info = 'release dates' + AND n.gender ='f' + AND rt.role ='actress' + AND t.production_year > 2000 + AND t.id = mi.movie_id + AND t.id = mc.movie_id + AND t.id = ci.movie_id + AND mc.movie_id = ci.movie_id + AND mc.movie_id = mi.movie_id + AND mi.movie_id = ci.movie_id + AND cn.id = mc.company_id + AND it.id = mi.info_type_id + AND n.id = ci.person_id + AND rt.id = ci.role_id + AND n.id = an.person_id + AND ci.person_id = an.person_id + AND chn.id = ci.person_role_id +---- +Angelina Jolie Kung Fu Panda + +# 20a - Query for movies with specific actor roles +query T +SELECT MIN(t.title) AS complete_downey_ironman_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, keyword AS k, kind_type AS kt, movie_keyword AS mk, name AS n, title AS t +WHERE cct1.kind = 'cast' + AND cct2.kind like '%complete%' + AND chn.name not like '%Sherlock%' and (chn.name like '%Tony%Stark%' or chn.name like '%Iron%Man%') + AND k.keyword in ('superhero', 'sequel', 'second-part', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence') + AND kt.kind = 'movie' + AND t.production_year > 1950 + AND kt.id = t.kind_id + AND t.id = mk.movie_id + AND t.id = ci.movie_id + AND t.id = cc.movie_id + AND mk.movie_id = ci.movie_id + AND mk.movie_id = cc.movie_id + AND ci.movie_id = cc.movie_id + AND chn.id = ci.person_role_id 
+ AND n.id = ci.person_id + AND k.id = mk.keyword_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id +---- +Iron Man + +# 20b - Query for complete Downey Iron Man movies +query T +SELECT MIN(t.title) AS complete_downey_ironman_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, keyword AS k, kind_type AS kt, movie_keyword AS mk, name AS n, title AS t +WHERE cct1.kind = 'cast' + AND cct2.kind like '%complete%' + AND chn.name not like '%Sherlock%' and (chn.name like '%Tony%Stark%' or chn.name like '%Iron%Man%') + AND k.keyword in ('superhero', 'sequel', 'second-part', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence') + AND kt.kind = 'movie' + AND n.name LIKE '%Downey%Robert%' + AND t.production_year > 2000 + AND kt.id = t.kind_id + AND t.id = mk.movie_id + AND t.id = ci.movie_id + AND t.id = cc.movie_id + AND mk.movie_id = ci.movie_id + AND mk.movie_id = cc.movie_id + AND ci.movie_id = cc.movie_id + AND chn.id = ci.person_role_id + AND n.id = ci.person_id + AND k.id = mk.keyword_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id +---- +Iron Man + +# 20c - Query for extended specific actor roles +query TT +SELECT MIN(n.name) AS cast_member, MIN(t.title) AS complete_dynamic_hero_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, keyword AS k, kind_type AS kt, movie_keyword AS mk, name AS n, title AS t +WHERE cct1.kind = 'cast' + AND cct2.kind like '%complete%' + AND chn.name is not NULL and (chn.name like '%man%' or chn.name like '%Man%') + AND k.keyword in ('superhero', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence', 'magnet', 'web', 'claw', 'laser') + AND kt.kind = 'movie' + AND t.production_year > 2000 + AND kt.id = t.kind_id + AND t.id = mk.movie_id + AND t.id = ci.movie_id + AND t.id = cc.movie_id + AND mk.movie_id = ci.movie_id + AND mk.movie_id = cc.movie_id + AND ci.movie_id = cc.movie_id + AND chn.id = ci.person_role_id + AND n.id = ci.person_id + AND k.id = mk.keyword_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id +---- +Downey Robert Jr. Iron Man + +# 21a - Query for movies with specific production years +query TTT +SELECT MIN(cn.name) AS company_name, MIN(lt.link) AS link_type, MIN(t.title) AS western_follow_up +FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, movie_link AS ml, title AS t +WHERE cn.country_code !='[pl]' + AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') + AND ct.kind ='production companies' + AND k.keyword ='sequel' + AND lt.link LIKE '%follow%' + AND mc.note IS NULL + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German') + AND t.production_year BETWEEN 1950 + AND 2000 + AND lt.id = ml.link_type_id + AND ml.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_type_id = ct.id + AND mc.company_id = cn.id + AND mi.movie_id = t.id + AND ml.movie_id = mk.movie_id + AND ml.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id + AND ml.movie_id = mi.movie_id + AND mk.movie_id = mi.movie_id + AND mc.movie_id = mi.movie_id +---- +Warner Bros. 
Pictures follows The Western Sequel + +# 21b - Query for German follow-up movies +query TTT +SELECT MIN(cn.name) AS company_name, MIN(lt.link) AS link_type, MIN(t.title) AS german_follow_up +FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, movie_link AS ml, title AS t +WHERE cn.country_code !='[pl]' + AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') + AND ct.kind ='production companies' + AND k.keyword ='sequel' + AND lt.link LIKE '%follow%' + AND mc.note IS NULL + AND mi.info IN ('Germany', 'German') + AND t.production_year BETWEEN 2000 + AND 2010 + AND lt.id = ml.link_type_id + AND ml.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_type_id = ct.id + AND mc.company_id = cn.id + AND mi.movie_id = t.id + AND ml.movie_id = mk.movie_id + AND ml.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id + AND ml.movie_id = mi.movie_id + AND mk.movie_id = mi.movie_id + AND mc.movie_id = mi.movie_id +---- +Berlin Film Studio follows Dark Blood + +# 21c - Query for extended specific production years +query TTT +SELECT MIN(cn.name) AS company_name, MIN(lt.link) AS link_type, MIN(t.title) AS western_follow_up +FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, movie_link AS ml, title AS t +WHERE cn.country_code !='[pl]' + AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') + AND ct.kind ='production companies' + AND k.keyword ='sequel' + AND lt.link LIKE '%follow%' + AND mc.note IS NULL + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German', 'English') + AND t.production_year BETWEEN 1950 + AND 2010 + AND lt.id = ml.link_type_id + AND ml.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_type_id = ct.id + AND mc.company_id = cn.id + AND mi.movie_id = t.id + AND ml.movie_id = mk.movie_id + AND ml.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id + AND ml.movie_id = mi.movie_id + AND mk.movie_id = mi.movie_id + AND mc.movie_id = mi.movie_id +---- +Berlin Film Studio follows Dark Blood + +# 22a - Query for movies with specific actor roles +query TTT +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS western_violent_movie +FROM company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE cn.country_code != '[us]' + AND it1.info = 'countries' + AND it2.info = 'rating' + AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') + AND kt.kind in ('movie', 'episode') + AND mc.note not like '%(USA)%' and mc.note like '%(200%)%' + AND mi.info IN ('Germany', 'German', 'USA', 'American') + AND mi_idx.info < '7.0' + AND t.production_year > 2008 + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mi_idx.movie_id + AND t.id = mc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mi_idx.movie_id + AND mk.movie_id = mc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mc.movie_id + AND mc.movie_id = mi_idx.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND ct.id = mc.company_type_id + AND cn.id = mc.company_id +---- +Berlin Film Studio 6.8 Berlin Noir + +# 22b - 
Query for western violent movies by non-US companies +query TTT +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS western_violent_movie +FROM company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE cn.country_code != '[us]' + AND it1.info = 'countries' + AND it2.info = 'rating' + AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') + AND kt.kind in ('movie', 'episode') + AND mc.note not like '%(USA)%' and mc.note like '%(200%)%' + AND mi.info IN ('Germany', 'German', 'USA', 'American') + AND mi_idx.info < '7.0' + AND t.production_year > 2009 + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mi_idx.movie_id + AND t.id = mc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mi_idx.movie_id + AND mk.movie_id = mc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mc.movie_id + AND mc.movie_id = mi_idx.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND ct.id = mc.company_type_id + AND cn.id = mc.company_id +---- +Berlin Film Studio 6.8 Berlin Noir + +# 22c - Query for extended actor roles +query TTT +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS western_violent_movie +FROM company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE cn.country_code != '[us]' + AND it1.info = 'countries' + AND it2.info = 'rating' + AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') + AND kt.kind in ('movie', 'episode') + AND mc.note not like '%(USA)%' and mc.note like '%(200%)%' + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Danish', 'Norwegian', 'German', 'USA', 'American') + AND mi_idx.info < '8.5' + AND t.production_year > 2005 + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mi_idx.movie_id + AND t.id = mc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mi_idx.movie_id + AND mk.movie_id = mc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mc.movie_id + AND mc.movie_id = mi_idx.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND ct.id = mc.company_type_id + AND cn.id = mc.company_id +---- +Berlin Film Studio 6.8 Berlin Noir + +# 22d - Query for specialized actor roles +query TTT +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS western_violent_movie +FROM company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE cn.country_code != '[us]' + AND it1.info = 'countries' + AND it2.info = 'rating' + AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') + AND kt.kind in ('movie', 'episode') + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Danish', 'Norwegian', 'German', 'USA', 'American') + AND mi_idx.info < '8.5' + AND t.production_year > 2005 + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mi_idx.movie_id + AND t.id = mc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = 
mi_idx.movie_id + AND mk.movie_id = mc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mc.movie_id + AND mc.movie_id = mi_idx.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND ct.id = mc.company_type_id + AND cn.id = mc.company_id +---- +Berlin Film Studio 6.8 Berlin Noir + +# 23a - Query for sequels with specific character names +query TT +SELECT MIN(kt.kind) AS movie_kind, MIN(t.title) AS complete_us_internet_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, company_name AS cn, company_type AS ct, info_type AS it1, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, title AS t +WHERE cct1.kind = 'complete+verified' + AND cn.country_code = '[us]' + AND it1.info = 'release dates' + AND kt.kind in ('movie') + AND mi.note like '%internet%' + AND mi.info is not NULL and (mi.info like 'USA:% 199%' or mi.info like 'USA:% 200%') + AND t.production_year > 2000 + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mc.movie_id + AND t.id = cc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mc.movie_id + AND mk.movie_id = cc.movie_id + AND mi.movie_id = mc.movie_id + AND mi.movie_id = cc.movie_id + AND mc.movie_id = cc.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id + AND cct1.id = cc.status_id +---- +movie Digital Connection + +# 23b - Query for complete nerdy internet movies +query TT +SELECT MIN(kt.kind) AS movie_kind, MIN(t.title) AS complete_nerdy_internet_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, company_name AS cn, company_type AS ct, info_type AS it1, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, title AS t +WHERE cct1.kind = 'complete+verified' + AND cn.country_code = '[us]' + AND it1.info = 'release dates' + AND k.keyword in ('nerd', 'loner', 'alienation', 'dignity') + AND kt.kind in ('movie') + AND mi.note like '%internet%' + AND mi.info like 'USA:% 200%' + AND t.production_year > 2000 + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mc.movie_id + AND t.id = cc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mc.movie_id + AND mk.movie_id = cc.movie_id + AND mi.movie_id = mc.movie_id + AND mi.movie_id = cc.movie_id + AND mc.movie_id = cc.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id + AND cct1.id = cc.status_id +---- +movie Digital Connection + +# 23c - Query for extended sequels with specific attributes +query TT +SELECT MIN(kt.kind) AS movie_kind, MIN(t.title) AS complete_us_internet_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, company_name AS cn, company_type AS ct, info_type AS it1, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, title AS t +WHERE cct1.kind = 'complete+verified' + AND cn.country_code = '[us]' + AND it1.info = 'release dates' + AND kt.kind in ('movie', 'tv movie', 'video movie', 'video game') + AND mi.note like '%internet%' + AND mi.info is not NULL and (mi.info like 'USA:% 199%' or mi.info like 'USA:% 200%') + AND t.production_year > 1990 + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mc.movie_id + AND t.id = cc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mc.movie_id + AND mk.movie_id = cc.movie_id + AND mi.movie_id = 
mc.movie_id + AND mi.movie_id = cc.movie_id + AND mc.movie_id = cc.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND cn.id = mc.company_id + AND ct.id = mc.company_type_id + AND cct1.id = cc.status_id +---- +movie Digital Connection + +# 24a - Query for movies with specific budgets +query TTT +SELECT MIN(chn.name) AS voiced_char_name, MIN(n.name) AS voicing_actress_name, MIN(t.title) AS voiced_action_movie_jap_eng +FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, name AS n, role_type AS rt, title AS t +WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') + AND cn.country_code ='[us]' + AND it.info = 'release dates' + AND k.keyword in ('hero', 'martial-arts', 'hand-to-hand-combat') + AND mi.info is not null and (mi.info like 'Japan:%201%' or mi.info like 'USA:%201%') + AND n.gender ='f' and n.name like '%An%' + AND rt.role ='actress' + AND t.production_year > 2010 + AND t.id = mi.movie_id + AND t.id = mc.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND mc.movie_id = ci.movie_id + AND mc.movie_id = mi.movie_id + AND mc.movie_id = mk.movie_id + AND mi.movie_id = ci.movie_id + AND mi.movie_id = mk.movie_id + AND ci.movie_id = mk.movie_id + AND cn.id = mc.company_id + AND it.id = mi.info_type_id + AND n.id = ci.person_id + AND rt.id = ci.role_id + AND n.id = an.person_id + AND ci.person_id = an.person_id + AND chn.id = ci.person_role_id + AND k.id = mk.keyword_id +---- +Batman's Assistant Angelina Jolie Kung Fu Panda 2 + +# 24b - Query for voiced characters in Kung Fu Panda +query TTT +SELECT MIN(chn.name) AS voiced_char_name, MIN(n.name) AS voicing_actress_name, MIN(t.title) AS kung_fu_panda +FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, name AS n, role_type AS rt, title AS t +WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') + AND cn.country_code ='[us]' + AND cn.name = 'DreamWorks Animation' + AND it.info = 'release dates' + AND k.keyword in ('hero', 'martial-arts', 'hand-to-hand-combat', 'computer-animated-movie') + AND mi.info is not null and (mi.info like 'Japan:%201%' or mi.info like 'USA:%201%') + AND n.gender ='f' and n.name like '%An%' + AND rt.role ='actress' + AND t.production_year > 2010 + AND t.title like 'Kung Fu Panda%' + AND t.id = mi.movie_id + AND t.id = mc.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND mc.movie_id = ci.movie_id + AND mc.movie_id = mi.movie_id + AND mc.movie_id = mk.movie_id + AND mi.movie_id = ci.movie_id + AND mi.movie_id = mk.movie_id + AND ci.movie_id = mk.movie_id + AND cn.id = mc.company_id + AND it.id = mi.info_type_id + AND n.id = ci.person_id + AND rt.id = ci.role_id + AND n.id = an.person_id + AND ci.person_id = an.person_id + AND chn.id = ci.person_role_id + AND k.id = mk.keyword_id +---- +Batman's Assistant Angelina Jolie Kung Fu Panda 2 + +# 25a - Query for cast combinations in specific movies +query TTTT +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS male_writer, MIN(t.title) AS violent_movie_title +FROM cast_info AS ci, info_type AS it1, info_type AS it2, keyword AS k, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t +WHERE ci.note in ('(writer)', '(head 
writer)', '(written by)', '(story)', '(story editor)') + AND it1.info = 'genres' + AND it2.info = 'votes' + AND k.keyword in ('murder', 'blood', 'gore', 'death', 'female-nudity') + AND mi.info = 'Horror' + AND n.gender = 'm' + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND ci.movie_id = mi.movie_id + AND ci.movie_id = mi_idx.movie_id + AND ci.movie_id = mk.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mk.movie_id + AND mi_idx.movie_id = mk.movie_id + AND n.id = ci.person_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND k.id = mk.keyword_id +---- +Horror 1000 Christian Bale Halloween + +# 25b - Query for violent horror films with male writers +query TTTT +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS male_writer, MIN(t.title) AS violent_movie_title +FROM cast_info AS ci, info_type AS it1, info_type AS it2, keyword AS k, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t +WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') + AND it1.info = 'genres' + AND it2.info = 'votes' + AND k.keyword in ('murder', 'blood', 'gore', 'death', 'female-nudity') + AND mi.info = 'Horror' + AND n.gender = 'm' + AND t.production_year > 2010 + AND t.title like 'Vampire%' + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND ci.movie_id = mi.movie_id + AND ci.movie_id = mi_idx.movie_id + AND ci.movie_id = mk.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mk.movie_id + AND mi_idx.movie_id = mk.movie_id + AND n.id = ci.person_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND k.id = mk.keyword_id +---- +Horror 1000 Christian Bale Vampire Chronicles + +# 25c - Query for extended cast combinations +query TTTT +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS male_writer, MIN(t.title) AS violent_movie_title +FROM cast_info AS ci, info_type AS it1, info_type AS it2, keyword AS k, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t +WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') + AND it1.info = 'genres' + AND it2.info = 'votes' + AND k.keyword in ('murder', 'violence', 'blood', 'gore', 'death', 'female-nudity', 'hospital') + AND mi.info in ('Horror', 'Action', 'Sci-Fi', 'Thriller', 'Crime', 'War') + AND n.gender = 'm' + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND ci.movie_id = mi.movie_id + AND ci.movie_id = mi_idx.movie_id + AND ci.movie_id = mk.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mk.movie_id + AND mi_idx.movie_id = mk.movie_id + AND n.id = ci.person_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND k.id = mk.keyword_id +---- +Horror 1000 Christian Bale Halloween + +# 26a - Query for specific movie genres with ratings +query TTTT +SELECT MIN(chn.name) AS character_name, MIN(mi_idx.info) AS rating, MIN(n.name) AS playing_actor, MIN(t.title) AS complete_hero_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, info_type AS it2, keyword AS k, kind_type AS kt, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t +WHERE cct1.kind = 'cast' + AND cct2.kind like '%complete%' + AND chn.name is not NULL and 
(chn.name like '%man%' or chn.name like '%Man%') + AND it2.info = 'rating' + AND k.keyword in ('superhero', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence', 'magnet', 'web', 'claw', 'laser') + AND kt.kind = 'movie' + AND mi_idx.info > '7.0' + AND t.production_year > 2000 + AND kt.id = t.kind_id + AND t.id = mk.movie_id + AND t.id = ci.movie_id + AND t.id = cc.movie_id + AND t.id = mi_idx.movie_id + AND mk.movie_id = ci.movie_id + AND mk.movie_id = cc.movie_id + AND mk.movie_id = mi_idx.movie_id + AND ci.movie_id = cc.movie_id + AND ci.movie_id = mi_idx.movie_id + AND cc.movie_id = mi_idx.movie_id + AND chn.id = ci.person_role_id + AND n.id = ci.person_id + AND k.id = mk.keyword_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id + AND it2.id = mi_idx.info_type_id +---- +Ironman 8.5 John Carpenter Marvel Superhero Epic + +# 26b - Query for complete hero movies with Man in character name +query TTT +SELECT MIN(chn.name) AS character_name, MIN(mi_idx.info) AS rating, MIN(t.title) AS complete_hero_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, info_type AS it2, keyword AS k, kind_type AS kt, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t +WHERE cct1.kind = 'cast' + AND cct2.kind like '%complete%' + AND chn.name is not NULL and (chn.name like '%man%' or chn.name like '%Man%') + AND it2.info = 'rating' + AND k.keyword in ('superhero', 'marvel-comics', 'based-on-comic', 'fight') + AND kt.kind = 'movie' + AND mi_idx.info > '8.0' + AND t.production_year > 2005 + AND kt.id = t.kind_id + AND t.id = mk.movie_id + AND t.id = ci.movie_id + AND t.id = cc.movie_id + AND t.id = mi_idx.movie_id + AND mk.movie_id = ci.movie_id + AND mk.movie_id = cc.movie_id + AND mk.movie_id = mi_idx.movie_id + AND ci.movie_id = cc.movie_id + AND ci.movie_id = mi_idx.movie_id + AND cc.movie_id = mi_idx.movie_id + AND chn.id = ci.person_role_id + AND n.id = ci.person_id + AND k.id = mk.keyword_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id + AND it2.id = mi_idx.info_type_id +---- +Ironman 8.5 Marvel Superhero Epic + +# 26c - Query for extended movie genres and ratings +query TTT +SELECT MIN(chn.name) AS character_name, MIN(mi_idx.info) AS rating, MIN(t.title) AS complete_hero_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, info_type AS it2, keyword AS k, kind_type AS kt, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t +WHERE cct1.kind = 'cast' + AND cct2.kind like '%complete%' + AND chn.name is not NULL and (chn.name like '%man%' or chn.name like '%Man%') + AND it2.info = 'rating' + AND k.keyword in ('superhero', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence', 'magnet', 'web', 'claw', 'laser') + AND kt.kind = 'movie' + AND t.production_year > 2000 + AND kt.id = t.kind_id + AND t.id = mk.movie_id + AND t.id = ci.movie_id + AND t.id = cc.movie_id + AND t.id = mi_idx.movie_id + AND mk.movie_id = ci.movie_id + AND mk.movie_id = cc.movie_id + AND mk.movie_id = mi_idx.movie_id + AND ci.movie_id = cc.movie_id + AND ci.movie_id = mi_idx.movie_id + AND cc.movie_id = mi_idx.movie_id + AND chn.id = ci.person_role_id + AND n.id = ci.person_id + AND k.id = mk.keyword_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id + AND it2.id = mi_idx.info_type_id +---- +Ironman 8.5 Marvel Superhero Epic + +# 27a - Query for movies with specific person roles +query TTT +SELECT MIN(cn.name) AS 
producing_company, MIN(lt.link) AS link_type, MIN(t.title) AS complete_western_sequel +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, movie_link AS ml, title AS t +WHERE cct1.kind in ('cast', 'crew') + AND cct2.kind = 'complete' + AND cn.country_code !='[pl]' + AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') + AND ct.kind ='production companies' + AND k.keyword ='sequel' + AND lt.link LIKE '%follow%' + AND mc.note IS NULL + AND mi.info IN ('Sweden', 'Germany','Swedish', 'German') + AND t.production_year BETWEEN 1950 + AND 2000 + AND lt.id = ml.link_type_id + AND ml.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_type_id = ct.id + AND mc.company_id = cn.id + AND mi.movie_id = t.id + AND t.id = cc.movie_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id + AND ml.movie_id = mk.movie_id + AND ml.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id + AND ml.movie_id = mi.movie_id + AND mk.movie_id = mi.movie_id + AND mc.movie_id = mi.movie_id + AND ml.movie_id = cc.movie_id + AND mk.movie_id = cc.movie_id + AND mc.movie_id = cc.movie_id + AND mi.movie_id = cc.movie_id +---- +Warner Bros. Pictures follows The Western Sequel + +# 27b - Query for complete western sequel films by non-Polish companies +query TTT +SELECT MIN(cn.name) AS producing_company, MIN(lt.link) AS link_type, MIN(t.title) AS complete_western_sequel +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, movie_link AS ml, title AS t +WHERE cct1.kind in ('cast', 'crew') + AND cct2.kind = 'complete' + AND cn.country_code !='[pl]' + AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') + AND ct.kind ='production companies' + AND k.keyword ='sequel' + AND lt.link LIKE '%follow%' + AND mc.note IS NULL + AND mi.info IN ('Sweden', 'Germany','Swedish', 'German') + AND t.production_year = 1998 + AND lt.id = ml.link_type_id + AND ml.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_type_id = ct.id + AND mc.company_id = cn.id + AND mi.movie_id = t.id + AND t.id = cc.movie_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id + AND ml.movie_id = mk.movie_id + AND ml.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id + AND ml.movie_id = mi.movie_id + AND mk.movie_id = mi.movie_id + AND mc.movie_id = mi.movie_id + AND ml.movie_id = cc.movie_id + AND mk.movie_id = cc.movie_id + AND mc.movie_id = cc.movie_id + AND mi.movie_id = cc.movie_id +---- +Warner Bros. 
Pictures follows The Western Sequel + +# 27c - Query for extended person roles +query TTT +SELECT MIN(cn.name) AS producing_company, MIN(lt.link) AS link_type, MIN(t.title) AS complete_western_sequel +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, movie_link AS ml, title AS t +WHERE cct1.kind = 'cast' + AND cct2.kind like 'complete%' + AND cn.country_code !='[pl]' + AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') + AND ct.kind ='production companies' + AND k.keyword ='sequel' + AND lt.link LIKE '%follow%' + AND mc.note IS NULL + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German', 'English') + AND t.production_year BETWEEN 1950 + AND 2010 + AND lt.id = ml.link_type_id + AND ml.movie_id = t.id + AND t.id = mk.movie_id + AND mk.keyword_id = k.id + AND t.id = mc.movie_id + AND mc.company_type_id = ct.id + AND mc.company_id = cn.id + AND mi.movie_id = t.id + AND t.id = cc.movie_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id + AND ml.movie_id = mk.movie_id + AND ml.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id + AND ml.movie_id = mi.movie_id + AND mk.movie_id = mi.movie_id + AND mc.movie_id = mi.movie_id + AND ml.movie_id = cc.movie_id + AND mk.movie_id = cc.movie_id + AND mc.movie_id = cc.movie_id + AND mi.movie_id = cc.movie_id +---- +Warner Bros. Pictures follows The Western Sequel + +# 28a - Query for movies with specific production years +query TTT +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS complete_euro_dark_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE cct1.kind = 'crew' + AND cct2.kind != 'complete+verified' + AND cn.country_code != '[us]' + AND it1.info = 'countries' + AND it2.info = 'rating' + AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') + AND kt.kind in ('movie', 'episode') + AND mc.note not like '%(USA)%' and mc.note like '%(200%)%' + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Danish', 'Norwegian', 'German', 'USA', 'American') + AND mi_idx.info < '8.5' + AND t.production_year > 2000 + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mi_idx.movie_id + AND t.id = mc.movie_id + AND t.id = cc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mi_idx.movie_id + AND mk.movie_id = mc.movie_id + AND mk.movie_id = cc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mc.movie_id + AND mi.movie_id = cc.movie_id + AND mc.movie_id = mi_idx.movie_id + AND mc.movie_id = cc.movie_id + AND mi_idx.movie_id = cc.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND ct.id = mc.company_type_id + AND cn.id = mc.company_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id +---- +Stockholm Productions 7.8 The Nordic Murders + +# 28b - Query for Euro dark movies with complete crew +query TTT +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS complete_euro_dark_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, company_name AS cn, company_type AS ct, 
info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE cct1.kind = 'crew' + AND cct2.kind != 'complete+verified' + AND cn.country_code != '[us]' + AND it1.info = 'countries' + AND it2.info = 'rating' + AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') + AND kt.kind in ('movie', 'episode') + AND mc.note not like '%(USA)%' and mc.note like '%(200%)%' + AND mi.info IN ('Sweden', 'Germany', 'Swedish', 'German') + AND mi_idx.info > '6.5' + AND t.production_year > 2005 + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mi_idx.movie_id + AND t.id = mc.movie_id + AND t.id = cc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mi_idx.movie_id + AND mk.movie_id = mc.movie_id + AND mk.movie_id = cc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mc.movie_id + AND mi.movie_id = cc.movie_id + AND mc.movie_id = mi_idx.movie_id + AND mc.movie_id = cc.movie_id + AND mi_idx.movie_id = cc.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND ct.id = mc.company_type_id + AND cn.id = mc.company_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id +---- +Stockholm Productions 7.8 The Nordic Murders + +# 28c - Query for extended movies with specific criteria +query TTT +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS complete_euro_dark_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t +WHERE cct1.kind = 'cast' + AND cct2.kind = 'complete' + AND cn.country_code != '[us]' + AND it1.info = 'countries' + AND it2.info = 'rating' + AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') + AND kt.kind in ('movie', 'episode') + AND mc.note not like '%(USA)%' and mc.note like '%(200%)%' + AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Danish', 'Norwegian', 'German', 'USA', 'American') + AND mi_idx.info < '8.5' + AND t.production_year > 2005 + AND kt.id = t.kind_id + AND t.id = mi.movie_id + AND t.id = mk.movie_id + AND t.id = mi_idx.movie_id + AND t.id = mc.movie_id + AND t.id = cc.movie_id + AND mk.movie_id = mi.movie_id + AND mk.movie_id = mi_idx.movie_id + AND mk.movie_id = mc.movie_id + AND mk.movie_id = cc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mc.movie_id + AND mi.movie_id = cc.movie_id + AND mc.movie_id = mi_idx.movie_id + AND mc.movie_id = cc.movie_id + AND mi_idx.movie_id = cc.movie_id + AND k.id = mk.keyword_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND ct.id = mc.company_type_id + AND cn.id = mc.company_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id +---- +Oslo Films 7.5 Scandinavian Crime + +# 29a - Query for movies with specific combinations +query TTT +SELECT MIN(chn.name) AS voiced_char, MIN(n.name) AS voicing_actress, MIN(t.title) AS voiced_animation +FROM aka_name AS an, complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, info_type AS it3, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, name AS n, person_info AS pi, role_type AS rt, title AS t +WHERE cct1.kind 
='cast' + AND cct2.kind ='complete+verified' + AND chn.name = 'Queen' + AND ci.note in ('(voice)', '(voice) (uncredited)', '(voice: English version)') + AND cn.country_code ='[us]' + AND it.info = 'release dates' + AND it3.info = 'trivia' + AND k.keyword = 'computer-animation' + AND mi.info is not null and (mi.info like 'Japan:%200%' or mi.info like 'USA:%200%') + AND n.gender ='f' and n.name like '%An%' + AND rt.role ='actress' + AND t.title = 'Shrek 2' + AND t.production_year between 2000 and 2010 + AND t.id = mi.movie_id + AND t.id = mc.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND t.id = cc.movie_id + AND mc.movie_id = ci.movie_id + AND mc.movie_id = mi.movie_id + AND mc.movie_id = mk.movie_id + AND mc.movie_id = cc.movie_id + AND mi.movie_id = ci.movie_id + AND mi.movie_id = mk.movie_id + AND mi.movie_id = cc.movie_id + AND ci.movie_id = mk.movie_id + AND ci.movie_id = cc.movie_id + AND mk.movie_id = cc.movie_id + AND cn.id = mc.company_id + AND it.id = mi.info_type_id + AND n.id = ci.person_id + AND rt.id = ci.role_id + AND n.id = an.person_id + AND ci.person_id = an.person_id + AND chn.id = ci.person_role_id + AND n.id = pi.person_id + AND ci.person_id = pi.person_id + AND it3.id = pi.info_type_id + AND k.id = mk.keyword_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id +---- +Queen Anne Hathaway Shrek 2 + +# 29b - Query for specific Queen character voice actress +query TTT +SELECT MIN(chn.name) AS voiced_char, MIN(n.name) AS voicing_actress, MIN(t.title) AS voiced_animation +FROM aka_name AS an, complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, info_type AS it3, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, name AS n, person_info AS pi, role_type AS rt, title AS t +WHERE cct1.kind ='cast' + AND cct2.kind ='complete+verified' + AND chn.name = 'Queen' + AND ci.note in ('(voice)', '(voice) (uncredited)', '(voice: English version)') + AND cn.country_code ='[us]' + AND it.info = 'release dates' + AND it3.info = 'height' + AND k.keyword = 'computer-animation' + AND mi.info like 'USA:%200%' + AND n.gender ='f' and n.name like '%An%' + AND rt.role ='actress' + AND t.title = 'Shrek 2' + AND t.production_year between 2000 and 2005 + AND t.id = mi.movie_id + AND t.id = mc.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND t.id = cc.movie_id + AND mc.movie_id = ci.movie_id + AND mc.movie_id = mi.movie_id + AND mc.movie_id = mk.movie_id + AND mc.movie_id = cc.movie_id + AND mi.movie_id = ci.movie_id + AND mi.movie_id = mk.movie_id + AND mi.movie_id = cc.movie_id + AND ci.movie_id = mk.movie_id + AND ci.movie_id = cc.movie_id + AND mk.movie_id = cc.movie_id + AND cn.id = mc.company_id + AND it.id = mi.info_type_id + AND n.id = ci.person_id + AND rt.id = ci.role_id + AND n.id = an.person_id + AND ci.person_id = an.person_id + AND chn.id = ci.person_role_id + AND n.id = pi.person_id + AND ci.person_id = pi.person_id + AND it3.id = pi.info_type_id + AND k.id = mk.keyword_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id +---- +Queen Anne Hathaway Shrek 2 + +# 29c - Query for extended specific combinations +query TTT +SELECT MIN(chn.name) AS voiced_char, MIN(n.name) AS voicing_actress, MIN(t.title) AS voiced_animation +FROM aka_name AS an, complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, info_type AS it3, keyword AS k, movie_companies 
AS mc, movie_info AS mi, movie_keyword AS mk, name AS n, person_info AS pi, role_type AS rt, title AS t +WHERE cct1.kind ='cast' + AND cct2.kind ='complete+verified' + AND ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') + AND cn.country_code ='[us]' + AND it.info = 'release dates' + AND it3.info = 'trivia' + AND k.keyword = 'computer-animation' + AND mi.info is not null and (mi.info like 'Japan:%200%' or mi.info like 'USA:%200%') + AND n.gender ='f' and n.name like '%An%' + AND rt.role ='actress' + AND t.production_year between 2000 and 2010 + AND t.id = mi.movie_id + AND t.id = mc.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND t.id = cc.movie_id + AND mc.movie_id = ci.movie_id + AND mc.movie_id = mi.movie_id + AND mc.movie_id = mk.movie_id + AND mc.movie_id = cc.movie_id + AND mi.movie_id = ci.movie_id + AND mi.movie_id = mk.movie_id + AND mi.movie_id = cc.movie_id + AND ci.movie_id = mk.movie_id + AND ci.movie_id = cc.movie_id + AND mk.movie_id = cc.movie_id + AND cn.id = mc.company_id + AND it.id = mi.info_type_id + AND n.id = ci.person_id + AND rt.id = ci.role_id + AND n.id = an.person_id + AND ci.person_id = an.person_id + AND chn.id = ci.person_role_id + AND n.id = pi.person_id + AND ci.person_id = pi.person_id + AND it3.id = pi.info_type_id + AND k.id = mk.keyword_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id +---- +Queen Anne Hathaway Shrek 2 + +# 30a - Query for top-rated action movies +query TTTT +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS writer, MIN(t.title) AS complete_violent_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, cast_info AS ci, info_type AS it1, info_type AS it2, keyword AS k, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t +WHERE cct1.kind in ('cast', 'crew') + AND cct2.kind ='complete+verified' + AND ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') + AND it1.info = 'genres' + AND it2.info = 'votes' + AND k.keyword in ('murder', 'violence', 'blood', 'gore', 'death', 'female-nudity', 'hospital') + AND mi.info in ('Horror', 'Thriller') + AND n.gender = 'm' + AND t.production_year > 2000 + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND t.id = cc.movie_id + AND ci.movie_id = mi.movie_id + AND ci.movie_id = mi_idx.movie_id + AND ci.movie_id = mk.movie_id + AND ci.movie_id = cc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mk.movie_id + AND mi.movie_id = cc.movie_id + AND mi_idx.movie_id = mk.movie_id + AND mi_idx.movie_id = cc.movie_id + AND mk.movie_id = cc.movie_id + AND n.id = ci.person_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND k.id = mk.keyword_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id +---- +Horror 52000 James Wan Saw IV + +# 30b - Query for ratings of female-cast-only movies +query TTTT +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS writer, MIN(t.title) AS complete_gore_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, cast_info AS ci, info_type AS it1, info_type AS it2, keyword AS k, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t +WHERE cct1.kind in ('cast', 'crew') + AND cct2.kind ='complete+verified' + AND ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story 
editor)') + AND it1.info = 'genres' + AND it2.info = 'votes' + AND k.keyword in ('murder', 'violence', 'blood', 'gore', 'death', 'female-nudity', 'hospital') + AND mi.info in ('Horror', 'Thriller') + AND n.gender = 'm' + AND t.production_year > 2000 and (t.title like '%Freddy%' or t.title like '%Jason%' or t.title like 'Saw%') + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND t.id = cc.movie_id + AND ci.movie_id = mi.movie_id + AND ci.movie_id = mi_idx.movie_id + AND ci.movie_id = mk.movie_id + AND ci.movie_id = cc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mk.movie_id + AND mi.movie_id = cc.movie_id + AND mi_idx.movie_id = mk.movie_id + AND mi_idx.movie_id = cc.movie_id + AND mk.movie_id = cc.movie_id + AND n.id = ci.person_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND k.id = mk.keyword_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id +---- +Horror 52000 James Wan Saw IV + +# 30c - Query for extended action movies +query TTTT +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS writer, MIN(t.title) AS complete_violent_movie +FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, cast_info AS ci, info_type AS it1, info_type AS it2, keyword AS k, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t +WHERE cct1.kind = 'cast' + AND cct2.kind ='complete+verified' + AND ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') + AND it1.info = 'genres' + AND it2.info = 'votes' + AND k.keyword in ('murder', 'violence', 'blood', 'gore', 'death', 'female-nudity', 'hospital') + AND mi.info in ('Horror', 'Action', 'Sci-Fi', 'Thriller', 'Crime', 'War') + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND t.id = cc.movie_id + AND ci.movie_id = mi.movie_id + AND ci.movie_id = mi_idx.movie_id + AND ci.movie_id = mk.movie_id + AND ci.movie_id = cc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mk.movie_id + AND mi.movie_id = cc.movie_id + AND mi_idx.movie_id = mk.movie_id + AND mi_idx.movie_id = cc.movie_id + AND mk.movie_id = cc.movie_id + AND n.id = ci.person_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND k.id = mk.keyword_id + AND cct1.id = cc.subject_id + AND cct2.id = cc.status_id +---- +Horror 52000 James Wan Saw IV + +# 31a - Query for movies with specific language and production values +query TTTT +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS writer, MIN(t.title) AS violent_liongate_movie +FROM cast_info AS ci, company_name AS cn, info_type AS it1, info_type AS it2, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t +WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') + AND cn.name like 'Lionsgate%' + AND it1.info = 'genres' + AND it2.info = 'votes' + AND k.keyword in ('murder', 'violence', 'blood', 'gore', 'death', 'female-nudity', 'hospital') + AND mi.info in ('Horror', 'Thriller') + AND n.gender = 'm' + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND t.id = mc.movie_id + AND ci.movie_id = mi.movie_id + AND ci.movie_id = mi_idx.movie_id + AND ci.movie_id = mk.movie_id + AND ci.movie_id = mc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND 
mi.movie_id = mk.movie_id + AND mi.movie_id = mc.movie_id + AND mi_idx.movie_id = mk.movie_id + AND mi_idx.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id + AND n.id = ci.person_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND k.id = mk.keyword_id + AND cn.id = mc.company_id +---- +Horror 45000 James Wan Halloween + +# 31b - Query for sci-fi female-focused movies +query TTTT +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS writer, MIN(t.title) AS violent_liongate_movie +FROM cast_info AS ci, company_name AS cn, info_type AS it1, info_type AS it2, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t +WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') + AND cn.name like 'Lionsgate%' + AND it1.info = 'genres' + AND it2.info = 'votes' + AND k.keyword in ('murder', 'violence', 'blood', 'gore', 'death', 'female-nudity', 'hospital') + AND mc.note like '%(Blu-ray)%' + AND mi.info in ('Horror', 'Thriller') + AND n.gender = 'm' + AND t.production_year > 2000 and (t.title like '%Freddy%' or t.title like '%Jason%' or t.title like 'Saw%') + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND t.id = mc.movie_id + AND ci.movie_id = mi.movie_id + AND ci.movie_id = mi_idx.movie_id + AND ci.movie_id = mk.movie_id + AND ci.movie_id = mc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mk.movie_id + AND mi.movie_id = mc.movie_id + AND mi_idx.movie_id = mk.movie_id + AND mi_idx.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id + AND n.id = ci.person_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND k.id = mk.keyword_id + AND cn.id = mc.company_id +---- +Horror 52000 James Wan Saw IV + +# 31c - Query for extended language and production values +query TTTT +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS writer, MIN(t.title) AS violent_liongate_movie +FROM cast_info AS ci, company_name AS cn, info_type AS it1, info_type AS it2, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t +WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') + AND cn.name like 'Lionsgate%' + AND it1.info = 'genres' + AND it2.info = 'votes' + AND k.keyword in ('murder', 'violence', 'blood', 'gore', 'death', 'female-nudity', 'hospital') + AND mi.info in ('Horror', 'Action', 'Sci-Fi', 'Thriller', 'Crime', 'War') + AND t.id = mi.movie_id + AND t.id = mi_idx.movie_id + AND t.id = ci.movie_id + AND t.id = mk.movie_id + AND t.id = mc.movie_id + AND ci.movie_id = mi.movie_id + AND ci.movie_id = mi_idx.movie_id + AND ci.movie_id = mk.movie_id + AND ci.movie_id = mc.movie_id + AND mi.movie_id = mi_idx.movie_id + AND mi.movie_id = mk.movie_id + AND mi.movie_id = mc.movie_id + AND mi_idx.movie_id = mk.movie_id + AND mi_idx.movie_id = mc.movie_id + AND mk.movie_id = mc.movie_id + AND n.id = ci.person_id + AND it1.id = mi.info_type_id + AND it2.id = mi_idx.info_type_id + AND k.id = mk.keyword_id + AND cn.id = mc.company_id +---- +Horror 45000 James Wan Halloween + +# 32a - Query for action movies with specific actor roles +query TTT +SELECT MIN(lt.link) AS link_type, MIN(t1.title) AS first_movie, MIN(t2.title) AS second_movie +FROM keyword AS k, link_type AS lt, movie_keyword AS mk, movie_link AS ml, title AS t1, title AS t2 
+WHERE k.keyword ='10,000-mile-club' + AND mk.keyword_id = k.id + AND t1.id = mk.movie_id + AND ml.movie_id = t1.id + AND ml.linked_movie_id = t2.id + AND lt.id = ml.link_type_id + AND mk.movie_id = t1.id +---- +edited into Interstellar Saving Private Ryan + +# 32b - Query for character-name-in-title movies and their connections +query TTT +SELECT MIN(lt.link) AS link_type, MIN(t1.title) AS first_movie, MIN(t2.title) AS second_movie +FROM keyword AS k, link_type AS lt, movie_keyword AS mk, movie_link AS ml, title AS t1, title AS t2 +WHERE k.keyword ='character-name-in-title' + AND mk.keyword_id = k.id + AND t1.id = mk.movie_id + AND ml.movie_id = t1.id + AND ml.linked_movie_id = t2.id + AND lt.id = ml.link_type_id + AND mk.movie_id = t1.id +---- +featured in Iron Man Avengers: Endgame + +# 33a - Query for directors of sequels with specific ratings +query TTTTTT +SELECT MIN(cn1.name) AS first_company, MIN(cn2.name) AS second_company, MIN(mi_idx1.info) AS first_rating, MIN(mi_idx2.info) AS second_rating, MIN(t1.title) AS first_movie, MIN(t2.title) AS second_movie +FROM company_name AS cn1, company_name AS cn2, info_type AS it1, info_type AS it2, kind_type AS kt1, kind_type AS kt2, link_type AS lt, movie_companies AS mc1, movie_companies AS mc2, movie_info_idx AS mi_idx1, movie_info_idx AS mi_idx2, movie_link AS ml, title AS t1, title AS t2 +WHERE cn1.country_code = '[us]' + AND it1.info = 'rating' + AND it2.info = 'rating' + AND kt1.kind in ('tv series') + AND kt2.kind in ('tv series') + AND lt.link in ('sequel', 'follows', 'followed by') + AND mi_idx2.info < '3.0' + AND t2.production_year between 2005 and 2008 + AND lt.id = ml.link_type_id + AND t1.id = ml.movie_id + AND t2.id = ml.linked_movie_id + AND it1.id = mi_idx1.info_type_id + AND t1.id = mi_idx1.movie_id + AND kt1.id = t1.kind_id + AND cn1.id = mc1.company_id + AND t1.id = mc1.movie_id + AND ml.movie_id = mi_idx1.movie_id + AND ml.movie_id = mc1.movie_id + AND mi_idx1.movie_id = mc1.movie_id + AND it2.id = mi_idx2.info_type_id + AND t2.id = mi_idx2.movie_id + AND kt2.id = t2.kind_id + AND cn2.id = mc2.company_id + AND t2.id = mc2.movie_id + AND ml.linked_movie_id = mi_idx2.movie_id + AND ml.linked_movie_id = mc2.movie_id + AND mi_idx2.movie_id = mc2.movie_id +---- +Paramount Pictures Paramount Pictures 8.2 2.8 Breaking Bad Breaking Bad: The Final Season + +# 33b - Query for linked TV series by country code +query TTTTTT +SELECT MIN(cn1.name) AS first_company, MIN(cn2.name) AS second_company, MIN(mi_idx1.info) AS first_rating, MIN(mi_idx2.info) AS second_rating, MIN(t1.title) AS first_movie, MIN(t2.title) AS second_movie +FROM company_name AS cn1, company_name AS cn2, info_type AS it1, info_type AS it2, kind_type AS kt1, kind_type AS kt2, link_type AS lt, movie_companies AS mc1, movie_companies AS mc2, movie_info_idx AS mi_idx1, movie_info_idx AS mi_idx2, movie_link AS ml, title AS t1, title AS t2 +WHERE cn1.country_code = '[nl]' + AND it1.info = 'rating' + AND it2.info = 'rating' + AND kt1.kind in ('tv series') + AND kt2.kind in ('tv series') + AND lt.link LIKE '%follow%' + AND mi_idx2.info < '3.0' + AND t2.production_year = 2007 + AND lt.id = ml.link_type_id + AND t1.id = ml.movie_id + AND t2.id = ml.linked_movie_id + AND it1.id = mi_idx1.info_type_id + AND t1.id = mi_idx1.movie_id + AND kt1.id = t1.kind_id + AND cn1.id = mc1.company_id + AND t1.id = mc1.movie_id + AND ml.movie_id = mi_idx1.movie_id + AND ml.movie_id = mc1.movie_id + AND mi_idx1.movie_id = mc1.movie_id + AND it2.id = mi_idx2.info_type_id + AND t2.id = 
mi_idx2.movie_id + AND kt2.id = t2.kind_id + AND cn2.id = mc2.company_id + AND t2.id = mc2.movie_id + AND ml.linked_movie_id = mi_idx2.movie_id + AND ml.linked_movie_id = mc2.movie_id + AND mi_idx2.movie_id = mc2.movie_id +---- +Dutch Entertainment Group Amsterdam Studios 8.5 2.5 Amsterdam Detective Amsterdam Detective: Cold Case + +# 33c - Query for linked TV series and episodes with specific ratings +query TTTTTT +SELECT MIN(cn1.name) AS first_company, MIN(cn2.name) AS second_company, MIN(mi_idx1.info) AS first_rating, MIN(mi_idx2.info) AS second_rating, MIN(t1.title) AS first_movie, MIN(t2.title) AS second_movie +FROM company_name AS cn1, company_name AS cn2, info_type AS it1, info_type AS it2, kind_type AS kt1, kind_type AS kt2, link_type AS lt, movie_companies AS mc1, movie_companies AS mc2, movie_info_idx AS mi_idx1, movie_info_idx AS mi_idx2, movie_link AS ml, title AS t1, title AS t2 +WHERE cn1.country_code != '[us]' + AND it1.info = 'rating' + AND it2.info = 'rating' + AND kt1.kind in ('tv series', 'episode') + AND kt2.kind in ('tv series', 'episode') + AND lt.link in ('sequel', 'follows', 'followed by') + AND mi_idx2.info < '3.5' + AND t2.production_year between 2000 and 2010 + AND lt.id = ml.link_type_id + AND t1.id = ml.movie_id + AND t2.id = ml.linked_movie_id + AND it1.id = mi_idx1.info_type_id + AND t1.id = mi_idx1.movie_id + AND kt1.id = t1.kind_id + AND cn1.id = mc1.company_id + AND t1.id = mc1.movie_id + AND ml.movie_id = mi_idx1.movie_id + AND ml.movie_id = mc1.movie_id + AND mi_idx1.movie_id = mc1.movie_id + AND it2.id = mi_idx2.info_type_id + AND t2.id = mi_idx2.movie_id + AND kt2.id = t2.kind_id + AND cn2.id = mc2.company_id + AND t2.id = mc2.movie_id + AND ml.linked_movie_id = mi_idx2.movie_id + AND ml.linked_movie_id = mc2.movie_id + AND mi_idx2.movie_id = mc2.movie_id +---- +Dutch Entertainment Group Amsterdam Studios 8.5 2.5 Amsterdam Detective Amsterdam Detective: Cold Case + +# Clean up all tables +statement ok +DROP TABLE company_type; + +statement ok +DROP TABLE info_type; + +statement ok +DROP TABLE title; + +statement ok +DROP TABLE movie_companies; + +statement ok +DROP TABLE movie_info_idx; + +statement ok +DROP TABLE movie_info; + +statement ok +DROP TABLE kind_type; + +statement ok +DROP TABLE cast_info; + +statement ok +DROP TABLE char_name; + +statement ok +DROP TABLE keyword; + +statement ok +DROP TABLE movie_keyword; + +statement ok +DROP TABLE company_name; + +statement ok +DROP TABLE name; + +statement ok +DROP TABLE role_type; + +statement ok +DROP TABLE link_type; + +statement ok +DROP TABLE movie_link; + +statement ok +DROP TABLE complete_cast; + +statement ok +DROP TABLE comp_cast_type; + +statement ok +DROP TABLE person_info; + +statement ok +DROP TABLE aka_title; + +statement ok +DROP TABLE aka_name; diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index 4964bcbc735c..f76e436e0ad3 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -149,6 +149,39 @@ drop table t statement ok drop table t2 + +############ +## 0 to represent the default value (target_partitions and planning_concurrency) +########### + +statement ok +SET datafusion.execution.target_partitions = 3; + +statement ok +SET datafusion.execution.planning_concurrency = 3; + +# when setting target_partitions and planning_concurrency to 3, their values will be 3 +query TB rowsort +SELECT name, value = 3 FROM 
information_schema.df_settings WHERE name IN ('datafusion.execution.target_partitions', 'datafusion.execution.planning_concurrency'); +---- +datafusion.execution.planning_concurrency true +datafusion.execution.target_partitions true + +statement ok +SET datafusion.execution.target_partitions = 0; + +statement ok +SET datafusion.execution.planning_concurrency = 0; + +# when setting target_partitions and planning_concurrency to 0, their values will be equal to the +# default values, which are different from 0 (which is invalid) +query TB rowsort +SELECT name, value = 0 FROM information_schema.df_settings WHERE name IN ('datafusion.execution.target_partitions', 'datafusion.execution.planning_concurrency'); +---- +datafusion.execution.planning_concurrency false +datafusion.execution.target_partitions false + + ############ ## SHOW VARIABLES should work ########### @@ -183,7 +216,7 @@ datafusion.catalog.location NULL datafusion.catalog.newlines_in_values false datafusion.execution.batch_size 8192 datafusion.execution.coalesce_batches true -datafusion.execution.collect_statistics false +datafusion.execution.collect_statistics true datafusion.execution.enable_recursive_ctes true datafusion.execution.enforce_batch_size_in_joins false datafusion.execution.keep_partition_by_columns false @@ -191,6 +224,7 @@ datafusion.execution.listing_table_ignore_subdirectory true datafusion.execution.max_buffered_batches_per_output_file 2 datafusion.execution.meta_fetch_concurrency 32 datafusion.execution.minimum_parallel_output_files 4 +datafusion.execution.objectstore_writer_buffer_size 10485760 datafusion.execution.parquet.allow_single_file_parallelism true datafusion.execution.parquet.binary_as_string false datafusion.execution.parquet.bloom_filter_fpp NULL @@ -229,6 +263,7 @@ datafusion.execution.skip_physical_aggregate_schema_check false datafusion.execution.soft_max_rows_per_output_file 50000000 datafusion.execution.sort_in_place_threshold_bytes 1048576 datafusion.execution.sort_spill_reservation_bytes 10485760 +datafusion.execution.spill_compression uncompressed datafusion.execution.split_file_groups_by_statistics false datafusion.execution.target_partitions 7 datafusion.execution.time_zone +00:00 @@ -239,9 +274,19 @@ datafusion.explain.physical_plan_only false datafusion.explain.show_schema false datafusion.explain.show_sizes true datafusion.explain.show_statistics false +datafusion.format.date_format %Y-%m-%d +datafusion.format.datetime_format %Y-%m-%dT%H:%M:%S%.f +datafusion.format.duration_format pretty +datafusion.format.null (empty) +datafusion.format.safe true +datafusion.format.time_format %H:%M:%S%.f +datafusion.format.timestamp_format %Y-%m-%dT%H:%M:%S%.f +datafusion.format.timestamp_tz_format NULL +datafusion.format.types_info false datafusion.optimizer.allow_symmetric_joins_without_pruning true datafusion.optimizer.default_filter_selectivity 20 datafusion.optimizer.enable_distinct_aggregation_soft_limit true +datafusion.optimizer.enable_dynamic_filter_pushdown true datafusion.optimizer.enable_round_robin_repartition true datafusion.optimizer.enable_topk_aggregation true datafusion.optimizer.expand_views_at_output false @@ -264,7 +309,7 @@ datafusion.sql_parser.collect_spans false datafusion.sql_parser.dialect generic datafusion.sql_parser.enable_ident_normalization true datafusion.sql_parser.enable_options_value_normalization false -datafusion.sql_parser.map_varchar_to_utf8view false +datafusion.sql_parser.map_string_types_to_utf8view true datafusion.sql_parser.parse_float_as_decimal false 
datafusion.sql_parser.recursion_limit 50 datafusion.sql_parser.support_varchar_with_length true @@ -283,7 +328,7 @@ datafusion.catalog.location NULL Location scanned to load tables for `default` s datafusion.catalog.newlines_in_values false Specifies whether newlines in (quoted) CSV values are supported. This is the default value for `format.newlines_in_values` for `CREATE EXTERNAL TABLE` if not specified explicitly in the statement. Parsing newlines in quoted values may be affected by execution behaviour such as parallel file scanning. Setting this to `true` ensures that newlines in values are parsed successfully, which may reduce performance. datafusion.execution.batch_size 8192 Default batch size while creating new batches, it's especially useful for buffer-in-memory batches since creating tiny batches would result in too much metadata memory consumption datafusion.execution.coalesce_batches true When set to true, record batches will be examined between each operator and small batches will be coalesced into larger batches. This is helpful when there are highly selective filters or joins that could produce tiny output batches. The target batch size is determined by the configuration setting -datafusion.execution.collect_statistics false Should DataFusion collect statistics after listing files +datafusion.execution.collect_statistics true Should DataFusion collect statistics when first creating a table. Has no effect after the table is created. Applies to the default `ListingTableProvider` in DataFusion. Defaults to true. datafusion.execution.enable_recursive_ctes true Should DataFusion support recursive CTEs datafusion.execution.enforce_batch_size_in_joins false Should DataFusion enforce batch size in joins or not. By default, DataFusion will not enforce batch size in joins. Enforcing batch size in joins can reduce memory usage when joining large tables with a highly-selective join filter, but is also slightly slower. datafusion.execution.keep_partition_by_columns false Should DataFusion keep the columns used for partition_by in the output RecordBatches @@ -291,11 +336,12 @@ datafusion.execution.listing_table_ignore_subdirectory true Should sub directori datafusion.execution.max_buffered_batches_per_output_file 2 This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption datafusion.execution.meta_fetch_concurrency 32 Number of files to read in parallel when inferring schema and statistics datafusion.execution.minimum_parallel_output_files 4 Guarantees a minimum level of output files running in parallel. RecordBatches will be distributed in round robin fashion to each parallel writer. Each writer is closed and a new file opened once soft_max_rows_per_output_file is reached. +datafusion.execution.objectstore_writer_buffer_size 10485760 Size (bytes) of data buffer DataFusion uses when writing output files. This affects the size of the data chunks that are uploaded to remote object stores (e.g. AWS S3). If very large (>= 100 GiB) output files are being written, it may be necessary to increase this size to avoid errors from the remote end point. datafusion.execution.parquet.allow_single_file_parallelism true (writing) Controls whether DataFusion will attempt to speed up writing parquet files by serializing them in parallel. 
Each column in each row group in each output file are serialized in parallel leveraging a maximum possible core count of n_files*n_row_groups*n_columns. datafusion.execution.parquet.binary_as_string false (reading) If true, parquet reader will read columns of `Binary/LargeBinary` with `Utf8`, and `BinaryView` with `Utf8View`. Parquet files generated by some legacy writers do not correctly set the UTF8 flag for strings, causing string columns to be loaded as BLOB instead. datafusion.execution.parquet.bloom_filter_fpp NULL (writing) Sets bloom filter false positive probability. If NULL, uses default parquet writer setting datafusion.execution.parquet.bloom_filter_ndv NULL (writing) Sets bloom filter number of distinct values. If NULL, uses default parquet writer setting -datafusion.execution.parquet.bloom_filter_on_read true (writing) Use any available bloom filters when reading parquet files +datafusion.execution.parquet.bloom_filter_on_read true (reading) Use any available bloom filters when reading parquet files datafusion.execution.parquet.bloom_filter_on_write false (writing) Write bloom filters for all columns when creating parquet files datafusion.execution.parquet.coerce_int96 NULL (reading) If true, parquet reader will read columns of physical type int96 as originating from a different resolution than nanosecond. This is useful for reading data from systems like Spark which stores microsecond resolution timestamps in an int96 allowing it to write values with a larger date range than 64-bit timestamps with nanosecond resolution. datafusion.execution.parquet.column_index_truncate_length 64 (writing) Sets column index truncate length @@ -329,6 +375,7 @@ datafusion.execution.skip_physical_aggregate_schema_check false When set to true datafusion.execution.soft_max_rows_per_output_file 50000000 Target number of rows in output files when writing multiple. This is a soft max, so it can be exceeded slightly. There also will be one file smaller than the limit if the total number of rows written is not roughly divisible by the soft max datafusion.execution.sort_in_place_threshold_bytes 1048576 When sorting, below what size should data be concatenated and sorted in a single RecordBatch rather than sorted in batches and merged. datafusion.execution.sort_spill_reservation_bytes 10485760 Specifies the reserved memory for each spillable sort operation to facilitate an in-memory merge. When a sort operation spills to disk, the in-memory data must be sorted and merged before being written to a file. This setting reserves a specific amount of memory for that in-memory sort/merge process. Note: This setting is irrelevant if the sort operation cannot spill (i.e., if there's no `DiskManager` configured). +datafusion.execution.spill_compression uncompressed Sets the compression codec used when spilling data to disk. Since datafusion writes spill files using the Arrow IPC Stream format, only codecs supported by the Arrow IPC Stream Writer are allowed. Valid values are: uncompressed, lz4_frame, zstd. Note: lz4_frame offers faster (de)compression, but typically results in larger spill files. In contrast, zstd achieves higher compression ratios at the cost of slower (de)compression speed. datafusion.execution.split_file_groups_by_statistics false Attempt to eliminate sorts by packing & sorting files with non-overlapping statistics into the same file groups. Currently experimental datafusion.execution.target_partitions 7 Number of partitions for query execution. Increasing partitions can increase concurrency. 
Defaults to the number of CPU cores on the system datafusion.execution.time_zone +00:00 The default time zone Some functions, e.g. `EXTRACT(HOUR from SOME_TIME)`, shift the underlying datetime according to this time zone, and then extract the hour @@ -339,9 +386,19 @@ datafusion.explain.physical_plan_only false When set to true, the explain statem datafusion.explain.show_schema false When set to true, the explain statement will print schema information datafusion.explain.show_sizes true When set to true, the explain statement will print the partition sizes datafusion.explain.show_statistics false When set to true, the explain statement will print operator statistics for physical plans +datafusion.format.date_format %Y-%m-%d Date format for date arrays +datafusion.format.datetime_format %Y-%m-%dT%H:%M:%S%.f Format for DateTime arrays +datafusion.format.duration_format pretty Duration format. Can be either `"pretty"` or `"ISO8601"` +datafusion.format.null (empty) Format string for nulls +datafusion.format.safe true If set to `true` any formatting errors will be written to the output instead of being converted into a [`std::fmt::Error`] +datafusion.format.time_format %H:%M:%S%.f Time format for time arrays +datafusion.format.timestamp_format %Y-%m-%dT%H:%M:%S%.f Timestamp format for timestamp arrays +datafusion.format.timestamp_tz_format NULL Timestamp format for timestamp with timezone arrays. When `None`, ISO 8601 format is used. +datafusion.format.types_info false Show types in visual representation batches datafusion.optimizer.allow_symmetric_joins_without_pruning true Should DataFusion allow symmetric hash joins for unbounded data sources even when its inputs do not have any ordering or filtering If the flag is not enabled, the SymmetricHashJoin operator will be unable to prune its internal buffers, resulting in certain join types - such as Full, Left, LeftAnti, LeftSemi, Right, RightAnti, and RightSemi - being produced only at the end of the execution. This is not typical in stream processing. Additionally, without proper design for long runner execution, all types of joins may encounter out-of-memory errors. datafusion.optimizer.default_filter_selectivity 20 The default filter selectivity used by Filter Statistics when an exact selectivity cannot be determined. Valid values are between 0 (no selectivity) and 100 (all rows are selected). datafusion.optimizer.enable_distinct_aggregation_soft_limit true When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. +datafusion.optimizer.enable_dynamic_filter_pushdown true When set to true attempts to push down dynamic filters generated by operators into the file scan phase. For example, for a query such as `SELECT * FROM t ORDER BY timestamp DESC LIMIT 10`, the optimizer will attempt to push down the current top 10 timestamps that the TopK operator references into the file scans. This means that if we already have 10 timestamps in the year 2025 any files that only have timestamps in the year 2024 can be skipped / pruned at various stages in the scan. 
datafusion.optimizer.enable_round_robin_repartition true When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores datafusion.optimizer.enable_topk_aggregation true When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible datafusion.optimizer.expand_views_at_output false When set to true, if the returned type is a view type then the output will be coerced to a non-view. Coerces `Utf8View` to `LargeUtf8`, and `BinaryView` to `LargeBinary`. @@ -354,7 +411,7 @@ datafusion.optimizer.prefer_existing_union false When set to true, the optimizer datafusion.optimizer.prefer_hash_join true When set to true, the physical plan optimizer will prefer HashJoin over SortMergeJoin. HashJoin can work more efficiently than SortMergeJoin but consumes more memory datafusion.optimizer.repartition_aggregations true Should DataFusion repartition data using the aggregate keys to execute aggregates in parallel using the provided `target_partitions` level datafusion.optimizer.repartition_file_min_size 10485760 Minimum total files size in bytes to perform file scan repartitioning. -datafusion.optimizer.repartition_file_scans true When set to `true`, file groups will be repartitioned to achieve maximum parallelism. Currently Parquet and CSV formats are supported. If set to `true`, all files will be repartitioned evenly (i.e., a single large file might be partitioned into smaller chunks) for parallel scanning. If set to `false`, different files will be read in parallel, but repartitioning won't happen within a single file. +datafusion.optimizer.repartition_file_scans true When set to `true`, datasource partitions will be repartitioned to achieve maximum parallelism. This applies to both in-memory partitions and FileSource's file groups (1 group is 1 partition). For FileSources, only Parquet and CSV formats are currently supported. If set to `true` for a FileSource, all files will be repartitioned evenly (i.e., a single large file might be partitioned into smaller chunks) for parallel scanning. If set to `false` for a FileSource, different files will be read in parallel, but repartitioning won't happen within a single file. If set to `true` for an in-memory source, all memtable's partitions will have their batches repartitioned evenly to the desired number of `target_partitions`. Repartitioning can change the total number of partitions and batches per partition, but does not slice the initial record tables provided to the MemTable on creation. datafusion.optimizer.repartition_joins true Should DataFusion repartition data using the join keys to execute joins in parallel using the provided `target_partitions` level datafusion.optimizer.repartition_sorts true Should DataFusion execute sorts in a per-partition fashion and merge afterwards instead of coalescing first and sorting globally. 
With this flag is enabled, plans in the form below ```text "SortExec: [a@0 ASC]", " CoalescePartitionsExec", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", ``` would turn into the plan below which performs better in multithreaded environments ```text "SortPreservingMergeExec: [a@0 ASC]", " SortExec: [a@0 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", ``` datafusion.optimizer.repartition_windows true Should DataFusion repartition data using the partitions keys to execute window functions in parallel using the provided `target_partitions` level @@ -364,7 +421,7 @@ datafusion.sql_parser.collect_spans false When set to true, the source locations datafusion.sql_parser.dialect generic Configure the SQL dialect used by DataFusion's parser; supported values include: Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, Ansi, DuckDB and Databricks. datafusion.sql_parser.enable_ident_normalization true When set to true, SQL parser will normalize ident (convert ident to lowercase when not quoted) datafusion.sql_parser.enable_options_value_normalization false When set to true, SQL parser will normalize options value (convert value to lowercase). Note that this option is ignored and will be removed in the future. All case-insensitive values are normalized automatically. -datafusion.sql_parser.map_varchar_to_utf8view false If true, `VARCHAR` is mapped to `Utf8View` during SQL planning. If false, `VARCHAR` is mapped to `Utf8` during SQL planning. Default is false. +datafusion.sql_parser.map_string_types_to_utf8view true If true, string types (VARCHAR, CHAR, Text, and String) are mapped to `Utf8View` during SQL planning. If false, they are mapped to `Utf8`. Default is true. datafusion.sql_parser.parse_float_as_decimal false When set to true, SQL parser will parse float as decimal type datafusion.sql_parser.recursion_limit 50 Specifies the recursion depth limit when parsing complex SQL Queries datafusion.sql_parser.support_varchar_with_length true If true, permit lengths for `VARCHAR` such as `VARCHAR(20)`, but ignore the length. If false, error if a `VARCHAR` with a length is specified. The Arrow type system does not have a notion of maximum string length and thus DataFusion can not enforce such limits. @@ -636,7 +693,7 @@ datafusion public abc CREATE EXTERNAL TABLE abc STORED AS CSV LOCATION ../../tes query TTT select routine_name, data_type, function_type from information_schema.routines where routine_name = 'string_agg'; ---- -string_agg LargeUtf8 AGGREGATE +string_agg String AGGREGATE # test every function type are included in the result query TTTTTTTBTTTT rowsort @@ -651,7 +708,7 @@ datafusion public date_trunc datafusion public date_trunc FUNCTION true Timestam datafusion public date_trunc datafusion public date_trunc FUNCTION true Timestamp(Second, None) SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) datafusion public date_trunc datafusion public date_trunc FUNCTION true Timestamp(Second, Some("+TZ")) SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) datafusion public rank datafusion public rank FUNCTION true NULL WINDOW Returns the rank of the current row within its partition, allowing gaps between ranks. This function provides a ranking similar to `row_number`, but skips ranks for identical values. 
rank() -datafusion public string_agg datafusion public string_agg FUNCTION true LargeUtf8 AGGREGATE Concatenates the values of string expressions and places separator values between them. If ordering is required, strings are concatenated in the specified order. This aggregation function can only mix DISTINCT and ORDER BY if the ordering expression is exactly the same as the first argument expression. string_agg([DISTINCT] expression, delimiter [ORDER BY expression]) +datafusion public string_agg datafusion public string_agg FUNCTION true String AGGREGATE Concatenates the values of string expressions and places separator values between them. If ordering is required, strings are concatenated in the specified order. This aggregation function can only mix DISTINCT and ORDER BY if the ordering expression is exactly the same as the first argument expression. string_agg([DISTINCT] expression, delimiter [ORDER BY expression]) query B select is_deterministic from information_schema.routines where routine_name = 'now'; @@ -660,119 +717,65 @@ false # test every function type are included in the result query TTTITTTTBI -select * from information_schema.parameters where specific_name = 'date_trunc' OR specific_name = 'string_agg' OR specific_name = 'rank' ORDER BY specific_name, rid; ----- -datafusion public date_trunc 1 IN precision Utf8 NULL false 0 -datafusion public date_trunc 2 IN expression Timestamp(Nanosecond, None) NULL false 0 -datafusion public date_trunc 1 OUT NULL Timestamp(Nanosecond, None) NULL false 0 -datafusion public date_trunc 1 IN precision Utf8View NULL false 1 -datafusion public date_trunc 2 IN expression Timestamp(Nanosecond, None) NULL false 1 -datafusion public date_trunc 1 OUT NULL Timestamp(Nanosecond, None) NULL false 1 -datafusion public date_trunc 1 IN precision Utf8 NULL false 2 -datafusion public date_trunc 2 IN expression Timestamp(Nanosecond, Some("+TZ")) NULL false 2 -datafusion public date_trunc 1 OUT NULL Timestamp(Nanosecond, Some("+TZ")) NULL false 2 -datafusion public date_trunc 1 IN precision Utf8View NULL false 3 -datafusion public date_trunc 2 IN expression Timestamp(Nanosecond, Some("+TZ")) NULL false 3 -datafusion public date_trunc 1 OUT NULL Timestamp(Nanosecond, Some("+TZ")) NULL false 3 -datafusion public date_trunc 1 IN precision Utf8 NULL false 4 -datafusion public date_trunc 2 IN expression Timestamp(Microsecond, None) NULL false 4 -datafusion public date_trunc 1 OUT NULL Timestamp(Microsecond, None) NULL false 4 -datafusion public date_trunc 1 IN precision Utf8View NULL false 5 -datafusion public date_trunc 2 IN expression Timestamp(Microsecond, None) NULL false 5 -datafusion public date_trunc 1 OUT NULL Timestamp(Microsecond, None) NULL false 5 -datafusion public date_trunc 1 IN precision Utf8 NULL false 6 -datafusion public date_trunc 2 IN expression Timestamp(Microsecond, Some("+TZ")) NULL false 6 -datafusion public date_trunc 1 OUT NULL Timestamp(Microsecond, Some("+TZ")) NULL false 6 -datafusion public date_trunc 1 IN precision Utf8View NULL false 7 -datafusion public date_trunc 2 IN expression Timestamp(Microsecond, Some("+TZ")) NULL false 7 -datafusion public date_trunc 1 OUT NULL Timestamp(Microsecond, Some("+TZ")) NULL false 7 -datafusion public date_trunc 1 IN precision Utf8 NULL false 8 -datafusion public date_trunc 2 IN expression Timestamp(Millisecond, None) NULL false 8 -datafusion public date_trunc 1 OUT NULL Timestamp(Millisecond, None) NULL false 8 -datafusion public date_trunc 1 IN precision Utf8View NULL false 9 -datafusion public 
date_trunc 2 IN expression Timestamp(Millisecond, None) NULL false 9 -datafusion public date_trunc 1 OUT NULL Timestamp(Millisecond, None) NULL false 9 -datafusion public date_trunc 1 IN precision Utf8 NULL false 10 -datafusion public date_trunc 2 IN expression Timestamp(Millisecond, Some("+TZ")) NULL false 10 -datafusion public date_trunc 1 OUT NULL Timestamp(Millisecond, Some("+TZ")) NULL false 10 -datafusion public date_trunc 1 IN precision Utf8View NULL false 11 -datafusion public date_trunc 2 IN expression Timestamp(Millisecond, Some("+TZ")) NULL false 11 -datafusion public date_trunc 1 OUT NULL Timestamp(Millisecond, Some("+TZ")) NULL false 11 -datafusion public date_trunc 1 IN precision Utf8 NULL false 12 -datafusion public date_trunc 2 IN expression Timestamp(Second, None) NULL false 12 -datafusion public date_trunc 1 OUT NULL Timestamp(Second, None) NULL false 12 -datafusion public date_trunc 1 IN precision Utf8View NULL false 13 -datafusion public date_trunc 2 IN expression Timestamp(Second, None) NULL false 13 -datafusion public date_trunc 1 OUT NULL Timestamp(Second, None) NULL false 13 -datafusion public date_trunc 1 IN precision Utf8 NULL false 14 -datafusion public date_trunc 2 IN expression Timestamp(Second, Some("+TZ")) NULL false 14 -datafusion public date_trunc 1 OUT NULL Timestamp(Second, Some("+TZ")) NULL false 14 -datafusion public date_trunc 1 IN precision Utf8View NULL false 15 -datafusion public date_trunc 2 IN expression Timestamp(Second, Some("+TZ")) NULL false 15 -datafusion public date_trunc 1 OUT NULL Timestamp(Second, Some("+TZ")) NULL false 15 -datafusion public string_agg 1 IN expression LargeUtf8 NULL false 0 -datafusion public string_agg 2 IN delimiter Utf8 NULL false 0 -datafusion public string_agg 1 OUT NULL LargeUtf8 NULL false 0 -datafusion public string_agg 1 IN expression LargeUtf8 NULL false 1 -datafusion public string_agg 2 IN delimiter LargeUtf8 NULL false 1 -datafusion public string_agg 1 OUT NULL LargeUtf8 NULL false 1 -datafusion public string_agg 1 IN expression LargeUtf8 NULL false 2 -datafusion public string_agg 2 IN delimiter Null NULL false 2 -datafusion public string_agg 1 OUT NULL LargeUtf8 NULL false 2 -datafusion public string_agg 1 IN expression Utf8 NULL false 3 -datafusion public string_agg 2 IN delimiter Utf8 NULL false 3 -datafusion public string_agg 1 OUT NULL LargeUtf8 NULL false 3 -datafusion public string_agg 1 IN expression Utf8 NULL false 4 -datafusion public string_agg 2 IN delimiter LargeUtf8 NULL false 4 -datafusion public string_agg 1 OUT NULL LargeUtf8 NULL false 4 -datafusion public string_agg 1 IN expression Utf8 NULL false 5 -datafusion public string_agg 2 IN delimiter Null NULL false 5 -datafusion public string_agg 1 OUT NULL LargeUtf8 NULL false 5 +select * from information_schema.parameters where specific_name = 'date_trunc' OR specific_name = 'string_agg' OR specific_name = 'rank' ORDER BY specific_name, rid, data_type; +---- +datafusion public date_trunc 1 IN precision String NULL false 0 +datafusion public date_trunc 2 IN expression Timestamp(Microsecond, None) NULL false 0 +datafusion public date_trunc 1 OUT NULL Timestamp(Microsecond, None) NULL false 0 +datafusion public date_trunc 1 IN precision String NULL false 1 +datafusion public date_trunc 2 IN expression Timestamp(Microsecond, Some("+TZ")) NULL false 1 +datafusion public date_trunc 1 OUT NULL Timestamp(Microsecond, Some("+TZ")) NULL false 1 +datafusion public date_trunc 1 IN precision String NULL false 2 +datafusion public date_trunc 2 IN expression 
Timestamp(Millisecond, None) NULL false 2 +datafusion public date_trunc 1 OUT NULL Timestamp(Millisecond, None) NULL false 2 +datafusion public date_trunc 1 IN precision String NULL false 3 +datafusion public date_trunc 2 IN expression Timestamp(Millisecond, Some("+TZ")) NULL false 3 +datafusion public date_trunc 1 OUT NULL Timestamp(Millisecond, Some("+TZ")) NULL false 3 +datafusion public date_trunc 1 IN precision String NULL false 4 +datafusion public date_trunc 2 IN expression Timestamp(Nanosecond, None) NULL false 4 +datafusion public date_trunc 1 OUT NULL Timestamp(Nanosecond, None) NULL false 4 +datafusion public date_trunc 1 IN precision String NULL false 5 +datafusion public date_trunc 2 IN expression Timestamp(Nanosecond, Some("+TZ")) NULL false 5 +datafusion public date_trunc 1 OUT NULL Timestamp(Nanosecond, Some("+TZ")) NULL false 5 +datafusion public date_trunc 1 IN precision String NULL false 6 +datafusion public date_trunc 2 IN expression Timestamp(Second, None) NULL false 6 +datafusion public date_trunc 1 OUT NULL Timestamp(Second, None) NULL false 6 +datafusion public date_trunc 1 IN precision String NULL false 7 +datafusion public date_trunc 2 IN expression Timestamp(Second, Some("+TZ")) NULL false 7 +datafusion public date_trunc 1 OUT NULL Timestamp(Second, Some("+TZ")) NULL false 7 +datafusion public string_agg 2 IN delimiter Null NULL false 0 +datafusion public string_agg 1 IN expression String NULL false 0 +datafusion public string_agg 1 OUT NULL String NULL false 0 +datafusion public string_agg 1 IN expression String NULL false 1 +datafusion public string_agg 2 IN delimiter String NULL false 1 +datafusion public string_agg 1 OUT NULL String NULL false 1 # test variable length arguments query TTTBI rowsort select specific_name, data_type, parameter_mode, is_variadic, rid from information_schema.parameters where specific_name = 'concat'; ---- -concat LargeUtf8 IN true 2 -concat LargeUtf8 OUT false 2 -concat Utf8 IN true 1 -concat Utf8 OUT false 1 -concat Utf8View IN true 0 -concat Utf8View OUT false 0 +concat String IN true 0 +concat String OUT false 0 # test ceorcion signature query TTITI rowsort select specific_name, data_type, ordinal_position, parameter_mode, rid from information_schema.parameters where specific_name = 'repeat'; ---- repeat Int64 2 IN 0 -repeat Int64 2 IN 1 -repeat Int64 2 IN 2 -repeat LargeUtf8 1 IN 1 -repeat LargeUtf8 1 OUT 1 -repeat Utf8 1 IN 0 -repeat Utf8 1 OUT 0 -repeat Utf8 1 OUT 2 -repeat Utf8View 1 IN 2 +repeat String 1 IN 0 +repeat String 1 OUT 0 query TT??TTT rowsort show functions like 'date_trunc'; ---- -date_trunc Timestamp(Microsecond, None) [precision, expression] [Utf8, Timestamp(Microsecond, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Microsecond, None) [precision, expression] [Utf8View, Timestamp(Microsecond, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Microsecond, Some("+TZ")) [precision, expression] [Utf8, Timestamp(Microsecond, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Microsecond, Some("+TZ")) [precision, expression] [Utf8View, Timestamp(Microsecond, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. 
date_trunc(precision, expression) -date_trunc Timestamp(Millisecond, None) [precision, expression] [Utf8, Timestamp(Millisecond, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Millisecond, None) [precision, expression] [Utf8View, Timestamp(Millisecond, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Millisecond, Some("+TZ")) [precision, expression] [Utf8, Timestamp(Millisecond, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Millisecond, Some("+TZ")) [precision, expression] [Utf8View, Timestamp(Millisecond, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Nanosecond, None) [precision, expression] [Utf8, Timestamp(Nanosecond, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Nanosecond, None) [precision, expression] [Utf8View, Timestamp(Nanosecond, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Nanosecond, Some("+TZ")) [precision, expression] [Utf8, Timestamp(Nanosecond, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Nanosecond, Some("+TZ")) [precision, expression] [Utf8View, Timestamp(Nanosecond, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Second, None) [precision, expression] [Utf8, Timestamp(Second, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Second, None) [precision, expression] [Utf8View, Timestamp(Second, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Second, Some("+TZ")) [precision, expression] [Utf8, Timestamp(Second, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) -date_trunc Timestamp(Second, Some("+TZ")) [precision, expression] [Utf8View, Timestamp(Second, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) +date_trunc Timestamp(Microsecond, None) [precision, expression] [String, Timestamp(Microsecond, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) +date_trunc Timestamp(Microsecond, Some("+TZ")) [precision, expression] [String, Timestamp(Microsecond, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) +date_trunc Timestamp(Millisecond, None) [precision, expression] [String, Timestamp(Millisecond, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) +date_trunc Timestamp(Millisecond, Some("+TZ")) [precision, expression] [String, Timestamp(Millisecond, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) +date_trunc Timestamp(Nanosecond, None) [precision, expression] [String, Timestamp(Nanosecond, None)] SCALAR Truncates a timestamp value to a specified precision. 
date_trunc(precision, expression) +date_trunc Timestamp(Nanosecond, Some("+TZ")) [precision, expression] [String, Timestamp(Nanosecond, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) +date_trunc Timestamp(Second, None) [precision, expression] [String, Timestamp(Second, None)] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) +date_trunc Timestamp(Second, Some("+TZ")) [precision, expression] [String, Timestamp(Second, Some("+TZ"))] SCALAR Truncates a timestamp value to a specified precision. date_trunc(precision, expression) statement ok show functions diff --git a/datafusion/sqllogictest/test_files/join.slt.part b/datafusion/sqllogictest/test_files/join.slt.part index 972dd2265343..19763ab0083f 100644 --- a/datafusion/sqllogictest/test_files/join.slt.part +++ b/datafusion/sqllogictest/test_files/join.slt.part @@ -842,7 +842,7 @@ LEFT JOIN department AS d ON (e.name = 'Alice' OR e.name = 'Bob'); ---- logical_plan -01)Left Join: Filter: e.name = Utf8("Alice") OR e.name = Utf8("Bob") +01)Left Join: Filter: e.name = Utf8View("Alice") OR e.name = Utf8View("Bob") 02)--SubqueryAlias: e 03)----TableScan: employees projection=[emp_id, name] 04)--SubqueryAlias: d @@ -929,7 +929,7 @@ ON (e.name = 'Alice' OR e.name = 'Bob'); logical_plan 01)Cross Join: 02)--SubqueryAlias: e -03)----Filter: employees.name = Utf8("Alice") OR employees.name = Utf8("Bob") +03)----Filter: employees.name = Utf8View("Alice") OR employees.name = Utf8View("Bob") 04)------TableScan: employees projection=[emp_id, name] 05)--SubqueryAlias: d 06)----TableScan: department projection=[dept_name] @@ -974,11 +974,11 @@ ON e.emp_id = d.emp_id WHERE ((dept_name != 'Engineering' AND e.name = 'Alice') OR (name != 'Alice' AND e.name = 'Carol')); ---- logical_plan -01)Filter: d.dept_name != Utf8("Engineering") AND e.name = Utf8("Alice") OR e.name != Utf8("Alice") AND e.name = Utf8("Carol") +01)Filter: d.dept_name != Utf8View("Engineering") AND e.name = Utf8View("Alice") OR e.name != Utf8View("Alice") AND e.name = Utf8View("Carol") 02)--Projection: e.emp_id, e.name, d.dept_name 03)----Left Join: e.emp_id = d.emp_id 04)------SubqueryAlias: e -05)--------Filter: employees.name = Utf8("Alice") OR employees.name != Utf8("Alice") AND employees.name = Utf8("Carol") +05)--------Filter: employees.name = Utf8View("Alice") OR employees.name != Utf8View("Alice") AND employees.name = Utf8View("Carol") 06)----------TableScan: employees projection=[emp_id, name] 07)------SubqueryAlias: d 08)--------TableScan: department projection=[emp_id, dept_name] @@ -1404,3 +1404,102 @@ set datafusion.execution.target_partitions = 4; statement ok set datafusion.optimizer.repartition_joins = false; + +statement ok +CREATE TABLE t1(v0 BIGINT, v1 BIGINT); + +statement ok +CREATE TABLE t0(v0 BIGINT, v1 BIGINT); + +statement ok +INSERT INTO t0(v0, v1) VALUES (1, 1), (1, 2), (3, 3), (4, 4); + +statement ok +INSERT INTO t1(v0, v1) VALUES (1, 1), (3, 2), (3, 5); + +query TT +explain SELECT * +FROM t0, +LATERAL (SELECT sum(v1) FROM t1 WHERE t0.v0 = t1.v0); +---- +logical_plan +01)Projection: t0.v0, t0.v1, sum(t1.v1) +02)--Left Join: t0.v0 = t1.v0 +03)----TableScan: t0 projection=[v0, v1] +04)----Projection: sum(t1.v1), t1.v0 +05)------Aggregate: groupBy=[[t1.v0]], aggr=[[sum(t1.v1)]] +06)--------TableScan: t1 projection=[v0, v1] +physical_plan +01)ProjectionExec: expr=[v0@1 as v0, v1@2 as v1, sum(t1.v1)@0 as sum(t1.v1)] +02)--CoalesceBatchesExec: target_batch_size=8192 
+03)----HashJoinExec: mode=CollectLeft, join_type=Right, on=[(v0@1, v0@0)], projection=[sum(t1.v1)@0, v0@2, v1@3] +04)------CoalescePartitionsExec +05)--------ProjectionExec: expr=[sum(t1.v1)@1 as sum(t1.v1), v0@0 as v0] +06)----------AggregateExec: mode=FinalPartitioned, gby=[v0@0 as v0], aggr=[sum(t1.v1)] +07)------------CoalesceBatchesExec: target_batch_size=8192 +08)--------------RepartitionExec: partitioning=Hash([v0@0], 4), input_partitions=4 +09)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +10)------------------AggregateExec: mode=Partial, gby=[v0@0 as v0], aggr=[sum(t1.v1)] +11)--------------------DataSourceExec: partitions=1, partition_sizes=[1] +12)------DataSourceExec: partitions=1, partition_sizes=[1] + +query III +SELECT * +FROM t0, +LATERAL (SELECT sum(v1) FROM t1 WHERE t0.v0 = t1.v0); +---- +1 1 1 +1 2 1 +3 3 7 +4 4 NULL + +query TT +explain SELECT * FROM t0, LATERAL (SELECT * FROM t1 WHERE t0.v0 = t1.v0); +---- +logical_plan +01)Inner Join: t0.v0 = t1.v0 +02)--TableScan: t0 projection=[v0, v1] +03)--TableScan: t1 projection=[v0, v1] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(v0@0, v0@0)] +03)----DataSourceExec: partitions=1, partition_sizes=[1] +04)----DataSourceExec: partitions=1, partition_sizes=[1] + +query IIII +SELECT * FROM t0, LATERAL (SELECT * FROM t1 WHERE t0.v0 = t1.v0); +---- +1 1 1 1 +1 2 1 1 +3 3 3 2 +3 3 3 5 + +query III +SELECT * FROM t0, LATERAL (SELECT 1); +---- +1 1 1 +1 2 1 +3 3 1 +4 4 1 + +query IIII +SELECT * FROM t0, LATERAL (SELECT * FROM t1 WHERE t1.v0 = 1); +---- +1 1 1 1 +1 2 1 1 +3 3 1 1 +4 4 1 1 + +query IIII +SELECT * FROM t0 JOIN LATERAL (SELECT * FROM t1 WHERE t1.v0 = 1) on true; +---- +1 1 1 1 +1 2 1 1 +3 3 1 1 +4 4 1 1 + +statement ok +drop table t1; + +statement ok +drop table t0; diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index ddf701ba04ef..3be5c1b1c370 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -1067,9 +1067,9 @@ LEFT JOIN join_t2 on join_t1.t1_id = join_t2.t2_id WHERE join_t2.t2_int < 10 or (join_t1.t1_int > 2 and join_t2.t2_name != 'w') ---- logical_plan -01)Inner Join: join_t1.t1_id = join_t2.t2_id Filter: join_t2.t2_int < UInt32(10) OR join_t1.t1_int > UInt32(2) AND join_t2.t2_name != Utf8("w") +01)Inner Join: join_t1.t1_id = join_t2.t2_id Filter: join_t2.t2_int < UInt32(10) OR join_t1.t1_int > UInt32(2) AND join_t2.t2_name != Utf8View("w") 02)--TableScan: join_t1 projection=[t1_id, t1_name, t1_int] -03)--Filter: join_t2.t2_int < UInt32(10) OR join_t2.t2_name != Utf8("w") +03)--Filter: join_t2.t2_int < UInt32(10) OR join_t2.t2_name != Utf8View("w") 04)----TableScan: join_t2 projection=[t2_id, t2_name, t2_int] # Reduce left join 3 (to inner join) @@ -1153,7 +1153,7 @@ WHERE join_t1.t1_name != 'b' ---- logical_plan 01)Left Join: join_t1.t1_id = join_t2.t2_id -02)--Filter: join_t1.t1_name != Utf8("b") +02)--Filter: join_t1.t1_name != Utf8View("b") 03)----TableScan: join_t1 projection=[t1_id, t1_name, t1_int] 04)--TableScan: join_t2 projection=[t2_id, t2_name, t2_int] @@ -1168,9 +1168,9 @@ WHERE join_t1.t1_name != 'b' and join_t2.t2_name = 'x' ---- logical_plan 01)Inner Join: join_t1.t1_id = join_t2.t2_id -02)--Filter: join_t1.t1_name != Utf8("b") +02)--Filter: join_t1.t1_name != Utf8View("b") 03)----TableScan: join_t1 projection=[t1_id, t1_name, t1_int] -04)--Filter: join_t2.t2_name = 
Utf8("x") +04)--Filter: join_t2.t2_name = Utf8View("x") 05)----TableScan: join_t2 projection=[t2_id, t2_name, t2_int] ### @@ -1373,7 +1373,7 @@ inner join join_t4 on join_t3.s3 = join_t4.s4 {id: 2} {id: 2} # join with struct key and nulls -# Note that intersect or except applies `null_equals_null` as true for Join. +# Note that intersect or except applies `null_equality` as `NullEquality::NullEqualsNull` for Join. query ? SELECT * FROM join_t3 EXCEPT @@ -4087,7 +4087,7 @@ logical_plan 07)------------TableScan: sales_global projection=[ts, sn, amount, currency] 08)----------SubqueryAlias: e 09)------------Projection: exchange_rates.ts, exchange_rates.currency_from, exchange_rates.rate -10)--------------Filter: exchange_rates.currency_to = Utf8("USD") +10)--------------Filter: exchange_rates.currency_to = Utf8View("USD") 11)----------------TableScan: exchange_rates projection=[ts, currency_from, currency_to, rate] physical_plan 01)SortExec: expr=[sn@1 ASC NULLS LAST], preserve_partitioning=[false] @@ -4385,7 +4385,7 @@ JOIN my_catalog.my_schema.table_with_many_types AS r ON l.binary_col = r.binary_ logical_plan 01)Projection: count(Int64(1)) AS count(*) 02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] -03)----Projection: +03)----Projection: 04)------Inner Join: l.binary_col = r.binary_col 05)--------SubqueryAlias: l 06)----------TableScan: my_catalog.my_schema.table_with_many_types projection=[binary_col] @@ -4644,7 +4644,7 @@ logical_plan 08)----Subquery: 09)------Filter: j3.j3_string = outer_ref(j2.j2_string) 10)--------TableScan: j3 projection=[j3_string, j3_id] -physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(Utf8, Column { relation: Some(Bare { table: "j2" }), name: "j2_string" }) +physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(Utf8View, Column { relation: Some(Bare { table: "j2" }), name: "j2_string" }) query TT explain SELECT * FROM j1, LATERAL (SELECT * FROM j1, LATERAL (SELECT * FROM j2 WHERE j1_id = j2_id) as j2) as j2; diff --git a/datafusion/sqllogictest/test_files/limit.slt b/datafusion/sqllogictest/test_files/limit.slt index 93ffa313b8f7..b46d15cb962a 100644 --- a/datafusion/sqllogictest/test_files/limit.slt +++ b/datafusion/sqllogictest/test_files/limit.slt @@ -365,7 +365,7 @@ EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 WHERE a > 3 LIMIT 3 OFFSET 6); logical_plan 01)Projection: count(Int64(1)) AS count(*) 02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] -03)----Projection: +03)----Projection: 04)------Limit: skip=6, fetch=3 05)--------Filter: t1.a > Int32(3) 06)----------TableScan: t1 projection=[a] @@ -835,6 +835,7 @@ explain with selection as ( select * from test_limit_with_partitions + order by part_key limit 1 ) select 1 as foo @@ -847,19 +848,19 @@ logical_plan 02)--Sort: selection.part_key ASC NULLS LAST, fetch=1000 03)----Projection: Int64(1) AS foo, selection.part_key 04)------SubqueryAlias: selection -05)--------Limit: skip=0, fetch=1 -06)----------TableScan: test_limit_with_partitions projection=[part_key], fetch=1 +05)--------Sort: test_limit_with_partitions.part_key ASC NULLS LAST, fetch=1 +06)----------TableScan: test_limit_with_partitions projection=[part_key] physical_plan -01)ProjectionExec: expr=[foo@0 as foo] -02)--SortExec: TopK(fetch=1000), expr=[part_key@1 ASC NULLS LAST], preserve_partitioning=[false] -03)----ProjectionExec: expr=[1 as foo, part_key@0 as part_key] 
-04)------CoalescePartitionsExec: fetch=1 -05)--------DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_limit_with_partitions/part-0.parquet:0..794], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_limit_with_partitions/part-1.parquet:0..794], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_limit_with_partitions/part-2.parquet:0..794]]}, projection=[part_key], limit=1, file_type=parquet +01)ProjectionExec: expr=[1 as foo] +02)--SortPreservingMergeExec: [part_key@0 ASC NULLS LAST], fetch=1 +03)----SortExec: TopK(fetch=1), expr=[part_key@0 ASC NULLS LAST], preserve_partitioning=[true] +04)------DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_limit_with_partitions/part-0.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_limit_with_partitions/part-1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_limit_with_partitions/part-2.parquet]]}, projection=[part_key], file_type=parquet, predicate=DynamicFilterPhysicalExpr [ true ] query I with selection as ( select * from test_limit_with_partitions + order by part_key limit 1 ) select 1 as foo diff --git a/datafusion/sqllogictest/test_files/listing_table_statistics.slt b/datafusion/sqllogictest/test_files/listing_table_statistics.slt new file mode 100644 index 000000000000..890d1f2e9250 --- /dev/null +++ b/datafusion/sqllogictest/test_files/listing_table_statistics.slt @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# Test file with different schema order but generating correct statistics for table +statement ok +COPY (SELECT * FROM values (1, 'a'), (2, 'b') t(int_col, str_col)) to 'test_files/scratch/table/1.parquet'; + +statement ok +COPY (SELECT * FROM values ('c', 3), ('d', -1) t(str_col, int_col)) to 'test_files/scratch/table/2.parquet'; + +statement ok +set datafusion.execution.collect_statistics = true; + +statement ok +set datafusion.explain.show_statistics = true; + +statement ok +create external table t stored as parquet location 'test_files/scratch/table'; + +query TT +explain format indent select * from t; +---- +logical_plan TableScan: t projection=[int_col, str_col] +physical_plan DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/table/2.parquet]]}, projection=[int_col, str_col], file_type=parquet, statistics=[Rows=Exact(4), Bytes=Exact(288), [(Col[0]: Min=Exact(Int64(-1)) Max=Exact(Int64(3)) Null=Exact(0)),(Col[1]: Min=Exact(Utf8View("a")) Max=Exact(Utf8View("d")) Null=Exact(0))]] + +statement ok +drop table t; + +statement ok +set datafusion.execution.collect_statistics = false; + +statement ok +set datafusion.explain.show_statistics = false; diff --git a/datafusion/sqllogictest/test_files/map.slt b/datafusion/sqllogictest/test_files/map.slt index 42a4ba621801..56481936e726 100644 --- a/datafusion/sqllogictest/test_files/map.slt +++ b/datafusion/sqllogictest/test_files/map.slt @@ -651,6 +651,57 @@ select map_extract(column1, 1), map_extract(column1, 5), map_extract(column1, 7) [NULL] [[4, NULL, 6]] [NULL] [NULL] [NULL] [[1, NULL, 3]] +# Tests for map_entries + +query ? +SELECT map_entries(MAP { 'a': 1, 'b': 3 }); +---- +[{key: a, value: 1}, {key: b, value: 3}] + +query error DataFusion error: Arrow error: Cast error: Cannot cast string 'a' to value of Int64 type +SELECT map_entries(MAP { 'a': 1, 2: 3 }); + +query ? +SELECT map_entries(MAP {'a':1, 'b':2, 'c':3 }) FROM t; +---- +[{key: a, value: 1}, {key: b, value: 2}, {key: c, value: 3}] +[{key: a, value: 1}, {key: b, value: 2}, {key: c, value: 3}] +[{key: a, value: 1}, {key: b, value: 2}, {key: c, value: 3}] + +query ? +SELECT map_entries(Map{column1: column2, column3: column4}) FROM t; +---- +[{key: a, value: 1}, {key: k1, value: 10}] +[{key: b, value: 2}, {key: k3, value: 30}] +[{key: d, value: 4}, {key: k5, value: 50}] + +query ? +SELECT map_entries(map(column5, column6)) FROM t; +---- +[{key: k1, value: 1}, {key: k2, value: 2}] +[{key: k3, value: 3}] +[{key: k5, value: 5}] + +query ? +SELECT map_entries(map(column8, column9)) FROM t; +---- +[{key: [1, 2, 3], value: a}] +[{key: [4], value: b}] +[{key: [1, 2], value: c}] + +query ? +SELECT map_entries(Map{}); +---- +[] + +query ? +SELECT map_entries(column1) from map_array_table_1; +---- +[{key: 1, value: [1, NULL, 3]}, {key: 2, value: [4, NULL, 6]}, {key: 3, value: [7, 8, 9]}] +[{key: 4, value: [1, NULL, 3]}, {key: 5, value: [4, NULL, 6]}, {key: 6, value: [7, 8, 9]}] +[{key: 7, value: [1, NULL, 3]}, {key: 8, value: [9, NULL, 6]}, {key: 9, value: [7, 8, 9]}] +NULL + # Tests for map_keys query ? diff --git a/datafusion/sqllogictest/test_files/min_max/fixed_size_list.slt b/datafusion/sqllogictest/test_files/min_max/fixed_size_list.slt new file mode 100644 index 000000000000..aa623b63cdc7 --- /dev/null +++ b/datafusion/sqllogictest/test_files/min_max/fixed_size_list.slt @@ -0,0 +1,133 @@ +# Min/Max with FixedSizeList over integers +query ?? 
+SELECT MIN(column1), MAX(column1) FROM VALUES +(arrow_cast(make_array(1, 2, 3, 4), 'FixedSizeList(4, Int64)')), +(arrow_cast(make_array(1, 2), 'FixedSizeList(2, Int64)')); +---- +[1, 2] [1, 2, 3, 4] + +# Min/Max with FixedSizeList over strings +query ?? +SELECT MIN(column1), MAX(column1) FROM VALUES +(arrow_cast(make_array('a', 'b', 'c'), 'FixedSizeList(3, Utf8)')), +(arrow_cast(make_array('a', 'b'), 'LargeList(Utf8)')); +---- +[a, b] [a, b, c] + +# Min/Max with FixedSizeList over booleans +query ?? +SELECT MIN(column1), MAX(column1) FROM VALUES +(arrow_cast(make_array(true, false, true), 'FixedSizeList(3, Boolean)')), +(arrow_cast(make_array(true, false), 'FixedSizeList(2, Boolean)')); +---- +[true, false] [true, false, true] + +# Min/Max with FixedSizeList over nullable integers +query ?? +SELECT MIN(column1), MAX(column1) FROM VALUES +(arrow_cast(make_array(NULL, 1, 2), 'FixedSizeList(3, Int64)')), +(arrow_cast(make_array(1, 2), 'FixedSizeList(2, Int64)')); +---- +[1, 2] [NULL, 1, 2] + +# Min/Max FixedSizeList with different lengths and nulls +query ?? +SELECT MIN(column1), MAX(column1) FROM VALUES +(arrow_cast(make_array(1, 2, 3, 4), 'FixedSizeList(4, Int64)')), +(arrow_cast(make_array(1, 2), 'FixedSizeList(2, Int64)')), +(arrow_cast(make_array(1, NULL, 3), 'FixedSizeList(3, Int64)')); +---- +[1, 2] [1, NULL, 3] + +# Min/Max FixedSizeList with only NULLs +query ?? +SELECT MIN(column1), MAX(column1) FROM VALUES +(arrow_cast(make_array(NULL, NULL), 'FixedSizeList(2, Int64)')), +(arrow_cast(make_array(NULL), 'FixedSizeList(1, Int64)')); +---- +[NULL] [NULL, NULL] + + +# Min/Max FixedSizeList of varying types (integers and NULLs) +query ?? +SELECT MIN(column1), MAX(column1) FROM VALUES +(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)')), +(arrow_cast(make_array(NULL, 2, 3), 'FixedSizeList(3, Int64)')), +(arrow_cast(make_array(1, 2, NULL), 'FixedSizeList(3, Int64)')); +---- +[1, 2, 3] [NULL, 2, 3] + +# Min/Max FixedSizeList grouped by key with NULLs and differing lengths +query I?? rowsort +SELECT column1, MIN(column2), MAX(column2) FROM VALUES +(0, arrow_cast(make_array(1, NULL, 3), 'FixedSizeList(3, Int64)')), +(0, arrow_cast(make_array(1, 2, 3, 4), 'FixedSizeList(4, Int64)')), +(1, arrow_cast(make_array(1, 2), 'FixedSizeList(2, Int64)')), +(1, arrow_cast(make_array(NULL, 5), 'FixedSizeList(2, Int64)')) +GROUP BY column1; +---- +0 [1, 2, 3, 4] [1, NULL, 3] +1 [1, 2] [NULL, 5] + +# Min/Max FixedSizeList grouped by key with NULLs and differing lengths +query I?? rowsort +SELECT column1, MIN(column2), MAX(column2) FROM VALUES +(0, arrow_cast(make_array(NULL), 'FixedSizeList(1, Int64)')), +(0, arrow_cast(make_array(NULL, NULL), 'FixedSizeList(2, Int64)')), +(1, arrow_cast(make_array(NULL), 'FixedSizeList(1, Int64)')) +GROUP BY column1; +---- +0 [NULL] [NULL, NULL] +1 [NULL] [NULL] + +# Min/Max grouped FixedSizeList with empty and non-empty +query I?? rowsort +SELECT column1, MIN(column2), MAX(column2) FROM VALUES +(0, arrow_cast(make_array(1), 'FixedSizeList(1, Int64)')), +(1, arrow_cast(make_array(5, 6), 'FixedSizeList(2, Int64)')) +GROUP BY column1; +---- +0 [1] [1] +1 [5, 6] [5, 6] + +# Min/Max over FixedSizeList with a window function +query ? 
+SELECT min(column1) OVER (ORDER BY column1) FROM VALUES +(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)')), +(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)')), +(arrow_cast(make_array(2, 3), 'FixedSizeList(2, Int64)')) +---- +[1, 2, 3] +[1, 2, 3] +[1, 2, 3] + +# Min/Max over FixedSizeList with a window function and nulls +query ? +SELECT min(column1) OVER (ORDER BY column1) FROM VALUES +(arrow_cast(make_array(NULL), 'FixedSizeList(1, Int64)')), +(arrow_cast(make_array(4, 5), 'FixedSizeList(2, Int64)')), +(arrow_cast(make_array(2, 3), 'FixedSizeList(2, Int64)')) +---- +[2, 3] +[2, 3] +[2, 3] + +# Min/Max over FixedSizeList with a window function, nulls and ROWS BETWEEN statement +query ? +SELECT min(column1) OVER (ORDER BY column1 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM VALUES +(arrow_cast(make_array(NULL), 'FixedSizeList(1, Int64)')), +(arrow_cast(make_array(4, 5), 'FixedSizeList(2, Int64)')), +(arrow_cast(make_array(2, 3), 'FixedSizeList(2, Int64)')) +---- +[2, 3] +[2, 3] +[4, 5] + +# Min/Max over FixedSizeList with a window function using a different column +query ? +SELECT max(column2) OVER (ORDER BY column1 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM VALUES +(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)'), arrow_cast(make_array(4, 5), 'FixedSizeList(2, Int64)')), +(arrow_cast(make_array(2, 3), 'FixedSizeList(2, Int64)'), arrow_cast(make_array(2, 3), 'FixedSizeList(2, Int64)')) +---- +[4, 5] +[4, 5] diff --git a/datafusion/sqllogictest/test_files/min_max/init_data.slt.part b/datafusion/sqllogictest/test_files/min_max/init_data.slt.part new file mode 100644 index 000000000000..57e14f6993d4 --- /dev/null +++ b/datafusion/sqllogictest/test_files/min_max/init_data.slt.part @@ -0,0 +1,155 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# -------------------------------------- +# 1. Min/Max over integers +# -------------------------------------- +statement ok +create table min_max_base_int as values + (make_array(1, 2, 3, 4)), + (make_array(1, 2)) +; + +# -------------------------------------- +# 2. Min/Max over strings +# -------------------------------------- +statement ok +create table min_max_base_string as values + (make_array('a', 'b', 'c')), + (make_array('a', 'b')) +; + +# -------------------------------------- +# 3. Min/Max over booleans +# -------------------------------------- +statement ok +create table min_max_base_bool as values + (make_array(true, false, true)), + (make_array(true, false)) +; + +# -------------------------------------- +# 4. Min/Max over nullable integers +# -------------------------------------- +statement ok +create table min_max_base_nullable_int as values + (make_array(NULL, 1, 2)), + (make_array(1, 2)) +; + +# -------------------------------------- +# 5. 
Min/Max with mixed lengths and nulls +# -------------------------------------- +statement ok +create table min_max_base_mixed_lengths_nulls as values + (make_array(1, 2, 3, 4)), + (make_array(1, 2)), + (make_array(1, NULL, 3)) +; + +# -------------------------------------- +# 6. Min/Max with only NULLs +# -------------------------------------- +statement ok +create table min_max_base_all_nulls as values + (make_array(NULL, NULL)), + (make_array(NULL)) +; + +# -------------------------------------- +# 7. Min/Max with partial NULLs +# -------------------------------------- +statement ok +create table min_max_base_null_variants as values + (make_array(1, 2, 3)), + (make_array(NULL, 2, 3)), + (make_array(1, 2, NULL)) +; + +# -------------------------------------- +# 8. Min/Max grouped by key with NULLs and differing lengths +# -------------------------------------- +statement ok +create table min_max_base_grouped_nulls as values + (0, make_array(1, NULL, 3)), + (0, make_array(1, 2, 3, 4)), + (1, make_array(1, 2)), + (1, make_array(NULL, 5)), + (1, make_array()) +; + +# -------------------------------------- +# 9. Min/Max grouped by key with only NULLs +# -------------------------------------- +statement ok +create table min_max_base_grouped_all_null as values + (0, make_array(NULL)), + (0, make_array(NULL, NULL)), + (1, make_array(NULL)) +; + +# -------------------------------------- +# 10. Min/Max grouped with empty and non-empty lists +# -------------------------------------- +statement ok +create table min_max_base_grouped_simple as values + (0, make_array()), + (0, make_array(1)), + (0, make_array()), + (1, make_array()), + (1, make_array(5, 6)) +; + +# -------------------------------------- +# 11. Min over with window function +# -------------------------------------- +statement ok +create table min_base_window_simple as values + (make_array(1, 2, 3)), + (make_array(1, 2, 3)), + (make_array(2, 3)) +; + +# -------------------------------------- +# 12. Min over with window + NULLs +# -------------------------------------- +statement ok +create table min_base_window_with_null as values + (make_array(NULL)), + (make_array(4, 5)), + (make_array(2, 3)) +; + +# -------------------------------------- +# 13. Min over with ROWS BETWEEN clause +# -------------------------------------- +statement ok +create table min_base_window_rows_between as values + (make_array(NULL)), + (make_array(4, 5)), + (make_array(2, 3)) +; + +# -------------------------------------- +# 14. Max over using different order column +# -------------------------------------- +statement ok +create table max_base_window_different_column as values + (make_array(1, 2, 3), make_array(4, 5)), + (make_array(2, 3), make_array(2, 3)), + (make_array(2, 3), NULL) +; diff --git a/datafusion/sqllogictest/test_files/min_max/large_list.slt b/datafusion/sqllogictest/test_files/min_max/large_list.slt new file mode 100644 index 000000000000..44789e9dd786 --- /dev/null +++ b/datafusion/sqllogictest/test_files/min_max/large_list.slt @@ -0,0 +1,143 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +include ./init_data.slt.part + +## -------------------------------------- +## 1. Min/Max over integers +## -------------------------------------- +statement ok +create table min_max_int as ( + select + arrow_cast(column1, 'LargeList(Int64)') as column1 + from min_max_base_int + ); + +## -------------------------------------- +## 2. Min/Max over strings +## -------------------------------------- +statement ok +create table min_max_string as ( + select + arrow_cast(column1, 'LargeList(Utf8)') as column1 +from min_max_base_string); + +## -------------------------------------- +## 3. Min/Max over booleans +## -------------------------------------- +statement ok +create table min_max_bool as +( + select + arrow_cast(column1, 'LargeList(Boolean)') as column1 +from min_max_base_bool); + +## -------------------------------------- +## 4. Min/Max over nullable integers +## -------------------------------------- +statement ok +create table min_max_nullable_int as ( + select + arrow_cast(column1, 'LargeList(Int64)') as column1 + from min_max_base_nullable_int +); + +## -------------------------------------- +## 5. Min/Max with mixed lengths and nulls +## -------------------------------------- +statement ok +create table min_max_mixed_lengths_nulls as (select + arrow_cast(column1, 'LargeList(Int64)') as column1 +from min_max_base_mixed_lengths_nulls); + +## -------------------------------------- +## 6. Min/Max with only NULLs +## -------------------------------------- +statement ok +create table min_max_all_nulls as (select + arrow_cast(column1, 'LargeList(Int64)') as column1 +from min_max_base_all_nulls); + +## -------------------------------------- +## 7. Min/Max with partial NULLs +## -------------------------------------- +statement ok +create table min_max_null_variants as (select + arrow_cast(column1, 'LargeList(Int64)') as column1 +from min_max_base_null_variants); + +## -------------------------------------- +## 8. Min/Max grouped by key with NULLs and differing lengths +## -------------------------------------- +statement ok +create table min_max_grouped_nulls as (select + column1, + arrow_cast(column2, 'LargeList(Int64)') as column2 +from min_max_base_grouped_nulls); + +## -------------------------------------- +## 9. Min/Max grouped by key with only NULLs +## -------------------------------------- +statement ok +create table min_max_grouped_all_null as (select + column1, + arrow_cast(column2, 'LargeList(Int64)') as column2 +from min_max_base_grouped_all_null); + +## -------------------------------------- +## 10. Min/Max grouped with simple sizes +## -------------------------------------- +statement ok +create table min_max_grouped_simple as (select + column1, + arrow_cast(column2, 'LargeList(Int64)') as column2 +from min_max_base_grouped_simple); + +## -------------------------------------- +## 11. Min over with window function +## -------------------------------------- +statement ok +create table min_window_simple as (select + arrow_cast(column1, 'LargeList(Int64)') as column1 +from min_base_window_simple); + +## -------------------------------------- +## 12. 
Min over with window + NULLs +## -------------------------------------- +statement ok +create table min_window_with_null as (select + arrow_cast(column1, 'LargeList(Int64)') as column1 +from min_base_window_with_null); + +## -------------------------------------- +## 13. Min over with ROWS BETWEEN clause +## -------------------------------------- +statement ok +create table min_window_rows_between as (select + arrow_cast(column1, 'LargeList(Int64)') as column1 +from min_base_window_rows_between); + +## -------------------------------------- +## 14. Max over using different order column +## -------------------------------------- +statement ok +create table max_window_different_column as (select + arrow_cast(column1, 'LargeList(Int64)') as column1, + arrow_cast(column2, 'LargeList(Int64)') as column2 +from max_base_window_different_column); + +include ./queries.slt.part diff --git a/datafusion/sqllogictest/test_files/min_max/list.slt b/datafusion/sqllogictest/test_files/min_max/list.slt new file mode 100644 index 000000000000..e63e8303c7d5 --- /dev/null +++ b/datafusion/sqllogictest/test_files/min_max/list.slt @@ -0,0 +1,132 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +include ./init_data.slt.part + +# -------------------------------------- +# 1. Min/Max over integers +# -------------------------------------- +statement ok +create table min_max_int as ( + select * from min_max_base_int ) +; + +# -------------------------------------- +# 2. Min/Max over strings +# -------------------------------------- +statement ok +create table min_max_string as ( + select * from min_max_base_string ) +; + +# -------------------------------------- +# 3. Min/Max over booleans +# -------------------------------------- +statement ok +create table min_max_bool as ( + select * from min_max_base_bool ) +; + +# -------------------------------------- +# 4. Min/Max over nullable integers +# -------------------------------------- +statement ok +create table min_max_nullable_int as ( + select * from min_max_base_nullable_int ) +; + +# -------------------------------------- +# 5. Min/Max with mixed lengths and nulls +# -------------------------------------- +statement ok +create table min_max_mixed_lengths_nulls as ( + select * from min_max_base_mixed_lengths_nulls ) +; + +# -------------------------------------- +# 6. Min/Max with only NULLs +# -------------------------------------- +statement ok +create table min_max_all_nulls as ( + select * from min_max_base_all_nulls ) +; + +# -------------------------------------- +# 7. Min/Max with partial NULLs +# -------------------------------------- +statement ok +create table min_max_null_variants as ( + select * from min_max_base_null_variants ) +; + +# -------------------------------------- +# 8. 
Min/Max grouped by key with NULLs and differing lengths +# -------------------------------------- +statement ok +create table min_max_grouped_nulls as ( + select * from min_max_base_grouped_nulls ) +; + +# -------------------------------------- +# 9. Min/Max grouped by key with only NULLs +# -------------------------------------- +statement ok +create table min_max_grouped_all_null as ( + select * from min_max_base_grouped_all_null ) +; + +# -------------------------------------- +# 10. Min/Max grouped with simple sizes +# -------------------------------------- +statement ok +create table min_max_grouped_simple as ( + select * from min_max_base_grouped_simple ) +; + +# -------------------------------------- +# 11. Min over with window function +# -------------------------------------- +statement ok +create table min_window_simple as ( + select * from min_base_window_simple ) +; + +# -------------------------------------- +# 12. Min over with window + NULLs +# -------------------------------------- +statement ok +create table min_window_with_null as ( + select * from min_base_window_with_null ) +; + +# -------------------------------------- +# 13. Min over with ROWS BETWEEN clause +# -------------------------------------- +statement ok +create table min_window_rows_between as ( + select * from min_base_window_rows_between ) +; + +# -------------------------------------- +# 14. Max over using different order column +# -------------------------------------- +statement ok +create table max_window_different_column as ( + select * from max_base_window_different_column ) +; + +include ./queries.slt.part diff --git a/datafusion/sqllogictest/test_files/min_max/queries.slt.part b/datafusion/sqllogictest/test_files/min_max/queries.slt.part new file mode 100644 index 000000000000..bc7fb840bf97 --- /dev/null +++ b/datafusion/sqllogictest/test_files/min_max/queries.slt.part @@ -0,0 +1,112 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +## 1. Min/Max List over integers +query ?? +SELECT MIN(column1), MAX(column1) FROM min_max_int; +---- +[1, 2] [1, 2, 3, 4] + +## 2. Min/Max List over strings +query ?? +SELECT MIN(column1), MAX(column1) FROM min_max_string; +---- +[a, b] [a, b, c] + +## 3. Min/Max List over booleans +query ?? +SELECT MIN(column1), MAX(column1) FROM min_max_bool; +---- +[true, false] [true, false, true] + +## 4. Min/Max List over nullable integers +query ?? +SELECT MIN(column1), MAX(column1) FROM min_max_nullable_int; +---- +[1, 2] [NULL, 1, 2] + +## 5. Min/Max List with mixed lengths and nulls +query ?? +SELECT MIN(column1), MAX(column1) FROM min_max_mixed_lengths_nulls; +---- +[1, 2] [1, NULL, 3] + +## 6. Min/Max List with only NULLs +query ?? +SELECT MIN(column1), MAX(column1) FROM min_max_all_nulls; +---- +[NULL] [NULL, NULL] + +## 7. 
Min/Max List with partial NULLs +query ?? +SELECT MIN(column1), MAX(column1) FROM min_max_null_variants; +---- +[1, 2, 3] [NULL, 2, 3] + +## 8. Min/Max List grouped by key with NULLs and differing lengths +query I?? +SELECT column1, MIN(column2), MAX(column2) FROM min_max_grouped_nulls GROUP BY column1 ORDER BY column1; +---- +0 [1, 2, 3, 4] [1, NULL, 3] +1 [] [NULL, 5] + +## 9. Min/Max List grouped by key with only NULLs +query I?? +SELECT column1, MIN(column2), MAX(column2) FROM min_max_grouped_all_null GROUP BY column1 ORDER BY column1; +---- +0 [NULL] [NULL, NULL] +1 [NULL] [NULL] + +## 10. Min/Max grouped List with simple sizes +query I?? +SELECT column1, MIN(column2), MAX(column2) FROM min_max_grouped_simple GROUP BY column1 ORDER BY column1; +---- +0 [] [1] +1 [] [5, 6] + +## 11. Min over List with window function +query ? +SELECT MIN(column1) OVER (ORDER BY column1) FROM min_window_simple; +---- +[1, 2, 3] +[1, 2, 3] +[1, 2, 3] + +## 12. Min over List with window + NULLs +query ? +SELECT MIN(column1) OVER (ORDER BY column1) FROM min_window_with_null; +---- +[2, 3] +[2, 3] +[2, 3] + +## 13. Min over List with ROWS BETWEEN clause +query ? +SELECT MIN(column1) OVER (ORDER BY column1 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM min_window_rows_between; +---- +[2, 3] +[2, 3] +[4, 5] + +## 14. Max over List using different order column +query ? +SELECT MAX(column2) OVER (ORDER BY column1 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM max_window_different_column; +---- +[4, 5] +[4, 5] +[2, 3] diff --git a/datafusion/sqllogictest/test_files/monotonic_projection_test.slt b/datafusion/sqllogictest/test_files/monotonic_projection_test.slt index e8700b1fea27..9c806cfa0d8a 100644 --- a/datafusion/sqllogictest/test_files/monotonic_projection_test.slt +++ b/datafusion/sqllogictest/test_files/monotonic_projection_test.slt @@ -129,12 +129,12 @@ ORDER BY a_str ASC, b ASC; ---- logical_plan 01)Sort: a_str ASC NULLS LAST, multiple_ordered_table.b ASC NULLS LAST -02)--Projection: CAST(multiple_ordered_table.a AS Utf8) AS a_str, multiple_ordered_table.b +02)--Projection: CAST(multiple_ordered_table.a AS Utf8View) AS a_str, multiple_ordered_table.b 03)----TableScan: multiple_ordered_table projection=[a, b] physical_plan 01)SortPreservingMergeExec: [a_str@0 ASC NULLS LAST, b@1 ASC NULLS LAST] 02)--SortExec: expr=[a_str@0 ASC NULLS LAST, b@1 ASC NULLS LAST], preserve_partitioning=[true] -03)----ProjectionExec: expr=[CAST(a@0 AS Utf8) as a_str, b@1 as b] +03)----ProjectionExec: expr=[CAST(a@0 AS Utf8View) as a_str, b@1 as b] 04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b], output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], file_type=csv, has_header=true diff --git a/datafusion/sqllogictest/test_files/order.slt b/datafusion/sqllogictest/test_files/order.slt index 4e8be56f3377..3fc90a6459f2 100644 --- a/datafusion/sqllogictest/test_files/order.slt +++ b/datafusion/sqllogictest/test_files/order.slt @@ -1040,12 +1040,12 @@ limit 5; ---- logical_plan 01)Sort: c_str ASC NULLS LAST, fetch=5 -02)--Projection: CAST(ordered_table.c AS Utf8) AS c_str +02)--Projection: CAST(ordered_table.c AS Utf8View) AS c_str 03)----TableScan: ordered_table projection=[c] physical_plan 01)SortPreservingMergeExec: [c_str@0 ASC NULLS LAST], fetch=5 02)--SortExec: TopK(fetch=5), expr=[c_str@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----ProjectionExec: 
expr=[CAST(c@0 AS Utf8) as c_str] +03)----ProjectionExec: expr=[CAST(c@0 AS Utf8View) as c_str] 04)------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c], output_ordering=[c@0 ASC NULLS LAST], file_type=csv, has_header=true @@ -1380,3 +1380,42 @@ physical_plan statement ok drop table table_with_ordered_not_null; + +# ORDER BY ALL +statement ok +set datafusion.sql_parser.dialect = 'DuckDB'; + +statement ok +CREATE OR REPLACE TABLE addresses AS + SELECT '123 Quack Blvd' AS address, 'DuckTown' AS city, '11111' AS zip + UNION ALL + SELECT '111 Duck Duck Goose Ln', 'DuckTown', '11111' + UNION ALL + SELECT '111 Duck Duck Goose Ln', 'Duck Town', '11111' + UNION ALL + SELECT '111 Duck Duck Goose Ln', 'Duck Town', '11111-0001'; + + +query TTT +SELECT * FROM addresses ORDER BY ALL; +---- +111 Duck Duck Goose Ln Duck Town 11111 +111 Duck Duck Goose Ln Duck Town 11111-0001 +111 Duck Duck Goose Ln DuckTown 11111 +123 Quack Blvd DuckTown 11111 + +query TTT +SELECT * FROM addresses ORDER BY ALL DESC; +---- +123 Quack Blvd DuckTown 11111 +111 Duck Duck Goose Ln DuckTown 11111 +111 Duck Duck Goose Ln Duck Town 11111-0001 +111 Duck Duck Goose Ln Duck Town 11111 + +query TT +SELECT address, zip FROM addresses ORDER BY ALL; +---- +111 Duck Duck Goose Ln 11111 +111 Duck Duck Goose Ln 11111 +111 Duck Duck Goose Ln 11111-0001 +123 Quack Blvd 11111 diff --git a/datafusion/sqllogictest/test_files/parquet.slt b/datafusion/sqllogictest/test_files/parquet.slt index 2970b2effb3e..33bb052baa51 100644 --- a/datafusion/sqllogictest/test_files/parquet.slt +++ b/datafusion/sqllogictest/test_files/parquet.slt @@ -304,6 +304,54 @@ select count(*) from listing_table; ---- 12 +# Test table pointing to the folder with parquet files(ends with /) +statement ok +CREATE EXTERNAL TABLE listing_table_folder_0 +STORED AS PARQUET +LOCATION 'test_files/scratch/parquet/test_table/'; + +statement ok +set datafusion.execution.listing_table_ignore_subdirectory = true; + +# scan file: 0.parquet 1.parquet 2.parquet +query I +select count(*) from listing_table_folder_0; +---- +9 + +statement ok +set datafusion.execution.listing_table_ignore_subdirectory = false; + +# scan file: 0.parquet 1.parquet 2.parquet 3.parquet +query I +select count(*) from listing_table_folder_0; +---- +12 + +# Test table pointing to the folder with parquet files(doesn't end with /) +statement ok +CREATE EXTERNAL TABLE listing_table_folder_1 +STORED AS PARQUET +LOCATION 'test_files/scratch/parquet/test_table'; + +statement ok +set datafusion.execution.listing_table_ignore_subdirectory = true; + +# scan file: 0.parquet 1.parquet 2.parquet +query I +select count(*) from listing_table_folder_1; +---- +9 + +statement ok +set datafusion.execution.listing_table_ignore_subdirectory = false; + +# scan file: 0.parquet 1.parquet 2.parquet 3.parquet +query I +select count(*) from listing_table_folder_1; +---- +12 + # Clean up statement ok DROP TABLE timestamp_with_tz; @@ -629,3 +677,78 @@ physical_plan statement ok drop table foo + + +# Tests for int96 timestamps written by spark +# See https://github.com/apache/datafusion/issues/9981 + +statement ok +CREATE EXTERNAL TABLE int96_from_spark +STORED AS PARQUET +LOCATION '../../parquet-testing/data/int96_from_spark.parquet'; + +# by default the value is read as nanosecond precision +query TTT +describe int96_from_spark +---- +a Timestamp(Nanosecond, None) YES + +# Note that the values are 
read as nanosecond precision +query P +select * from int96_from_spark +---- +2024-01-01T20:34:56.123456 +2024-01-01T01:00:00 +1816-03-29T08:56:08.066277376 +2024-12-30T23:00:00 +NULL +1815-11-08T16:01:01.191053312 + +statement ok +drop table int96_from_spark; + +# Enable coercion of int96 to microseconds +statement ok +set datafusion.execution.parquet.coerce_int96 = ms; + +statement ok +CREATE EXTERNAL TABLE int96_from_spark +STORED AS PARQUET +LOCATION '../../parquet-testing/data/int96_from_spark.parquet'; + +# Print schema +query TTT +describe int96_from_spark; +---- +a Timestamp(Millisecond, None) YES + +# Per https://github.com/apache/parquet-testing/blob/6e851ddd768d6af741c7b15dc594874399fc3cff/data/int96_from_spark.md?plain=1#L37 +# these values should be +# +# Some("2024-01-01T12:34:56.123456"), +# Some("2024-01-01T01:00:00Z"), +# Some("9999-12-31T01:00:00-02:00"), +# Some("2024-12-31T01:00:00+02:00"), +# None, +# Some("290000-12-31T01:00:00+02:00")) +# +# However, printing the large dates (9999-12-31 and 290000-12-31) is not supported by +# arrow yet +# +# See https://github.com/apache/arrow-rs/issues/7287 +query P +select * from int96_from_spark +---- +2024-01-01T20:34:56.123 +2024-01-01T01:00:00 +9999-12-31T03:00:00 +2024-12-30T23:00:00 +NULL +ERROR: Cast error: Failed to convert -9357363680509551 to datetime for Timestamp(Millisecond, None) + +# Cleanup / reset default setting +statement ok +drop table int96_from_spark; + +statement ok +set datafusion.execution.parquet.coerce_int96 = ns; diff --git a/datafusion/sqllogictest/test_files/parquet_filter_pushdown.slt b/datafusion/sqllogictest/test_files/parquet_filter_pushdown.slt index 758113b70835..ed3bed1c2004 100644 --- a/datafusion/sqllogictest/test_files/parquet_filter_pushdown.slt +++ b/datafusion/sqllogictest/test_files/parquet_filter_pushdown.slt @@ -54,7 +54,6 @@ LOCATION 'test_files/scratch/parquet_filter_pushdown/parquet_table/'; statement ok set datafusion.execution.parquet.pushdown_filters = true; -## Create table without pushdown statement ok CREATE EXTERNAL TABLE t_pushdown(a varchar, b int, c float) STORED AS PARQUET LOCATION 'test_files/scratch/parquet_filter_pushdown/parquet_table/'; @@ -81,7 +80,9 @@ EXPLAIN select a from t_pushdown where b > 2 ORDER BY a; ---- logical_plan 01)Sort: t_pushdown.a ASC NULLS LAST -02)--TableScan: t_pushdown projection=[a], full_filters=[t_pushdown.b > Int32(2)] +02)--Projection: t_pushdown.a +03)----Filter: t_pushdown.b > Int32(2) +04)------TableScan: t_pushdown projection=[a, b], partial_filters=[t_pushdown.b > Int32(2)] physical_plan 01)SortPreservingMergeExec: [a@0 ASC NULLS LAST] 02)--SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] @@ -127,7 +128,9 @@ EXPLAIN select a from t_pushdown where b > 2 AND a IS NOT NULL order by a; ---- logical_plan 01)Sort: t_pushdown.a ASC NULLS LAST -02)--TableScan: t_pushdown projection=[a], full_filters=[t_pushdown.b > Int32(2), t_pushdown.a IS NOT NULL] +02)--Projection: t_pushdown.a +03)----Filter: t_pushdown.b > Int32(2) AND t_pushdown.a IS NOT NULL +04)------TableScan: t_pushdown projection=[a, b], partial_filters=[t_pushdown.b > Int32(2), t_pushdown.a IS NOT NULL] physical_plan 01)SortPreservingMergeExec: [a@0 ASC NULLS LAST] 02)--SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] @@ -144,7 +147,9 @@ EXPLAIN select b from t_pushdown where a = 'bar' order by b; ---- logical_plan 01)Sort: t_pushdown.b ASC NULLS LAST -02)--TableScan: t_pushdown projection=[b], full_filters=[t_pushdown.a = Utf8("bar")] 
+02)--Projection: t_pushdown.b +03)----Filter: t_pushdown.a = Utf8View("bar") +04)------TableScan: t_pushdown projection=[a, b], partial_filters=[t_pushdown.a = Utf8View("bar")] physical_plan 01)SortPreservingMergeExec: [b@0 ASC NULLS LAST] 02)--SortExec: expr=[b@0 ASC NULLS LAST], preserve_partitioning=[true] @@ -156,3 +161,120 @@ DROP TABLE t; statement ok DROP TABLE t_pushdown; + +## Test filter pushdown with a predicate that references both a partition column and a file column +statement ok +set datafusion.execution.parquet.pushdown_filters = true; + +## Create table +statement ok +CREATE EXTERNAL TABLE t_pushdown(part text, val text) +STORED AS PARQUET +PARTITIONED BY (part) +LOCATION 'test_files/scratch/parquet_filter_pushdown/parquet_part_test/'; + +statement ok +COPY ( + SELECT arrow_cast('a', 'Utf8') AS val +) TO 'test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=a/file.parquet' +STORED AS PARQUET; + +statement ok +COPY ( + SELECT arrow_cast('b', 'Utf8') AS val +) TO 'test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=b/file.parquet' +STORED AS PARQUET; + +statement ok +COPY ( + SELECT arrow_cast('xyz', 'Utf8') AS val +) TO 'test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=c/file.parquet' +STORED AS PARQUET; + +query TT +select * from t_pushdown where part == val order by part, val; +---- +a a +b b + +query TT +select * from t_pushdown where part != val order by part, val; +---- +xyz c + +# If we reference both a file and partition column the predicate cannot be pushed down +query TT +EXPLAIN select * from t_pushdown where part != val +---- +logical_plan +01)Filter: t_pushdown.val != t_pushdown.part +02)--TableScan: t_pushdown projection=[val, part], partial_filters=[t_pushdown.val != t_pushdown.part] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: val@0 != part@1 +03)----DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=a/file.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=b/file.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=c/file.parquet]]}, projection=[val, part], file_type=parquet + +# If we reference only a partition column it gets evaluated during the listing phase +query TT +EXPLAIN select * from t_pushdown where part != 'a'; +---- +logical_plan TableScan: t_pushdown projection=[val, part], full_filters=[t_pushdown.part != Utf8View("a")] +physical_plan DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=b/file.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=c/file.parquet]]}, projection=[val, part], file_type=parquet + +# And if we reference only a file column it gets pushed down +query TT +EXPLAIN select * from t_pushdown where val != 'c'; +---- +logical_plan +01)Filter: t_pushdown.val != Utf8View("c") +02)--TableScan: t_pushdown projection=[val, part], partial_filters=[t_pushdown.val != Utf8View("c")] +physical_plan DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=a/file.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=b/file.parquet], 
[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=c/file.parquet]]}, projection=[val, part], file_type=parquet, predicate=val@0 != c, pruning_predicate=val_null_count@2 != row_count@3 AND (val_min@0 != c OR c != val_max@1), required_guarantees=[val not in (c)] + +# If we have a mix of filters: +# - The partition filters get evaluated during planning +# - The mixed filters end up in a FilterExec +# - The file filters get pushed down into the scan +query TT +EXPLAIN select * from t_pushdown where val != 'd' AND val != 'c' AND part = 'a' AND part != val; +---- +logical_plan +01)Filter: t_pushdown.val != Utf8View("d") AND t_pushdown.val != Utf8View("c") AND t_pushdown.val != t_pushdown.part +02)--TableScan: t_pushdown projection=[val, part], full_filters=[t_pushdown.part = Utf8View("a")], partial_filters=[t_pushdown.val != Utf8View("d"), t_pushdown.val != Utf8View("c"), t_pushdown.val != t_pushdown.part] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: val@0 != part@1 +03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +04)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=a/file.parquet]]}, projection=[val, part], file_type=parquet, predicate=val@0 != d AND val@0 != c, pruning_predicate=val_null_count@2 != row_count@3 AND (val_min@0 != d OR d != val_max@1) AND val_null_count@2 != row_count@3 AND (val_min@0 != c OR c != val_max@1), required_guarantees=[val not in (c, d)] + +# The order of filters should not matter +query TT +EXPLAIN select val, part from t_pushdown where part = 'a' AND part = val; +---- +logical_plan +01)Filter: t_pushdown.val = t_pushdown.part +02)--TableScan: t_pushdown projection=[val, part], full_filters=[t_pushdown.part = Utf8View("a")], partial_filters=[t_pushdown.val = t_pushdown.part] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: val@0 = part@1 +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=a/file.parquet]]}, projection=[val, part], file_type=parquet + +query TT +select val, part from t_pushdown where part = 'a' AND part = val; +---- +a a + +query TT +EXPLAIN select val, part from t_pushdown where part = val AND part = 'a'; +---- +logical_plan +01)Filter: t_pushdown.val = t_pushdown.part +02)--TableScan: t_pushdown projection=[val, part], full_filters=[t_pushdown.part = Utf8View("a")], partial_filters=[t_pushdown.val = t_pushdown.part] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: val@0 = part@1 +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_part_test/part=a/file.parquet]]}, projection=[val, part], file_type=parquet + +query TT +select val, part from t_pushdown where part = val AND part = 'a'; +---- +a a diff --git a/datafusion/sqllogictest/test_files/parquet_statistics.slt b/datafusion/sqllogictest/test_files/parquet_statistics.slt new file mode 100644 index 000000000000..efbe69bd856c --- /dev/null +++ b/datafusion/sqllogictest/test_files/parquet_statistics.slt @@ -0,0 +1,125 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Tests for statistics in parquet files. +# Writes data into two files: +# * test_table/0.parquet +# * test_table/1.parquet +# +# And verifies statistics are correctly calculated for the table +# +# NOTE that statistics are ONLY gathered when the table is first created +# so the table must be recreated to see the effects of the setting + +query I +COPY (values (1), (2), (3)) +TO 'test_files/scratch/parquet_statistics/test_table/0.parquet' +STORED AS PARQUET; +---- +3 + +query I +COPY (values (3), (4)) +TO 'test_files/scratch/parquet_statistics/test_table/1.parquet' +STORED AS PARQUET; +---- +2 + +statement ok +set datafusion.explain.physical_plan_only = true; + +statement ok +set datafusion.explain.show_statistics = true; + +###### +# By default, the statistics are gathered +###### + +# Recreate the table to pick up the current setting +statement ok +CREATE EXTERNAL TABLE test_table +STORED AS PARQUET +LOCATION 'test_files/scratch/parquet_statistics/test_table'; + +query TT +EXPLAIN SELECT * FROM test_table WHERE column1 = 1; +---- +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192, statistics=[Rows=Inexact(2), Bytes=Inexact(44), [(Col[0]: Min=Exact(Int64(1)) Max=Exact(Int64(1)) Null=Inexact(0))]] +02)--FilterExec: column1@0 = 1, statistics=[Rows=Inexact(2), Bytes=Inexact(44), [(Col[0]: Min=Exact(Int64(1)) Max=Exact(Int64(1)) Null=Inexact(0))]] +03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=2, statistics=[Rows=Inexact(5), Bytes=Inexact(173), [(Col[0]: Min=Inexact(Int64(1)) Max=Inexact(Int64(4)) Null=Inexact(0))]] +04)------DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_statistics/test_table/0.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_statistics/test_table/1.parquet]]}, projection=[column1], file_type=parquet, predicate=column1@0 = 1, pruning_predicate=column1_null_count@2 != row_count@3 AND column1_min@0 <= 1 AND 1 <= column1_max@1, required_guarantees=[column1 in (1)] +05), statistics=[Rows=Inexact(5), Bytes=Inexact(173), [(Col[0]: Min=Inexact(Int64(1)) Max=Inexact(Int64(4)) Null=Inexact(0))]] + +# cleanup +statement ok +DROP TABLE test_table; + +###### +# When the setting is true, statistics are gathered +###### + +statement ok +set datafusion.execution.collect_statistics = true; + +# Recreate the table to pick up the current setting +statement ok +CREATE EXTERNAL TABLE test_table +STORED AS PARQUET +LOCATION 'test_files/scratch/parquet_statistics/test_table'; + +query TT +EXPLAIN SELECT * FROM test_table WHERE column1 = 1; +---- +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192, statistics=[Rows=Inexact(2), Bytes=Inexact(44), [(Col[0]: Min=Exact(Int64(1)) Max=Exact(Int64(1)) Null=Inexact(0))]] +02)--FilterExec: column1@0 = 1, statistics=[Rows=Inexact(2), Bytes=Inexact(44), [(Col[0]: Min=Exact(Int64(1)) Max=Exact(Int64(1)) Null=Inexact(0))]] +03)----RepartitionExec: 
partitioning=RoundRobinBatch(4), input_partitions=2, statistics=[Rows=Inexact(5), Bytes=Inexact(173), [(Col[0]: Min=Inexact(Int64(1)) Max=Inexact(Int64(4)) Null=Inexact(0))]] +04)------DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_statistics/test_table/0.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_statistics/test_table/1.parquet]]}, projection=[column1], file_type=parquet, predicate=column1@0 = 1, pruning_predicate=column1_null_count@2 != row_count@3 AND column1_min@0 <= 1 AND 1 <= column1_max@1, required_guarantees=[column1 in (1)] +05), statistics=[Rows=Inexact(5), Bytes=Inexact(173), [(Col[0]: Min=Inexact(Int64(1)) Max=Inexact(Int64(4)) Null=Inexact(0))]] + +# cleanup +statement ok +DROP TABLE test_table; + + +###### +# When the setting is false, the statistics are NOT gathered +###### + +statement ok +set datafusion.execution.collect_statistics = false; + +# Recreate the table to pick up the current setting +statement ok +CREATE EXTERNAL TABLE test_table +STORED AS PARQUET +LOCATION 'test_files/scratch/parquet_statistics/test_table'; + +query TT +EXPLAIN SELECT * FROM test_table WHERE column1 = 1; +---- +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:)]] +02)--FilterExec: column1@0 = 1, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]: Min=Exact(Int64(1)) Max=Exact(Int64(1)))]] +03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=2, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:)]] +04)------DataSourceExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_statistics/test_table/0.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_statistics/test_table/1.parquet]]}, projection=[column1], file_type=parquet, predicate=column1@0 = 1, pruning_predicate=column1_null_count@2 != row_count@3 AND column1_min@0 <= 1 AND 1 <= column1_max@1, required_guarantees=[column1 in (1)] +05), statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:)]] + +# cleanup +statement ok +DROP TABLE test_table; diff --git a/datafusion/sqllogictest/test_files/predicates.slt b/datafusion/sqllogictest/test_files/predicates.slt index b263e39f3b11..b4b31fa78a69 100644 --- a/datafusion/sqllogictest/test_files/predicates.slt +++ b/datafusion/sqllogictest/test_files/predicates.slt @@ -662,11 +662,11 @@ OR ---- logical_plan 01)Projection: lineitem.l_partkey -02)--Inner Join: lineitem.l_partkey = part.p_partkey Filter: part.p_brand = Utf8("Brand#12") AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15) +02)--Inner Join: lineitem.l_partkey = part.p_partkey Filter: part.p_brand = Utf8View("Brand#12") AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND 
lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15) 03)----Filter: lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) 04)------TableScan: lineitem projection=[l_partkey, l_quantity], partial_filters=[lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)] -05)----Filter: (part.p_brand = Utf8("Brand#12") AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1) -06)------TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8("Brand#12") AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_size <= Int32(15)] +05)----Filter: (part.p_brand = Utf8View("Brand#12") AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1) +06)------TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8View("Brand#12") AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_size <= Int32(15)] physical_plan 01)CoalesceBatchesExec: target_batch_size=8192 02)--HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)], filter=p_brand@1 = Brand#12 AND l_quantity@0 >= Some(100),15,2 AND l_quantity@0 <= Some(1100),15,2 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND l_quantity@0 >= Some(1000),15,2 AND l_quantity@0 <= Some(2000),15,2 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND l_quantity@0 >= Some(2000),15,2 AND l_quantity@0 <= Some(3000),15,2 AND p_size@2 <= 15, projection=[l_partkey@0] @@ -755,8 +755,8 @@ logical_plan 05)--------Inner Join: lineitem.l_partkey = part.p_partkey 06)----------TableScan: lineitem projection=[l_partkey, l_extendedprice, l_discount] 07)----------Projection: part.p_partkey -08)------------Filter: part.p_brand = Utf8("Brand#12") OR part.p_brand = Utf8("Brand#23") -09)--------------TableScan: part projection=[p_partkey, p_brand], partial_filters=[part.p_brand = Utf8("Brand#12") OR part.p_brand = Utf8("Brand#23")] +08)------------Filter: part.p_brand = Utf8View("Brand#12") OR part.p_brand = Utf8View("Brand#23") +09)--------------TableScan: part projection=[p_partkey, p_brand], partial_filters=[part.p_brand = Utf8View("Brand#12") OR part.p_brand = Utf8View("Brand#23")] 10)------TableScan: partsupp projection=[ps_partkey, ps_suppkey] physical_plan 01)AggregateExec: mode=SinglePartitioned, gby=[p_partkey@2 as p_partkey], aggr=[sum(lineitem.l_extendedprice), avg(lineitem.l_discount), count(DISTINCT partsupp.ps_suppkey)] diff --git a/datafusion/sqllogictest/test_files/prepare.slt 
b/datafusion/sqllogictest/test_files/prepare.slt index 33df0d26f361..d61603ae6558 100644 --- a/datafusion/sqllogictest/test_files/prepare.slt +++ b/datafusion/sqllogictest/test_files/prepare.slt @@ -92,7 +92,7 @@ DEALLOCATE my_plan statement ok PREPARE my_plan AS SELECT * FROM person WHERE id < $1; -statement error No value found for placeholder with id \$1 +statement error Prepared statement 'my_plan' expects 1 parameters, but 0 provided EXECUTE my_plan statement ok diff --git a/datafusion/sqllogictest/test_files/push_down_filter.slt b/datafusion/sqllogictest/test_files/push_down_filter.slt index 67965146e76b..a0d319332462 100644 --- a/datafusion/sqllogictest/test_files/push_down_filter.slt +++ b/datafusion/sqllogictest/test_files/push_down_filter.slt @@ -18,7 +18,7 @@ # Test push down filter statement ok -set datafusion.explain.logical_plan_only = true; +set datafusion.explain.physical_plan_only = true; statement ok CREATE TABLE IF NOT EXISTS v AS VALUES(1,[1,2,3]),(2,[3,4,5]); @@ -35,12 +35,14 @@ select uc2 from (select unnest(column2) as uc2, column1 from v) where column1 = query TT explain select uc2 from (select unnest(column2) as uc2, column1 from v) where column1 = 2; ---- -logical_plan -01)Projection: __unnest_placeholder(v.column2,depth=1) AS uc2 -02)--Unnest: lists[__unnest_placeholder(v.column2)|depth=1] structs[] -03)----Projection: v.column2 AS __unnest_placeholder(v.column2), v.column1 -04)------Filter: v.column1 = Int64(2) -05)--------TableScan: v projection=[column1, column2] +physical_plan +01)ProjectionExec: expr=[__unnest_placeholder(v.column2,depth=1)@0 as uc2] +02)--UnnestExec +03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +04)------ProjectionExec: expr=[column2@1 as __unnest_placeholder(v.column2), column1@0 as column1] +05)--------CoalesceBatchesExec: target_batch_size=8192 +06)----------FilterExec: column1@0 = 2 +07)------------DataSourceExec: partitions=1, partition_sizes=[1] query I select uc2 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3; @@ -52,13 +54,15 @@ select uc2 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3; query TT explain select uc2 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3; ---- -logical_plan -01)Projection: __unnest_placeholder(v.column2,depth=1) AS uc2 -02)--Filter: __unnest_placeholder(v.column2,depth=1) > Int64(3) -03)----Projection: __unnest_placeholder(v.column2,depth=1) -04)------Unnest: lists[__unnest_placeholder(v.column2)|depth=1] structs[] -05)--------Projection: v.column2 AS __unnest_placeholder(v.column2), v.column1 -06)----------TableScan: v projection=[column1, column2] +physical_plan +01)ProjectionExec: expr=[__unnest_placeholder(v.column2,depth=1)@0 as uc2] +02)--CoalesceBatchesExec: target_batch_size=8192 +03)----FilterExec: __unnest_placeholder(v.column2,depth=1)@0 > 3 +04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +05)--------ProjectionExec: expr=[__unnest_placeholder(v.column2,depth=1)@0 as __unnest_placeholder(v.column2,depth=1)] +06)----------UnnestExec +07)------------ProjectionExec: expr=[column2@1 as __unnest_placeholder(v.column2), column1@0 as column1] +08)--------------DataSourceExec: partitions=1, partition_sizes=[1] query II select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3 AND column1 = 2; @@ -70,13 +74,16 @@ select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where query TT explain select uc2, column1 from (select unnest(column2) as 
uc2, column1 from v) where uc2 > 3 AND column1 = 2; ---- -logical_plan -01)Projection: __unnest_placeholder(v.column2,depth=1) AS uc2, v.column1 -02)--Filter: __unnest_placeholder(v.column2,depth=1) > Int64(3) -03)----Unnest: lists[__unnest_placeholder(v.column2)|depth=1] structs[] -04)------Projection: v.column2 AS __unnest_placeholder(v.column2), v.column1 -05)--------Filter: v.column1 = Int64(2) -06)----------TableScan: v projection=[column1, column2] +physical_plan +01)ProjectionExec: expr=[__unnest_placeholder(v.column2,depth=1)@0 as uc2, column1@1 as column1] +02)--CoalesceBatchesExec: target_batch_size=8192 +03)----FilterExec: __unnest_placeholder(v.column2,depth=1)@0 > 3 +04)------UnnestExec +05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +06)----------ProjectionExec: expr=[column2@1 as __unnest_placeholder(v.column2), column1@0 as column1] +07)------------CoalesceBatchesExec: target_batch_size=8192 +08)--------------FilterExec: column1@0 = 2 +09)----------------DataSourceExec: partitions=1, partition_sizes=[1] query II select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3 OR column1 = 2; @@ -89,12 +96,14 @@ select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where query TT explain select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3 OR column1 = 2; ---- -logical_plan -01)Projection: __unnest_placeholder(v.column2,depth=1) AS uc2, v.column1 -02)--Filter: __unnest_placeholder(v.column2,depth=1) > Int64(3) OR v.column1 = Int64(2) -03)----Unnest: lists[__unnest_placeholder(v.column2)|depth=1] structs[] -04)------Projection: v.column2 AS __unnest_placeholder(v.column2), v.column1 -05)--------TableScan: v projection=[column1, column2] +physical_plan +01)ProjectionExec: expr=[__unnest_placeholder(v.column2,depth=1)@0 as uc2, column1@1 as column1] +02)--CoalesceBatchesExec: target_batch_size=8192 +03)----FilterExec: __unnest_placeholder(v.column2,depth=1)@0 > 3 OR column1@1 = 2 +04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +05)--------UnnestExec +06)----------ProjectionExec: expr=[column2@1 as __unnest_placeholder(v.column2), column1@0 as column1] +07)------------DataSourceExec: partitions=1, partition_sizes=[1] statement ok drop table v; @@ -111,12 +120,14 @@ select * from (select column1, unnest(column2) as o from d) where o['a'] = 1; query TT explain select * from (select column1, unnest(column2) as o from d) where o['a'] = 1; ---- -logical_plan -01)Projection: d.column1, __unnest_placeholder(d.column2,depth=1) AS o -02)--Filter: get_field(__unnest_placeholder(d.column2,depth=1), Utf8("a")) = Int64(1) -03)----Unnest: lists[__unnest_placeholder(d.column2)|depth=1] structs[] -04)------Projection: d.column1, d.column2 AS __unnest_placeholder(d.column2) -05)--------TableScan: d projection=[column1, column2] +physical_plan +01)ProjectionExec: expr=[column1@0 as column1, __unnest_placeholder(d.column2,depth=1)@1 as o] +02)--CoalesceBatchesExec: target_batch_size=8192 +03)----FilterExec: get_field(__unnest_placeholder(d.column2,depth=1)@1, a) = 1 +04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +05)--------UnnestExec +06)----------ProjectionExec: expr=[column1@0 as column1, column2@1 as __unnest_placeholder(d.column2)] +07)------------DataSourceExec: partitions=1, partition_sizes=[1] @@ -179,9 +190,9 @@ LOCATION 'test_files/scratch/parquet/test_filter_with_limit/'; query TT explain select * from 
test_filter_with_limit where value = 2 limit 1; ---- -logical_plan -01)Limit: skip=0, fetch=1 -02)--TableScan: test_filter_with_limit projection=[part_key, value], full_filters=[test_filter_with_limit.value = Int32(2)], fetch=1 +physical_plan +01)CoalescePartitionsExec: fetch=1 +02)--DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_filter_with_limit/part-0.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_filter_with_limit/part-1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_filter_with_limit/part-2.parquet]]}, projection=[part_key, value], limit=1, file_type=parquet, predicate=value@1 = 2, pruning_predicate=value_null_count@2 != row_count@3 AND value_min@0 <= 2 AND 2 <= value_max@1, required_guarantees=[value in (2)] query II select * from test_filter_with_limit where value = 2 limit 1; @@ -218,43 +229,43 @@ LOCATION 'test_files/scratch/push_down_filter/t.parquet'; query TT explain select a from t where a = '100'; ---- -logical_plan TableScan: t projection=[a], full_filters=[t.a = Int32(100)] +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/t.parquet]]}, projection=[a], file_type=parquet, predicate=a@0 = 100, pruning_predicate=a_null_count@2 != row_count@3 AND a_min@0 <= 100 AND 100 <= a_max@1, required_guarantees=[a in (100)] # The predicate should not have a column cast when the value is a valid i32 query TT explain select a from t where a != '100'; ---- -logical_plan TableScan: t projection=[a], full_filters=[t.a != Int32(100)] +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/t.parquet]]}, projection=[a], file_type=parquet, predicate=a@0 != 100, pruning_predicate=a_null_count@2 != row_count@3 AND (a_min@0 != 100 OR 100 != a_max@1), required_guarantees=[a not in (100)] # The predicate should still have the column cast when the value is a NOT valid i32 query TT explain select a from t where a = '99999999999'; ---- -logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("99999999999")] +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/t.parquet]]}, projection=[a], file_type=parquet, predicate=CAST(a@0 AS Utf8) = 99999999999 # The predicate should still have the column cast when the value is a NOT valid i32 query TT explain select a from t where a = '99.99'; ---- -logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("99.99")] +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/t.parquet]]}, projection=[a], file_type=parquet, predicate=CAST(a@0 AS Utf8) = 99.99 # The predicate should still have the column cast when the value is a NOT valid i32 query TT explain select a from t where a = ''; ---- -logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("")] +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/t.parquet]]}, projection=[a], file_type=parquet, predicate=CAST(a@0 AS Utf8) = # The predicate should not have a column cast when the operator is = or != and the literal can be round-trip casted without losing information. 
query TT explain select a from t where cast(a as string) = '100'; ---- -logical_plan TableScan: t projection=[a], full_filters=[t.a = Int32(100)] +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/t.parquet]]}, projection=[a], file_type=parquet, predicate=a@0 = 100, pruning_predicate=a_null_count@2 != row_count@3 AND a_min@0 <= 100 AND 100 <= a_max@1, required_guarantees=[a in (100)] # The predicate should still have the column cast when the literal alters its string representation after round-trip casting (leading zero lost). query TT explain select a from t where CAST(a AS string) = '0123'; ---- -logical_plan TableScan: t projection=[a], full_filters=[CAST(t.a AS Utf8) = Utf8("0123")] +physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/push_down_filter/t.parquet]]}, projection=[a], file_type=parquet, predicate=CAST(a@0 AS Utf8View) = 0123 statement ok diff --git a/datafusion/sqllogictest/test_files/regexp.slt b/datafusion/sqllogictest/test_files/regexp.slt deleted file mode 100644 index 44ba61e877d9..000000000000 --- a/datafusion/sqllogictest/test_files/regexp.slt +++ /dev/null @@ -1,898 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -statement ok -CREATE TABLE t (str varchar, pattern varchar, start int, flags varchar) AS VALUES - ('abc', '^(a)', 1, 'i'), - ('ABC', '^(A).*', 1, 'i'), - ('aBc', '(b|d)', 1, 'i'), - ('AbC', '(B|D)', 2, null), - ('aBC', '^(b|c)', 3, null), - ('4000', '\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b', 1, null), - ('4010', '\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b', 2, null), - ('Düsseldorf','[\p{Letter}-]+', 3, null), - ('Москва', '[\p{L}-]+', 4, null), - ('Köln', '[a-zA-Z]ö[a-zA-Z]{2}', 1, null), - ('إسرائيل', '^\p{Arabic}+$', 2, null); - -# -# regexp_like tests -# - -query B -SELECT regexp_like(str, pattern, flags) FROM t; ----- -true -true -true -false -false -false -true -true -true -true -true - -query B -SELECT str ~ NULL FROM t; ----- -NULL -NULL -NULL -NULL -NULL -NULL -NULL -NULL -NULL -NULL -NULL - -query B -select str ~ right('foo', NULL) FROM t; ----- -NULL -NULL -NULL -NULL -NULL -NULL -NULL -NULL -NULL -NULL -NULL - -query B -select right('foo', NULL) !~ str FROM t; ----- -NULL -NULL -NULL -NULL -NULL -NULL -NULL -NULL -NULL -NULL -NULL - -query B -SELECT regexp_like('foobarbequebaz', ''); ----- -true - -query B -SELECT regexp_like('', ''); ----- -true - -query B -SELECT regexp_like('foobarbequebaz', '(bar)(beque)'); ----- -true - -query B -SELECT regexp_like('fooBarb -eQuebaz', '(bar).*(que)', 'is'); ----- -true - -query B -SELECT regexp_like('foobarbequebaz', '(ba3r)(bequ34e)'); ----- -false - -query B -SELECT regexp_like('foobarbequebaz', '^.*(barbequ[0-9]*e).*$', 'm'); ----- -true - -query B -SELECT regexp_like('aaa-0', '.*-(\d)'); ----- -true - -query B -SELECT regexp_like('bb-1', '.*-(\d)'); ----- -true - -query B -SELECT regexp_like('aa', '.*-(\d)'); ----- -false - -query B -SELECT regexp_like(NULL, '.*-(\d)'); ----- -NULL - -query B -SELECT regexp_like('aaa-0', NULL); ----- -NULL - -query B -SELECT regexp_like(null, '.*-(\d)'); ----- -NULL - -query error Error during planning: regexp_like\(\) does not support the "global" option -SELECT regexp_like('bb-1', '.*-(\d)', 'g'); - -query error Error during planning: regexp_like\(\) does not support the "global" option -SELECT regexp_like('bb-1', '.*-(\d)', 'g'); - -query error Arrow error: Compute error: Regular expression did not compile: CompiledTooBig\(10485760\) -SELECT regexp_like('aaaaa', 'a{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}'); - -# look-around is not supported and will just return false -query B -SELECT regexp_like('(?<=[A-Z]\w )Smith', 'John Smith', 'i'); ----- -false - -query B -select regexp_like('aaa-555', '.*-(\d*)'); ----- -true - -# -# regexp_match tests -# - -query ? -SELECT regexp_match(str, pattern, flags) FROM t; ----- -[a] -[A] -[B] -NULL -NULL -NULL -[010] -[Düsseldorf] -[Москва] -[Köln] -[إسرائيل] - -# test string view -statement ok -CREATE TABLE t_stringview AS -SELECT arrow_cast(str, 'Utf8View') as str, arrow_cast(pattern, 'Utf8View') as pattern, arrow_cast(flags, 'Utf8View') as flags FROM t; - -query ? -SELECT regexp_match(str, pattern, flags) FROM t_stringview; ----- -[a] -[A] -[B] -NULL -NULL -NULL -[010] -[Düsseldorf] -[Москва] -[Köln] -[إسرائيل] - -statement ok -DROP TABLE t_stringview; - -query ? -SELECT regexp_match('foobarbequebaz', ''); ----- -[] - -query ? -SELECT regexp_match('', ''); ----- -[] - -query ? -SELECT regexp_match('foobarbequebaz', '(bar)(beque)'); ----- -[bar, beque] - -query ? -SELECT regexp_match('fooBarb -eQuebaz', '(bar).*(que)', 'is'); ----- -[Bar, Que] - -query ? -SELECT regexp_match('foobarbequebaz', '(ba3r)(bequ34e)'); ----- -NULL - -query ? 
-SELECT regexp_match('foobarbequebaz', '^.*(barbequ[0-9]*e).*$', 'm'); ----- -[barbeque] - -query ? -SELECT regexp_match('aaa-0', '.*-(\d)'); ----- -[0] - -query ? -SELECT regexp_match('bb-1', '.*-(\d)'); ----- -[1] - -query ? -SELECT regexp_match('aa', '.*-(\d)'); ----- -NULL - -query ? -SELECT regexp_match(NULL, '.*-(\d)'); ----- -NULL - -query ? -SELECT regexp_match('aaa-0', NULL); ----- -NULL - -query ? -SELECT regexp_match(null, '.*-(\d)'); ----- -NULL - -query error Error during planning: regexp_match\(\) does not support the "global" option -SELECT regexp_match('bb-1', '.*-(\d)', 'g'); - -query error Error during planning: regexp_match\(\) does not support the "global" option -SELECT regexp_match('bb-1', '.*-(\d)', 'g'); - -query error Arrow error: Compute error: Regular expression did not compile: CompiledTooBig\(10485760\) -SELECT regexp_match('aaaaa', 'a{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}'); - -# look-around is not supported and will just return null -query ? -SELECT regexp_match('(?<=[A-Z]\w )Smith', 'John Smith', 'i'); ----- -NULL - -# ported test -query ? -SELECT regexp_match('aaa-555', '.*-(\d*)'); ----- -[555] - -query B -select 'abc' ~ null; ----- -NULL - -query B -select null ~ null; ----- -NULL - -query B -select null ~ 'abc'; ----- -NULL - -query B -select 'abc' ~* null; ----- -NULL - -query B -select null ~* null; ----- -NULL - -query B -select null ~* 'abc'; ----- -NULL - -query B -select 'abc' !~ null; ----- -NULL - -query B -select null !~ null; ----- -NULL - -query B -select null !~ 'abc'; ----- -NULL - -query B -select 'abc' !~* null; ----- -NULL - -query B -select null !~* null; ----- -NULL - -query B -select null !~* 'abc'; ----- -NULL - -# -# regexp_replace tests -# - -query T -SELECT regexp_replace(str, pattern, 'X', concat('g', flags)) FROM t; ----- -Xbc -X -aXc -AbC -aBC -4000 -X -X -X -X -X - -# test string view -statement ok -CREATE TABLE t_stringview AS -SELECT arrow_cast(str, 'Utf8View') as str, arrow_cast(pattern, 'Utf8View') as pattern, arrow_cast(flags, 'Utf8View') as flags FROM t; - -query T -SELECT regexp_replace(str, pattern, 'X', concat('g', flags)) FROM t_stringview; ----- -Xbc -X -aXc -AbC -aBC -4000 -X -X -X -X -X - -statement ok -DROP TABLE t_stringview; - -query T -SELECT regexp_replace('ABCabcABC', '(abc)', 'X', 'gi'); ----- -XXX - -query T -SELECT regexp_replace('ABCabcABC', '(abc)', 'X', 'i'); ----- -XabcABC - -query T -SELECT regexp_replace('foobarbaz', 'b..', 'X', 'g'); ----- -fooXX - -query T -SELECT regexp_replace('foobarbaz', 'b..', 'X'); ----- -fooXbaz - -query T -SELECT regexp_replace('foobarbaz', 'b(..)', 'X\\1Y', 'g'); ----- -fooXarYXazY - -query T -SELECT regexp_replace('foobarbaz', 'b(..)', 'X\\1Y', NULL); ----- -NULL - -query T -SELECT regexp_replace('foobarbaz', 'b(..)', NULL, 'g'); ----- -NULL - -query T -SELECT regexp_replace('foobarbaz', NULL, 'X\\1Y', 'g'); ----- -NULL - -query T -SELECT regexp_replace('Thomas', '.[mN]a.', 'M'); ----- -ThM - -query T -SELECT regexp_replace(NULL, 'b(..)', 'X\\1Y', 'g'); ----- -NULL - -query T -SELECT regexp_replace('foobar', 'bar', 'xx', 'gi') ----- -fooxx - -query T -SELECT regexp_replace(arrow_cast('foobar', 'Dictionary(Int32, Utf8)'), 'bar', 'xx', 'gi') ----- -fooxx - -query TTT -select - regexp_replace(col, NULL, 'c'), - regexp_replace(col, 'a', NULL), - regexp_replace(col, 'a', 'c', NULL) -from (values ('a'), ('b')) as tbl(col); ----- -NULL NULL NULL -NULL NULL NULL - -# multiline string -query B -SELECT 'foo\nbar\nbaz' ~ 'bar'; ----- -true - -statement error -Error 
during planning: Cannot infer common argument type for regex operation List(Field { name: "item", data_type: Int64, nullable: true, dict_is_ordered: false, metadata -: {} }) ~ List(Field { name: "item", data_type: Int64, nullable: true, dict_is_ordered: false, metadata: {} }) -select [1,2] ~ [3]; - -query B -SELECT 'foo\nbar\nbaz' LIKE '%bar%'; ----- -true - -query B -SELECT NULL LIKE NULL; ----- -NULL - -query B -SELECT NULL iLIKE NULL; ----- -NULL - -query B -SELECT NULL not LIKE NULL; ----- -NULL - -query B -SELECT NULL not iLIKE NULL; ----- -NULL - -# regexp_count tests - -# regexp_count tests from postgresql -# https://github.com/postgres/postgres/blob/56d23855c864b7384970724f3ad93fb0fc319e51/src/test/regress/sql/strings.sql#L226-L235 - -query I -SELECT regexp_count('123123123123123', '(12)3'); ----- -5 - -query I -SELECT regexp_count('123123123123', '123', 1); ----- -4 - -query I -SELECT regexp_count('123123123123', '123', 3); ----- -3 - -query I -SELECT regexp_count('123123123123', '123', 33); ----- -0 - -query I -SELECT regexp_count('ABCABCABCABC', 'Abc', 1, ''); ----- -0 - -query I -SELECT regexp_count('ABCABCABCABC', 'Abc', 1, 'i'); ----- -4 - -statement error -External error: query failed: DataFusion error: Arrow error: Compute error: regexp_count() requires start to be 1 based -SELECT regexp_count('123123123123', '123', 0); - -statement error -External error: query failed: DataFusion error: Arrow error: Compute error: regexp_count() requires start to be 1 based -SELECT regexp_count('123123123123', '123', -3); - -statement error -External error: statement failed: DataFusion error: Arrow error: Compute error: regexp_count() does not support global flag -SELECT regexp_count('123123123123', '123', 1, 'g'); - -query I -SELECT regexp_count(str, '\w') from t; ----- -3 -3 -3 -3 -3 -4 -4 -10 -6 -4 -7 - -query I -SELECT regexp_count(str, '\w{2}', start) from t; ----- -1 -1 -1 -1 -0 -2 -1 -4 -1 -2 -3 - -query I -SELECT regexp_count(str, 'ab', 1, 'i') from t; ----- -1 -1 -1 -1 -1 -0 -0 -0 -0 -0 -0 - - -query I -SELECT regexp_count(str, pattern) from t; ----- -1 -1 -0 -0 -0 -0 -1 -1 -1 -1 -1 - -query I -SELECT regexp_count(str, pattern, start) from t; ----- -1 -1 -0 -0 -0 -0 -0 -1 -1 -1 -1 - -query I -SELECT regexp_count(str, pattern, start, flags) from t; ----- -1 -1 -1 -0 -0 -0 -0 -1 -1 -1 -1 - -# test type coercion -query I -SELECT regexp_count(arrow_cast(str, 'Utf8'), arrow_cast(pattern, 'LargeUtf8'), arrow_cast(start, 'Int32'), flags) from t; ----- -1 -1 -1 -0 -0 -0 -0 -1 -1 -1 -1 - -# test string views - -statement ok -CREATE TABLE t_stringview AS -SELECT arrow_cast(str, 'Utf8View') as str, arrow_cast(pattern, 'Utf8View') as pattern, arrow_cast(start, 'Int64') as start, arrow_cast(flags, 'Utf8View') as flags FROM t; - -query I -SELECT regexp_count(str, '\w') from t_stringview; ----- -3 -3 -3 -3 -3 -4 -4 -10 -6 -4 -7 - -query I -SELECT regexp_count(str, '\w{2}', start) from t_stringview; ----- -1 -1 -1 -1 -0 -2 -1 -4 -1 -2 -3 - -query I -SELECT regexp_count(str, 'ab', 1, 'i') from t_stringview; ----- -1 -1 -1 -1 -1 -0 -0 -0 -0 -0 -0 - - -query I -SELECT regexp_count(str, pattern) from t_stringview; ----- -1 -1 -0 -0 -0 -0 -1 -1 -1 -1 -1 - -query I -SELECT regexp_count(str, pattern, start) from t_stringview; ----- -1 -1 -0 -0 -0 -0 -0 -1 -1 -1 -1 - -query I -SELECT regexp_count(str, pattern, start, flags) from t_stringview; ----- -1 -1 -1 -0 -0 -0 -0 -1 -1 -1 -1 - -# test type coercion -query I -SELECT regexp_count(arrow_cast(str, 'Utf8'), arrow_cast(pattern, 'LargeUtf8'), 
arrow_cast(start, 'Int32'), flags) from t_stringview; ----- -1 -1 -1 -0 -0 -0 -0 -1 -1 -1 -1 - -# NULL tests - -query I -SELECT regexp_count(NULL, NULL); ----- -0 - -query I -SELECT regexp_count(NULL, 'a'); ----- -0 - -query I -SELECT regexp_count('a', NULL); ----- -0 - -query I -SELECT regexp_count(NULL, NULL, NULL, NULL); ----- -0 - -statement ok -CREATE TABLE empty_table (str varchar, pattern varchar, start int, flags varchar); - -query I -SELECT regexp_count(str, pattern, start, flags) from empty_table; ----- - -statement ok -INSERT INTO empty_table VALUES ('a', NULL, 1, 'i'), (NULL, 'a', 1, 'i'), (NULL, NULL, 1, 'i'), (NULL, NULL, NULL, 'i'); - -query I -SELECT regexp_count(str, pattern, start, flags) from empty_table; ----- -0 -0 -0 -0 - -statement ok -drop table t; - -statement ok -create or replace table strings as values - ('FooBar'), - ('Foo'), - ('Foo'), - ('Bar'), - ('FooBar'), - ('Bar'), - ('Baz'); - -statement ok -create or replace table dict_table as -select arrow_cast(column1, 'Dictionary(Int32, Utf8)') as column1 -from strings; - -query T -select column1 from dict_table where column1 LIKE '%oo%'; ----- -FooBar -Foo -Foo -FooBar - -query T -select column1 from dict_table where column1 NOT LIKE '%oo%'; ----- -Bar -Bar -Baz - -query T -select column1 from dict_table where column1 ILIKE '%oO%'; ----- -FooBar -Foo -Foo -FooBar - -query T -select column1 from dict_table where column1 NOT ILIKE '%oO%'; ----- -Bar -Bar -Baz - - -# plan should not cast the column, instead it should use the dictionary directly -query TT -explain select column1 from dict_table where column1 LIKE '%oo%'; ----- -logical_plan -01)Filter: dict_table.column1 LIKE Utf8("%oo%") -02)--TableScan: dict_table projection=[column1] -physical_plan -01)CoalesceBatchesExec: target_batch_size=8192 -02)--FilterExec: column1@0 LIKE %oo% -03)----DataSourceExec: partitions=1, partition_sizes=[1] - -# Ensure casting / coercion works for all operators -# (there should be no casts to Utf8) -query TT -explain select - column1 LIKE '%oo%', - column1 NOT LIKE '%oo%', - column1 ILIKE '%oo%', - column1 NOT ILIKE '%oo%' -from dict_table; ----- -logical_plan -01)Projection: dict_table.column1 LIKE Utf8("%oo%"), dict_table.column1 NOT LIKE Utf8("%oo%"), dict_table.column1 ILIKE Utf8("%oo%"), dict_table.column1 NOT ILIKE Utf8("%oo%") -02)--TableScan: dict_table projection=[column1] -physical_plan -01)ProjectionExec: expr=[column1@0 LIKE %oo% as dict_table.column1 LIKE Utf8("%oo%"), column1@0 NOT LIKE %oo% as dict_table.column1 NOT LIKE Utf8("%oo%"), column1@0 ILIKE %oo% as dict_table.column1 ILIKE Utf8("%oo%"), column1@0 NOT ILIKE %oo% as dict_table.column1 NOT ILIKE Utf8("%oo%")] -02)--DataSourceExec: partitions=1, partition_sizes=[1] - -statement ok -drop table strings - -statement ok -drop table dict_table diff --git a/datafusion/sqllogictest/test_files/regexp/README.md b/datafusion/sqllogictest/test_files/regexp/README.md new file mode 100644 index 000000000000..7e5efc5b5ddf --- /dev/null +++ b/datafusion/sqllogictest/test_files/regexp/README.md @@ -0,0 +1,59 @@ + + +# Regexp Test Files + +This directory contains test files for regular expression (regexp) functions in DataFusion. 
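Concretely, every test file in this directory follows the same shape: it pulls in the shared fixture with an `include` directive and then exercises one function against the `regexp_test_data` table that the fixture defines. A minimal sketch, mirroring the `regexp_like.slt` file added further down in this patch:

```
# import the shared test data
include ./init_data.slt.part

query B
SELECT regexp_like(str, pattern, flags) FROM regexp_test_data;
```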
+ +## Directory Structure + +``` +regexp/ + - init_data.slt.part // Shared test data for regexp functions + - regexp_like.slt // Tests for regexp_like function + - regexp_count.slt // Tests for regexp_count function + - regexp_match.slt // Tests for regexp_match function + - regexp_replace.slt // Tests for regexp_replace function +``` + +## Tested Functions + +1. `regexp_like`: Check if a string matches a regular expression +2. `regexp_count`: Count occurrences of a pattern in a string +3. `regexp_match`: Extract matching substrings +4. `regexp_replace`: Replace matched substrings + +## Test Data + +Test data is centralized in the `init_data.slt.part` file and imported into each test file using the `include` directive. This approach ensures: + +- Consistent test data across different regexp function tests +- Easy maintenance of test data +- Reduced duplication + +## Test Coverage + +Each test file covers: + +- Basic functionality +- Case-insensitive matching +- Null handling +- Start position tests +- Capture group handling +- Different string types (UTF-8, Unicode) diff --git a/datafusion/sqllogictest/test_files/regexp/init_data.slt.part b/datafusion/sqllogictest/test_files/regexp/init_data.slt.part new file mode 100644 index 000000000000..ed6fb0e872df --- /dev/null +++ b/datafusion/sqllogictest/test_files/regexp/init_data.slt.part @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +statement ok +create table regexp_test_data (str varchar, pattern varchar, start int, flags varchar) as values + (NULL, '^(a)', 1, 'i'), + ('abc', '^(a)', 1, 'i'), + ('ABC', '^(A).*', 1, 'i'), + ('aBc', '(b|d)', 1, 'i'), + ('AbC', '(B|D)', 2, null), + ('aBC', '^(b|c)', 3, null), + ('4000', '\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b', 1, null), + ('4010', '\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b', 2, null), + ('Düsseldorf','[\p{Letter}-]+', 3, null), + ('Москва', '[\p{L}-]+', 4, null), + ('Köln', '[a-zA-Z]ö[a-zA-Z]{2}', 1, null), + ('إسرائيل', '^\p{Arabic}+$', 2, null); diff --git a/datafusion/sqllogictest/test_files/regexp/regexp_count.slt b/datafusion/sqllogictest/test_files/regexp/regexp_count.slt new file mode 100644 index 000000000000..d842a1ee81df --- /dev/null +++ b/datafusion/sqllogictest/test_files/regexp/regexp_count.slt @@ -0,0 +1,344 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Import common test data +include ./init_data.slt.part + +# regexp_count tests from postgresql +# https://github.com/postgres/postgres/blob/56d23855c864b7384970724f3ad93fb0fc319e51/src/test/regress/sql/strings.sql#L226-L235 + +query I +SELECT regexp_count('123123123123123', '(12)3'); +---- +5 + +query I +SELECT regexp_count('123123123123', '123', 1); +---- +4 + +query I +SELECT regexp_count('123123123123', '123', 3); +---- +3 + +query I +SELECT regexp_count('123123123123', '123', 33); +---- +0 + +query I +SELECT regexp_count('ABCABCABCABC', 'Abc', 1, ''); +---- +0 + +query I +SELECT regexp_count('ABCABCABCABC', 'Abc', 1, 'i'); +---- +4 + +statement error +External error: query failed: DataFusion error: Arrow error: Compute error: regexp_count() requires start to be 1 based +SELECT regexp_count('123123123123', '123', 0); + +statement error +External error: query failed: DataFusion error: Arrow error: Compute error: regexp_count() requires start to be 1 based +SELECT regexp_count('123123123123', '123', -3); + +statement error +External error: statement failed: DataFusion error: Arrow error: Compute error: regexp_count() does not support global flag +SELECT regexp_count('123123123123', '123', 1, 'g'); + +query I +SELECT regexp_count(str, '\w') from regexp_test_data; +---- +0 +3 +3 +3 +3 +3 +4 +4 +10 +6 +4 +7 + +query I +SELECT regexp_count(str, '\w{2}', start) from regexp_test_data; +---- +0 +1 +1 +1 +1 +0 +2 +1 +4 +1 +2 +3 + +query I +SELECT regexp_count(str, 'ab', 1, 'i') from regexp_test_data; +---- +0 +1 +1 +1 +1 +1 +0 +0 +0 +0 +0 +0 + + +query I +SELECT regexp_count(str, pattern) from regexp_test_data; +---- +0 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 +1 + +query I +SELECT regexp_count(str, pattern, start) from regexp_test_data; +---- +0 +1 +1 +0 +0 +0 +0 +0 +1 +1 +1 +1 + +query I +SELECT regexp_count(str, pattern, start, flags) from regexp_test_data; +---- +0 +1 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 + +# test type coercion +query I +SELECT regexp_count(arrow_cast(str, 'Utf8'), arrow_cast(pattern, 'LargeUtf8'), arrow_cast(start, 'Int32'), flags) from regexp_test_data; +---- +0 +1 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 + +# test string views + +statement ok +CREATE TABLE t_stringview AS +SELECT arrow_cast(str, 'Utf8View') as str, arrow_cast(pattern, 'Utf8View') as pattern, arrow_cast(start, 'Int64') as start, arrow_cast(flags, 'Utf8View') as flags FROM regexp_test_data; + +query I +SELECT regexp_count(str, '\w') from t_stringview; +---- +0 +3 +3 +3 +3 +3 +4 +4 +10 +6 +4 +7 + +query I +SELECT regexp_count(str, '\w{2}', start) from t_stringview; +---- +0 +1 +1 +1 +1 +0 +2 +1 +4 +1 +2 +3 + +query I +SELECT regexp_count(str, 'ab', 1, 'i') from t_stringview; +---- +0 +1 +1 +1 +1 +1 +0 +0 +0 +0 +0 +0 + + +query I +SELECT regexp_count(str, pattern) from t_stringview; +---- +0 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 +1 + +query I +SELECT regexp_count(str, pattern, start) from t_stringview; +---- +0 +1 +1 +0 +0 +0 +0 +0 +1 +1 +1 +1 + +query I +SELECT regexp_count(str, pattern, start, flags) from t_stringview; +---- +0 +1 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 + +# test type coercion +query I +SELECT regexp_count(arrow_cast(str, 
'Utf8'), arrow_cast(pattern, 'LargeUtf8'), arrow_cast(start, 'Int32'), flags) from t_stringview; +---- +0 +1 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 + +# NULL tests + +query I +SELECT regexp_count(NULL, NULL); +---- +0 + +query I +SELECT regexp_count(NULL, 'a'); +---- +0 + +query I +SELECT regexp_count('a', NULL); +---- +0 + +query I +SELECT regexp_count(NULL, NULL, NULL, NULL); +---- +0 + +statement ok +CREATE TABLE empty_table (str varchar, pattern varchar, start int, flags varchar); + +query I +SELECT regexp_count(str, pattern, start, flags) from empty_table; +---- + +statement ok +INSERT INTO empty_table VALUES ('a', NULL, 1, 'i'), (NULL, 'a', 1, 'i'), (NULL, NULL, 1, 'i'), (NULL, NULL, NULL, 'i'); + +query I +SELECT regexp_count(str, pattern, start, flags) from empty_table; +---- +0 +0 +0 +0 + +statement ok +drop table t_stringview; + +statement ok +drop table empty_table; diff --git a/datafusion/sqllogictest/test_files/regexp/regexp_like.slt b/datafusion/sqllogictest/test_files/regexp/regexp_like.slt new file mode 100644 index 000000000000..223ef22b9861 --- /dev/null +++ b/datafusion/sqllogictest/test_files/regexp/regexp_like.slt @@ -0,0 +1,279 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
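One behavioral difference is worth keeping in mind when reading the expected outputs: the shared fixture begins with a NULL `str` row, and `regexp_count` reports it as 0 in the results above, while `regexp_like` (and `regexp_match`) propagate NULL, which is why the first expected value in the queries below is NULL rather than a count. Both behaviors are already pinned down by scalar tests in these files; a minimal side-by-side sketch taken from them:

```
query I
SELECT regexp_count(NULL, 'a');
----
0

query B
SELECT regexp_like(NULL, '.*-(\d)');
----
NULL
```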
+ +# Import common test data +include ./init_data.slt.part + +query B +SELECT regexp_like(str, pattern, flags) FROM regexp_test_data; +---- +NULL +true +true +true +false +false +false +true +true +true +true +true + +query B +SELECT str ~ NULL FROM regexp_test_data; +---- +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL + +query B +select str ~ right('foo', NULL) FROM regexp_test_data; +---- +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL + +query B +select right('foo', NULL) !~ str FROM regexp_test_data; +---- +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL + +query B +SELECT regexp_like('foobarbequebaz', ''); +---- +true + +query B +SELECT regexp_like('', ''); +---- +true + +query B +SELECT regexp_like('foobarbequebaz', '(bar)(beque)'); +---- +true + +query B +SELECT regexp_like('fooBarbeQuebaz', '(bar).*(que)', 'is'); +---- +true + +query B +SELECT regexp_like('foobarbequebaz', '(ba3r)(bequ34e)'); +---- +false + +query B +SELECT regexp_like('foobarbequebaz', '^.*(barbequ[0-9]*e).*$', 'm'); +---- +true + +query B +SELECT regexp_like('aaa-0', '.*-(\d)'); +---- +true + +query B +SELECT regexp_like('bb-1', '.*-(\d)'); +---- +true + +query B +SELECT regexp_like('aa', '.*-(\d)'); +---- +false + +query B +SELECT regexp_like(NULL, '.*-(\d)'); +---- +NULL + +query B +SELECT regexp_like('aaa-0', NULL); +---- +NULL + +query B +SELECT regexp_like(null, '.*-(\d)'); +---- +NULL + +query error Error during planning: regexp_like\(\) does not support the "global" option +SELECT regexp_like('bb-1', '.*-(\d)', 'g'); + +query error Error during planning: regexp_like\(\) does not support the "global" option +SELECT regexp_like('bb-1', '.*-(\d)', 'g'); + +query error Arrow error: Compute error: Regular expression did not compile: CompiledTooBig\(10485760\) +SELECT regexp_like('aaaaa', 'a{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}'); + +# look-around is not supported and will just return false +query B +SELECT regexp_like('(?<=[A-Z]\w )Smith', 'John Smith', 'i'); +---- +false + +query B +select regexp_like('aaa-555', '.*-(\d*)'); +---- +true + +# multiline string +query B +SELECT 'foo\nbar\nbaz' ~ 'bar'; +---- +true + +statement error +Error during planning: Cannot infer common argument type for regex operation List(Field { name: "item", data_type: Int64, nullable: true, metadata: {} }) ~ List(Field { name: "item", data_type: Int64, nullable: true, metadata: {} }) +select [1,2] ~ [3]; + +query B +SELECT 'foo\nbar\nbaz' LIKE '%bar%'; +---- +true + +query B +SELECT NULL LIKE NULL; +---- +NULL + +query B +SELECT NULL iLIKE NULL; +---- +NULL + +query B +SELECT NULL not LIKE NULL; +---- +NULL + +query B +SELECT NULL not iLIKE NULL; +---- +NULL + +statement ok +create or replace table strings as values + ('FooBar'), + ('Foo'), + ('Foo'), + ('Bar'), + ('FooBar'), + ('Bar'), + ('Baz'); + +statement ok +create or replace table dict_table as +select arrow_cast(column1, 'Dictionary(Int32, Utf8)') as column1 +from strings; + +query T +select column1 from dict_table where column1 LIKE '%oo%'; +---- +FooBar +Foo +Foo +FooBar + +query T +select column1 from dict_table where column1 NOT LIKE '%oo%'; +---- +Bar +Bar +Baz + +query T +select column1 from dict_table where column1 ILIKE '%oO%'; +---- +FooBar +Foo +Foo +FooBar + +query T +select column1 from dict_table where column1 NOT ILIKE '%oO%'; +---- +Bar +Bar +Baz + + +# plan should not cast the column, instead it should use the dictionary directly +query TT +explain select column1 from 
dict_table where column1 LIKE '%oo%'; +---- +logical_plan +01)Filter: dict_table.column1 LIKE Utf8("%oo%") +02)--TableScan: dict_table projection=[column1] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: column1@0 LIKE %oo% +03)----DataSourceExec: partitions=1, partition_sizes=[1] + +# Ensure casting / coercion works for all operators +# (there should be no casts to Utf8) +query TT +explain select + column1 LIKE '%oo%', + column1 NOT LIKE '%oo%', + column1 ILIKE '%oo%', + column1 NOT ILIKE '%oo%' +from dict_table; +---- +logical_plan +01)Projection: dict_table.column1 LIKE Utf8("%oo%"), dict_table.column1 NOT LIKE Utf8("%oo%"), dict_table.column1 ILIKE Utf8("%oo%"), dict_table.column1 NOT ILIKE Utf8("%oo%") +02)--TableScan: dict_table projection=[column1] +physical_plan +01)ProjectionExec: expr=[column1@0 LIKE %oo% as dict_table.column1 LIKE Utf8("%oo%"), column1@0 NOT LIKE %oo% as dict_table.column1 NOT LIKE Utf8("%oo%"), column1@0 ILIKE %oo% as dict_table.column1 ILIKE Utf8("%oo%"), column1@0 NOT ILIKE %oo% as dict_table.column1 NOT ILIKE Utf8("%oo%")] +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +statement ok +drop table strings + +statement ok +drop table dict_table diff --git a/datafusion/sqllogictest/test_files/regexp/regexp_match.slt b/datafusion/sqllogictest/test_files/regexp/regexp_match.slt new file mode 100644 index 000000000000..e79af4774aa2 --- /dev/null +++ b/datafusion/sqllogictest/test_files/regexp/regexp_match.slt @@ -0,0 +1,201 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Import common test data +include ./init_data.slt.part + +query ? +SELECT regexp_match(str, pattern, flags) FROM regexp_test_data; +---- +NULL +[a] +[A] +[B] +NULL +NULL +NULL +[010] +[Düsseldorf] +[Москва] +[Köln] +[إسرائيل] + +# test string view +statement ok +CREATE TABLE t_stringview AS +SELECT arrow_cast(str, 'Utf8View') as str, arrow_cast(pattern, 'Utf8View') as pattern, arrow_cast(flags, 'Utf8View') as flags FROM regexp_test_data; + +query ? +SELECT regexp_match(str, pattern, flags) FROM t_stringview; +---- +NULL +[a] +[A] +[B] +NULL +NULL +NULL +[010] +[Düsseldorf] +[Москва] +[Köln] +[إسرائيل] + +statement ok +DROP TABLE t_stringview; + +query ? +SELECT regexp_match('foobarbequebaz', ''); +---- +[] + +query ? +SELECT regexp_match('', ''); +---- +[] + +query ? +SELECT regexp_match('foobarbequebaz', '(bar)(beque)'); +---- +[bar, beque] + +query ? +SELECT regexp_match('fooBarb +eQuebaz', '(bar).*(que)', 'is'); +---- +[Bar, Que] + +query ? +SELECT regexp_match('foobarbequebaz', '(ba3r)(bequ34e)'); +---- +NULL + +query ? +SELECT regexp_match('foobarbequebaz', '^.*(barbequ[0-9]*e).*$', 'm'); +---- +[barbeque] + +query ? +SELECT regexp_match('aaa-0', '.*-(\d)'); +---- +[0] + +query ? 
+SELECT regexp_match('bb-1', '.*-(\d)'); +---- +[1] + +query ? +SELECT regexp_match('aa', '.*-(\d)'); +---- +NULL + +query ? +SELECT regexp_match(NULL, '.*-(\d)'); +---- +NULL + +query ? +SELECT regexp_match('aaa-0', NULL); +---- +NULL + +query ? +SELECT regexp_match(null, '.*-(\d)'); +---- +NULL + +query error Error during planning: regexp_match\(\) does not support the "global" option +SELECT regexp_match('bb-1', '.*-(\d)', 'g'); + +query error Error during planning: regexp_match\(\) does not support the "global" option +SELECT regexp_match('bb-1', '.*-(\d)', 'g'); + +query error Arrow error: Compute error: Regular expression did not compile: CompiledTooBig\(10485760\) +SELECT regexp_match('aaaaa', 'a{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}{5}'); + +# look-around is not supported and will just return null +query ? +SELECT regexp_match('(?<=[A-Z]\w )Smith', 'John Smith', 'i'); +---- +NULL + +# ported test +query ? +SELECT regexp_match('aaa-555', '.*-(\d*)'); +---- +[555] + +query B +select 'abc' ~ null; +---- +NULL + +query B +select null ~ null; +---- +NULL + +query B +select null ~ 'abc'; +---- +NULL + +query B +select 'abc' ~* null; +---- +NULL + +query B +select null ~* null; +---- +NULL + +query B +select null ~* 'abc'; +---- +NULL + +query B +select 'abc' !~ null; +---- +NULL + +query B +select null !~ null; +---- +NULL + +query B +select null !~ 'abc'; +---- +NULL + +query B +select 'abc' !~* null; +---- +NULL + +query B +select null !~* null; +---- +NULL + +query B +select null !~* 'abc'; +---- +NULL diff --git a/datafusion/sqllogictest/test_files/regexp/regexp_replace.slt b/datafusion/sqllogictest/test_files/regexp/regexp_replace.slt new file mode 100644 index 000000000000..a16801adcef7 --- /dev/null +++ b/datafusion/sqllogictest/test_files/regexp/regexp_replace.slt @@ -0,0 +1,129 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# Import common test data +include ./init_data.slt.part + +query T +SELECT regexp_replace(str, pattern, 'X', concat('g', flags)) FROM regexp_test_data; +---- +NULL +Xbc +X +aXc +AbC +aBC +4000 +X +X +X +X +X + +# test string view +statement ok +CREATE TABLE t_stringview AS +SELECT arrow_cast(str, 'Utf8View') as str, arrow_cast(pattern, 'Utf8View') as pattern, arrow_cast(flags, 'Utf8View') as flags FROM regexp_test_data; + +query T +SELECT regexp_replace(str, pattern, 'X', concat('g', flags)) FROM t_stringview; +---- +NULL +Xbc +X +aXc +AbC +aBC +4000 +X +X +X +X +X + +statement ok +DROP TABLE t_stringview; + +query T +SELECT regexp_replace('ABCabcABC', '(abc)', 'X', 'gi'); +---- +XXX + +query T +SELECT regexp_replace('ABCabcABC', '(abc)', 'X', 'i'); +---- +XabcABC + +query T +SELECT regexp_replace('foobarbaz', 'b..', 'X', 'g'); +---- +fooXX + +query T +SELECT regexp_replace('foobarbaz', 'b..', 'X'); +---- +fooXbaz + +query T +SELECT regexp_replace('foobarbaz', 'b(..)', 'X\\1Y', 'g'); +---- +fooXarYXazY + +query T +SELECT regexp_replace('foobarbaz', 'b(..)', 'X\\1Y', NULL); +---- +NULL + +query T +SELECT regexp_replace('foobarbaz', 'b(..)', NULL, 'g'); +---- +NULL + +query T +SELECT regexp_replace('foobarbaz', NULL, 'X\\1Y', 'g'); +---- +NULL + +query T +SELECT regexp_replace('Thomas', '.[mN]a.', 'M'); +---- +ThM + +query T +SELECT regexp_replace(NULL, 'b(..)', 'X\\1Y', 'g'); +---- +NULL + +query T +SELECT regexp_replace('foobar', 'bar', 'xx', 'gi') +---- +fooxx + +query T +SELECT regexp_replace(arrow_cast('foobar', 'Dictionary(Int32, Utf8)'), 'bar', 'xx', 'gi') +---- +fooxx + +query TTT +select + regexp_replace(col, NULL, 'c'), + regexp_replace(col, 'a', NULL), + regexp_replace(col, 'a', 'c', NULL) +from (values ('a'), ('b')) as tbl(col); +---- +NULL NULL NULL +NULL NULL NULL diff --git a/datafusion/sqllogictest/test_files/repartition.slt b/datafusion/sqllogictest/test_files/repartition.slt index 70666346e2ca..29d20d10b671 100644 --- a/datafusion/sqllogictest/test_files/repartition.slt +++ b/datafusion/sqllogictest/test_files/repartition.slt @@ -46,8 +46,8 @@ physical_plan 01)AggregateExec: mode=FinalPartitioned, gby=[column1@0 as column1], aggr=[sum(parquet_table.column2)] 02)--CoalesceBatchesExec: target_batch_size=8192 03)----RepartitionExec: partitioning=Hash([column1@0], 4), input_partitions=4 -04)------AggregateExec: mode=Partial, gby=[column1@0 as column1], aggr=[sum(parquet_table.column2)] -05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +05)--------AggregateExec: mode=Partial, gby=[column1@0 as column1], aggr=[sum(parquet_table.column2)] 06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition/parquet_table/2.parquet]]}, projection=[column1, column2], file_type=parquet # disable round robin repartitioning diff --git a/datafusion/sqllogictest/test_files/repartition_scan.slt b/datafusion/sqllogictest/test_files/repartition_scan.slt index 2b30de572c8c..0b851f917855 100644 --- a/datafusion/sqllogictest/test_files/repartition_scan.slt +++ b/datafusion/sqllogictest/test_files/repartition_scan.slt @@ -61,7 +61,7 @@ logical_plan physical_plan 01)CoalesceBatchesExec: target_batch_size=8192 02)--FilterExec: column1@0 != 42 -03)----DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..137], 
[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:137..274], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:274..411], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:411..547]]}, projection=[column1], file_type=parquet, predicate=column1@0 != 42, pruning_predicate=column1_null_count@2 != row_count@3 AND (column1_min@0 != 42 OR 42 != column1_max@1), required_guarantees=[column1 not in (42)] +03)----DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..141], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:141..282], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:282..423], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:423..563]]}, projection=[column1], file_type=parquet, predicate=column1@0 != 42, pruning_predicate=column1_null_count@2 != row_count@3 AND (column1_min@0 != 42 OR 42 != column1_max@1), required_guarantees=[column1 not in (42)] # disable round robin repartitioning statement ok @@ -77,7 +77,7 @@ logical_plan physical_plan 01)CoalesceBatchesExec: target_batch_size=8192 02)--FilterExec: column1@0 != 42 -03)----DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..137], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:137..274], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:274..411], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:411..547]]}, projection=[column1], file_type=parquet, predicate=column1@0 != 42, pruning_predicate=column1_null_count@2 != row_count@3 AND (column1_min@0 != 42 OR 42 != column1_max@1), required_guarantees=[column1 not in (42)] +03)----DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..141], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:141..282], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:282..423], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:423..563]]}, projection=[column1], file_type=parquet, predicate=column1@0 != 42, pruning_predicate=column1_null_count@2 != row_count@3 AND (column1_min@0 != 42 OR 42 != column1_max@1), required_guarantees=[column1 not in (42)] # enable round robin repartitioning again statement ok @@ -102,7 +102,7 @@ physical_plan 02)--SortExec: expr=[column1@0 ASC NULLS LAST], preserve_partitioning=[true] 03)----CoalesceBatchesExec: target_batch_size=8192 04)------FilterExec: column1@0 != 42 -05)--------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..272], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:272..538, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..6], 
[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:6..278], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:278..547]]}, projection=[column1], file_type=parquet, predicate=column1@0 != 42, pruning_predicate=column1_null_count@2 != row_count@3 AND (column1_min@0 != 42 OR 42 != column1_max@1), required_guarantees=[column1 not in (42)] +05)--------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..280], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:280..554, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..6], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:6..286], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:286..563]]}, projection=[column1], file_type=parquet, predicate=column1@0 != 42, pruning_predicate=column1_null_count@2 != row_count@3 AND (column1_min@0 != 42 OR 42 != column1_max@1), required_guarantees=[column1 not in (42)] ## Read the files as though they are ordered @@ -138,7 +138,7 @@ physical_plan 01)SortPreservingMergeExec: [column1@0 ASC NULLS LAST] 02)--CoalesceBatchesExec: target_batch_size=8192 03)----FilterExec: column1@0 != 42 -04)------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..269], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..273], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:273..547], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:269..538]]}, projection=[column1], output_ordering=[column1@0 ASC NULLS LAST], file_type=parquet, predicate=column1@0 != 42, pruning_predicate=column1_null_count@2 != row_count@3 AND (column1_min@0 != 42 OR 42 != column1_max@1), required_guarantees=[column1 not in (42)] +04)------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..277], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:281..563], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:277..554]]}, projection=[column1], output_ordering=[column1@0 ASC NULLS LAST], file_type=parquet, predicate=column1@0 != 42, pruning_predicate=column1_null_count@2 != row_count@3 AND (column1_min@0 != 42 OR 42 != column1_max@1), required_guarantees=[column1 not in (42)] # Cleanup statement ok diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index f583d659fd4f..ca0b472de9e0 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -1832,7 +1832,7 @@ query TT EXPLAIN SELECT letter, letter = LEFT('APACHE', 1) FROM simple_string; ---- logical_plan -01)Projection: simple_string.letter, simple_string.letter = Utf8("A") AS simple_string.letter = left(Utf8("APACHE"),Int64(1)) +01)Projection: simple_string.letter, simple_string.letter = 
Utf8View("A") AS simple_string.letter = left(Utf8("APACHE"),Int64(1)) 02)--TableScan: simple_string projection=[letter] physical_plan 01)ProjectionExec: expr=[letter@0 as letter, letter@0 = A as simple_string.letter = left(Utf8("APACHE"),Int64(1))] @@ -1851,10 +1851,10 @@ query TT EXPLAIN SELECT letter, letter = LEFT(letter2, 1) FROM simple_string; ---- logical_plan -01)Projection: simple_string.letter, simple_string.letter = left(simple_string.letter2, Int64(1)) +01)Projection: simple_string.letter, simple_string.letter = CAST(left(simple_string.letter2, Int64(1)) AS Utf8View) 02)--TableScan: simple_string projection=[letter, letter2] physical_plan -01)ProjectionExec: expr=[letter@0 as letter, letter@0 = left(letter2@1, 1) as simple_string.letter = left(simple_string.letter2,Int64(1))] +01)ProjectionExec: expr=[letter@0 as letter, letter@0 = CAST(left(letter2@1, 1) AS Utf8View) as simple_string.letter = left(simple_string.letter2,Int64(1))] 02)--DataSourceExec: partitions=1, partition_sizes=[1] query TB diff --git a/datafusion/sqllogictest/test_files/simplify_expr.slt b/datafusion/sqllogictest/test_files/simplify_expr.slt index 9985ab49c2da..c77163dc996d 100644 --- a/datafusion/sqllogictest/test_files/simplify_expr.slt +++ b/datafusion/sqllogictest/test_files/simplify_expr.slt @@ -35,22 +35,22 @@ query TT explain select b from t where b ~ '.*' ---- logical_plan -01)Filter: t.b IS NOT NULL +01)Filter: t.b ~ Utf8View(".*") 02)--TableScan: t projection=[b] physical_plan 01)CoalesceBatchesExec: target_batch_size=8192 -02)--FilterExec: b@0 IS NOT NULL +02)--FilterExec: b@0 ~ .* 03)----DataSourceExec: partitions=1, partition_sizes=[1] query TT explain select b from t where b !~ '.*' ---- logical_plan -01)Filter: t.b = Utf8("") +01)Filter: t.b !~ Utf8View(".*") 02)--TableScan: t projection=[b] physical_plan 01)CoalesceBatchesExec: target_batch_size=8192 -02)--FilterExec: b@0 = +02)--FilterExec: b@0 !~ .* 03)----DataSourceExec: partitions=1, partition_sizes=[1] query T @@ -107,4 +107,3 @@ query B SELECT a / NULL::DECIMAL(4,3) > 1.2::decimal(2,1) FROM VALUES (1) AS t(a); ---- NULL - diff --git a/datafusion/sqllogictest/test_files/simplify_predicates.slt b/datafusion/sqllogictest/test_files/simplify_predicates.slt new file mode 100644 index 000000000000..0dd551d96d0c --- /dev/null +++ b/datafusion/sqllogictest/test_files/simplify_predicates.slt @@ -0,0 +1,234 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# Test cases for predicate simplification feature +# Basic redundant comparison simplification + +statement ok +set datafusion.explain.logical_plan_only=true; + +statement ok +CREATE TABLE test_data ( + int_col INT, + float_col FLOAT, + str_col VARCHAR, + date_col DATE, + bool_col BOOLEAN +); + +# x > 5 AND x > 6 should simplify to x > 6 +query TT +EXPLAIN SELECT * FROM test_data WHERE int_col > 5 AND int_col > 6; +---- +logical_plan +01)Filter: test_data.int_col > Int32(6) +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col] + +# x > 5 AND x >= 6 should simplify to x >= 6 +query TT +EXPLAIN SELECT * FROM test_data WHERE int_col > 5 AND int_col >= 6; +---- +logical_plan +01)Filter: test_data.int_col >= Int32(6) +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col] + +# x < 10 AND x <= 8 should simplify to x <= 8 +query TT +EXPLAIN SELECT * FROM test_data WHERE int_col < 10 AND int_col <= 8; +---- +logical_plan +01)Filter: test_data.int_col <= Int32(8) +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col] + +# x > 5 AND x > 6 AND x > 7 should simplify to x > 7 +query TT +EXPLAIN SELECT * FROM test_data WHERE int_col > 5 AND int_col > 6 AND int_col > 7; +---- +logical_plan +01)Filter: test_data.int_col > Int32(7) +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col] + +# x > 5 AND y < 10 AND x > 6 AND y < 8 should simplify to x > 6 AND y < 8 +query TT +EXPLAIN SELECT * FROM test_data WHERE int_col > 5 AND float_col < 10 AND int_col > 6 AND float_col < 8; +---- +logical_plan +01)Filter: test_data.float_col < Float32(8) AND test_data.int_col > Int32(6) +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col] + +# x = 7 AND x = 7 should simplify to x = 7 +query TT +EXPLAIN SELECT * FROM test_data WHERE int_col = 7 AND int_col = 7; +---- +logical_plan +01)Filter: test_data.int_col = Int32(7) +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col] + +# x = 7 AND x = 6 should simplify to false +query TT +EXPLAIN SELECT * FROM test_data WHERE int_col = 7 AND int_col = 6; +---- +logical_plan EmptyRelation + +# TODO: x = 7 AND x < 2 should simplify to false +query TT +EXPLAIN SELECT * FROM test_data WHERE int_col = 7 AND int_col < 2; +---- +logical_plan +01)Filter: test_data.int_col = Int32(7) AND test_data.int_col < Int32(2) +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col] + + +# TODO: x = 7 AND x > 5 should simplify to x = 7 +query TT +EXPLAIN SELECT * FROM test_data WHERE int_col = 7 AND int_col > 5; +---- +logical_plan +01)Filter: test_data.int_col = Int32(7) AND test_data.int_col > Int32(5) +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col] + +# str_col > 'apple' AND str_col > 'banana' should simplify to str_col > 'banana' +query TT +EXPLAIN SELECT * FROM test_data WHERE str_col > 'apple' AND str_col > 'banana'; +---- +logical_plan +01)Filter: test_data.str_col > Utf8View("banana") +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col] + +# date_col > '2023-01-01' AND date_col > '2023-02-01' should simplify to date_col > '2023-02-01' +query TT +EXPLAIN SELECT * FROM test_data WHERE date_col > '2023-01-01' AND date_col > '2023-02-01'; +---- +logical_plan +01)Filter: test_data.date_col > Date32("2023-02-01") +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, 
bool_col] + +query TT +EXPLAIN SELECT * FROM test_data WHERE bool_col = true AND bool_col = false; +---- +logical_plan +01)Filter: test_data.bool_col AND NOT test_data.bool_col +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col] + + +# This shouldn't be simplified since they're different relationships +query TT +EXPLAIN SELECT * FROM test_data WHERE int_col > float_col AND int_col > 5; +---- +logical_plan +01)Filter: CAST(test_data.int_col AS Float32) > test_data.float_col AND test_data.int_col > Int32(5) +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col] + +# Should simplify the int_col predicates but preserve the others +query TT +EXPLAIN SELECT * FROM test_data +WHERE int_col > 5 + AND int_col > 10 + AND str_col LIKE 'A%' + AND float_col BETWEEN 1 AND 100; +---- +logical_plan +01)Filter: test_data.str_col LIKE Utf8View("A%") AND test_data.float_col >= Float32(1) AND test_data.float_col <= Float32(100) AND test_data.int_col > Int32(10) +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col] + +statement ok +CREATE TABLE test_data2 ( + id INT, + value INT +); + +query TT +EXPLAIN SELECT t1.int_col, t2.value +FROM test_data t1 +JOIN test_data2 t2 ON t1.int_col = t2.id +WHERE t1.int_col > 5 + AND t1.int_col > 10 + AND t2.value < 100 + AND t2.value < 50; +---- +logical_plan +01)Projection: t1.int_col, t2.value +02)--Inner Join: t1.int_col = t2.id +03)----SubqueryAlias: t1 +04)------Filter: test_data.int_col > Int32(10) +05)--------TableScan: test_data projection=[int_col] +06)----SubqueryAlias: t2 +07)------Filter: test_data2.value < Int32(50) AND test_data2.id > Int32(10) +08)--------TableScan: test_data2 projection=[id, value] + +# Handling negated predicates +# NOT (x < 10) AND NOT (x < 5) should simplify to NOT (x < 10) +query TT +EXPLAIN SELECT * FROM test_data WHERE NOT (int_col < 10) AND NOT (int_col < 5); +---- +logical_plan +01)Filter: test_data.int_col >= Int32(10) +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col] + +# x > 5 AND x < 10 should be preserved (can't be simplified) +query TT +EXPLAIN SELECT * FROM test_data WHERE int_col > 5 AND int_col < 10; +---- +logical_plan +01)Filter: test_data.int_col > Int32(5) AND test_data.int_col < Int32(10) +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col] + +# 5 < x AND 3 < x should simplify to 5 < x +query TT +EXPLAIN SELECT * FROM test_data WHERE 5 < int_col AND 3 < int_col; +---- +logical_plan +01)Filter: test_data.int_col > Int32(5) +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col] + +# CAST(x AS FLOAT) > 5.0 AND CAST(x AS FLOAT) > 6.0 should simplify +query TT +EXPLAIN SELECT * FROM test_data WHERE CAST(int_col AS FLOAT) > 5.0 AND CAST(int_col AS FLOAT) > 6.0; +---- +logical_plan +01)Filter: CAST(CAST(test_data.int_col AS Float32) AS Float64) > Float64(6) +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col] + +# x = 5 AND x = 6 (logically impossible) +query TT +EXPLAIN SELECT * FROM test_data WHERE int_col = 5 AND int_col = 6; +---- +logical_plan EmptyRelation + +# (x > 5 OR y < 10) AND (x > 6 OR y < 8) +# This is more complex but could still benefit from some simplification +query TT +EXPLAIN SELECT * FROM test_data +WHERE (int_col > 5 OR float_col < 10) + AND (int_col > 6 OR float_col < 8); +---- +logical_plan +01)Filter: (test_data.int_col > Int32(5) OR test_data.float_col < 
Float32(10)) AND (test_data.int_col > Int32(6) OR test_data.float_col < Float32(8)) +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col] + +# Combination of AND and OR with simplifiable predicates +query TT +EXPLAIN SELECT * FROM test_data +WHERE (int_col > 5 AND int_col > 6) + OR (float_col < 10 AND float_col < 8); +---- +logical_plan +01)Filter: test_data.int_col > Int32(5) AND test_data.int_col > Int32(6) OR test_data.float_col < Float32(10) AND test_data.float_col < Float32(8) +02)--TableScan: test_data projection=[int_col, float_col, str_col, date_col, bool_col] + +statement ok +set datafusion.explain.logical_plan_only=false; diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index 162c9a17b61f..c17fe8dfc7e6 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -695,6 +695,144 @@ select t2.* from t1 right anti join t2 on t1.a = t2.a and t1.b = t2.b ---- 51 54 +# RIGHTSEMI join tests + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 13 b + ) + select t2.* from t1 right semi join t2 on t1.a = t2.a and t1.b = t2.b +) order by 1, 2; +---- +11 12 +11 13 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 13 b + ) + select t2.* from t1 right semi join t2 on t1.a = t2.a and t1.b != t2.b +) order by 1, 2; +---- +11 12 +11 13 + +query II +select * from ( +with +t1 as ( + select null a, 12 b union all + select 11 a, 13 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 13 b + ) + select t2.* from t1 right semi join t2 on t1.a = t2.a and t1.b != t2.b +) order by 1, 2; +---- +11 12 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b) + select t2.* from t1 right semi join t1 t2 on t1.a = t2.a and t1.b = t2.b +) order by 1, 2; +---- +11 12 +11 13 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b) + select t2.* from t1 right semi join t1 t2 on t1.a = t2.a and t1.b != t2.b +) order by 1, 2; +---- +11 12 +11 13 + +query II +select * from ( +with +t1 as ( + select null a, 12 b union all + select 11 a, 13 b) + select t2.* from t1 right semi join t1 t2 on t1.a = t2.a and t1.b != t2.b +) order by 1, 2; +---- + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 14 b + ) + select t2.* from t1 right semi join t2 on t1.a = t2.a and t1.b != t2.b +) order by 1, 2; +---- +11 12 +11 14 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 12 b union all + select 11 a, 14 b + ), +t2 as ( + select 11 a, 12 b union all + select 11 a, 13 b + ) + select t2.* from t1 right semi join t2 on t1.a = t2.a and t1.b != t2.b +) order by 1, 2; +---- +11 12 +11 13 + +# Test RIGHTSEMI with cross batch data distribution + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b union all + select 12 a, 14 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 14 b union all + select 12 a, 15 b + ) + select t2.* from t1 right semi join t2 on t1.a = t2.a and t1.b != t2.b +) order by 1, 2; +---- +11 12 +11 14 +12 15 + # return sql params back to default values statement ok set 
datafusion.optimizer.prefer_hash_join = true; diff --git a/datafusion/sqllogictest/test_files/spark/README.md b/datafusion/sqllogictest/test_files/spark/README.md new file mode 100644 index 000000000000..cffd28009889 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/README.md @@ -0,0 +1,67 @@ + + +# Spark Test Files + +This directory contains test files for the `spark` test suite. + +## RoadMap + +Implementing the `datafusion-spark` compatible functions project is still a work in progress. +Many of the tests in this directory are commented out and are waiting for help with implementation. + +For more information, please see: + +- [The `datafusion-spark` Epic](https://github.com/apache/datafusion/issues/15914) +- [Spark Test Generation Script](https://github.com/apache/datafusion/pull/16409#issuecomment-2972618052) + +## Testing Guide + +When testing Spark functions: + +- Functions must be tested on both `Scalar` and `Array` inputs +- Test cases should only contain `SELECT` statements with the function being tested +- Add explicit casts to input values to ensure the correct data type is used (e.g., `0::INT`) + - Explicit casting is necessary because DataFusion and Spark do not infer data types in the same way + +### Finding Test Cases + +At a minimum, you can refer to the following documentation sources to verify and compare function behavior: + +1. Databricks SQL Function Reference: + https://docs.databricks.com/aws/en/sql/language-manual/functions/NAME +2. Apache Spark SQL Function Reference: + https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.NAME.html +3. PySpark SQL Function Reference: + https://spark.apache.org/docs/latest/api/sql/#NAME + +**Note:** Replace `NAME` in each URL with the actual function name (e.g., for the `ASCII` function, use `ascii` instead +of `NAME`). + +### Scalar Example: + +```sql +SELECT expm1(0::INT); +``` + +### Array Example: + +```sql +SELECT expm1(a) FROM (VALUES (0::INT), (1::INT)) AS t(a); +``` diff --git a/datafusion/sqllogictest/test_files/spark/array/array.slt b/datafusion/sqllogictest/test_files/spark/array/array.slt new file mode 100644 index 000000000000..57aa080be393 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/array/array.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library.
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT array(1, 2, 3); +## PySpark 3.5.5 Result: {'array(1, 2, 3)': [1, 2, 3], 'typeof(array(1, 2, 3))': 'array', 'typeof(1)': 'int', 'typeof(2)': 'int', 'typeof(3)': 'int'} +#query +#SELECT array(1::int, 2::int, 3::int); diff --git a/datafusion/sqllogictest/test_files/spark/array/array_repeat.slt b/datafusion/sqllogictest/test_files/spark/array/array_repeat.slt new file mode 100644 index 000000000000..544c39608f33 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/array/array_repeat.slt @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT array_repeat('123', 2); +## PySpark 3.5.5 Result: {'array_repeat(123, 2)': ['123', '123'], 'typeof(array_repeat(123, 2))': 'array', 'typeof(123)': 'string', 'typeof(2)': 'int'} +#query +#SELECT array_repeat('123'::string, 2::int); diff --git a/datafusion/sqllogictest/test_files/spark/array/sequence.slt b/datafusion/sqllogictest/test_files/spark/array/sequence.slt new file mode 100644 index 000000000000..bb4aa06bfd25 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/array/sequence.slt @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT sequence(1, 5); +## PySpark 3.5.5 Result: {'sequence(1, 5)': [1, 2, 3, 4, 5], 'typeof(sequence(1, 5))': 'array', 'typeof(1)': 'int', 'typeof(5)': 'int'} +#query +#SELECT sequence(1::int, 5::int); + +## Original Query: SELECT sequence(5, 1); +## PySpark 3.5.5 Result: {'sequence(5, 1)': [5, 4, 3, 2, 1], 'typeof(sequence(5, 1))': 'array', 'typeof(5)': 'int', 'typeof(1)': 'int'} +#query +#SELECT sequence(5::int, 1::int); diff --git a/datafusion/sqllogictest/test_files/spark/bitwise/bit_count.slt b/datafusion/sqllogictest/test_files/spark/bitwise/bit_count.slt new file mode 100644 index 000000000000..dfbf1f1ff5d4 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/bitwise/bit_count.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT bit_count(0); +## PySpark 3.5.5 Result: {'bit_count(0)': 0, 'typeof(bit_count(0))': 'int', 'typeof(0)': 'int'} +#query +#SELECT bit_count(0::int); + diff --git a/datafusion/sqllogictest/test_files/spark/bitwise/bit_get.slt b/datafusion/sqllogictest/test_files/spark/bitwise/bit_get.slt new file mode 100644 index 000000000000..9086d5651e98 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/bitwise/bit_get.slt @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT bit_get(11, 0); +## PySpark 3.5.5 Result: {'bit_get(11, 0)': 1, 'typeof(bit_get(11, 0))': 'tinyint', 'typeof(11)': 'int', 'typeof(0)': 'int'} +#query +#SELECT bit_get(11::int, 0::int); + +## Original Query: SELECT bit_get(11, 2); +## PySpark 3.5.5 Result: {'bit_get(11, 2)': 0, 'typeof(bit_get(11, 2))': 'tinyint', 'typeof(11)': 'int', 'typeof(2)': 'int'} +#query +#SELECT bit_get(11::int, 2::int); + diff --git a/datafusion/sqllogictest/test_files/spark/bitwise/getbit.slt b/datafusion/sqllogictest/test_files/spark/bitwise/getbit.slt new file mode 100644 index 000000000000..6f25603c5b4a --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/bitwise/getbit.slt @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT getbit(11, 0); +## PySpark 3.5.5 Result: {'getbit(11, 0)': 1, 'typeof(getbit(11, 0))': 'tinyint', 'typeof(11)': 'int', 'typeof(0)': 'int'} +#query +#SELECT getbit(11::int, 0::int); + +## Original Query: SELECT getbit(11, 2); +## PySpark 3.5.5 Result: {'getbit(11, 2)': 0, 'typeof(getbit(11, 2))': 'tinyint', 'typeof(11)': 'int', 'typeof(2)': 'int'} +#query +#SELECT getbit(11::int, 2::int); + diff --git a/datafusion/sqllogictest/test_files/spark/bitwise/shiftright.slt b/datafusion/sqllogictest/test_files/spark/bitwise/shiftright.slt new file mode 100644 index 000000000000..8c8cc366ffb6 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/bitwise/shiftright.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT shiftright(4, 1); +## PySpark 3.5.5 Result: {'shiftright(4, 1)': 2, 'typeof(shiftright(4, 1))': 'int', 'typeof(4)': 'int', 'typeof(1)': 'int'} +#query +#SELECT shiftright(4::int, 1::int); + diff --git a/datafusion/sqllogictest/test_files/spark/bitwise/shiftrightunsigned.slt b/datafusion/sqllogictest/test_files/spark/bitwise/shiftrightunsigned.slt new file mode 100644 index 000000000000..34be404005c4 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/bitwise/shiftrightunsigned.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT shiftrightunsigned(4, 1); +## PySpark 3.5.5 Result: {'shiftrightunsigned(4, 1)': 2, 'typeof(shiftrightunsigned(4, 1))': 'int', 'typeof(4)': 'int', 'typeof(1)': 'int'} +#query +#SELECT shiftrightunsigned(4::int, 1::int); + diff --git a/datafusion/sqllogictest/test_files/spark/collection/concat.slt b/datafusion/sqllogictest/test_files/spark/collection/concat.slt new file mode 100644 index 000000000000..8ac5a836c5d9 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/collection/concat.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT concat('Spark', 'SQL'); +## PySpark 3.5.5 Result: {'concat(Spark, SQL)': 'SparkSQL', 'typeof(concat(Spark, SQL))': 'string', 'typeof(Spark)': 'string', 'typeof(SQL)': 'string'} +#query +#SELECT concat('Spark'::string, 'SQL'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/collection/reverse.slt b/datafusion/sqllogictest/test_files/spark/collection/reverse.slt new file mode 100644 index 000000000000..961cbd3baa6e --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/collection/reverse.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT reverse('Spark SQL'); +## PySpark 3.5.5 Result: {'reverse(Spark SQL)': 'LQS krapS', 'typeof(reverse(Spark SQL))': 'string', 'typeof(Spark SQL)': 'string'} +#query +#SELECT reverse('Spark SQL'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/conditional/coalesce.slt b/datafusion/sqllogictest/test_files/spark/conditional/coalesce.slt new file mode 100644 index 000000000000..9c9f9796fc39 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/conditional/coalesce.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT coalesce(NULL, 1, NULL); +## PySpark 3.5.5 Result: {'coalesce(NULL, 1, NULL)': 1, 'typeof(coalesce(NULL, 1, NULL))': 'int', 'typeof(NULL)': 'void', 'typeof(1)': 'int'} +#query +#SELECT coalesce(NULL::void, 1::int); + diff --git a/datafusion/sqllogictest/test_files/spark/conditional/if.slt b/datafusion/sqllogictest/test_files/spark/conditional/if.slt new file mode 100644 index 000000000000..e63bab6754d4 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/conditional/if.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT if(1 < 2, 'a', 'b'); +## PySpark 3.5.5 Result: {'(IF((1 < 2), a, b))': 'a', 'typeof((IF((1 < 2), a, b)))': 'string', 'typeof((1 < 2))': 'boolean', 'typeof(a)': 'string', 'typeof(b)': 'string'} +#query +#SELECT if((1 < 2)::boolean, 'a'::string, 'b'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/conditional/nullif.slt b/datafusion/sqllogictest/test_files/spark/conditional/nullif.slt new file mode 100644 index 000000000000..c13793ab078b --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/conditional/nullif.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT nullif(2, 2); +## PySpark 3.5.5 Result: {'nullif(2, 2)': None, 'typeof(nullif(2, 2))': 'int', 'typeof(2)': 'int'} +#query +#SELECT nullif(2::int, 2::int); + diff --git a/datafusion/sqllogictest/test_files/spark/conditional/nvl2.slt b/datafusion/sqllogictest/test_files/spark/conditional/nvl2.slt new file mode 100644 index 000000000000..e4507a8223cd --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/conditional/nvl2.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT nvl2(NULL, 2, 1); +## PySpark 3.5.5 Result: {'nvl2(NULL, 2, 1)': 1, 'typeof(nvl2(NULL, 2, 1))': 'int', 'typeof(NULL)': 'void', 'typeof(2)': 'int', 'typeof(1)': 'int'} +#query +#SELECT nvl2(NULL::void, 2::int, 1::int); + diff --git a/datafusion/sqllogictest/test_files/spark/csv/schema_of_csv.slt b/datafusion/sqllogictest/test_files/spark/csv/schema_of_csv.slt new file mode 100644 index 000000000000..8f790748bfa7 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/csv/schema_of_csv.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
+ +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT schema_of_csv('1,abc'); +## PySpark 3.5.5 Result: {'schema_of_csv(1,abc)': 'STRUCT<_c0: INT, _c1: STRING>', 'typeof(schema_of_csv(1,abc))': 'string', 'typeof(1,abc)': 'string'} +#query +#SELECT schema_of_csv('1,abc'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/add_months.slt b/datafusion/sqllogictest/test_files/spark/datetime/add_months.slt new file mode 100644 index 000000000000..df71582048b4 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/add_months.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT add_months('2016-08-31', 1); +## PySpark 3.5.5 Result: {'add_months(2016-08-31, 1)': datetime.date(2016, 9, 30), 'typeof(add_months(2016-08-31, 1))': 'date', 'typeof(2016-08-31)': 'string', 'typeof(1)': 'int'} +#query +#SELECT add_months('2016-08-31'::string, 1::int); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/convert_timezone.slt b/datafusion/sqllogictest/test_files/spark/datetime/convert_timezone.slt new file mode 100644 index 000000000000..303787a96b46 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/convert_timezone.slt @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT convert_timezone('Europe/Brussels', 'America/Los_Angeles', timestamp_ntz'2021-12-06 00:00:00'); +## PySpark 3.5.5 Result: {"convert_timezone(Europe/Brussels, America/Los_Angeles, TIMESTAMP_NTZ '2021-12-06 00:00:00')": datetime.datetime(2021, 12, 5, 15, 0), "typeof(convert_timezone(Europe/Brussels, America/Los_Angeles, TIMESTAMP_NTZ '2021-12-06 00:00:00'))": 'timestamp_ntz', 'typeof(Europe/Brussels)': 'string', 'typeof(America/Los_Angeles)': 'string', "typeof(TIMESTAMP_NTZ '2021-12-06 00:00:00')": 'timestamp_ntz'} +#query +#SELECT convert_timezone('Europe/Brussels'::string, 'America/Los_Angeles'::string, TIMESTAMP_NTZ '2021-12-06 00:00:00'::timestamp_ntz); + +## Original Query: SELECT convert_timezone('Europe/Brussels', timestamp_ntz'2021-12-05 15:00:00'); +## PySpark 3.5.5 Result: {"convert_timezone(current_timezone(), Europe/Brussels, TIMESTAMP_NTZ '2021-12-05 15:00:00')": datetime.datetime(2021, 12, 6, 0, 0), "typeof(convert_timezone(current_timezone(), Europe/Brussels, TIMESTAMP_NTZ '2021-12-05 15:00:00'))": 'timestamp_ntz', 'typeof(Europe/Brussels)': 'string', "typeof(TIMESTAMP_NTZ '2021-12-05 15:00:00')": 'timestamp_ntz'} +#query +#SELECT convert_timezone('Europe/Brussels'::string, TIMESTAMP_NTZ '2021-12-05 15:00:00'::timestamp_ntz); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/curdate.slt b/datafusion/sqllogictest/test_files/spark/datetime/curdate.slt new file mode 100644 index 000000000000..5dd57084c591 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/curdate.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT curdate(); +## PySpark 3.5.5 Result: {'current_date()': datetime.date(2025, 6, 14), 'typeof(current_date())': 'date'} +#query +#SELECT curdate(); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/current_date.slt b/datafusion/sqllogictest/test_files/spark/datetime/current_date.slt new file mode 100644 index 000000000000..50e29f8f34d6 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/current_date.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT current_date(); +## PySpark 3.5.5 Result: {'current_date()': datetime.date(2025, 6, 14), 'typeof(current_date())': 'date'} +#query +#SELECT current_date(); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/current_timestamp.slt b/datafusion/sqllogictest/test_files/spark/datetime/current_timestamp.slt new file mode 100644 index 000000000000..b440e246ab5e --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/current_timestamp.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT current_timestamp(); +## PySpark 3.5.5 Result: {'current_timestamp()': datetime.datetime(2025, 6, 14, 23, 57, 38, 948981), 'typeof(current_timestamp())': 'timestamp'} +#query +#SELECT current_timestamp(); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/current_timezone.slt b/datafusion/sqllogictest/test_files/spark/datetime/current_timezone.slt new file mode 100644 index 000000000000..c811bc980504 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/current_timezone.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT current_timezone(); +## PySpark 3.5.5 Result: {'current_timezone()': 'America/Los_Angeles', 'typeof(current_timezone())': 'string'} +#query +#SELECT current_timezone(); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/date_add.slt b/datafusion/sqllogictest/test_files/spark/datetime/date_add.slt new file mode 100644 index 000000000000..3a9fcd2c2e53 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/date_add.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT date_add('2016-07-30', 1); +## PySpark 3.5.5 Result: {'date_add(2016-07-30, 1)': datetime.date(2016, 7, 31), 'typeof(date_add(2016-07-30, 1))': 'date', 'typeof(2016-07-30)': 'string', 'typeof(1)': 'int'} +#query +#SELECT date_add('2016-07-30'::string, 1::int); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/date_diff.slt b/datafusion/sqllogictest/test_files/spark/datetime/date_diff.slt new file mode 100644 index 000000000000..8b17bd74e2d1 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/date_diff.slt @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT date_diff('2009-07-30', '2009-07-31'); +## PySpark 3.5.5 Result: {'date_diff(2009-07-30, 2009-07-31)': -1, 'typeof(date_diff(2009-07-30, 2009-07-31))': 'int', 'typeof(2009-07-30)': 'string', 'typeof(2009-07-31)': 'string'} +#query +#SELECT date_diff('2009-07-30'::string, '2009-07-31'::string); + +## Original Query: SELECT date_diff('2009-07-31', '2009-07-30'); +## PySpark 3.5.5 Result: {'date_diff(2009-07-31, 2009-07-30)': 1, 'typeof(date_diff(2009-07-31, 2009-07-30))': 'int', 'typeof(2009-07-31)': 'string', 'typeof(2009-07-30)': 'string'} +#query +#SELECT date_diff('2009-07-31'::string, '2009-07-30'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/date_format.slt b/datafusion/sqllogictest/test_files/spark/datetime/date_format.slt new file mode 100644 index 000000000000..87dd581e0111 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/date_format.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT date_format('2016-04-08', 'y'); +## PySpark 3.5.5 Result: {'date_format(2016-04-08, y)': '2016', 'typeof(date_format(2016-04-08, y))': 'string', 'typeof(2016-04-08)': 'string', 'typeof(y)': 'string'} +#query +#SELECT date_format('2016-04-08'::string, 'y'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/date_part.slt b/datafusion/sqllogictest/test_files/spark/datetime/date_part.slt new file mode 100644 index 000000000000..b2f00a0435e4 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/date_part.slt @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT date_part('MINUTE', INTERVAL '123 23:55:59.002001' DAY TO SECOND); +## PySpark 3.5.5 Result: {"date_part(MINUTE, INTERVAL '123 23:55:59.002001' DAY TO SECOND)": 55, "typeof(date_part(MINUTE, INTERVAL '123 23:55:59.002001' DAY TO SECOND))": 'tinyint', 'typeof(MINUTE)': 'string', "typeof(INTERVAL '123 23:55:59.002001' DAY TO SECOND)": 'interval day to second'} +#query +#SELECT date_part('MINUTE'::string, INTERVAL '123 23:55:59.002001' DAY TO SECOND::interval day to second); + +## Original Query: SELECT date_part('MONTH', INTERVAL '2021-11' YEAR TO MONTH); +## PySpark 3.5.5 Result: {"date_part(MONTH, INTERVAL '2021-11' YEAR TO MONTH)": 11, "typeof(date_part(MONTH, INTERVAL '2021-11' YEAR TO MONTH))": 'tinyint', 'typeof(MONTH)': 'string', "typeof(INTERVAL '2021-11' YEAR TO MONTH)": 'interval year to month'} +#query +#SELECT date_part('MONTH'::string, INTERVAL '2021-11' YEAR TO MONTH::interval year to month); + +## Original Query: SELECT date_part('SECONDS', timestamp'2019-10-01 00:00:01.000001'); +## PySpark 3.5.5 Result: {"date_part(SECONDS, TIMESTAMP '2019-10-01 00:00:01.000001')": Decimal('1.000001'), "typeof(date_part(SECONDS, TIMESTAMP '2019-10-01 00:00:01.000001'))": 'decimal(8,6)', 'typeof(SECONDS)': 'string', "typeof(TIMESTAMP '2019-10-01 00:00:01.000001')": 'timestamp'} +#query +#SELECT date_part('SECONDS'::string, TIMESTAMP '2019-10-01 00:00:01.000001'::timestamp); + +## Original Query: SELECT date_part('YEAR', TIMESTAMP '2019-08-12 01:00:00.123456'); +## PySpark 3.5.5 Result: {"date_part(YEAR, TIMESTAMP '2019-08-12 01:00:00.123456')": 2019, "typeof(date_part(YEAR, TIMESTAMP '2019-08-12 01:00:00.123456'))": 'int', 'typeof(YEAR)': 'string', "typeof(TIMESTAMP '2019-08-12 01:00:00.123456')": 'timestamp'} +#query +#SELECT date_part('YEAR'::string, TIMESTAMP '2019-08-12 01:00:00.123456'::timestamp); + +## Original Query: SELECT date_part('days', interval 5 days 3 hours 7 minutes); +## PySpark 3.5.5 Result: {"date_part(days, INTERVAL '5 03:07' DAY TO MINUTE)": 5, "typeof(date_part(days, INTERVAL '5 03:07' DAY TO MINUTE))": 'int', 'typeof(days)': 'string', "typeof(INTERVAL '5 03:07' DAY TO MINUTE)": 'interval day to minute'} +#query +#SELECT date_part('days'::string, INTERVAL '5 03:07' DAY TO MINUTE::interval day to minute); + +## Original Query: SELECT date_part('doy', DATE'2019-08-12'); +## PySpark 3.5.5 Result: {"date_part(doy, DATE '2019-08-12')": 224, "typeof(date_part(doy, DATE '2019-08-12'))": 'int', 'typeof(doy)': 'string', "typeof(DATE '2019-08-12')": 'date'} +#query +#SELECT date_part('doy'::string, DATE '2019-08-12'::date); + +## Original Query: SELECT date_part('seconds', interval 5 hours 30 seconds 1 milliseconds 1 microseconds); +## PySpark 3.5.5 Result: {"date_part(seconds, INTERVAL '05:00:30.001001' HOUR TO SECOND)": Decimal('30.001001'), "typeof(date_part(seconds, INTERVAL '05:00:30.001001' HOUR TO SECOND))": 'decimal(8,6)', 'typeof(seconds)': 'string', "typeof(INTERVAL '05:00:30.001001' HOUR TO SECOND)": 'interval hour to second'} +#query +#SELECT date_part('seconds'::string, INTERVAL '05:00:30.001001' HOUR TO SECOND::interval hour to second); + +## Original Query: SELECT date_part('week', timestamp'2019-08-12 01:00:00.123456'); +## PySpark 3.5.5 Result: {"date_part(week, TIMESTAMP '2019-08-12 01:00:00.123456')": 33, "typeof(date_part(week, TIMESTAMP '2019-08-12 01:00:00.123456'))": 'int', 'typeof(week)': 'string', "typeof(TIMESTAMP '2019-08-12 
01:00:00.123456')": 'timestamp'} +#query +#SELECT date_part('week'::string, TIMESTAMP '2019-08-12 01:00:00.123456'::timestamp); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/date_sub.slt b/datafusion/sqllogictest/test_files/spark/datetime/date_sub.slt new file mode 100644 index 000000000000..54faf2586c10 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/date_sub.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT date_sub('2016-07-30', 1); +## PySpark 3.5.5 Result: {'date_sub(2016-07-30, 1)': datetime.date(2016, 7, 29), 'typeof(date_sub(2016-07-30, 1))': 'date', 'typeof(2016-07-30)': 'string', 'typeof(1)': 'int'} +#query +#SELECT date_sub('2016-07-30'::string, 1::int); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/date_trunc.slt b/datafusion/sqllogictest/test_files/spark/datetime/date_trunc.slt new file mode 100644 index 000000000000..4a8da2ceb1d6 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/date_trunc.slt @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT date_trunc('DD', '2015-03-05T09:32:05.359'); +## PySpark 3.5.5 Result: {'date_trunc(DD, 2015-03-05T09:32:05.359)': datetime.datetime(2015, 3, 5, 0, 0), 'typeof(date_trunc(DD, 2015-03-05T09:32:05.359))': 'timestamp', 'typeof(DD)': 'string', 'typeof(2015-03-05T09:32:05.359)': 'string'} +#query +#SELECT date_trunc('DD'::string, '2015-03-05T09:32:05.359'::string); + +## Original Query: SELECT date_trunc('HOUR', '2015-03-05T09:32:05.359'); +## PySpark 3.5.5 Result: {'date_trunc(HOUR, 2015-03-05T09:32:05.359)': datetime.datetime(2015, 3, 5, 9, 0), 'typeof(date_trunc(HOUR, 2015-03-05T09:32:05.359))': 'timestamp', 'typeof(HOUR)': 'string', 'typeof(2015-03-05T09:32:05.359)': 'string'} +#query +#SELECT date_trunc('HOUR'::string, '2015-03-05T09:32:05.359'::string); + +## Original Query: SELECT date_trunc('MILLISECOND', '2015-03-05T09:32:05.123456'); +## PySpark 3.5.5 Result: {'date_trunc(MILLISECOND, 2015-03-05T09:32:05.123456)': datetime.datetime(2015, 3, 5, 9, 32, 5, 123000), 'typeof(date_trunc(MILLISECOND, 2015-03-05T09:32:05.123456))': 'timestamp', 'typeof(MILLISECOND)': 'string', 'typeof(2015-03-05T09:32:05.123456)': 'string'} +#query +#SELECT date_trunc('MILLISECOND'::string, '2015-03-05T09:32:05.123456'::string); + +## Original Query: SELECT date_trunc('MM', '2015-03-05T09:32:05.359'); +## PySpark 3.5.5 Result: {'date_trunc(MM, 2015-03-05T09:32:05.359)': datetime.datetime(2015, 3, 1, 0, 0), 'typeof(date_trunc(MM, 2015-03-05T09:32:05.359))': 'timestamp', 'typeof(MM)': 'string', 'typeof(2015-03-05T09:32:05.359)': 'string'} +#query +#SELECT date_trunc('MM'::string, '2015-03-05T09:32:05.359'::string); + +## Original Query: SELECT date_trunc('YEAR', '2015-03-05T09:32:05.359'); +## PySpark 3.5.5 Result: {'date_trunc(YEAR, 2015-03-05T09:32:05.359)': datetime.datetime(2015, 1, 1, 0, 0), 'typeof(date_trunc(YEAR, 2015-03-05T09:32:05.359))': 'timestamp', 'typeof(YEAR)': 'string', 'typeof(2015-03-05T09:32:05.359)': 'string'} +#query +#SELECT date_trunc('YEAR'::string, '2015-03-05T09:32:05.359'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/dateadd.slt b/datafusion/sqllogictest/test_files/spark/datetime/dateadd.slt new file mode 100644 index 000000000000..5c9e34fa12e2 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/dateadd.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT dateadd('2016-07-30', 1); +## PySpark 3.5.5 Result: {'date_add(2016-07-30, 1)': datetime.date(2016, 7, 31), 'typeof(date_add(2016-07-30, 1))': 'date', 'typeof(2016-07-30)': 'string', 'typeof(1)': 'int'} +#query +#SELECT dateadd('2016-07-30'::string, 1::int); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/datediff.slt b/datafusion/sqllogictest/test_files/spark/datetime/datediff.slt new file mode 100644 index 000000000000..a1f83ce01b6b --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/datediff.slt @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT datediff('2009-07-30', '2009-07-31'); +## PySpark 3.5.5 Result: {'datediff(2009-07-30, 2009-07-31)': -1, 'typeof(datediff(2009-07-30, 2009-07-31))': 'int', 'typeof(2009-07-30)': 'string', 'typeof(2009-07-31)': 'string'} +#query +#SELECT datediff('2009-07-30'::string, '2009-07-31'::string); + +## Original Query: SELECT datediff('2009-07-31', '2009-07-30'); +## PySpark 3.5.5 Result: {'datediff(2009-07-31, 2009-07-30)': 1, 'typeof(datediff(2009-07-31, 2009-07-30))': 'int', 'typeof(2009-07-31)': 'string', 'typeof(2009-07-30)': 'string'} +#query +#SELECT datediff('2009-07-31'::string, '2009-07-30'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/datepart.slt b/datafusion/sqllogictest/test_files/spark/datetime/datepart.slt new file mode 100644 index 000000000000..731e26eebc52 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/datepart.slt @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT datepart('MINUTE', INTERVAL '123 23:55:59.002001' DAY TO SECOND); +## PySpark 3.5.5 Result: {"datepart(MINUTE FROM INTERVAL '123 23:55:59.002001' DAY TO SECOND)": 55, "typeof(datepart(MINUTE FROM INTERVAL '123 23:55:59.002001' DAY TO SECOND))": 'tinyint', 'typeof(MINUTE)': 'string', "typeof(INTERVAL '123 23:55:59.002001' DAY TO SECOND)": 'interval day to second'} +#query +#SELECT datepart('MINUTE'::string, INTERVAL '123 23:55:59.002001' DAY TO SECOND::interval day to second); + +## Original Query: SELECT datepart('MONTH', INTERVAL '2021-11' YEAR TO MONTH); +## PySpark 3.5.5 Result: {"datepart(MONTH FROM INTERVAL '2021-11' YEAR TO MONTH)": 11, "typeof(datepart(MONTH FROM INTERVAL '2021-11' YEAR TO MONTH))": 'tinyint', 'typeof(MONTH)': 'string', "typeof(INTERVAL '2021-11' YEAR TO MONTH)": 'interval year to month'} +#query +#SELECT datepart('MONTH'::string, INTERVAL '2021-11' YEAR TO MONTH::interval year to month); + +## Original Query: SELECT datepart('SECONDS', timestamp'2019-10-01 00:00:01.000001'); +## PySpark 3.5.5 Result: {"datepart(SECONDS FROM TIMESTAMP '2019-10-01 00:00:01.000001')": Decimal('1.000001'), "typeof(datepart(SECONDS FROM TIMESTAMP '2019-10-01 00:00:01.000001'))": 'decimal(8,6)', 'typeof(SECONDS)': 'string', "typeof(TIMESTAMP '2019-10-01 00:00:01.000001')": 'timestamp'} +#query +#SELECT datepart('SECONDS'::string, TIMESTAMP '2019-10-01 00:00:01.000001'::timestamp); + +## Original Query: SELECT datepart('YEAR', TIMESTAMP '2019-08-12 01:00:00.123456'); +## PySpark 3.5.5 Result: {"datepart(YEAR FROM TIMESTAMP '2019-08-12 01:00:00.123456')": 2019, "typeof(datepart(YEAR FROM TIMESTAMP '2019-08-12 01:00:00.123456'))": 'int', 'typeof(YEAR)': 'string', "typeof(TIMESTAMP '2019-08-12 01:00:00.123456')": 'timestamp'} +#query +#SELECT datepart('YEAR'::string, TIMESTAMP '2019-08-12 01:00:00.123456'::timestamp); + +## Original Query: SELECT datepart('days', interval 5 days 3 hours 7 minutes); +## PySpark 3.5.5 Result: {"datepart(days FROM INTERVAL '5 03:07' DAY TO MINUTE)": 5, "typeof(datepart(days FROM INTERVAL '5 03:07' DAY TO MINUTE))": 'int', 'typeof(days)': 'string', "typeof(INTERVAL '5 03:07' DAY TO MINUTE)": 'interval day to minute'} +#query +#SELECT datepart('days'::string, INTERVAL '5 03:07' DAY TO MINUTE::interval day to minute); + +## Original Query: SELECT datepart('doy', DATE'2019-08-12'); +## PySpark 3.5.5 Result: {"datepart(doy FROM DATE '2019-08-12')": 224, "typeof(datepart(doy FROM DATE '2019-08-12'))": 'int', 'typeof(doy)': 'string', "typeof(DATE '2019-08-12')": 'date'} +#query +#SELECT datepart('doy'::string, DATE '2019-08-12'::date); + +## Original Query: SELECT datepart('seconds', interval 5 hours 30 seconds 1 milliseconds 1 microseconds); +## PySpark 3.5.5 Result: {"datepart(seconds FROM INTERVAL '05:00:30.001001' HOUR TO SECOND)": Decimal('30.001001'), "typeof(datepart(seconds FROM INTERVAL '05:00:30.001001' HOUR TO SECOND))": 'decimal(8,6)', 'typeof(seconds)': 'string', "typeof(INTERVAL '05:00:30.001001' HOUR TO SECOND)": 'interval hour to second'} +#query +#SELECT 
datepart('seconds'::string, INTERVAL '05:00:30.001001' HOUR TO SECOND::interval hour to second); + +## Original Query: SELECT datepart('week', timestamp'2019-08-12 01:00:00.123456'); +## PySpark 3.5.5 Result: {"datepart(week FROM TIMESTAMP '2019-08-12 01:00:00.123456')": 33, "typeof(datepart(week FROM TIMESTAMP '2019-08-12 01:00:00.123456'))": 'int', 'typeof(week)': 'string', "typeof(TIMESTAMP '2019-08-12 01:00:00.123456')": 'timestamp'} +#query +#SELECT datepart('week'::string, TIMESTAMP '2019-08-12 01:00:00.123456'::timestamp); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/day.slt b/datafusion/sqllogictest/test_files/spark/datetime/day.slt new file mode 100644 index 000000000000..e6426a20eb19 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/day.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT day('2009-07-30'); +## PySpark 3.5.5 Result: {'day(2009-07-30)': 30, 'typeof(day(2009-07-30))': 'int', 'typeof(2009-07-30)': 'string'} +#query +#SELECT day('2009-07-30'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/dayofmonth.slt b/datafusion/sqllogictest/test_files/spark/datetime/dayofmonth.slt new file mode 100644 index 000000000000..d685c3d64c1e --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/dayofmonth.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT dayofmonth('2009-07-30'); +## PySpark 3.5.5 Result: {'dayofmonth(2009-07-30)': 30, 'typeof(dayofmonth(2009-07-30))': 'int', 'typeof(2009-07-30)': 'string'} +#query +#SELECT dayofmonth('2009-07-30'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/dayofweek.slt b/datafusion/sqllogictest/test_files/spark/datetime/dayofweek.slt new file mode 100644 index 000000000000..c482b6d1566d --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/dayofweek.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT dayofweek('2009-07-30'); +## PySpark 3.5.5 Result: {'dayofweek(2009-07-30)': 5, 'typeof(dayofweek(2009-07-30))': 'int', 'typeof(2009-07-30)': 'string'} +#query +#SELECT dayofweek('2009-07-30'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/dayofyear.slt b/datafusion/sqllogictest/test_files/spark/datetime/dayofyear.slt new file mode 100644 index 000000000000..71b1afb9d826 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/dayofyear.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT dayofyear('2016-04-09'); +## PySpark 3.5.5 Result: {'dayofyear(2016-04-09)': 100, 'typeof(dayofyear(2016-04-09))': 'int', 'typeof(2016-04-09)': 'string'} +#query +#SELECT dayofyear('2016-04-09'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/hour.slt b/datafusion/sqllogictest/test_files/spark/datetime/hour.slt new file mode 100644 index 000000000000..1dd6bcd4669e --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/hour.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT hour('2009-07-30 12:58:59'); +## PySpark 3.5.5 Result: {'hour(2009-07-30 12:58:59)': 12, 'typeof(hour(2009-07-30 12:58:59))': 'int', 'typeof(2009-07-30 12:58:59)': 'string'} +#query +#SELECT hour('2009-07-30 12:58:59'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/last_day.slt b/datafusion/sqllogictest/test_files/spark/datetime/last_day.slt new file mode 100644 index 000000000000..c298a2e69c98 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/last_day.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT last_day('2009-01-12'); +## PySpark 3.5.5 Result: {'last_day(2009-01-12)': datetime.date(2009, 1, 31), 'typeof(last_day(2009-01-12))': 'date', 'typeof(2009-01-12)': 'string'} +#query +#SELECT last_day('2009-01-12'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/localtimestamp.slt b/datafusion/sqllogictest/test_files/spark/datetime/localtimestamp.slt new file mode 100644 index 000000000000..3533c0c6131d --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/localtimestamp.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT localtimestamp(); +## PySpark 3.5.5 Result: {'localtimestamp()': datetime.datetime(2025, 6, 14, 23, 57, 39, 529742), 'typeof(localtimestamp())': 'timestamp_ntz'} +#query +#SELECT localtimestamp(); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/make_date.slt b/datafusion/sqllogictest/test_files/spark/datetime/make_date.slt new file mode 100644 index 000000000000..adeeff7efb5c --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/make_date.slt @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT make_date(2013, 7, 15); +## PySpark 3.5.5 Result: {'make_date(2013, 7, 15)': datetime.date(2013, 7, 15), 'typeof(make_date(2013, 7, 15))': 'date', 'typeof(2013)': 'int', 'typeof(7)': 'int', 'typeof(15)': 'int'} +#query +#SELECT make_date(2013::int, 7::int, 15::int); + +## Original Query: SELECT make_date(2019, 7, NULL); +## PySpark 3.5.5 Result: {'make_date(2019, 7, NULL)': None, 'typeof(make_date(2019, 7, NULL))': 'date', 'typeof(2019)': 'int', 'typeof(7)': 'int', 'typeof(NULL)': 'void'} +#query +#SELECT make_date(2019::int, 7::int, NULL::void); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/make_dt_interval.slt b/datafusion/sqllogictest/test_files/spark/datetime/make_dt_interval.slt new file mode 100644 index 000000000000..cdde3d3c7bb5 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/make_dt_interval.slt @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT make_dt_interval(1, 12, 30, 01.001001); +## PySpark 3.5.5 Result: {'make_dt_interval(1, 12, 30, 1.001001)': datetime.timedelta(days=1, seconds=45001, microseconds=1001), 'typeof(make_dt_interval(1, 12, 30, 1.001001))': 'interval day to second', 'typeof(1)': 'int', 'typeof(12)': 'int', 'typeof(30)': 'int', 'typeof(1.001001)': 'decimal(7,6)'} +#query +#SELECT make_dt_interval(1::int, 12::int, 30::int, 1.001001::decimal(7,6)); + +## Original Query: SELECT make_dt_interval(100, null, 3); +## PySpark 3.5.5 Result: {'make_dt_interval(100, NULL, 3, 0.000000)': None, 'typeof(make_dt_interval(100, NULL, 3, 0.000000))': 'interval day to second', 'typeof(100)': 'int', 'typeof(NULL)': 'void', 'typeof(3)': 'int'} +#query +#SELECT make_dt_interval(100::int, NULL::void, 3::int); + +## Original Query: SELECT make_dt_interval(2); +## PySpark 3.5.5 Result: {'make_dt_interval(2, 0, 0, 0.000000)': datetime.timedelta(days=2), 'typeof(make_dt_interval(2, 0, 0, 0.000000))': 'interval day to second', 'typeof(2)': 'int'} +#query +#SELECT make_dt_interval(2::int); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/make_timestamp.slt b/datafusion/sqllogictest/test_files/spark/datetime/make_timestamp.slt new file mode 100644 index 000000000000..89c6bf32151b --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/make_timestamp.slt @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT make_timestamp(2014, 12, 28, 6, 30, 45.887); +## PySpark 3.5.5 Result: {'make_timestamp(2014, 12, 28, 6, 30, 45.887)': datetime.datetime(2014, 12, 28, 6, 30, 45, 887000), 'typeof(make_timestamp(2014, 12, 28, 6, 30, 45.887))': 'timestamp', 'typeof(2014)': 'int', 'typeof(12)': 'int', 'typeof(28)': 'int', 'typeof(6)': 'int', 'typeof(30)': 'int', 'typeof(45.887)': 'decimal(5,3)'} +#query +#SELECT make_timestamp(2014::int, 12::int, 28::int, 6::int, 30::int, 45.887::decimal(5,3)); + +## Original Query: SELECT make_timestamp(2014, 12, 28, 6, 30, 45.887, 'CET'); +## PySpark 3.5.5 Result: {'make_timestamp(2014, 12, 28, 6, 30, 45.887, CET)': datetime.datetime(2014, 12, 27, 21, 30, 45, 887000), 'typeof(make_timestamp(2014, 12, 28, 6, 30, 45.887, CET))': 'timestamp', 'typeof(2014)': 'int', 'typeof(12)': 'int', 'typeof(28)': 'int', 'typeof(6)': 'int', 'typeof(30)': 'int', 'typeof(45.887)': 'decimal(5,3)', 'typeof(CET)': 'string'} +#query +#SELECT make_timestamp(2014::int, 12::int, 28::int, 6::int, 30::int, 45.887::decimal(5,3), 'CET'::string); + +## Original Query: SELECT make_timestamp(2019, 6, 30, 23, 59, 1); +## PySpark 3.5.5 Result: {'make_timestamp(2019, 6, 30, 23, 59, 1)': datetime.datetime(2019, 6, 30, 23, 59, 1), 'typeof(make_timestamp(2019, 6, 30, 23, 59, 1))': 'timestamp', 'typeof(2019)': 'int', 'typeof(6)': 'int', 'typeof(30)': 'int', 'typeof(23)': 'int', 'typeof(59)': 'int', 'typeof(1)': 'int'} +#query +#SELECT make_timestamp(2019::int, 6::int, 30::int, 23::int, 59::int, 1::int); + +## Original Query: SELECT make_timestamp(2019, 6, 30, 23, 59, 60); +## PySpark 3.5.5 Result: {'make_timestamp(2019, 6, 30, 23, 59, 60)': datetime.datetime(2019, 7, 1, 0, 0), 'typeof(make_timestamp(2019, 6, 30, 23, 59, 60))': 'timestamp', 'typeof(2019)': 'int', 'typeof(6)': 'int', 'typeof(30)': 'int', 'typeof(23)': 'int', 'typeof(59)': 'int', 'typeof(60)': 'int'} +#query +#SELECT make_timestamp(2019::int, 6::int, 30::int, 23::int, 59::int, 60::int); + +## Original Query: SELECT make_timestamp(null, 7, 22, 15, 30, 0); +## PySpark 3.5.5 Result: {'make_timestamp(NULL, 7, 22, 15, 30, 0)': None, 'typeof(make_timestamp(NULL, 7, 22, 15, 30, 0))': 'timestamp', 'typeof(NULL)': 'void', 'typeof(7)': 'int', 'typeof(22)': 'int', 'typeof(15)': 'int', 'typeof(30)': 'int', 'typeof(0)': 'int'} +#query +#SELECT make_timestamp(NULL::void, 7::int, 22::int, 15::int, 30::int, 0::int); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/make_timestamp_ltz.slt b/datafusion/sqllogictest/test_files/spark/datetime/make_timestamp_ltz.slt new file mode 100644 index 000000000000..8bb93df3d2e5 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/make_timestamp_ltz.slt @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT make_timestamp_ltz(2014, 12, 28, 6, 30, 45.887); +## PySpark 3.5.5 Result: {'make_timestamp_ltz(2014, 12, 28, 6, 30, 45.887)': datetime.datetime(2014, 12, 28, 6, 30, 45, 887000), 'typeof(make_timestamp_ltz(2014, 12, 28, 6, 30, 45.887))': 'timestamp', 'typeof(2014)': 'int', 'typeof(12)': 'int', 'typeof(28)': 'int', 'typeof(6)': 'int', 'typeof(30)': 'int', 'typeof(45.887)': 'decimal(5,3)'} +#query +#SELECT make_timestamp_ltz(2014::int, 12::int, 28::int, 6::int, 30::int, 45.887::decimal(5,3)); + +## Original Query: SELECT make_timestamp_ltz(2014, 12, 28, 6, 30, 45.887, 'CET'); +## PySpark 3.5.5 Result: {'make_timestamp_ltz(2014, 12, 28, 6, 30, 45.887, CET)': datetime.datetime(2014, 12, 27, 21, 30, 45, 887000), 'typeof(make_timestamp_ltz(2014, 12, 28, 6, 30, 45.887, CET))': 'timestamp', 'typeof(2014)': 'int', 'typeof(12)': 'int', 'typeof(28)': 'int', 'typeof(6)': 'int', 'typeof(30)': 'int', 'typeof(45.887)': 'decimal(5,3)', 'typeof(CET)': 'string'} +#query +#SELECT make_timestamp_ltz(2014::int, 12::int, 28::int, 6::int, 30::int, 45.887::decimal(5,3), 'CET'::string); + +## Original Query: SELECT make_timestamp_ltz(2019, 6, 30, 23, 59, 60); +## PySpark 3.5.5 Result: {'make_timestamp_ltz(2019, 6, 30, 23, 59, 60)': datetime.datetime(2019, 7, 1, 0, 0), 'typeof(make_timestamp_ltz(2019, 6, 30, 23, 59, 60))': 'timestamp', 'typeof(2019)': 'int', 'typeof(6)': 'int', 'typeof(30)': 'int', 'typeof(23)': 'int', 'typeof(59)': 'int', 'typeof(60)': 'int'} +#query +#SELECT make_timestamp_ltz(2019::int, 6::int, 30::int, 23::int, 59::int, 60::int); + +## Original Query: SELECT make_timestamp_ltz(null, 7, 22, 15, 30, 0); +## PySpark 3.5.5 Result: {'make_timestamp_ltz(NULL, 7, 22, 15, 30, 0)': None, 'typeof(make_timestamp_ltz(NULL, 7, 22, 15, 30, 0))': 'timestamp', 'typeof(NULL)': 'void', 'typeof(7)': 'int', 'typeof(22)': 'int', 'typeof(15)': 'int', 'typeof(30)': 'int', 'typeof(0)': 'int'} +#query +#SELECT make_timestamp_ltz(NULL::void, 7::int, 22::int, 15::int, 30::int, 0::int); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/make_timestamp_ntz.slt b/datafusion/sqllogictest/test_files/spark/datetime/make_timestamp_ntz.slt new file mode 100644 index 000000000000..b65649d81a98 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/make_timestamp_ntz.slt @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT make_timestamp_ntz(2014, 12, 28, 6, 30, 45.887); +## PySpark 3.5.5 Result: {'make_timestamp_ntz(2014, 12, 28, 6, 30, 45.887)': datetime.datetime(2014, 12, 28, 6, 30, 45, 887000), 'typeof(make_timestamp_ntz(2014, 12, 28, 6, 30, 45.887))': 'timestamp_ntz', 'typeof(2014)': 'int', 'typeof(12)': 'int', 'typeof(28)': 'int', 'typeof(6)': 'int', 'typeof(30)': 'int', 'typeof(45.887)': 'decimal(5,3)'} +#query +#SELECT make_timestamp_ntz(2014::int, 12::int, 28::int, 6::int, 30::int, 45.887::decimal(5,3)); + +## Original Query: SELECT make_timestamp_ntz(2019, 6, 30, 23, 59, 60); +## PySpark 3.5.5 Result: {'make_timestamp_ntz(2019, 6, 30, 23, 59, 60)': datetime.datetime(2019, 7, 1, 0, 0), 'typeof(make_timestamp_ntz(2019, 6, 30, 23, 59, 60))': 'timestamp_ntz', 'typeof(2019)': 'int', 'typeof(6)': 'int', 'typeof(30)': 'int', 'typeof(23)': 'int', 'typeof(59)': 'int', 'typeof(60)': 'int'} +#query +#SELECT make_timestamp_ntz(2019::int, 6::int, 30::int, 23::int, 59::int, 60::int); + +## Original Query: SELECT make_timestamp_ntz(null, 7, 22, 15, 30, 0); +## PySpark 3.5.5 Result: {'make_timestamp_ntz(NULL, 7, 22, 15, 30, 0)': None, 'typeof(make_timestamp_ntz(NULL, 7, 22, 15, 30, 0))': 'timestamp_ntz', 'typeof(NULL)': 'void', 'typeof(7)': 'int', 'typeof(22)': 'int', 'typeof(15)': 'int', 'typeof(30)': 'int', 'typeof(0)': 'int'} +#query +#SELECT make_timestamp_ntz(NULL::void, 7::int, 22::int, 15::int, 30::int, 0::int); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/make_ym_interval.slt b/datafusion/sqllogictest/test_files/spark/datetime/make_ym_interval.slt new file mode 100644 index 000000000000..2a5bc7b3d28c --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/make_ym_interval.slt @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT make_ym_interval(-1, 1); +## PySpark 3.5.5 Result: {'make_ym_interval(-1, 1)': -11, 'typeof(make_ym_interval(-1, 1))': 'interval year to month', 'typeof(-1)': 'int', 'typeof(1)': 'int'} +#query +#SELECT make_ym_interval(-1::int, 1::int); + +## Original Query: SELECT make_ym_interval(1, 0); +## PySpark 3.5.5 Result: {'make_ym_interval(1, 0)': 12, 'typeof(make_ym_interval(1, 0))': 'interval year to month', 'typeof(1)': 'int', 'typeof(0)': 'int'} +#query +#SELECT make_ym_interval(1::int, 0::int); + +## Original Query: SELECT make_ym_interval(1, 2); +## PySpark 3.5.5 Result: {'make_ym_interval(1, 2)': 14, 'typeof(make_ym_interval(1, 2))': 'interval year to month', 'typeof(1)': 'int', 'typeof(2)': 'int'} +#query +#SELECT make_ym_interval(1::int, 2::int); + +## Original Query: SELECT make_ym_interval(2); +## PySpark 3.5.5 Result: {'make_ym_interval(2, 0)': 24, 'typeof(make_ym_interval(2, 0))': 'interval year to month', 'typeof(2)': 'int'} +#query +#SELECT make_ym_interval(2::int); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/minute.slt b/datafusion/sqllogictest/test_files/spark/datetime/minute.slt new file mode 100644 index 000000000000..451f5d2c4025 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/minute.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT minute('2009-07-30 12:58:59'); +## PySpark 3.5.5 Result: {'minute(2009-07-30 12:58:59)': 58, 'typeof(minute(2009-07-30 12:58:59))': 'int', 'typeof(2009-07-30 12:58:59)': 'string'} +#query +#SELECT minute('2009-07-30 12:58:59'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/month.slt b/datafusion/sqllogictest/test_files/spark/datetime/month.slt new file mode 100644 index 000000000000..fc182e157ffd --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/month.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT month('2016-07-30'); +## PySpark 3.5.5 Result: {'month(2016-07-30)': 7, 'typeof(month(2016-07-30))': 'int', 'typeof(2016-07-30)': 'string'} +#query +#SELECT month('2016-07-30'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/months_between.slt b/datafusion/sqllogictest/test_files/spark/datetime/months_between.slt new file mode 100644 index 000000000000..252e364a4818 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/months_between.slt @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT months_between('1997-02-28 10:30:00', '1996-10-30'); +## PySpark 3.5.5 Result: {'months_between(1997-02-28 10:30:00, 1996-10-30, true)': 3.94959677, 'typeof(months_between(1997-02-28 10:30:00, 1996-10-30, true))': 'double', 'typeof(1997-02-28 10:30:00)': 'string', 'typeof(1996-10-30)': 'string'} +#query +#SELECT months_between('1997-02-28 10:30:00'::string, '1996-10-30'::string); + +## Original Query: SELECT months_between('1997-02-28 10:30:00', '1996-10-30', false); +## PySpark 3.5.5 Result: {'months_between(1997-02-28 10:30:00, 1996-10-30, false)': 3.9495967741935485, 'typeof(months_between(1997-02-28 10:30:00, 1996-10-30, false))': 'double', 'typeof(1997-02-28 10:30:00)': 'string', 'typeof(1996-10-30)': 'string', 'typeof(false)': 'boolean'} +#query +#SELECT months_between('1997-02-28 10:30:00'::string, '1996-10-30'::string, false::boolean); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/next_day.slt b/datafusion/sqllogictest/test_files/spark/datetime/next_day.slt new file mode 100644 index 000000000000..2c8f30b895e1 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/next_day.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT next_day('2015-01-14', 'TU'); +## PySpark 3.5.5 Result: {'next_day(2015-01-14, TU)': datetime.date(2015, 1, 20), 'typeof(next_day(2015-01-14, TU))': 'date', 'typeof(2015-01-14)': 'string', 'typeof(TU)': 'string'} +#query +#SELECT next_day('2015-01-14'::string, 'TU'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/now.slt b/datafusion/sqllogictest/test_files/spark/datetime/now.slt new file mode 100644 index 000000000000..36ae31ff5ec0 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/now.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT now(); +## PySpark 3.5.5 Result: {'now()': datetime.datetime(2025, 6, 14, 23, 57, 39, 982956), 'typeof(now())': 'timestamp'} +#query +#SELECT now(); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/quarter.slt b/datafusion/sqllogictest/test_files/spark/datetime/quarter.slt new file mode 100644 index 000000000000..e283d78d4fac --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/quarter.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT quarter('2016-08-31'); +## PySpark 3.5.5 Result: {'quarter(2016-08-31)': 3, 'typeof(quarter(2016-08-31))': 'int', 'typeof(2016-08-31)': 'string'} +#query +#SELECT quarter('2016-08-31'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/second.slt b/datafusion/sqllogictest/test_files/spark/datetime/second.slt new file mode 100644 index 000000000000..799d37101d51 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/second.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT second('2009-07-30 12:58:59'); +## PySpark 3.5.5 Result: {'second(2009-07-30 12:58:59)': 59, 'typeof(second(2009-07-30 12:58:59))': 'int', 'typeof(2009-07-30 12:58:59)': 'string'} +#query +#SELECT second('2009-07-30 12:58:59'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/timestamp_micros.slt b/datafusion/sqllogictest/test_files/spark/datetime/timestamp_micros.slt new file mode 100644 index 000000000000..b73955112307 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/timestamp_micros.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT timestamp_micros(1230219000123123); +## PySpark 3.5.5 Result: {'timestamp_micros(1230219000123123)': datetime.datetime(2008, 12, 25, 7, 30, 0, 123123), 'typeof(timestamp_micros(1230219000123123))': 'timestamp', 'typeof(1230219000123123)': 'bigint'} +#query +#SELECT timestamp_micros(1230219000123123::bigint); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/timestamp_millis.slt b/datafusion/sqllogictest/test_files/spark/datetime/timestamp_millis.slt new file mode 100644 index 000000000000..a0b7c13772ee --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/timestamp_millis.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT timestamp_millis(1230219000123); +## PySpark 3.5.5 Result: {'timestamp_millis(1230219000123)': datetime.datetime(2008, 12, 25, 7, 30, 0, 123000), 'typeof(timestamp_millis(1230219000123))': 'timestamp', 'typeof(1230219000123)': 'bigint'} +#query +#SELECT timestamp_millis(1230219000123::bigint); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/timestamp_seconds.slt b/datafusion/sqllogictest/test_files/spark/datetime/timestamp_seconds.slt new file mode 100644 index 000000000000..a883ab7bd1e1 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/timestamp_seconds.slt @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT timestamp_seconds(1230219000); +## PySpark 3.5.5 Result: {'timestamp_seconds(1230219000)': datetime.datetime(2008, 12, 25, 7, 30), 'typeof(timestamp_seconds(1230219000))': 'timestamp', 'typeof(1230219000)': 'int'} +#query +#SELECT timestamp_seconds(1230219000::int); + +## Original Query: SELECT timestamp_seconds(1230219000.123); +## PySpark 3.5.5 Result: {'timestamp_seconds(1230219000.123)': datetime.datetime(2008, 12, 25, 7, 30, 0, 123000), 'typeof(timestamp_seconds(1230219000.123))': 'timestamp', 'typeof(1230219000.123)': 'decimal(13,3)'} +#query +#SELECT timestamp_seconds(1230219000.123::decimal(13,3)); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/to_date.slt b/datafusion/sqllogictest/test_files/spark/datetime/to_date.slt new file mode 100644 index 000000000000..d7128942b950 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/to_date.slt @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT to_date('2009-07-30 04:17:52'); +## PySpark 3.5.5 Result: {'to_date(2009-07-30 04:17:52)': datetime.date(2009, 7, 30), 'typeof(to_date(2009-07-30 04:17:52))': 'date', 'typeof(2009-07-30 04:17:52)': 'string'} +#query +#SELECT to_date('2009-07-30 04:17:52'::string); + +## Original Query: SELECT to_date('2016-12-31', 'yyyy-MM-dd'); +## PySpark 3.5.5 Result: {'to_date(2016-12-31, yyyy-MM-dd)': datetime.date(2016, 12, 31), 'typeof(to_date(2016-12-31, yyyy-MM-dd))': 'date', 'typeof(2016-12-31)': 'string', 'typeof(yyyy-MM-dd)': 'string'} +#query +#SELECT to_date('2016-12-31'::string, 'yyyy-MM-dd'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/to_timestamp.slt b/datafusion/sqllogictest/test_files/spark/datetime/to_timestamp.slt new file mode 100644 index 000000000000..6511268b68b9 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/to_timestamp.slt @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT to_timestamp('2016-12-31 00:12:00'); +## PySpark 3.5.5 Result: {'to_timestamp(2016-12-31 00:12:00)': datetime.datetime(2016, 12, 31, 0, 12), 'typeof(to_timestamp(2016-12-31 00:12:00))': 'timestamp', 'typeof(2016-12-31 00:12:00)': 'string'} +#query +#SELECT to_timestamp('2016-12-31 00:12:00'::string); + +## Original Query: SELECT to_timestamp('2016-12-31', 'yyyy-MM-dd'); +## PySpark 3.5.5 Result: {'to_timestamp(2016-12-31, yyyy-MM-dd)': datetime.datetime(2016, 12, 31, 0, 0), 'typeof(to_timestamp(2016-12-31, yyyy-MM-dd))': 'timestamp', 'typeof(2016-12-31)': 'string', 'typeof(yyyy-MM-dd)': 'string'} +#query +#SELECT to_timestamp('2016-12-31'::string, 'yyyy-MM-dd'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/to_timestamp_ltz.slt b/datafusion/sqllogictest/test_files/spark/datetime/to_timestamp_ltz.slt new file mode 100644 index 000000000000..9181fb5ea399 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/to_timestamp_ltz.slt @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT to_timestamp_ltz('2016-12-31 00:12:00'); +## PySpark 3.5.5 Result: {'to_timestamp_ltz(2016-12-31 00:12:00)': datetime.datetime(2016, 12, 31, 0, 12), 'typeof(to_timestamp_ltz(2016-12-31 00:12:00))': 'timestamp', 'typeof(2016-12-31 00:12:00)': 'string'} +#query +#SELECT to_timestamp_ltz('2016-12-31 00:12:00'::string); + +## Original Query: SELECT to_timestamp_ltz('2016-12-31', 'yyyy-MM-dd'); +## PySpark 3.5.5 Result: {'to_timestamp_ltz(2016-12-31, yyyy-MM-dd)': datetime.datetime(2016, 12, 31, 0, 0), 'typeof(to_timestamp_ltz(2016-12-31, yyyy-MM-dd))': 'timestamp', 'typeof(2016-12-31)': 'string', 'typeof(yyyy-MM-dd)': 'string'} +#query +#SELECT to_timestamp_ltz('2016-12-31'::string, 'yyyy-MM-dd'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/to_timestamp_ntz.slt b/datafusion/sqllogictest/test_files/spark/datetime/to_timestamp_ntz.slt new file mode 100644 index 000000000000..5e93fcd067bf --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/to_timestamp_ntz.slt @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT to_timestamp_ntz('2016-12-31 00:12:00'); +## PySpark 3.5.5 Result: {'to_timestamp_ntz(2016-12-31 00:12:00)': datetime.datetime(2016, 12, 31, 0, 12), 'typeof(to_timestamp_ntz(2016-12-31 00:12:00))': 'timestamp_ntz', 'typeof(2016-12-31 00:12:00)': 'string'} +#query +#SELECT to_timestamp_ntz('2016-12-31 00:12:00'::string); + +## Original Query: SELECT to_timestamp_ntz('2016-12-31', 'yyyy-MM-dd'); +## PySpark 3.5.5 Result: {'to_timestamp_ntz(2016-12-31, yyyy-MM-dd)': datetime.datetime(2016, 12, 31, 0, 0), 'typeof(to_timestamp_ntz(2016-12-31, yyyy-MM-dd))': 'timestamp_ntz', 'typeof(2016-12-31)': 'string', 'typeof(yyyy-MM-dd)': 'string'} +#query +#SELECT to_timestamp_ntz('2016-12-31'::string, 'yyyy-MM-dd'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/to_unix_timestamp.slt b/datafusion/sqllogictest/test_files/spark/datetime/to_unix_timestamp.slt new file mode 100644 index 000000000000..7b3589e4c8bb --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/to_unix_timestamp.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT to_unix_timestamp('2016-04-08', 'yyyy-MM-dd'); +## PySpark 3.5.5 Result: {'to_unix_timestamp(2016-04-08, yyyy-MM-dd)': 1460098800, 'typeof(to_unix_timestamp(2016-04-08, yyyy-MM-dd))': 'bigint', 'typeof(2016-04-08)': 'string', 'typeof(yyyy-MM-dd)': 'string'} +#query +#SELECT to_unix_timestamp('2016-04-08'::string, 'yyyy-MM-dd'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/to_utc_timestamp.slt b/datafusion/sqllogictest/test_files/spark/datetime/to_utc_timestamp.slt new file mode 100644 index 000000000000..e4c7b244fffe --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/to_utc_timestamp.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT to_utc_timestamp('2016-08-31', 'Asia/Seoul'); +## PySpark 3.5.5 Result: {'to_utc_timestamp(2016-08-31, Asia/Seoul)': datetime.datetime(2016, 8, 30, 15, 0), 'typeof(to_utc_timestamp(2016-08-31, Asia/Seoul))': 'timestamp', 'typeof(2016-08-31)': 'string', 'typeof(Asia/Seoul)': 'string'} +#query +#SELECT to_utc_timestamp('2016-08-31'::string, 'Asia/Seoul'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/trunc.slt b/datafusion/sqllogictest/test_files/spark/datetime/trunc.slt new file mode 100644 index 000000000000..f716c8e950c1 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/trunc.slt @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT trunc('2009-02-12', 'MM'); +## PySpark 3.5.5 Result: {'trunc(2009-02-12, MM)': datetime.date(2009, 2, 1), 'typeof(trunc(2009-02-12, MM))': 'date', 'typeof(2009-02-12)': 'string', 'typeof(MM)': 'string'} +#query +#SELECT trunc('2009-02-12'::string, 'MM'::string); + +## Original Query: SELECT trunc('2015-10-27', 'YEAR'); +## PySpark 3.5.5 Result: {'trunc(2015-10-27, YEAR)': datetime.date(2015, 1, 1), 'typeof(trunc(2015-10-27, YEAR))': 'date', 'typeof(2015-10-27)': 'string', 'typeof(YEAR)': 'string'} +#query +#SELECT trunc('2015-10-27'::string, 'YEAR'::string); + +## Original Query: SELECT trunc('2019-08-04', 'quarter'); +## PySpark 3.5.5 Result: {'trunc(2019-08-04, quarter)': datetime.date(2019, 7, 1), 'typeof(trunc(2019-08-04, quarter))': 'date', 'typeof(2019-08-04)': 'string', 'typeof(quarter)': 'string'} +#query +#SELECT trunc('2019-08-04'::string, 'quarter'::string); + +## Original Query: SELECT trunc('2019-08-04', 'week'); +## PySpark 3.5.5 Result: {'trunc(2019-08-04, week)': datetime.date(2019, 7, 29), 'typeof(trunc(2019-08-04, week))': 'date', 'typeof(2019-08-04)': 'string', 'typeof(week)': 'string'} +#query +#SELECT trunc('2019-08-04'::string, 'week'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/try_to_timestamp.slt b/datafusion/sqllogictest/test_files/spark/datetime/try_to_timestamp.slt new file mode 100644 index 000000000000..5e07b5a12e35 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/try_to_timestamp.slt @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT try_to_timestamp('2016-12-31 00:12:00'); +## PySpark 3.5.5 Result: {'try_to_timestamp(2016-12-31 00:12:00)': datetime.datetime(2016, 12, 31, 0, 12), 'typeof(try_to_timestamp(2016-12-31 00:12:00))': 'timestamp', 'typeof(2016-12-31 00:12:00)': 'string'} +#query +#SELECT try_to_timestamp('2016-12-31 00:12:00'::string); + +## Original Query: SELECT try_to_timestamp('2016-12-31', 'yyyy-MM-dd'); +## PySpark 3.5.5 Result: {'try_to_timestamp(2016-12-31, yyyy-MM-dd)': datetime.datetime(2016, 12, 31, 0, 0), 'typeof(try_to_timestamp(2016-12-31, yyyy-MM-dd))': 'timestamp', 'typeof(2016-12-31)': 'string', 'typeof(yyyy-MM-dd)': 'string'} +#query +#SELECT try_to_timestamp('2016-12-31'::string, 'yyyy-MM-dd'::string); + +## Original Query: SELECT try_to_timestamp('foo', 'yyyy-MM-dd'); +## PySpark 3.5.5 Result: {'try_to_timestamp(foo, yyyy-MM-dd)': None, 'typeof(try_to_timestamp(foo, yyyy-MM-dd))': 'timestamp', 'typeof(foo)': 'string', 'typeof(yyyy-MM-dd)': 'string'} +#query +#SELECT try_to_timestamp('foo'::string, 'yyyy-MM-dd'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/unix_timestamp.slt b/datafusion/sqllogictest/test_files/spark/datetime/unix_timestamp.slt new file mode 100644 index 000000000000..3a128981c60b --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/unix_timestamp.slt @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT unix_timestamp('2016-04-08', 'yyyy-MM-dd'); +## PySpark 3.5.5 Result: {'unix_timestamp(2016-04-08, yyyy-MM-dd)': 1460098800, 'typeof(unix_timestamp(2016-04-08, yyyy-MM-dd))': 'bigint', 'typeof(2016-04-08)': 'string', 'typeof(yyyy-MM-dd)': 'string'} +#query +#SELECT unix_timestamp('2016-04-08'::string, 'yyyy-MM-dd'::string); + +## Original Query: SELECT unix_timestamp(); +## PySpark 3.5.5 Result: {'unix_timestamp(current_timestamp(), yyyy-MM-dd HH:mm:ss)': 1749970660, 'typeof(unix_timestamp(current_timestamp(), yyyy-MM-dd HH:mm:ss))': 'bigint'} +#query +#SELECT unix_timestamp(); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/weekday.slt b/datafusion/sqllogictest/test_files/spark/datetime/weekday.slt new file mode 100644 index 000000000000..7e65d600260d --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/weekday.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT weekday('2009-07-30'); +## PySpark 3.5.5 Result: {'weekday(2009-07-30)': 3, 'typeof(weekday(2009-07-30))': 'int', 'typeof(2009-07-30)': 'string'} +#query +#SELECT weekday('2009-07-30'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/weekofyear.slt b/datafusion/sqllogictest/test_files/spark/datetime/weekofyear.slt new file mode 100644 index 000000000000..185bc51ffaeb --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/weekofyear.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT weekofyear('2008-02-20'); +## PySpark 3.5.5 Result: {'weekofyear(2008-02-20)': 8, 'typeof(weekofyear(2008-02-20))': 'int', 'typeof(2008-02-20)': 'string'} +#query +#SELECT weekofyear('2008-02-20'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/datetime/year.slt b/datafusion/sqllogictest/test_files/spark/datetime/year.slt new file mode 100644 index 000000000000..9a19a618f73f --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/datetime/year.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT year('2016-07-30'); +## PySpark 3.5.5 Result: {'year(2016-07-30)': 2016, 'typeof(year(2016-07-30))': 'int', 'typeof(2016-07-30)': 'string'} +#query +#SELECT year('2016-07-30'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/hash/crc32.slt b/datafusion/sqllogictest/test_files/spark/hash/crc32.slt new file mode 100644 index 000000000000..3fc442bbdde6 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/hash/crc32.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT crc32('Spark'); +## PySpark 3.5.5 Result: {'crc32(Spark)': 1557323817, 'typeof(crc32(Spark))': 'bigint', 'typeof(Spark)': 'string'} +#query +#SELECT crc32('Spark'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/hash/md5.slt b/datafusion/sqllogictest/test_files/spark/hash/md5.slt new file mode 100644 index 000000000000..32aafcbc3768 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/hash/md5.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT md5('Spark'); +## PySpark 3.5.5 Result: {'md5(Spark)': '8cde774d6f7333752ed72cacddb05126', 'typeof(md5(Spark))': 'string', 'typeof(Spark)': 'string'} +#query +#SELECT md5('Spark'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/hash/sha.slt b/datafusion/sqllogictest/test_files/spark/hash/sha.slt new file mode 100644 index 000000000000..30965f26843b --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/hash/sha.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT sha('Spark'); +## PySpark 3.5.5 Result: {'sha(Spark)': '85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c', 'typeof(sha(Spark))': 'string', 'typeof(Spark)': 'string'} +#query +#SELECT sha('Spark'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/hash/sha1.slt b/datafusion/sqllogictest/test_files/spark/hash/sha1.slt new file mode 100644 index 000000000000..e245cd5b3a2f --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/hash/sha1.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT sha1('Spark'); +## PySpark 3.5.5 Result: {'sha1(Spark)': '85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c', 'typeof(sha1(Spark))': 'string', 'typeof(Spark)': 'string'} +#query +#SELECT sha1('Spark'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/hash/sha2.slt b/datafusion/sqllogictest/test_files/spark/hash/sha2.slt new file mode 100644 index 000000000000..7690a38773b0 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/hash/sha2.slt @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +query T +SELECT sha2('Spark', 0::INT); +---- +529bc3b07127ecb7e53a4dcf1991d9152c24537d919178022b2c42657f79a26b + +query T +SELECT sha2('Spark', 256::INT); +---- +529bc3b07127ecb7e53a4dcf1991d9152c24537d919178022b2c42657f79a26b + +query T +SELECT sha2('Spark', 224::INT); +---- +dbeab94971678d36af2195851c0f7485775a2a7c60073d62fc04549c + +query T +SELECT sha2('Spark', 384::INT); +---- +1e40b8d06c248a1cc32428c22582b6219d072283078fa140d9ad297ecadf2cabefc341b857ad36226aa8d6d79f2ab67d + +query T +SELECT sha2('Spark', 512::INT); +---- +44844a586c54c9a212da1dbfe05c5f1705de1af5fda1f0d36297623249b279fd8f0ccec03f888f4fb13bf7cd83fdad58591c797f81121a23cfdd5e0897795238 + +query T +SELECT sha2('Spark', 128::INT); +---- +NULL + +query T +SELECT sha2(expr, 256::INT) FROM VALUES ('foo'), ('bar') AS t(expr); +---- +2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae +fcde2b2edba56bf408601fb721fe9b5c338d10ee429ea04fae5511b68fbf8fb9 + +query T +SELECT sha2(expr, 128::INT) FROM VALUES ('foo'), ('bar') AS t(expr); +---- +NULL +NULL + +query T +SELECT sha2('foo', bit_length) FROM VALUES (0::INT), (256::INT), (224::INT), (384::INT), (512::INT), (128::INT) AS t(bit_length); +---- +2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae +2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae +0808f64e60d58979fcb676c96ec938270dea42445aeefcd3a4e6f8db +98c11ffdfdd540676b1a137cb1a22b2a70350c9a44171d6b1180c6be5cbb2ee3f79d532c8a1dd9ef2e8e08e752a3babb +f7fbba6e0636f890e56fbbf3283e524c6fa3204ae298382d624741d0dc6638326e282c41be5e4254d8820772c5518a2c5a8c0c7f7eda19594a7eb539453e1ed7 +NULL + +query T +SELECT sha2(expr, bit_length) FROM VALUES ('foo',0::INT), ('bar',224::INT), ('baz',384::INT), ('qux',512::INT), ('qux',128::INT) AS t(expr, bit_length); +---- +2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae +07daf010de7f7f0d8d76a76eb8d1eb40182c8d1e7a3877a6686c9bf0 +967004d25de4abc1bd6a7c9a216254a5ac0733e8ad96dc9f1ea0fad9619da7c32d654ec8ad8ba2f9b5728fed6633bd91 +8c6be9ed448a34883a13a13f4ead4aefa036b67dcda59020c01e57ea075ea8a4792d428f2c6fd0c09d1c49994d6c22789336e062188df29572ed07e7f9779c52 +NULL diff --git a/datafusion/sqllogictest/test_files/spark/json/get_json_object.slt b/datafusion/sqllogictest/test_files/spark/json/get_json_object.slt new file mode 100644 index 000000000000..bd4f1c35f5c8 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/json/get_json_object.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT get_json_object('{"a":"b"}', '$.a'); +## PySpark 3.5.5 Result: {'get_json_object({"a":"b"}, $.a)': 'b', 'typeof(get_json_object({"a":"b"}, $.a))': 'string', 'typeof({"a":"b"})': 'string', 'typeof($.a)': 'string'} +#query +#SELECT get_json_object('{"a":"b"}'::string, '$.a'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/json/json_object_keys.slt b/datafusion/sqllogictest/test_files/spark/json/json_object_keys.slt new file mode 100644 index 000000000000..d890ac8d96b5 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/json/json_object_keys.slt @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT json_object_keys('{"f1":"abc","f2":{"f3":"a", "f4":"b"}}'); +## PySpark 3.5.5 Result: {'json_object_keys({"f1":"abc","f2":{"f3":"a", "f4":"b"}})': ['f1', 'f2'], 'typeof(json_object_keys({"f1":"abc","f2":{"f3":"a", "f4":"b"}}))': 'array', 'typeof({"f1":"abc","f2":{"f3":"a", "f4":"b"}})': 'string'} +#query +#SELECT json_object_keys('{"f1":"abc","f2":{"f3":"a", "f4":"b"}}'::string); + +## Original Query: SELECT json_object_keys('{"key": "value"}'); +## PySpark 3.5.5 Result: {'json_object_keys({"key": "value"})': ['key'], 'typeof(json_object_keys({"key": "value"}))': 'array', 'typeof({"key": "value"})': 'string'} +#query +#SELECT json_object_keys('{"key": "value"}'::string); + +## Original Query: SELECT json_object_keys('{}'); +## PySpark 3.5.5 Result: {'json_object_keys({})': [], 'typeof(json_object_keys({}))': 'array', 'typeof({})': 'string'} +#query +#SELECT json_object_keys('{}'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/math/abs.slt b/datafusion/sqllogictest/test_files/spark/math/abs.slt new file mode 100644 index 000000000000..bc2993726714 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/abs.slt @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT abs(-1); +## PySpark 3.5.5 Result: {'abs(-1)': 1, 'typeof(abs(-1))': 'int', 'typeof(-1)': 'int'} +#query +#SELECT abs(-1::int); + +## Original Query: SELECT abs(INTERVAL -'1-1' YEAR TO MONTH); +## PySpark 3.5.5 Result: {"abs(INTERVAL '-1-1' YEAR TO MONTH)": 13, "typeof(abs(INTERVAL '-1-1' YEAR TO MONTH))": 'interval year to month', "typeof(INTERVAL '-1-1' YEAR TO MONTH)": 'interval year to month'} +#query +#SELECT abs(INTERVAL '-1-1' YEAR TO MONTH::interval year to month); + diff --git a/datafusion/sqllogictest/test_files/spark/math/acos.slt b/datafusion/sqllogictest/test_files/spark/math/acos.slt new file mode 100644 index 000000000000..1c16f395ca2f --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/acos.slt @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT acos(1); +## PySpark 3.5.5 Result: {'ACOS(1)': 0.0, 'typeof(ACOS(1))': 'double', 'typeof(1)': 'int'} +#query +#SELECT acos(1::int); + +## Original Query: SELECT acos(2); +## PySpark 3.5.5 Result: {'ACOS(2)': nan, 'typeof(ACOS(2))': 'double', 'typeof(2)': 'int'} +#query +#SELECT acos(2::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/acosh.slt b/datafusion/sqllogictest/test_files/spark/math/acosh.slt new file mode 100644 index 000000000000..5ebfa190c4b9 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/acosh.slt @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT acosh(0); +## PySpark 3.5.5 Result: {'ACOSH(0)': nan, 'typeof(ACOSH(0))': 'double', 'typeof(0)': 'int'} +#query +#SELECT acosh(0::int); + +## Original Query: SELECT acosh(1); +## PySpark 3.5.5 Result: {'ACOSH(1)': 0.0, 'typeof(ACOSH(1))': 'double', 'typeof(1)': 'int'} +#query +#SELECT acosh(1::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/asin.slt b/datafusion/sqllogictest/test_files/spark/math/asin.slt new file mode 100644 index 000000000000..d782475a5f2a --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/asin.slt @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT asin(0); +## PySpark 3.5.5 Result: {'ASIN(0)': 0.0, 'typeof(ASIN(0))': 'double', 'typeof(0)': 'int'} +#query +#SELECT asin(0::int); + +## Original Query: SELECT asin(2); +## PySpark 3.5.5 Result: {'ASIN(2)': nan, 'typeof(ASIN(2))': 'double', 'typeof(2)': 'int'} +#query +#SELECT asin(2::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/asinh.slt b/datafusion/sqllogictest/test_files/spark/math/asinh.slt new file mode 100644 index 000000000000..1169fa7fbda6 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/asinh.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT asinh(0); +## PySpark 3.5.5 Result: {'ASINH(0)': 0.0, 'typeof(ASINH(0))': 'double', 'typeof(0)': 'int'} +#query +#SELECT asinh(0::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/atan.slt b/datafusion/sqllogictest/test_files/spark/math/atan.slt new file mode 100644 index 000000000000..c4b5baf1c77b --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/atan.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT atan(0); +## PySpark 3.5.5 Result: {'ATAN(0)': 0.0, 'typeof(ATAN(0))': 'double', 'typeof(0)': 'int'} +#query +#SELECT atan(0::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/atan2.slt b/datafusion/sqllogictest/test_files/spark/math/atan2.slt new file mode 100644 index 000000000000..c4af082f7b27 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/atan2.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT atan2(0, 0); +## PySpark 3.5.5 Result: {'ATAN2(0, 0)': 0.0, 'typeof(ATAN2(0, 0))': 'double', 'typeof(0)': 'int'} +#query +#SELECT atan2(0::int, 0::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/atanh.slt b/datafusion/sqllogictest/test_files/spark/math/atanh.slt new file mode 100644 index 000000000000..00681b57d030 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/atanh.slt @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT atanh(0); +## PySpark 3.5.5 Result: {'ATANH(0)': 0.0, 'typeof(ATANH(0))': 'double', 'typeof(0)': 'int'} +#query +#SELECT atanh(0::int); + +## Original Query: SELECT atanh(2); +## PySpark 3.5.5 Result: {'ATANH(2)': nan, 'typeof(ATANH(2))': 'double', 'typeof(2)': 'int'} +#query +#SELECT atanh(2::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/bin.slt b/datafusion/sqllogictest/test_files/spark/math/bin.slt new file mode 100644 index 000000000000..fab65eb837ee --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/bin.slt @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT bin(-13); +## PySpark 3.5.5 Result: {'bin(-13)': '1111111111111111111111111111111111111111111111111111111111110011', 'typeof(bin(-13))': 'string', 'typeof(-13)': 'int'} +#query +#SELECT bin(-13::int); + +## Original Query: SELECT bin(13); +## PySpark 3.5.5 Result: {'bin(13)': '1101', 'typeof(bin(13))': 'string', 'typeof(13)': 'int'} +#query +#SELECT bin(13::int); + +## Original Query: SELECT bin(13.3); +## PySpark 3.5.5 Result: {'bin(13.3)': '1101', 'typeof(bin(13.3))': 'string', 'typeof(13.3)': 'decimal(3,1)'} +#query +#SELECT bin(13.3::decimal(3,1)); + diff --git a/datafusion/sqllogictest/test_files/spark/math/bround.slt b/datafusion/sqllogictest/test_files/spark/math/bround.slt new file mode 100644 index 000000000000..3db3f1ebf15c --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/bround.slt @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT bround(2.5, 0); +## PySpark 3.5.5 Result: {'bround(2.5, 0)': Decimal('2'), 'typeof(bround(2.5, 0))': 'decimal(2,0)', 'typeof(2.5)': 'decimal(2,1)', 'typeof(0)': 'int'} +#query +#SELECT bround(2.5::decimal(2,1), 0::int); + +## Original Query: SELECT bround(25, -1); +## PySpark 3.5.5 Result: {'bround(25, -1)': 20, 'typeof(bround(25, -1))': 'int', 'typeof(25)': 'int', 'typeof(-1)': 'int'} +#query +#SELECT bround(25::int, -1::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/cbrt.slt b/datafusion/sqllogictest/test_files/spark/math/cbrt.slt new file mode 100644 index 000000000000..7fee600d4142 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/cbrt.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT cbrt(27.0); +## PySpark 3.5.5 Result: {'CBRT(27.0)': 3.0, 'typeof(CBRT(27.0))': 'double', 'typeof(27.0)': 'decimal(3,1)'} +#query +#SELECT cbrt(27.0::decimal(3,1)); + diff --git a/datafusion/sqllogictest/test_files/spark/math/ceil.slt b/datafusion/sqllogictest/test_files/spark/math/ceil.slt new file mode 100644 index 000000000000..1e6d86858fa2 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/ceil.slt @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT ceil(-0.1); +## PySpark 3.5.5 Result: {'CEIL(-0.1)': Decimal('0'), 'typeof(CEIL(-0.1))': 'decimal(1,0)', 'typeof(-0.1)': 'decimal(1,1)'} +#query +#SELECT ceil(-0.1::decimal(1,1)); + +## Original Query: SELECT ceil(3.1411, -3); +## PySpark 3.5.5 Result: {'ceil(3.1411, -3)': Decimal('1000'), 'typeof(ceil(3.1411, -3))': 'decimal(4,0)', 'typeof(3.1411)': 'decimal(5,4)', 'typeof(-3)': 'int'} +#query +#SELECT ceil(3.1411::decimal(5,4), -3::int); + +## Original Query: SELECT ceil(3.1411, 3); +## PySpark 3.5.5 Result: {'ceil(3.1411, 3)': Decimal('3.142'), 'typeof(ceil(3.1411, 3))': 'decimal(5,3)', 'typeof(3.1411)': 'decimal(5,4)', 'typeof(3)': 'int'} +#query +#SELECT ceil(3.1411::decimal(5,4), 3::int); + +## Original Query: SELECT ceil(5); +## PySpark 3.5.5 Result: {'CEIL(5)': 5, 'typeof(CEIL(5))': 'bigint', 'typeof(5)': 'int'} +#query +#SELECT ceil(5::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/ceiling.slt b/datafusion/sqllogictest/test_files/spark/math/ceiling.slt new file mode 100644 index 000000000000..3db7f4a192a0 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/ceiling.slt @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT ceiling(-0.1); +## PySpark 3.5.5 Result: {'ceiling(-0.1)': Decimal('0'), 'typeof(ceiling(-0.1))': 'decimal(1,0)', 'typeof(-0.1)': 'decimal(1,1)'} +#query +#SELECT ceiling(-0.1::decimal(1,1)); + +## Original Query: SELECT ceiling(3.1411, -3); +## PySpark 3.5.5 Result: {'ceiling(3.1411, -3)': Decimal('1000'), 'typeof(ceiling(3.1411, -3))': 'decimal(4,0)', 'typeof(3.1411)': 'decimal(5,4)', 'typeof(-3)': 'int'} +#query +#SELECT ceiling(3.1411::decimal(5,4), -3::int); + +## Original Query: SELECT ceiling(3.1411, 3); +## PySpark 3.5.5 Result: {'ceiling(3.1411, 3)': Decimal('3.142'), 'typeof(ceiling(3.1411, 3))': 'decimal(5,3)', 'typeof(3.1411)': 'decimal(5,4)', 'typeof(3)': 'int'} +#query +#SELECT ceiling(3.1411::decimal(5,4), 3::int); + +## Original Query: SELECT ceiling(5); +## PySpark 3.5.5 Result: {'ceiling(5)': 5, 'typeof(ceiling(5))': 'bigint', 'typeof(5)': 'int'} +#query +#SELECT ceiling(5::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/conv.slt b/datafusion/sqllogictest/test_files/spark/math/conv.slt new file mode 100644 index 000000000000..bc6033f4f400 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/conv.slt @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT conv('100', 2, 10); +## PySpark 3.5.5 Result: {'conv(100, 2, 10)': '4', 'typeof(conv(100, 2, 10))': 'string', 'typeof(100)': 'string', 'typeof(2)': 'int', 'typeof(10)': 'int'} +#query +#SELECT conv('100'::string, 2::int, 10::int); + +## Original Query: SELECT conv(-10, 16, -10); +## PySpark 3.5.5 Result: {'conv(-10, 16, -10)': '-16', 'typeof(conv(-10, 16, -10))': 'string', 'typeof(-10)': 'int', 'typeof(16)': 'int'} +#query +#SELECT conv(-10::int, 16::int, -10::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/cos.slt b/datafusion/sqllogictest/test_files/spark/math/cos.slt new file mode 100644 index 000000000000..c122173805c3 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/cos.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership.
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT cos(0); +## PySpark 3.5.5 Result: {'COS(0)': 1.0, 'typeof(COS(0))': 'double', 'typeof(0)': 'int'} +#query +#SELECT cos(0::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/cosh.slt b/datafusion/sqllogictest/test_files/spark/math/cosh.slt new file mode 100644 index 000000000000..73313defc95e --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/cosh.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT cosh(0); +## PySpark 3.5.5 Result: {'COSH(0)': 1.0, 'typeof(COSH(0))': 'double', 'typeof(0)': 'int'} +#query +#SELECT cosh(0::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/cot.slt b/datafusion/sqllogictest/test_files/spark/math/cot.slt new file mode 100644 index 000000000000..77ecea32502e --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/cot.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT cot(1); +## PySpark 3.5.5 Result: {'COT(1)': 0.6420926159343306, 'typeof(COT(1))': 'double', 'typeof(1)': 'int'} +#query +#SELECT cot(1::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/csc.slt b/datafusion/sqllogictest/test_files/spark/math/csc.slt new file mode 100644 index 000000000000..d64a7ffbf116 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/csc.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT csc(1); +## PySpark 3.5.5 Result: {'CSC(1)': 1.1883951057781212, 'typeof(CSC(1))': 'double', 'typeof(1)': 'int'} +#query +#SELECT csc(1::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/degrees.slt b/datafusion/sqllogictest/test_files/spark/math/degrees.slt new file mode 100644 index 000000000000..95d5541ca01b --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/degrees.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT degrees(3.141592653589793); +## PySpark 3.5.5 Result: {'DEGREES(3.141592653589793)': 180.0, 'typeof(DEGREES(3.141592653589793))': 'double', 'typeof(3.141592653589793)': 'decimal(16,15)'} +#query +#SELECT degrees(3.141592653589793::decimal(16,15)); + diff --git a/datafusion/sqllogictest/test_files/spark/math/e.slt b/datafusion/sqllogictest/test_files/spark/math/e.slt new file mode 100644 index 000000000000..ad4ed0bd9340 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/e.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT e(); +## PySpark 3.5.5 Result: {'E()': 2.718281828459045, 'typeof(E())': 'double'} +#query +#SELECT e(); + diff --git a/datafusion/sqllogictest/test_files/spark/math/exp.slt b/datafusion/sqllogictest/test_files/spark/math/exp.slt new file mode 100644 index 000000000000..9ee28533ffb8 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/exp.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
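+# Sketch of how the ported query further below could be enabled once the
+# Spark-compatible `exp` function is available (an assumption, not part of the
+# gold data; the expected value follows the PySpark result recorded in this file):
+#
+# query R
+# SELECT exp(0::int);
+# ----
+# 1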
+ +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT exp(0); +## PySpark 3.5.5 Result: {'EXP(0)': 1.0, 'typeof(EXP(0))': 'double', 'typeof(0)': 'int'} +#query +#SELECT exp(0::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/expm1.slt b/datafusion/sqllogictest/test_files/spark/math/expm1.slt new file mode 100644 index 000000000000..96d4abb0414b --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/expm1.slt @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +query R +SELECT expm1(0::INT); +---- +0 + +query R +SELECT expm1(1::INT); +---- +1.718281828459045 + +query R +SELECT expm1(a) FROM (VALUES (0::INT), (1::INT)) AS t(a); +---- +0 +1.718281828459045 diff --git a/datafusion/sqllogictest/test_files/spark/math/factorial.slt b/datafusion/sqllogictest/test_files/spark/math/factorial.slt new file mode 100644 index 000000000000..f8eae5d95ab8 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/factorial.slt @@ -0,0 +1,66 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT factorial(5); +## PySpark 3.5.5 Result: {'factorial(5)': 120, 'typeof(factorial(5))': 'bigint', 'typeof(5)': 'int'} +query I +SELECT factorial(5::INT); +---- +120 + +query I +SELECT factorial(a) +FROM VALUES + (-1::INT), + (0::INT), (1::INT), (2::INT), (3::INT), (4::INT), (5::INT), (6::INT), (7::INT), (8::INT), (9::INT), (10::INT), + (11::INT), (12::INT), (13::INT), (14::INT), (15::INT), (16::INT), (17::INT), (18::INT), (19::INT), (20::INT), + (21::INT), + (NULL) AS t(a); +---- +NULL +1 +1 +2 +6 +24 +120 +720 +5040 +40320 +362880 +3628800 +39916800 +479001600 +6227020800 +87178291200 +1307674368000 +20922789888000 +355687428096000 +6402373705728000 +121645100408832000 +2432902008176640000 +NULL +NULL + +query error Error during planning: Failed to coerce arguments to satisfy a call to 'factorial' function +SELECT factorial(5::BIGINT); diff --git a/datafusion/sqllogictest/test_files/spark/math/floor.slt b/datafusion/sqllogictest/test_files/spark/math/floor.slt new file mode 100644 index 000000000000..5e4a63a1a24d --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/floor.slt @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT floor(-0.1); +## PySpark 3.5.5 Result: {'FLOOR(-0.1)': Decimal('-1'), 'typeof(FLOOR(-0.1))': 'decimal(1,0)', 'typeof(-0.1)': 'decimal(1,1)'} +#query +#SELECT floor(-0.1::decimal(1,1)); + +## Original Query: SELECT floor(3.1411, -3); +## PySpark 3.5.5 Result: {'floor(3.1411, -3)': Decimal('0'), 'typeof(floor(3.1411, -3))': 'decimal(4,0)', 'typeof(3.1411)': 'decimal(5,4)', 'typeof(-3)': 'int'} +#query +#SELECT floor(3.1411::decimal(5,4), -3::int); + +## Original Query: SELECT floor(3.1411, 3); +## PySpark 3.5.5 Result: {'floor(3.1411, 3)': Decimal('3.141'), 'typeof(floor(3.1411, 3))': 'decimal(5,3)', 'typeof(3.1411)': 'decimal(5,4)', 'typeof(3)': 'int'} +#query +#SELECT floor(3.1411::decimal(5,4), 3::int); + +## Original Query: SELECT floor(5); +## PySpark 3.5.5 Result: {'FLOOR(5)': 5, 'typeof(FLOOR(5))': 'bigint', 'typeof(5)': 'int'} +#query +#SELECT floor(5::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/greatest.slt b/datafusion/sqllogictest/test_files/spark/math/greatest.slt new file mode 100644 index 000000000000..51cdb0d8613c --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/greatest.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT greatest(10, 9, 2, 4, 3); +## PySpark 3.5.5 Result: {'greatest(10, 9, 2, 4, 3)': 10, 'typeof(greatest(10, 9, 2, 4, 3))': 'int', 'typeof(10)': 'int', 'typeof(9)': 'int', 'typeof(2)': 'int', 'typeof(4)': 'int', 'typeof(3)': 'int'} +#query +#SELECT greatest(10::int, 9::int, 2::int, 4::int, 3::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/hex.slt b/datafusion/sqllogictest/test_files/spark/math/hex.slt new file mode 100644 index 000000000000..24db1a318358 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/hex.slt @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +query T +SELECT hex('Spark SQL'); +---- +537061726B2053514C + +query T +SELECT hex(1234::INT); +---- +4D2 + +query T +SELECT hex(a) from VALUES (1234::INT), (NULL), (456::INT) AS t(a); +---- +4D2 +NULL +1C8 + +query T +SELECT hex(a) from VALUES ('foo'), (NULL), ('foobarbaz') AS t(a); +---- +666F6F +NULL +666F6F62617262617A diff --git a/datafusion/sqllogictest/test_files/spark/math/hypot.slt b/datafusion/sqllogictest/test_files/spark/math/hypot.slt new file mode 100644 index 000000000000..3a087f565450 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/hypot.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT hypot(3, 4); +## PySpark 3.5.5 Result: {'HYPOT(3, 4)': 5.0, 'typeof(HYPOT(3, 4))': 'double', 'typeof(3)': 'int', 'typeof(4)': 'int'} +#query +#SELECT hypot(3::int, 4::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/least.slt b/datafusion/sqllogictest/test_files/spark/math/least.slt new file mode 100644 index 000000000000..d0ef2e66998f --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/least.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
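+# Sketch of an additional case that could accompany the ported query further
+# below once `least` is enabled (an assumption based on Spark's documented
+# semantics, where NULL arguments are skipped rather than propagated; not taken
+# from the gold data):
+#
+# query I
+# SELECT least(10::int, NULL, 3::int);
+# ----
+# 3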
+ +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT least(10, 9, 2, 4, 3); +## PySpark 3.5.5 Result: {'least(10, 9, 2, 4, 3)': 2, 'typeof(least(10, 9, 2, 4, 3))': 'int', 'typeof(10)': 'int', 'typeof(9)': 'int', 'typeof(2)': 'int', 'typeof(4)': 'int', 'typeof(3)': 'int'} +#query +#SELECT least(10::int, 9::int, 2::int, 4::int, 3::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/ln.slt b/datafusion/sqllogictest/test_files/spark/math/ln.slt new file mode 100644 index 000000000000..ec8fbfeabb3b --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/ln.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT ln(1); +## PySpark 3.5.5 Result: {'ln(1)': 0.0, 'typeof(ln(1))': 'double', 'typeof(1)': 'int'} +#query +#SELECT ln(1::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/log.slt b/datafusion/sqllogictest/test_files/spark/math/log.slt new file mode 100644 index 000000000000..a439fd4f21a7 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/log.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
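+# Sketch of an additional two-argument case that could accompany the ported
+# query further below once `log` is enabled (an assumption: Spark-style
+# log(base, expr) semantics, so log(2, 8) evaluates to 3.0; not taken from the
+# gold data):
+#
+# query R
+# SELECT log(2::int, 8::int);
+# ----
+# 3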
+ +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT log(10, 100); +## PySpark 3.5.5 Result: {'LOG(10, 100)': 2.0, 'typeof(LOG(10, 100))': 'double', 'typeof(10)': 'int', 'typeof(100)': 'int'} +#query +#SELECT log(10::int, 100::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/log10.slt b/datafusion/sqllogictest/test_files/spark/math/log10.slt new file mode 100644 index 000000000000..d4867c388f49 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/log10.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT log10(10); +## PySpark 3.5.5 Result: {'LOG10(10)': 1.0, 'typeof(LOG10(10))': 'double', 'typeof(10)': 'int'} +#query +#SELECT log10(10::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/log1p.slt b/datafusion/sqllogictest/test_files/spark/math/log1p.slt new file mode 100644 index 000000000000..1602263d85c3 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/log1p.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT log1p(0); +## PySpark 3.5.5 Result: {'LOG1P(0)': 0.0, 'typeof(LOG1P(0))': 'double', 'typeof(0)': 'int'} +#query +#SELECT log1p(0::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/log2.slt b/datafusion/sqllogictest/test_files/spark/math/log2.slt new file mode 100644 index 000000000000..e52f2c7a7ea0 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/log2.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT log2(2); +## PySpark 3.5.5 Result: {'LOG2(2)': 1.0, 'typeof(LOG2(2))': 'double', 'typeof(2)': 'int'} +#query +#SELECT log2(2::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/mod.slt b/datafusion/sqllogictest/test_files/spark/math/mod.slt new file mode 100644 index 000000000000..b39db3aac4b8 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/mod.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT MOD(2, 1.8); +## PySpark 3.5.5 Result: {'mod(2, 1.8)': Decimal('0.2'), 'typeof(mod(2, 1.8))': 'decimal(2,1)', 'typeof(2)': 'int', 'typeof(1.8)': 'decimal(2,1)'} +#query +#SELECT MOD(2::int, 1.8::decimal(2,1)); + diff --git a/datafusion/sqllogictest/test_files/spark/math/negative.slt b/datafusion/sqllogictest/test_files/spark/math/negative.slt new file mode 100644 index 000000000000..00688a98d782 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/negative.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT negative(1); +## PySpark 3.5.5 Result: {'negative(1)': -1, 'typeof(negative(1))': 'int', 'typeof(1)': 'int'} +#query +#SELECT negative(1::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/pi.slt b/datafusion/sqllogictest/test_files/spark/math/pi.slt new file mode 100644 index 000000000000..147991780204 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/pi.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT pi(); +## PySpark 3.5.5 Result: {'PI()': 3.141592653589793, 'typeof(PI())': 'double'} +#query +#SELECT pi(); + diff --git a/datafusion/sqllogictest/test_files/spark/math/pmod.slt b/datafusion/sqllogictest/test_files/spark/math/pmod.slt new file mode 100644 index 000000000000..1b751fc8e016 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/pmod.slt @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT pmod(-10, 3); +## PySpark 3.5.5 Result: {'pmod(-10, 3)': 2, 'typeof(pmod(-10, 3))': 'int', 'typeof(-10)': 'int', 'typeof(3)': 'int'} +#query +#SELECT pmod(-10::int, 3::int); + +## Original Query: SELECT pmod(10, 3); +## PySpark 3.5.5 Result: {'pmod(10, 3)': 1, 'typeof(pmod(10, 3))': 'int', 'typeof(10)': 'int', 'typeof(3)': 'int'} +#query +#SELECT pmod(10::int, 3::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/positive.slt b/datafusion/sqllogictest/test_files/spark/math/positive.slt new file mode 100644 index 000000000000..c6a44bb5610d --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/positive.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT positive(1); +## PySpark 3.5.5 Result: {'(+ 1)': 1, 'typeof((+ 1))': 'int', 'typeof(1)': 'int'} +#query +#SELECT positive(1::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/pow.slt b/datafusion/sqllogictest/test_files/spark/math/pow.slt new file mode 100644 index 000000000000..83d85ce0e57f --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/pow.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT pow(2, 3); +## PySpark 3.5.5 Result: {'pow(2, 3)': 8.0, 'typeof(pow(2, 3))': 'double', 'typeof(2)': 'int', 'typeof(3)': 'int'} +#query +#SELECT pow(2::int, 3::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/power.slt b/datafusion/sqllogictest/test_files/spark/math/power.slt new file mode 100644 index 000000000000..3e56944d1304 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/power.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT power(2, 3); +## PySpark 3.5.5 Result: {'POWER(2, 3)': 8.0, 'typeof(POWER(2, 3))': 'double', 'typeof(2)': 'int', 'typeof(3)': 'int'} +#query +#SELECT power(2::int, 3::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/radians.slt b/datafusion/sqllogictest/test_files/spark/math/radians.slt new file mode 100644 index 000000000000..e65177d6e208 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/radians.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT radians(180); +## PySpark 3.5.5 Result: {'RADIANS(180)': 3.141592653589793, 'typeof(RADIANS(180))': 'double', 'typeof(180)': 'int'} +#query +#SELECT radians(180::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/rand.slt b/datafusion/sqllogictest/test_files/spark/math/rand.slt new file mode 100644 index 000000000000..af24e996743d --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/rand.slt @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT rand(); +## PySpark 3.5.5 Result: {'rand()': 0.949892358232337, 'typeof(rand())': 'double'} +#query +#SELECT rand(); + +## Original Query: SELECT rand(0); +## PySpark 3.5.5 Result: {'rand(0)': 0.7604953758285915, 'typeof(rand(0))': 'double', 'typeof(0)': 'int'} +#query +#SELECT rand(0::int); + +## Original Query: SELECT rand(null); +## PySpark 3.5.5 Result: {'rand(NULL)': 0.7604953758285915, 'typeof(rand(NULL))': 'double', 'typeof(NULL)': 'void'} +#query +#SELECT rand(NULL::void); + diff --git a/datafusion/sqllogictest/test_files/spark/math/randn.slt b/datafusion/sqllogictest/test_files/spark/math/randn.slt new file mode 100644 index 000000000000..8e1a6ec79805 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/randn.slt @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT randn(); +## PySpark 3.5.5 Result: {'randn()': 1.498983714060803, 'typeof(randn())': 'double'} +#query +#SELECT randn(); + +## Original Query: SELECT randn(0); +## PySpark 3.5.5 Result: {'randn(0)': 1.6034991609278433, 'typeof(randn(0))': 'double', 'typeof(0)': 'int'} +#query +#SELECT randn(0::int); + +## Original Query: SELECT randn(null); +## PySpark 3.5.5 Result: {'randn(NULL)': 1.6034991609278433, 'typeof(randn(NULL))': 'double', 'typeof(NULL)': 'void'} +#query +#SELECT randn(NULL::void); + diff --git a/datafusion/sqllogictest/test_files/spark/math/random.slt b/datafusion/sqllogictest/test_files/spark/math/random.slt new file mode 100644 index 000000000000..31e4a0e63360 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/random.slt @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT random(); +## PySpark 3.5.5 Result: {'rand()': 0.7460731389309176, 'typeof(rand())': 'double'} +#query +#SELECT random(); + +## Original Query: SELECT random(0); +## PySpark 3.5.5 Result: {'rand(0)': 0.7604953758285915, 'typeof(rand(0))': 'double', 'typeof(0)': 'int'} +#query +#SELECT random(0::int); + +## Original Query: SELECT random(null); +## PySpark 3.5.5 Result: {'rand(NULL)': 0.7604953758285915, 'typeof(rand(NULL))': 'double', 'typeof(NULL)': 'void'} +#query +#SELECT random(NULL::void); + diff --git a/datafusion/sqllogictest/test_files/spark/math/rint.slt b/datafusion/sqllogictest/test_files/spark/math/rint.slt new file mode 100644 index 000000000000..c62438faed95 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/rint.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT rint(12.3456); +## PySpark 3.5.5 Result: {'rint(12.3456)': 12.0, 'typeof(rint(12.3456))': 'double', 'typeof(12.3456)': 'decimal(6,4)'} +#query +#SELECT rint(12.3456::decimal(6,4)); + diff --git a/datafusion/sqllogictest/test_files/spark/math/round.slt b/datafusion/sqllogictest/test_files/spark/math/round.slt new file mode 100644 index 000000000000..0b7ac371d560 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/round.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT round(2.5, 0); +## PySpark 3.5.5 Result: {'round(2.5, 0)': Decimal('3'), 'typeof(round(2.5, 0))': 'decimal(2,0)', 'typeof(2.5)': 'decimal(2,1)', 'typeof(0)': 'int'} +#query +#SELECT round(2.5::decimal(2,1), 0::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/sec.slt b/datafusion/sqllogictest/test_files/spark/math/sec.slt new file mode 100644 index 000000000000..5c5328e55c9d --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/sec.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT sec(0); +## PySpark 3.5.5 Result: {'SEC(0)': 1.0, 'typeof(SEC(0))': 'double', 'typeof(0)': 'int'} +#query +#SELECT sec(0::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/shiftleft.slt b/datafusion/sqllogictest/test_files/spark/math/shiftleft.slt new file mode 100644 index 000000000000..92db7f6bc26e --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/shiftleft.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT shiftleft(2, 1); +## PySpark 3.5.5 Result: {'shiftleft(2, 1)': 4, 'typeof(shiftleft(2, 1))': 'int', 'typeof(2)': 'int', 'typeof(1)': 'int'} +#query +#SELECT shiftleft(2::int, 1::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/sign.slt b/datafusion/sqllogictest/test_files/spark/math/sign.slt new file mode 100644 index 000000000000..b0f5ce963105 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/sign.slt @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT sign(40); +## PySpark 3.5.5 Result: {'sign(40)': 1.0, 'typeof(sign(40))': 'double', 'typeof(40)': 'int'} +#query +#SELECT sign(40::int); + +## Original Query: SELECT sign(INTERVAL -'100' YEAR); +## PySpark 3.5.5 Result: {"sign(INTERVAL '-100' YEAR)": -1.0, "typeof(sign(INTERVAL '-100' YEAR))": 'double', "typeof(INTERVAL '-100' YEAR)": 'interval year'} +#query +#SELECT sign(INTERVAL '-100' YEAR::interval year); + diff --git a/datafusion/sqllogictest/test_files/spark/math/signum.slt b/datafusion/sqllogictest/test_files/spark/math/signum.slt new file mode 100644 index 000000000000..3531196ad06b --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/signum.slt @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT signum(40); +## PySpark 3.5.5 Result: {'SIGNUM(40)': 1.0, 'typeof(SIGNUM(40))': 'double', 'typeof(40)': 'int'} +#query +#SELECT signum(40::int); + +## Original Query: SELECT signum(INTERVAL -'100' YEAR); +## PySpark 3.5.5 Result: {"SIGNUM(INTERVAL '-100' YEAR)": -1.0, "typeof(SIGNUM(INTERVAL '-100' YEAR))": 'double', "typeof(INTERVAL '-100' YEAR)": 'interval year'} +#query +#SELECT signum(INTERVAL '-100' YEAR::interval year); + diff --git a/datafusion/sqllogictest/test_files/spark/math/sin.slt b/datafusion/sqllogictest/test_files/spark/math/sin.slt new file mode 100644 index 000000000000..994ff492f9e3 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/sin.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT sin(0); +## PySpark 3.5.5 Result: {'SIN(0)': 0.0, 'typeof(SIN(0))': 'double', 'typeof(0)': 'int'} +#query +#SELECT sin(0::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/sinh.slt b/datafusion/sqllogictest/test_files/spark/math/sinh.slt new file mode 100644 index 000000000000..c743a81e2362 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/sinh.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT sinh(0); +## PySpark 3.5.5 Result: {'SINH(0)': 0.0, 'typeof(SINH(0))': 'double', 'typeof(0)': 'int'} +#query +#SELECT sinh(0::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/sqrt.slt b/datafusion/sqllogictest/test_files/spark/math/sqrt.slt new file mode 100644 index 000000000000..7c2eaccabf17 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/sqrt.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT sqrt(4); +## PySpark 3.5.5 Result: {'SQRT(4)': 2.0, 'typeof(SQRT(4))': 'double', 'typeof(4)': 'int'} +#query +#SELECT sqrt(4::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/tan.slt b/datafusion/sqllogictest/test_files/spark/math/tan.slt new file mode 100644 index 000000000000..5880edea52c7 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/tan.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT tan(0); +## PySpark 3.5.5 Result: {'TAN(0)': 0.0, 'typeof(TAN(0))': 'double', 'typeof(0)': 'int'} +#query +#SELECT tan(0::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/tanh.slt b/datafusion/sqllogictest/test_files/spark/math/tanh.slt new file mode 100644 index 000000000000..5db2e167a702 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/tanh.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT tanh(0); +## PySpark 3.5.5 Result: {'TANH(0)': 0.0, 'typeof(TANH(0))': 'double', 'typeof(0)': 'int'} +#query +#SELECT tanh(0::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/try_add.slt b/datafusion/sqllogictest/test_files/spark/math/try_add.slt new file mode 100644 index 000000000000..fb8653ce3795 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/try_add.slt @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT try_add(1, 2); +## PySpark 3.5.5 Result: {'try_add(1, 2)': 3, 'typeof(try_add(1, 2))': 'int', 'typeof(1)': 'int', 'typeof(2)': 'int'} +#query +#SELECT try_add(1::int, 2::int); + +## Original Query: SELECT try_add(2147483647, 1); +## PySpark 3.5.5 Result: {'try_add(2147483647, 1)': None, 'typeof(try_add(2147483647, 1))': 'int', 'typeof(2147483647)': 'int', 'typeof(1)': 'int'} +#query +#SELECT try_add(2147483647::int, 1::int); + +## Original Query: SELECT try_add(date'2021-01-01', 1); +## PySpark 3.5.5 Result: {"try_add(DATE '2021-01-01', 1)": datetime.date(2021, 1, 2), "typeof(try_add(DATE '2021-01-01', 1))": 'date', "typeof(DATE '2021-01-01')": 'date', 'typeof(1)': 'int'} +#query +#SELECT try_add(DATE '2021-01-01'::date, 1::int); + +## Original Query: SELECT try_add(date'2021-01-01', interval 1 year); +## PySpark 3.5.5 Result: {"try_add(DATE '2021-01-01', INTERVAL '1' YEAR)": datetime.date(2022, 1, 1), "typeof(try_add(DATE '2021-01-01', INTERVAL '1' YEAR))": 'date', "typeof(DATE '2021-01-01')": 'date', "typeof(INTERVAL '1' YEAR)": 'interval year'} +#query +#SELECT try_add(DATE '2021-01-01'::date, INTERVAL '1' YEAR::interval year); + +## Original Query: SELECT try_add(interval 1 year, interval 2 year); +## PySpark 3.5.5 Result: {"try_add(INTERVAL '1' YEAR, INTERVAL '2' YEAR)": 36, "typeof(try_add(INTERVAL '1' YEAR, INTERVAL '2' YEAR))": 'interval year', "typeof(INTERVAL '1' YEAR)": 'interval year', "typeof(INTERVAL '2' YEAR)": 'interval year'} +#query +#SELECT try_add(INTERVAL '1' YEAR::interval year, INTERVAL '2' YEAR::interval year); + +## Original Query: SELECT try_add(timestamp'2021-01-01 00:00:00', interval 1 day); +## PySpark 3.5.5 Result: {"try_add(TIMESTAMP '2021-01-01 00:00:00', INTERVAL '1' DAY)": datetime.datetime(2021, 1, 2, 0, 0), "typeof(try_add(TIMESTAMP '2021-01-01 00:00:00', INTERVAL '1' DAY))": 'timestamp', "typeof(TIMESTAMP '2021-01-01 00:00:00')": 'timestamp', "typeof(INTERVAL '1' DAY)": 'interval day'} +#query +#SELECT try_add(TIMESTAMP '2021-01-01 00:00:00'::timestamp, INTERVAL '1' DAY::interval day); + diff --git a/datafusion/sqllogictest/test_files/spark/math/try_divide.slt b/datafusion/sqllogictest/test_files/spark/math/try_divide.slt new file mode 100644 index 000000000000..1e2e6b555f5e --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/try_divide.slt @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT try_divide(1, 0); +## PySpark 3.5.5 Result: {'try_divide(1, 0)': None, 'typeof(try_divide(1, 0))': 'double', 'typeof(1)': 'int', 'typeof(0)': 'int'} +#query +#SELECT try_divide(1::int, 0::int); + +## Original Query: SELECT try_divide(2L, 2L); +## PySpark 3.5.5 Result: {'try_divide(2, 2)': 1.0, 'typeof(try_divide(2, 2))': 'double', 'typeof(2)': 'bigint'} +#query +#SELECT try_divide(2::bigint, 2::bigint); + +## Original Query: SELECT try_divide(3, 2); +## PySpark 3.5.5 Result: {'try_divide(3, 2)': 1.5, 'typeof(try_divide(3, 2))': 'double', 'typeof(3)': 'int', 'typeof(2)': 'int'} +#query +#SELECT try_divide(3::int, 2::int); + +## Original Query: SELECT try_divide(interval 2 month, 0); +## PySpark 3.5.5 Result: {"try_divide(INTERVAL '2' MONTH, 0)": None, "typeof(try_divide(INTERVAL '2' MONTH, 0))": 'interval year to month', "typeof(INTERVAL '2' MONTH)": 'interval month', 'typeof(0)': 'int'} +#query +#SELECT try_divide(INTERVAL '2' MONTH::interval month, 0::int); + +## Original Query: SELECT try_divide(interval 2 month, 2); +## PySpark 3.5.5 Result: {"try_divide(INTERVAL '2' MONTH, 2)": 1, "typeof(try_divide(INTERVAL '2' MONTH, 2))": 'interval year to month', "typeof(INTERVAL '2' MONTH)": 'interval month', 'typeof(2)': 'int'} +#query +#SELECT try_divide(INTERVAL '2' MONTH::interval month, 2::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/try_multiply.slt b/datafusion/sqllogictest/test_files/spark/math/try_multiply.slt new file mode 100644 index 000000000000..f5eaad6841c7 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/try_multiply.slt @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT try_multiply(-2147483648, 10); +## PySpark 3.5.5 Result: {'try_multiply(-2147483648, 10)': None, 'typeof(try_multiply(-2147483648, 10))': 'int', 'typeof(-2147483648)': 'int', 'typeof(10)': 'int'} +#query +#SELECT try_multiply(-2147483648::int, 10::int); + +## Original Query: SELECT try_multiply(2, 3); +## PySpark 3.5.5 Result: {'try_multiply(2, 3)': 6, 'typeof(try_multiply(2, 3))': 'int', 'typeof(2)': 'int', 'typeof(3)': 'int'} +#query +#SELECT try_multiply(2::int, 3::int); + +## Original Query: SELECT try_multiply(interval 2 year, 3); +## PySpark 3.5.5 Result: {"try_multiply(INTERVAL '2' YEAR, 3)": 72, "typeof(try_multiply(INTERVAL '2' YEAR, 3))": 'interval year to month', "typeof(INTERVAL '2' YEAR)": 'interval year', 'typeof(3)': 'int'} +#query +#SELECT try_multiply(INTERVAL '2' YEAR::interval year, 3::int); + diff --git a/datafusion/sqllogictest/test_files/spark/math/try_subtract.slt b/datafusion/sqllogictest/test_files/spark/math/try_subtract.slt new file mode 100644 index 000000000000..30af6877bf75 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/try_subtract.slt @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT try_subtract(-2147483648, 1); +## PySpark 3.5.5 Result: {'try_subtract(-2147483648, 1)': None, 'typeof(try_subtract(-2147483648, 1))': 'int', 'typeof(-2147483648)': 'int', 'typeof(1)': 'int'} +#query +#SELECT try_subtract(-2147483648::int, 1::int); + +## Original Query: SELECT try_subtract(2, 1); +## PySpark 3.5.5 Result: {'try_subtract(2, 1)': 1, 'typeof(try_subtract(2, 1))': 'int', 'typeof(2)': 'int', 'typeof(1)': 'int'} +#query +#SELECT try_subtract(2::int, 1::int); + +## Original Query: SELECT try_subtract(date'2021-01-01', interval 1 year); +## PySpark 3.5.5 Result: {"try_subtract(DATE '2021-01-01', INTERVAL '1' YEAR)": datetime.date(2020, 1, 1), "typeof(try_subtract(DATE '2021-01-01', INTERVAL '1' YEAR))": 'date', "typeof(DATE '2021-01-01')": 'date', "typeof(INTERVAL '1' YEAR)": 'interval year'} +#query +#SELECT try_subtract(DATE '2021-01-01'::date, INTERVAL '1' YEAR::interval year); + +## Original Query: SELECT try_subtract(date'2021-01-02', 1); +## PySpark 3.5.5 Result: {"try_subtract(DATE '2021-01-02', 1)": datetime.date(2021, 1, 1), "typeof(try_subtract(DATE '2021-01-02', 1))": 'date', "typeof(DATE '2021-01-02')": 'date', 'typeof(1)': 'int'} +#query +#SELECT try_subtract(DATE '2021-01-02'::date, 1::int); + +## Original Query: SELECT try_subtract(interval 2 year, interval 1 year); +## PySpark 3.5.5 Result: {"try_subtract(INTERVAL '2' YEAR, INTERVAL '1' YEAR)": 12, "typeof(try_subtract(INTERVAL '2' YEAR, INTERVAL '1' YEAR))": 'interval year', "typeof(INTERVAL '2' YEAR)": 'interval year', "typeof(INTERVAL '1' YEAR)": 'interval year'} +#query +#SELECT try_subtract(INTERVAL '2' YEAR::interval year, INTERVAL '1' YEAR::interval year); + +## Original Query: SELECT try_subtract(timestamp'2021-01-02 00:00:00', interval 1 day); +## PySpark 3.5.5 Result: {"try_subtract(TIMESTAMP '2021-01-02 00:00:00', INTERVAL '1' DAY)": datetime.datetime(2021, 1, 1, 0, 0), "typeof(try_subtract(TIMESTAMP '2021-01-02 00:00:00', INTERVAL '1' DAY))": 'timestamp', "typeof(TIMESTAMP '2021-01-02 00:00:00')": 'timestamp', "typeof(INTERVAL '1' DAY)": 'interval day'} +#query +#SELECT try_subtract(TIMESTAMP '2021-01-02 00:00:00'::timestamp, INTERVAL '1' DAY::interval day); + diff --git a/datafusion/sqllogictest/test_files/spark/math/width_bucket.slt b/datafusion/sqllogictest/test_files/spark/math/width_bucket.slt new file mode 100644 index 000000000000..b01ad9e587a8 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/math/width_bucket.slt @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT width_bucket(-0.9, 5.2, 0.5, 2); +## PySpark 3.5.5 Result: {'width_bucket(-0.9, 5.2, 0.5, 2)': 3, 'typeof(width_bucket(-0.9, 5.2, 0.5, 2))': 'bigint', 'typeof(-0.9)': 'decimal(1,1)', 'typeof(5.2)': 'decimal(2,1)', 'typeof(0.5)': 'decimal(1,1)', 'typeof(2)': 'int'} +#query +#SELECT width_bucket(-0.9::decimal(1,1), 5.2::decimal(2,1), 0.5::decimal(1,1), 2::int); + +## Original Query: SELECT width_bucket(-2.1, 1.3, 3.4, 3); +## PySpark 3.5.5 Result: {'width_bucket(-2.1, 1.3, 3.4, 3)': 0, 'typeof(width_bucket(-2.1, 1.3, 3.4, 3))': 'bigint', 'typeof(-2.1)': 'decimal(2,1)', 'typeof(1.3)': 'decimal(2,1)', 'typeof(3.4)': 'decimal(2,1)', 'typeof(3)': 'int'} +#query +#SELECT width_bucket(-2.1::decimal(2,1), 1.3::decimal(2,1), 3.4::decimal(2,1), 3::int); + +## Original Query: SELECT width_bucket(5.3, 0.2, 10.6, 5); +## PySpark 3.5.5 Result: {'width_bucket(5.3, 0.2, 10.6, 5)': 3, 'typeof(width_bucket(5.3, 0.2, 10.6, 5))': 'bigint', 'typeof(5.3)': 'decimal(2,1)', 'typeof(0.2)': 'decimal(1,1)', 'typeof(10.6)': 'decimal(3,1)', 'typeof(5)': 'int'} +#query +#SELECT width_bucket(5.3::decimal(2,1), 0.2::decimal(1,1), 10.6::decimal(3,1), 5::int); + +## Original Query: SELECT width_bucket(8.1, 0.0, 5.7, 4); +## PySpark 3.5.5 Result: {'width_bucket(8.1, 0.0, 5.7, 4)': 5, 'typeof(width_bucket(8.1, 0.0, 5.7, 4))': 'bigint', 'typeof(8.1)': 'decimal(2,1)', 'typeof(0.0)': 'decimal(1,1)', 'typeof(5.7)': 'decimal(2,1)', 'typeof(4)': 'int'} +#query +#SELECT width_bucket(8.1::decimal(2,1), 0.0::decimal(1,1), 5.7::decimal(2,1), 4::int); + +## Original Query: SELECT width_bucket(INTERVAL '0' DAY, INTERVAL '0' DAY, INTERVAL '10' DAY, 10); +## PySpark 3.5.5 Result: {"width_bucket(INTERVAL '0' DAY, INTERVAL '0' DAY, INTERVAL '10' DAY, 10)": 1, "typeof(width_bucket(INTERVAL '0' DAY, INTERVAL '0' DAY, INTERVAL '10' DAY, 10))": 'bigint', "typeof(INTERVAL '0' DAY)": 'interval day', "typeof(INTERVAL '10' DAY)": 'interval day', 'typeof(10)': 'int'} +#query +#SELECT width_bucket(INTERVAL '0' DAY::interval day, INTERVAL '0' DAY::interval day, INTERVAL '10' DAY::interval day, 10::int); + +## Original Query: SELECT width_bucket(INTERVAL '0' YEAR, INTERVAL '0' YEAR, INTERVAL '10' YEAR, 10); +## PySpark 3.5.5 Result: {"width_bucket(INTERVAL '0' YEAR, INTERVAL '0' YEAR, INTERVAL '10' YEAR, 10)": 1, "typeof(width_bucket(INTERVAL '0' YEAR, INTERVAL '0' YEAR, INTERVAL '10' YEAR, 10))": 'bigint', "typeof(INTERVAL '0' YEAR)": 'interval year', "typeof(INTERVAL '10' YEAR)": 'interval year', 'typeof(10)': 'int'} +#query +#SELECT width_bucket(INTERVAL '0' YEAR::interval year, INTERVAL '0' YEAR::interval year, INTERVAL '10' YEAR::interval year, 10::int); + +## Original Query: SELECT width_bucket(INTERVAL '1' DAY, INTERVAL '0' DAY, INTERVAL '10' DAY, 10); +## PySpark 3.5.5 Result: {"width_bucket(INTERVAL '1' DAY, INTERVAL '0' DAY, INTERVAL '10' DAY, 10)": 2, "typeof(width_bucket(INTERVAL '1' DAY, INTERVAL '0' DAY, INTERVAL '10' DAY, 10))": 'bigint', "typeof(INTERVAL '1' DAY)": 'interval day', "typeof(INTERVAL '0' DAY)": 'interval day', "typeof(INTERVAL '10' DAY)": 'interval day', 'typeof(10)': 'int'} +#query +#SELECT width_bucket(INTERVAL '1' DAY::interval day, INTERVAL '0' DAY::interval day, INTERVAL '10' DAY::interval 
day, 10::int); + +## Original Query: SELECT width_bucket(INTERVAL '1' YEAR, INTERVAL '0' YEAR, INTERVAL '10' YEAR, 10); +## PySpark 3.5.5 Result: {"width_bucket(INTERVAL '1' YEAR, INTERVAL '0' YEAR, INTERVAL '10' YEAR, 10)": 2, "typeof(width_bucket(INTERVAL '1' YEAR, INTERVAL '0' YEAR, INTERVAL '10' YEAR, 10))": 'bigint', "typeof(INTERVAL '1' YEAR)": 'interval year', "typeof(INTERVAL '0' YEAR)": 'interval year', "typeof(INTERVAL '10' YEAR)": 'interval year', 'typeof(10)': 'int'} +#query +#SELECT width_bucket(INTERVAL '1' YEAR::interval year, INTERVAL '0' YEAR::interval year, INTERVAL '10' YEAR::interval year, 10::int); + diff --git a/datafusion/sqllogictest/test_files/spark/misc/assert_true.slt b/datafusion/sqllogictest/test_files/spark/misc/assert_true.slt new file mode 100644 index 000000000000..c55ff71b05a5 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/misc/assert_true.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT assert_true(0 < 1); +## PySpark 3.5.5 Result: {"assert_true((0 < 1), '(0 < 1)' is not true!)": None, "typeof(assert_true((0 < 1), '(0 < 1)' is not true!))": 'void', 'typeof((0 < 1))': 'boolean'} +#query +#SELECT assert_true((0 < 1)::boolean); + diff --git a/datafusion/sqllogictest/test_files/spark/misc/current_catalog.slt b/datafusion/sqllogictest/test_files/spark/misc/current_catalog.slt new file mode 100644 index 000000000000..3fd49775e407 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/misc/current_catalog.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT current_catalog(); +## PySpark 3.5.5 Result: {'current_catalog()': 'spark_catalog', 'typeof(current_catalog())': 'string'} +#query +#SELECT current_catalog(); + diff --git a/datafusion/sqllogictest/test_files/spark/misc/current_database.slt b/datafusion/sqllogictest/test_files/spark/misc/current_database.slt new file mode 100644 index 000000000000..917ee660d805 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/misc/current_database.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT current_database(); +## PySpark 3.5.5 Result: {'current_database()': 'default', 'typeof(current_database())': 'string'} +#query +#SELECT current_database(); + diff --git a/datafusion/sqllogictest/test_files/spark/misc/current_schema.slt b/datafusion/sqllogictest/test_files/spark/misc/current_schema.slt new file mode 100644 index 000000000000..b96d1077b8d1 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/misc/current_schema.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT current_schema(); +## PySpark 3.5.5 Result: {'current_database()': 'default', 'typeof(current_database())': 'string'} +#query +#SELECT current_schema(); + diff --git a/datafusion/sqllogictest/test_files/spark/misc/current_user.slt b/datafusion/sqllogictest/test_files/spark/misc/current_user.slt new file mode 100644 index 000000000000..7e8ea20b323c --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/misc/current_user.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT current_user(); +## PySpark 3.5.5 Result: {'current_user()': 'r', 'typeof(current_user())': 'string'} +#query +#SELECT current_user(); + diff --git a/datafusion/sqllogictest/test_files/spark/misc/equal_null.slt b/datafusion/sqllogictest/test_files/spark/misc/equal_null.slt new file mode 100644 index 000000000000..a5f9b1090eb9 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/misc/equal_null.slt @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+ +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT equal_null(1, '11'); +## PySpark 3.5.5 Result: {'equal_null(1, 11)': False, 'typeof(equal_null(1, 11))': 'boolean', 'typeof(1)': 'int', 'typeof(11)': 'string'} +#query +#SELECT equal_null(1::int, '11'::string); + +## Original Query: SELECT equal_null(3, 3); +## PySpark 3.5.5 Result: {'equal_null(3, 3)': True, 'typeof(equal_null(3, 3))': 'boolean', 'typeof(3)': 'int'} +#query +#SELECT equal_null(3::int, 3::int); + +## Original Query: SELECT equal_null(NULL, 'abc'); +## PySpark 3.5.5 Result: {'equal_null(NULL, abc)': False, 'typeof(equal_null(NULL, abc))': 'boolean', 'typeof(NULL)': 'void', 'typeof(abc)': 'string'} +#query +#SELECT equal_null(NULL::void, 'abc'::string); + +## Original Query: SELECT equal_null(NULL, NULL); +## PySpark 3.5.5 Result: {'equal_null(NULL, NULL)': True, 'typeof(equal_null(NULL, NULL))': 'boolean', 'typeof(NULL)': 'void'} +#query +#SELECT equal_null(NULL::void, NULL::void); + +## Original Query: SELECT equal_null(true, NULL); +## PySpark 3.5.5 Result: {'equal_null(true, NULL)': False, 'typeof(equal_null(true, NULL))': 'boolean', 'typeof(true)': 'boolean', 'typeof(NULL)': 'void'} +#query +#SELECT equal_null(true::boolean, NULL::void); + diff --git a/datafusion/sqllogictest/test_files/spark/misc/input_file_block_length.slt b/datafusion/sqllogictest/test_files/spark/misc/input_file_block_length.slt new file mode 100644 index 000000000000..7872ef4c2857 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/misc/input_file_block_length.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT input_file_block_length(); +## PySpark 3.5.5 Result: {'input_file_block_length()': -1, 'typeof(input_file_block_length())': 'bigint'} +#query +#SELECT input_file_block_length(); + diff --git a/datafusion/sqllogictest/test_files/spark/misc/input_file_block_start.slt b/datafusion/sqllogictest/test_files/spark/misc/input_file_block_start.slt new file mode 100644 index 000000000000..35e527eb5f6d --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/misc/input_file_block_start.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT input_file_block_start(); +## PySpark 3.5.5 Result: {'input_file_block_start()': -1, 'typeof(input_file_block_start())': 'bigint'} +#query +#SELECT input_file_block_start(); + diff --git a/datafusion/sqllogictest/test_files/spark/misc/input_file_name.slt b/datafusion/sqllogictest/test_files/spark/misc/input_file_name.slt new file mode 100644 index 000000000000..dbbfce734149 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/misc/input_file_name.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT input_file_name(); +## PySpark 3.5.5 Result: {'input_file_name()': '', 'typeof(input_file_name())': 'string'} +#query +#SELECT input_file_name(); + diff --git a/datafusion/sqllogictest/test_files/spark/misc/java_method.slt b/datafusion/sqllogictest/test_files/spark/misc/java_method.slt new file mode 100644 index 000000000000..30af81009ffb --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/misc/java_method.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT java_method('java.util.UUID', 'randomUUID'); +## PySpark 3.5.5 Result: {'java_method(java.util.UUID, randomUUID)': 'e0d43859-1003-4f43-bfff-f2e3c34981e2', 'typeof(java_method(java.util.UUID, randomUUID))': 'string', 'typeof(java.util.UUID)': 'string', 'typeof(randomUUID)': 'string'} +#query +#SELECT java_method('java.util.UUID'::string, 'randomUUID'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/misc/monotonically_increasing_id.slt b/datafusion/sqllogictest/test_files/spark/misc/monotonically_increasing_id.slt new file mode 100644 index 000000000000..d47f4da86a9b --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/misc/monotonically_increasing_id.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT monotonically_increasing_id(); +## PySpark 3.5.5 Result: {'monotonically_increasing_id()': 0, 'typeof(monotonically_increasing_id())': 'bigint'} +#query +#SELECT monotonically_increasing_id(); + diff --git a/datafusion/sqllogictest/test_files/spark/misc/reflect.slt b/datafusion/sqllogictest/test_files/spark/misc/reflect.slt new file mode 100644 index 000000000000..464455d2470f --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/misc/reflect.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT reflect('java.util.UUID', 'randomUUID'); +## PySpark 3.5.5 Result: {'reflect(java.util.UUID, randomUUID)': 'bcf8f6e4-0d46-41a1-bc3c-9f793c8f8aa8', 'typeof(reflect(java.util.UUID, randomUUID))': 'string', 'typeof(java.util.UUID)': 'string', 'typeof(randomUUID)': 'string'} +#query +#SELECT reflect('java.util.UUID'::string, 'randomUUID'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/misc/spark_partition_id.slt b/datafusion/sqllogictest/test_files/spark/misc/spark_partition_id.slt new file mode 100644 index 000000000000..5407c7442c6d --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/misc/spark_partition_id.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT spark_partition_id(); +## PySpark 3.5.5 Result: {'SPARK_PARTITION_ID()': 0, 'typeof(SPARK_PARTITION_ID())': 'int'} +#query +#SELECT spark_partition_id(); + diff --git a/datafusion/sqllogictest/test_files/spark/misc/typeof.slt b/datafusion/sqllogictest/test_files/spark/misc/typeof.slt new file mode 100644 index 000000000000..8ae58f29ce8e --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/misc/typeof.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT typeof(1); +## PySpark 3.5.5 Result: {'typeof(1)': 'int', 'typeof(typeof(1))': 'string'} +#query +#SELECT typeof(1::int); + diff --git a/datafusion/sqllogictest/test_files/spark/misc/user.slt b/datafusion/sqllogictest/test_files/spark/misc/user.slt new file mode 100644 index 000000000000..87c71e8ea7df --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/misc/user.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT user(); +## PySpark 3.5.5 Result: {'current_user()': 'r', 'typeof(current_user())': 'string'} +#query +#SELECT user(); + diff --git a/datafusion/sqllogictest/test_files/spark/misc/uuid.slt b/datafusion/sqllogictest/test_files/spark/misc/uuid.slt new file mode 100644 index 000000000000..f975adf9908c --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/misc/uuid.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT uuid(); +## PySpark 3.5.5 Result: {'uuid()': '96981e67-62f6-49bc-a6f4-2f9bc676edda', 'typeof(uuid())': 'string'} +#query +#SELECT uuid(); + diff --git a/datafusion/sqllogictest/test_files/spark/misc/version.slt b/datafusion/sqllogictest/test_files/spark/misc/version.slt new file mode 100644 index 000000000000..db495192a6b2 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/misc/version.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT version(); +## PySpark 3.5.5 Result: {'version()': '3.5.5 7c29c664cdc9321205a98a14858aaf8daaa19db2', 'typeof(version())': 'string'} +#query +#SELECT version(); + diff --git a/datafusion/sqllogictest/test_files/spark/predicate/ilike.slt b/datafusion/sqllogictest/test_files/spark/predicate/ilike.slt new file mode 100644 index 000000000000..6a147e6be551 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/predicate/ilike.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT ilike('Spark', '_Park'); +## PySpark 3.5.5 Result: {'ilike(Spark, _Park)': True, 'typeof(ilike(Spark, _Park))': 'boolean', 'typeof(Spark)': 'string', 'typeof(_Park)': 'string'} +#query +#SELECT ilike('Spark'::string, '_Park'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/predicate/isnotnull.slt b/datafusion/sqllogictest/test_files/spark/predicate/isnotnull.slt new file mode 100644 index 000000000000..59f3d6cafa3a --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/predicate/isnotnull.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT isnotnull(1); +## PySpark 3.5.5 Result: {'(1 IS NOT NULL)': True, 'typeof((1 IS NOT NULL))': 'boolean', 'typeof(1)': 'int'} +#query +#SELECT isnotnull(1::int); + diff --git a/datafusion/sqllogictest/test_files/spark/predicate/isnull.slt b/datafusion/sqllogictest/test_files/spark/predicate/isnull.slt new file mode 100644 index 000000000000..b62492e5d1bf --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/predicate/isnull.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT isnull(1); +## PySpark 3.5.5 Result: {'(1 IS NULL)': False, 'typeof((1 IS NULL))': 'boolean', 'typeof(1)': 'int'} +#query +#SELECT isnull(1::int); + diff --git a/datafusion/sqllogictest/test_files/spark/predicate/like.slt b/datafusion/sqllogictest/test_files/spark/predicate/like.slt new file mode 100644 index 000000000000..71aebbee2a8d --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/predicate/like.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT like('Spark', '_park'); +## PySpark 3.5.5 Result: {'Spark LIKE _park': True, 'typeof(Spark LIKE _park)': 'boolean', 'typeof(Spark)': 'string', 'typeof(_park)': 'string'} +#query +#SELECT like('Spark'::string, '_park'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/string/ascii.slt b/datafusion/sqllogictest/test_files/spark/string/ascii.slt new file mode 100644 index 000000000000..623154ffaa7b --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/ascii.slt @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +query I +SELECT ascii('234'); +---- +50 + +query I +SELECT ascii(''); +---- +0 + +query I +SELECT ascii('222'); +---- +50 + +query I +SELECT ascii('😀'); +---- +128512 + +query I +SELECT ascii(2::INT); +---- +50 + +query I +SELECT ascii(a) FROM (VALUES ('Spark'), ('PySpark'), ('Pandas API')) AS t(a); +---- +83 +80 +80 diff --git a/datafusion/sqllogictest/test_files/spark/string/base64.slt b/datafusion/sqllogictest/test_files/spark/string/base64.slt new file mode 100644 index 000000000000..ac0a8e4307a2 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/base64.slt @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT base64('Spark SQL'); +## PySpark 3.5.5 Result: {'base64(Spark SQL)': 'U3BhcmsgU1FM', 'typeof(base64(Spark SQL))': 'string', 'typeof(Spark SQL)': 'string'} +#query +#SELECT base64('Spark SQL'::string); + +## Original Query: SELECT base64(x'537061726b2053514c'); +## PySpark 3.5.5 Result: {"base64(X'537061726B2053514C')": 'U3BhcmsgU1FM', "typeof(base64(X'537061726B2053514C'))": 'string', "typeof(X'537061726B2053514C')": 'binary'} +#query +#SELECT base64(X'537061726B2053514C'::binary); + diff --git a/datafusion/sqllogictest/test_files/spark/string/bit_length.slt b/datafusion/sqllogictest/test_files/spark/string/bit_length.slt new file mode 100644 index 000000000000..4c7703f36df1 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/bit_length.slt @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT bit_length('Spark SQL'); +## PySpark 3.5.5 Result: {'bit_length(Spark SQL)': 72, 'typeof(bit_length(Spark SQL))': 'int', 'typeof(Spark SQL)': 'string'} +#query +#SELECT bit_length('Spark SQL'::string); + +## Original Query: SELECT bit_length(x'537061726b2053514c'); +## PySpark 3.5.5 Result: {"bit_length(X'537061726B2053514C')": 72, "typeof(bit_length(X'537061726B2053514C'))": 'int', "typeof(X'537061726B2053514C')": 'binary'} +#query +#SELECT bit_length(X'537061726B2053514C'::binary); + diff --git a/datafusion/sqllogictest/test_files/spark/string/btrim.slt b/datafusion/sqllogictest/test_files/spark/string/btrim.slt new file mode 100644 index 000000000000..4cfbf4f14a22 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/btrim.slt @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
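If Spark-compatible `bit_length` support lands, the gold data above suggests an enabled test along these lines (the `query I` tag is assumed):

query I
SELECT bit_length('Spark SQL');
----
72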
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT btrim(' SparkSQL '); +## PySpark 3.5.5 Result: {'btrim( SparkSQL )': 'SparkSQL', 'typeof(btrim( SparkSQL ))': 'string', 'typeof( SparkSQL )': 'string'} +#query +#SELECT btrim(' SparkSQL '::string); + +## Original Query: SELECT btrim('SSparkSQLS', 'SL'); +## PySpark 3.5.5 Result: {'btrim(SSparkSQLS, SL)': 'parkSQ', 'typeof(btrim(SSparkSQLS, SL))': 'string', 'typeof(SSparkSQLS)': 'string', 'typeof(SL)': 'string'} +#query +#SELECT btrim('SSparkSQLS'::string, 'SL'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/string/char.slt b/datafusion/sqllogictest/test_files/spark/string/char.slt new file mode 100644 index 000000000000..abf0c7076a02 Binary files /dev/null and b/datafusion/sqllogictest/test_files/spark/string/char.slt differ diff --git a/datafusion/sqllogictest/test_files/spark/string/char_length.slt b/datafusion/sqllogictest/test_files/spark/string/char_length.slt new file mode 100644 index 000000000000..13d39a9ef249 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/char_length.slt @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT CHAR_LENGTH('Spark SQL '); +## PySpark 3.5.5 Result: {'char_length(Spark SQL )': 10, 'typeof(char_length(Spark SQL ))': 'int', 'typeof(Spark SQL )': 'string'} +#query +#SELECT CHAR_LENGTH('Spark SQL '::string); + +## Original Query: SELECT char_length('Spark SQL '); +## PySpark 3.5.5 Result: {'char_length(Spark SQL )': 10, 'typeof(char_length(Spark SQL ))': 'int', 'typeof(Spark SQL )': 'string'} +#query +#SELECT char_length('Spark SQL '::string); + +## Original Query: SELECT char_length(x'537061726b2053514c'); +## PySpark 3.5.5 Result: {"char_length(X'537061726B2053514C')": 9, "typeof(char_length(X'537061726B2053514C'))": 'int', "typeof(X'537061726B2053514C')": 'binary'} +#query +#SELECT char_length(X'537061726B2053514C'::binary); + diff --git a/datafusion/sqllogictest/test_files/spark/string/character_length.slt b/datafusion/sqllogictest/test_files/spark/string/character_length.slt new file mode 100644 index 000000000000..497d2a13274a --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/character_length.slt @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
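Based on the char_length gold data above (10 characters, including the trailing space), an enabled test could plausibly read as follows; the `query I` tag is assumed:

query I
SELECT char_length('Spark SQL ');
----
10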
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT CHARACTER_LENGTH('Spark SQL '); +## PySpark 3.5.5 Result: {'character_length(Spark SQL )': 10, 'typeof(character_length(Spark SQL ))': 'int', 'typeof(Spark SQL )': 'string'} +#query +#SELECT CHARACTER_LENGTH('Spark SQL '::string); + +## Original Query: SELECT character_length('Spark SQL '); +## PySpark 3.5.5 Result: {'character_length(Spark SQL )': 10, 'typeof(character_length(Spark SQL ))': 'int', 'typeof(Spark SQL )': 'string'} +#query +#SELECT character_length('Spark SQL '::string); + +## Original Query: SELECT character_length(x'537061726b2053514c'); +## PySpark 3.5.5 Result: {"character_length(X'537061726B2053514C')": 9, "typeof(character_length(X'537061726B2053514C'))": 'int', "typeof(X'537061726B2053514C')": 'binary'} +#query +#SELECT character_length(X'537061726B2053514C'::binary); + diff --git a/datafusion/sqllogictest/test_files/spark/string/chr.slt b/datafusion/sqllogictest/test_files/spark/string/chr.slt new file mode 100644 index 000000000000..c3ea88e8e3be --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/chr.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT chr(65); +## PySpark 3.5.5 Result: {'chr(65)': 'A', 'typeof(chr(65))': 'string', 'typeof(65)': 'int'} +#query +#SELECT chr(65::int); + diff --git a/datafusion/sqllogictest/test_files/spark/string/concat_ws.slt b/datafusion/sqllogictest/test_files/spark/string/concat_ws.slt new file mode 100644 index 000000000000..9172d204c33a --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/concat_ws.slt @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT concat_ws(' ', 'Spark', 'SQL'); +## PySpark 3.5.5 Result: {'concat_ws( , Spark, SQL)': 'Spark SQL', 'typeof(concat_ws( , Spark, SQL))': 'string', 'typeof( )': 'string', 'typeof(Spark)': 'string', 'typeof(SQL)': 'string'} +#query +#SELECT concat_ws(' '::string, 'Spark'::string, 'SQL'::string); + +## Original Query: SELECT concat_ws('/', 'foo', null, 'bar'); +## PySpark 3.5.5 Result: {'concat_ws(/, foo, NULL, bar)': 'foo/bar', 'typeof(concat_ws(/, foo, NULL, bar))': 'string', 'typeof(/)': 'string', 'typeof(foo)': 'string', 'typeof(NULL)': 'void', 'typeof(bar)': 'string'} +#query +#SELECT concat_ws('/'::string, 'foo'::string, NULL::void, 'bar'::string); + +## Original Query: SELECT concat_ws('s'); +## PySpark 3.5.5 Result: {'concat_ws(s)': '', 'typeof(concat_ws(s))': 'string', 'typeof(s)': 'string'} +#query +#SELECT concat_ws('s'::string); + +## Original Query: SELECT concat_ws(null, 'Spark', 'SQL'); +## PySpark 3.5.5 Result: {'concat_ws(NULL, Spark, SQL)': None, 'typeof(concat_ws(NULL, Spark, SQL))': 'string', 'typeof(NULL)': 'void', 'typeof(Spark)': 'string', 'typeof(SQL)': 'string'} +#query +#SELECT concat_ws(NULL::void, 'Spark'::string, 'SQL'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/string/contains.slt b/datafusion/sqllogictest/test_files/spark/string/contains.slt new file mode 100644 index 000000000000..80b05b8f255e --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/contains.slt @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
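The concat_ws gold data above documents that NULL arguments are skipped while a NULL separator yields NULL. Once the Spark-compatible behavior is wired up, the skip case might be enabled roughly like this sketch (tag assumed):

query T
SELECT concat_ws('/', 'foo', null, 'bar');
----
foo/bar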
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT contains('Spark SQL', 'SPARK'); +## PySpark 3.5.5 Result: {'contains(Spark SQL, SPARK)': False, 'typeof(contains(Spark SQL, SPARK))': 'boolean', 'typeof(Spark SQL)': 'string', 'typeof(SPARK)': 'string'} +#query +#SELECT contains('Spark SQL'::string, 'SPARK'::string); + +## Original Query: SELECT contains('Spark SQL', 'Spark'); +## PySpark 3.5.5 Result: {'contains(Spark SQL, Spark)': True, 'typeof(contains(Spark SQL, Spark))': 'boolean', 'typeof(Spark SQL)': 'string', 'typeof(Spark)': 'string'} +#query +#SELECT contains('Spark SQL'::string, 'Spark'::string); + +## Original Query: SELECT contains('Spark SQL', null); +## PySpark 3.5.5 Result: {'contains(Spark SQL, NULL)': None, 'typeof(contains(Spark SQL, NULL))': 'boolean', 'typeof(Spark SQL)': 'string', 'typeof(NULL)': 'void'} +#query +#SELECT contains('Spark SQL'::string, NULL::void); + +## Original Query: SELECT contains(x'537061726b2053514c', x'537061726b'); +## PySpark 3.5.5 Result: {"contains(X'537061726B2053514C', X'537061726B')": True, "typeof(contains(X'537061726B2053514C', X'537061726B'))": 'boolean', "typeof(X'537061726B2053514C')": 'binary', "typeof(X'537061726B')": 'binary'} +#query +#SELECT contains(X'537061726B2053514C'::binary, X'537061726B'::binary); + diff --git a/datafusion/sqllogictest/test_files/spark/string/decode.slt b/datafusion/sqllogictest/test_files/spark/string/decode.slt new file mode 100644 index 000000000000..c6848dccd47a --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/decode.slt @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
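The contains cases above hinge on case sensitivity ('SPARK' does not match, 'Spark' does). A possible enabled form, assuming booleans render as true/false under a `query B` tag as in other DataFusion sqllogictest files:

query B
SELECT contains('Spark SQL', 'Spark');
----
true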
+ +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT decode(2, 1, 'Southlake', 2, 'San Francisco', 3, 'New Jersey', 4, 'Seattle', 'Non domestic'); +## PySpark 3.5.5 Result: {'decode(2, 1, Southlake, 2, San Francisco, 3, New Jersey, 4, Seattle, Non domestic)': 'San Francisco', 'typeof(decode(2, 1, Southlake, 2, San Francisco, 3, New Jersey, 4, Seattle, Non domestic))': 'string', 'typeof(2)': 'int', 'typeof(1)': 'int', 'typeof(Southlake)': 'string', 'typeof(San Francisco)': 'string', 'typeof(3)': 'int', 'typeof(New Jersey)': 'string', 'typeof(4)': 'int', 'typeof(Seattle)': 'string', 'typeof(Non domestic)': 'string'} +#query +#SELECT decode(2::int, 1::int, 'Southlake'::string, 2::int, 'San Francisco'::string, 3::int, 'New Jersey'::string, 4::int, 'Seattle'::string, 'Non domestic'::string); + +## Original Query: SELECT decode(6, 1, 'Southlake', 2, 'San Francisco', 3, 'New Jersey', 4, 'Seattle'); +## PySpark 3.5.5 Result: {'decode(6, 1, Southlake, 2, San Francisco, 3, New Jersey, 4, Seattle)': None, 'typeof(decode(6, 1, Southlake, 2, San Francisco, 3, New Jersey, 4, Seattle))': 'string', 'typeof(6)': 'int', 'typeof(1)': 'int', 'typeof(Southlake)': 'string', 'typeof(2)': 'int', 'typeof(San Francisco)': 'string', 'typeof(3)': 'int', 'typeof(New Jersey)': 'string', 'typeof(4)': 'int', 'typeof(Seattle)': 'string'} +#query +#SELECT decode(6::int, 1::int, 'Southlake'::string, 2::int, 'San Francisco'::string, 3::int, 'New Jersey'::string, 4::int, 'Seattle'::string); + +## Original Query: SELECT decode(6, 1, 'Southlake', 2, 'San Francisco', 3, 'New Jersey', 4, 'Seattle', 'Non domestic'); +## PySpark 3.5.5 Result: {'decode(6, 1, Southlake, 2, San Francisco, 3, New Jersey, 4, Seattle, Non domestic)': 'Non domestic', 'typeof(decode(6, 1, Southlake, 2, San Francisco, 3, New Jersey, 4, Seattle, Non domestic))': 'string', 'typeof(6)': 'int', 'typeof(1)': 'int', 'typeof(Southlake)': 'string', 'typeof(2)': 'int', 'typeof(San Francisco)': 'string', 'typeof(3)': 'int', 'typeof(New Jersey)': 'string', 'typeof(4)': 'int', 'typeof(Seattle)': 'string', 'typeof(Non domestic)': 'string'} +#query +#SELECT decode(6::int, 1::int, 'Southlake'::string, 2::int, 'San Francisco'::string, 3::int, 'New Jersey'::string, 4::int, 'Seattle'::string, 'Non domestic'::string); + +## Original Query: SELECT decode(null, 6, 'Spark', NULL, 'SQL', 4, 'rocks'); +## PySpark 3.5.5 Result: {'decode(NULL, 6, Spark, NULL, SQL, 4, rocks)': 'SQL', 'typeof(decode(NULL, 6, Spark, NULL, SQL, 4, rocks))': 'string', 'typeof(NULL)': 'void', 'typeof(6)': 'int', 'typeof(Spark)': 'string', 'typeof(SQL)': 'string', 'typeof(4)': 'int', 'typeof(rocks)': 'string'} +#query +#SELECT decode(NULL::void, 6::int, 'Spark'::string, NULL::void, 'SQL'::string, 4::int, 'rocks'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/string/elt.slt b/datafusion/sqllogictest/test_files/spark/string/elt.slt new file mode 100644 index 000000000000..406be0d5e5d8 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/elt.slt @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements.
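If a Spark-style lookup decode (as opposed to DataFusion's existing decode(expr, format) for binary encodings) is eventually exposed here, the first gold-data case above would plausibly become (tag and rendering assumed):

query T
SELECT decode(2, 1, 'Southlake', 2, 'San Francisco', 3, 'New Jersey', 4, 'Seattle', 'Non domestic');
----
San Francisco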
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT elt(1, 'scala', 'java'); +## PySpark 3.5.5 Result: {'elt(1, scala, java)': 'scala', 'typeof(elt(1, scala, java))': 'string', 'typeof(1)': 'int', 'typeof(scala)': 'string', 'typeof(java)': 'string'} +#query +#SELECT elt(1::int, 'scala'::string, 'java'::string); + +## Original Query: SELECT elt(2, 'a', 1); +## PySpark 3.5.5 Result: {'elt(2, a, 1)': '1', 'typeof(elt(2, a, 1))': 'string', 'typeof(2)': 'int', 'typeof(a)': 'string', 'typeof(1)': 'int'} +#query +#SELECT elt(2::int, 'a'::string, 1::int); + diff --git a/datafusion/sqllogictest/test_files/spark/string/encode.slt b/datafusion/sqllogictest/test_files/spark/string/encode.slt new file mode 100644 index 000000000000..627ed842c7e8 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/encode.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
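Per the elt gold data above (1-based index into the argument list), an enabled test could look roughly like this sketch (tag assumed):

query T
SELECT elt(1, 'scala', 'java');
----
scala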
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT encode('abc', 'utf-8'); +## PySpark 3.5.5 Result: {'encode(abc, utf-8)': bytearray(b'abc'), 'typeof(encode(abc, utf-8))': 'binary', 'typeof(abc)': 'string', 'typeof(utf-8)': 'string'} +#query +#SELECT encode('abc'::string, 'utf-8'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/string/endswith.slt b/datafusion/sqllogictest/test_files/spark/string/endswith.slt new file mode 100644 index 000000000000..b933afe11045 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/endswith.slt @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT endswith('Spark SQL', 'SQL'); +## PySpark 3.5.5 Result: {'endswith(Spark SQL, SQL)': True, 'typeof(endswith(Spark SQL, SQL))': 'boolean', 'typeof(Spark SQL)': 'string', 'typeof(SQL)': 'string'} +#query +#SELECT endswith('Spark SQL'::string, 'SQL'::string); + +## Original Query: SELECT endswith('Spark SQL', 'Spark'); +## PySpark 3.5.5 Result: {'endswith(Spark SQL, Spark)': False, 'typeof(endswith(Spark SQL, Spark))': 'boolean', 'typeof(Spark SQL)': 'string', 'typeof(Spark)': 'string'} +#query +#SELECT endswith('Spark SQL'::string, 'Spark'::string); + +## Original Query: SELECT endswith('Spark SQL', null); +## PySpark 3.5.5 Result: {'endswith(Spark SQL, NULL)': None, 'typeof(endswith(Spark SQL, NULL))': 'boolean', 'typeof(Spark SQL)': 'string', 'typeof(NULL)': 'void'} +#query +#SELECT endswith('Spark SQL'::string, NULL::void); + +## Original Query: SELECT endswith(x'537061726b2053514c', x'53514c'); +## PySpark 3.5.5 Result: {"endswith(X'537061726B2053514C', X'53514C')": True, "typeof(endswith(X'537061726B2053514C', X'53514C'))": 'boolean', "typeof(X'537061726B2053514C')": 'binary', "typeof(X'53514C')": 'binary'} +#query +#SELECT endswith(X'537061726B2053514C'::binary, X'53514C'::binary); + +## Original Query: SELECT endswith(x'537061726b2053514c', x'537061726b'); +## PySpark 3.5.5 Result: {"endswith(X'537061726B2053514C', X'537061726B')": False, "typeof(endswith(X'537061726B2053514C', X'537061726B'))": 'boolean', "typeof(X'537061726B2053514C')": 'binary', "typeof(X'537061726B')": 'binary'} +#query +#SELECT endswith(X'537061726B2053514C'::binary, X'537061726B'::binary); + diff --git a/datafusion/sqllogictest/test_files/spark/string/find_in_set.slt 
b/datafusion/sqllogictest/test_files/spark/string/find_in_set.slt new file mode 100644 index 000000000000..9715879152b4 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/find_in_set.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT find_in_set('ab','abc,b,ab,c,def'); +## PySpark 3.5.5 Result: {'find_in_set(ab, abc,b,ab,c,def)': 3, 'typeof(find_in_set(ab, abc,b,ab,c,def))': 'int', 'typeof(ab)': 'string', 'typeof(abc,b,ab,c,def)': 'string'} +#query +#SELECT find_in_set('ab'::string, 'abc,b,ab,c,def'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/string/format_number.slt b/datafusion/sqllogictest/test_files/spark/string/format_number.slt new file mode 100644 index 000000000000..01af282bcbe6 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/format_number.slt @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
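A sketch of the find_in_set case above in enabled form, assuming `query I` and the 1-based position documented in the gold data:

query I
SELECT find_in_set('ab', 'abc,b,ab,c,def');
----
3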
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT format_number(12332.123456, '##################.###'); +## PySpark 3.5.5 Result: {'format_number(12332.123456, ##################.###)': '12332.123', 'typeof(format_number(12332.123456, ##################.###))': 'string', 'typeof(12332.123456)': 'decimal(11,6)', 'typeof(##################.###)': 'string'} +#query +#SELECT format_number(12332.123456::decimal(11,6), '##################.###'::string); + +## Original Query: SELECT format_number(12332.123456, 4); +## PySpark 3.5.5 Result: {'format_number(12332.123456, 4)': '12,332.1235', 'typeof(format_number(12332.123456, 4))': 'string', 'typeof(12332.123456)': 'decimal(11,6)', 'typeof(4)': 'int'} +#query +#SELECT format_number(12332.123456::decimal(11,6), 4::int); + diff --git a/datafusion/sqllogictest/test_files/spark/string/format_string.slt b/datafusion/sqllogictest/test_files/spark/string/format_string.slt new file mode 100644 index 000000000000..2505557a71ae --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/format_string.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT format_string("Hello World %d %s", 100, "days"); +## PySpark 3.5.5 Result: {'format_string(Hello World %d %s, 100, days)': 'Hello World 100 days', 'typeof(format_string(Hello World %d %s, 100, days))': 'string', 'typeof(Hello World %d %s)': 'string', 'typeof(100)': 'int', 'typeof(days)': 'string'} +#query +#SELECT format_string('Hello World %d %s'::string, 100::int, 'days'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/string/initcap.slt b/datafusion/sqllogictest/test_files/spark/string/initcap.slt new file mode 100644 index 000000000000..b464cad4716d --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/initcap.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
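The second format_number case above (scale 4 with grouped thousands) would plausibly be enabled as shown below; the tag and exact rendering are assumed:

query T
SELECT format_number(12332.123456, 4);
----
12,332.1235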
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT initcap('sPark sql'); +## PySpark 3.5.5 Result: {'initcap(sPark sql)': 'Spark Sql', 'typeof(initcap(sPark sql))': 'string', 'typeof(sPark sql)': 'string'} +#query +#SELECT initcap('sPark sql'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/string/instr.slt b/datafusion/sqllogictest/test_files/spark/string/instr.slt new file mode 100644 index 000000000000..55ba2d731281 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/instr.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT instr('SparkSQL', 'SQL'); +## PySpark 3.5.5 Result: {'instr(SparkSQL, SQL)': 6, 'typeof(instr(SparkSQL, SQL))': 'int', 'typeof(SparkSQL)': 'string', 'typeof(SQL)': 'string'} +#query +#SELECT instr('SparkSQL'::string, 'SQL'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/string/lcase.slt b/datafusion/sqllogictest/test_files/spark/string/lcase.slt new file mode 100644 index 000000000000..795cf620fd24 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/lcase.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
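Per the instr gold data above (1-based position of the first occurrence), a possible enabled form (tag assumed):

query I
SELECT instr('SparkSQL', 'SQL');
----
6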
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT lcase('SparkSql'); +## PySpark 3.5.5 Result: {'lcase(SparkSql)': 'sparksql', 'typeof(lcase(SparkSql))': 'string', 'typeof(SparkSql)': 'string'} +#query +#SELECT lcase('SparkSql'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/string/left.slt b/datafusion/sqllogictest/test_files/spark/string/left.slt new file mode 100644 index 000000000000..12effeb38fc5 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/left.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT left('Spark SQL', 3); +## PySpark 3.5.5 Result: {'left(Spark SQL, 3)': 'Spa', 'typeof(left(Spark SQL, 3))': 'string', 'typeof(Spark SQL)': 'string', 'typeof(3)': 'int'} +#query +#SELECT left('Spark SQL'::string, 3::int); + diff --git a/datafusion/sqllogictest/test_files/spark/string/len.slt b/datafusion/sqllogictest/test_files/spark/string/len.slt new file mode 100644 index 000000000000..ae469cdc3514 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/len.slt @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT len('Spark SQL '); +## PySpark 3.5.5 Result: {'len(Spark SQL )': 10, 'typeof(len(Spark SQL ))': 'int', 'typeof(Spark SQL )': 'string'} +#query +#SELECT len('Spark SQL '::string); + +## Original Query: SELECT len(x'537061726b2053514c'); +## PySpark 3.5.5 Result: {"len(X'537061726B2053514C')": 9, "typeof(len(X'537061726B2053514C'))": 'int', "typeof(X'537061726B2053514C')": 'binary'} +#query +#SELECT len(X'537061726B2053514C'::binary); + diff --git a/datafusion/sqllogictest/test_files/spark/string/length.slt b/datafusion/sqllogictest/test_files/spark/string/length.slt new file mode 100644 index 000000000000..94252a99772c --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/length.slt @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT length('Spark SQL '); +## PySpark 3.5.5 Result: {'length(Spark SQL )': 10, 'typeof(length(Spark SQL ))': 'int', 'typeof(Spark SQL )': 'string'} +#query +#SELECT length('Spark SQL '::string); + +## Original Query: SELECT length(x'537061726b2053514c'); +## PySpark 3.5.5 Result: {"length(X'537061726B2053514C')": 9, "typeof(length(X'537061726B2053514C'))": 'int', "typeof(X'537061726B2053514C')": 'binary'} +#query +#SELECT length(X'537061726B2053514C'::binary); + diff --git a/datafusion/sqllogictest/test_files/spark/string/levenshtein.slt b/datafusion/sqllogictest/test_files/spark/string/levenshtein.slt new file mode 100644 index 000000000000..76d731da20bc --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/levenshtein.slt @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT levenshtein('kitten', 'sitting'); +## PySpark 3.5.5 Result: {'levenshtein(kitten, sitting)': 3, 'typeof(levenshtein(kitten, sitting))': 'int', 'typeof(kitten)': 'string', 'typeof(sitting)': 'string'} +#query +#SELECT levenshtein('kitten'::string, 'sitting'::string); + +## Original Query: SELECT levenshtein('kitten', 'sitting', 2); +## PySpark 3.5.5 Result: {'levenshtein(kitten, sitting, 2)': -1, 'typeof(levenshtein(kitten, sitting, 2))': 'int', 'typeof(kitten)': 'string', 'typeof(sitting)': 'string', 'typeof(2)': 'int'} +#query +#SELECT levenshtein('kitten'::string, 'sitting'::string, 2::int); + diff --git a/datafusion/sqllogictest/test_files/spark/string/locate.slt b/datafusion/sqllogictest/test_files/spark/string/locate.slt new file mode 100644 index 000000000000..c5a6c625e95d --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/locate.slt @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
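The levenshtein gold data above includes the bounded variant, which returns -1 when the distance exceeds the threshold; an enabled sketch (tag assumed):

query I
SELECT levenshtein('kitten', 'sitting', 2);
----
-1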
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT locate('bar', 'foobarbar'); +## PySpark 3.5.5 Result: {'locate(bar, foobarbar, 1)': 4, 'typeof(locate(bar, foobarbar, 1))': 'int', 'typeof(bar)': 'string', 'typeof(foobarbar)': 'string'} +#query +#SELECT locate('bar'::string, 'foobarbar'::string); + +## Original Query: SELECT locate('bar', 'foobarbar', 5); +## PySpark 3.5.5 Result: {'locate(bar, foobarbar, 5)': 7, 'typeof(locate(bar, foobarbar, 5))': 'int', 'typeof(bar)': 'string', 'typeof(foobarbar)': 'string', 'typeof(5)': 'int'} +#query +#SELECT locate('bar'::string, 'foobarbar'::string, 5::int); + diff --git a/datafusion/sqllogictest/test_files/spark/string/lower.slt b/datafusion/sqllogictest/test_files/spark/string/lower.slt new file mode 100644 index 000000000000..e221e622148c --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/lower.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT lower('SparkSql'); +## PySpark 3.5.5 Result: {'lower(SparkSql)': 'sparksql', 'typeof(lower(SparkSql))': 'string', 'typeof(SparkSql)': 'string'} +#query +#SELECT lower('SparkSql'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/string/lpad.slt b/datafusion/sqllogictest/test_files/spark/string/lpad.slt new file mode 100644 index 000000000000..f9716d59cc03 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/lpad.slt @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
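Based on the locate gold data above (default start position 1), an enabled test could read (tag assumed):

query I
SELECT locate('bar', 'foobarbar');
----
4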
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT lpad('hi', 1, '??'); +## PySpark 3.5.5 Result: {'lpad(hi, 1, ??)': 'h', 'typeof(lpad(hi, 1, ??))': 'string', 'typeof(hi)': 'string', 'typeof(1)': 'int', 'typeof(??)': 'string'} +#query +#SELECT lpad('hi'::string, 1::int, '??'::string); + +## Original Query: SELECT lpad('hi', 5); +## PySpark 3.5.5 Result: {'lpad(hi, 5, )': ' hi', 'typeof(lpad(hi, 5, ))': 'string', 'typeof(hi)': 'string', 'typeof(5)': 'int'} +#query +#SELECT lpad('hi'::string, 5::int); + +## Original Query: SELECT lpad('hi', 5, '??'); +## PySpark 3.5.5 Result: {'lpad(hi, 5, ??)': '???hi', 'typeof(lpad(hi, 5, ??))': 'string', 'typeof(hi)': 'string', 'typeof(5)': 'int', 'typeof(??)': 'string'} +#query +#SELECT lpad('hi'::string, 5::int, '??'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/string/ltrim.slt b/datafusion/sqllogictest/test_files/spark/string/ltrim.slt new file mode 100644 index 000000000000..8719dad4d4d3 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/ltrim.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
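Per the lpad gold data above, the padded case might be enabled roughly as follows; the tag and rendering are assumed:

query T
SELECT lpad('hi', 5, '??');
----
???hi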
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT ltrim(' SparkSQL '); +## PySpark 3.5.5 Result: {'ltrim( SparkSQL )': 'SparkSQL ', 'typeof(ltrim( SparkSQL ))': 'string', 'typeof( SparkSQL )': 'string'} +#query +#SELECT ltrim(' SparkSQL '::string); + diff --git a/datafusion/sqllogictest/test_files/spark/string/luhn_check.slt b/datafusion/sqllogictest/test_files/spark/string/luhn_check.slt new file mode 100644 index 000000000000..a28c60e62917 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/luhn_check.slt @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT luhn_check('79927398713'); +## PySpark 3.5.5 Result: {'luhn_check(79927398713)': True, 'typeof(luhn_check(79927398713))': 'boolean', 'typeof(79927398713)': 'string'} +#query +#SELECT luhn_check('79927398713'::string); + +## Original Query: SELECT luhn_check('79927398714'); +## PySpark 3.5.5 Result: {'luhn_check(79927398714)': False, 'typeof(luhn_check(79927398714))': 'boolean', 'typeof(79927398714)': 'string'} +#query +#SELECT luhn_check('79927398714'::string); + +## Original Query: SELECT luhn_check('8112189876'); +## PySpark 3.5.5 Result: {'luhn_check(8112189876)': True, 'typeof(luhn_check(8112189876))': 'boolean', 'typeof(8112189876)': 'string'} +#query +#SELECT luhn_check('8112189876'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/string/mask.slt b/datafusion/sqllogictest/test_files/spark/string/mask.slt new file mode 100644 index 000000000000..6468af3cc4f9 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/mask.slt @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT mask('AbCD123-@$#'); +## PySpark 3.5.5 Result: {'mask(AbCD123-@$#, X, x, n, NULL)': 'XxXXnnn-@$#', 'typeof(mask(AbCD123-@$#, X, x, n, NULL))': 'string', 'typeof(AbCD123-@$#)': 'string'} +#query +#SELECT mask('AbCD123-@$#'::string); + +## Original Query: SELECT mask('AbCD123-@$#', 'Q'); +## PySpark 3.5.5 Result: {'mask(AbCD123-@$#, Q, x, n, NULL)': 'QxQQnnn-@$#', 'typeof(mask(AbCD123-@$#, Q, x, n, NULL))': 'string', 'typeof(AbCD123-@$#)': 'string', 'typeof(Q)': 'string'} +#query +#SELECT mask('AbCD123-@$#'::string, 'Q'::string); + +## Original Query: SELECT mask('AbCD123-@$#', 'Q', 'q'); +## PySpark 3.5.5 Result: {'mask(AbCD123-@$#, Q, q, n, NULL)': 'QqQQnnn-@$#', 'typeof(mask(AbCD123-@$#, Q, q, n, NULL))': 'string', 'typeof(AbCD123-@$#)': 'string', 'typeof(Q)': 'string', 'typeof(q)': 'string'} +#query +#SELECT mask('AbCD123-@$#'::string, 'Q'::string, 'q'::string); + +## Original Query: SELECT mask('AbCD123-@$#', 'Q', 'q', 'd'); +## PySpark 3.5.5 Result: {'mask(AbCD123-@$#, Q, q, d, NULL)': 'QqQQddd-@$#', 'typeof(mask(AbCD123-@$#, Q, q, d, NULL))': 'string', 'typeof(AbCD123-@$#)': 'string', 'typeof(Q)': 'string', 'typeof(q)': 'string', 'typeof(d)': 'string'} +#query +#SELECT mask('AbCD123-@$#'::string, 'Q'::string, 'q'::string, 'd'::string); + +## Original Query: SELECT mask('AbCD123-@$#', 'Q', 'q', 'd', 'o'); +## PySpark 3.5.5 Result: {'mask(AbCD123-@$#, Q, q, d, o)': 'QqQQdddoooo', 'typeof(mask(AbCD123-@$#, Q, q, d, o))': 'string', 'typeof(AbCD123-@$#)': 'string', 'typeof(Q)': 'string', 'typeof(q)': 'string', 'typeof(d)': 'string', 'typeof(o)': 'string'} +#query +#SELECT mask('AbCD123-@$#'::string, 'Q'::string, 'q'::string, 'd'::string, 'o'::string); + +## Original Query: SELECT mask('AbCD123-@$#', NULL, 'q', 'd', 'o'); +## PySpark 3.5.5 Result: {'mask(AbCD123-@$#, NULL, q, d, o)': 'AqCDdddoooo', 'typeof(mask(AbCD123-@$#, NULL, q, d, o))': 'string', 'typeof(AbCD123-@$#)': 'string', 'typeof(NULL)': 'void', 'typeof(q)': 'string', 'typeof(d)': 'string', 'typeof(o)': 'string'} +#query +#SELECT mask('AbCD123-@$#'::string, NULL::void, 'q'::string, 'd'::string, 'o'::string); + +## Original Query: SELECT mask('AbCD123-@$#', NULL, NULL, 'd', 'o'); +## PySpark 3.5.5 Result: {'mask(AbCD123-@$#, NULL, NULL, d, o)': 'AbCDdddoooo', 'typeof(mask(AbCD123-@$#, NULL, NULL, d, o))': 'string', 'typeof(AbCD123-@$#)': 'string', 'typeof(NULL)': 'void', 'typeof(d)': 'string', 'typeof(o)': 'string'} +#query +#SELECT mask('AbCD123-@$#'::string, NULL::void, NULL::void, 'd'::string, 'o'::string); + +## Original Query: SELECT mask('AbCD123-@$#', NULL, NULL, NULL, 'o'); +## PySpark 3.5.5 Result: {'mask(AbCD123-@$#, NULL, NULL, NULL, o)': 'AbCD123oooo', 'typeof(mask(AbCD123-@$#, NULL, NULL, NULL, o))': 'string', 'typeof(AbCD123-@$#)': 'string', 'typeof(NULL)': 'void', 'typeof(o)': 'string'} +#query +#SELECT mask('AbCD123-@$#'::string, NULL::void, NULL::void, NULL::void, 'o'::string); + +## Original Query: SELECT mask('AbCD123-@$#', NULL, NULL, NULL, NULL); +## PySpark 3.5.5 Result: {'mask(AbCD123-@$#, NULL, NULL, NULL, NULL)': 'AbCD123-@$#',
'typeof(mask(AbCD123-@$#, NULL, NULL, NULL, NULL))': 'string', 'typeof(AbCD123-@$#)': 'string', 'typeof(NULL)': 'void'} +#query +#SELECT mask('AbCD123-@$#'::string, NULL::void); + +## Original Query: SELECT mask('abcd-EFGH-8765-4321'); +## PySpark 3.5.5 Result: {'mask(abcd-EFGH-8765-4321, X, x, n, NULL)': 'xxxx-XXXX-nnnn-nnnn', 'typeof(mask(abcd-EFGH-8765-4321, X, x, n, NULL))': 'string', 'typeof(abcd-EFGH-8765-4321)': 'string'} +#query +#SELECT mask('abcd-EFGH-8765-4321'::string); + +## Original Query: SELECT mask('abcd-EFGH-8765-4321', 'Q'); +## PySpark 3.5.5 Result: {'mask(abcd-EFGH-8765-4321, Q, x, n, NULL)': 'xxxx-QQQQ-nnnn-nnnn', 'typeof(mask(abcd-EFGH-8765-4321, Q, x, n, NULL))': 'string', 'typeof(abcd-EFGH-8765-4321)': 'string', 'typeof(Q)': 'string'} +#query +#SELECT mask('abcd-EFGH-8765-4321'::string, 'Q'::string); + +## Original Query: SELECT mask(NULL); +## PySpark 3.5.5 Result: {'mask(NULL, X, x, n, NULL)': None, 'typeof(mask(NULL, X, x, n, NULL))': 'string', 'typeof(NULL)': 'void'} +#query +#SELECT mask(NULL::void); + +## Original Query: SELECT mask(NULL, NULL, NULL, NULL, 'o'); +## PySpark 3.5.5 Result: {'mask(NULL, NULL, NULL, NULL, o)': None, 'typeof(mask(NULL, NULL, NULL, NULL, o))': 'string', 'typeof(NULL)': 'void', 'typeof(o)': 'string'} +#query +#SELECT mask(NULL::void, 'o'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/string/octet_length.slt b/datafusion/sqllogictest/test_files/spark/string/octet_length.slt new file mode 100644 index 000000000000..1efab7973232 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/octet_length.slt @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT octet_length('Spark SQL'); +## PySpark 3.5.5 Result: {'octet_length(Spark SQL)': 9, 'typeof(octet_length(Spark SQL))': 'int', 'typeof(Spark SQL)': 'string'} +#query +#SELECT octet_length('Spark SQL'::string); + +## Original Query: SELECT octet_length(x'537061726b2053514c'); +## PySpark 3.5.5 Result: {"octet_length(X'537061726B2053514C')": 9, "typeof(octet_length(X'537061726B2053514C'))": 'int', "typeof(X'537061726B2053514C')": 'binary'} +#query +#SELECT octet_length(X'537061726B2053514C'::binary); + diff --git a/datafusion/sqllogictest/test_files/spark/string/position.slt b/datafusion/sqllogictest/test_files/spark/string/position.slt new file mode 100644 index 000000000000..b79c24248cdd --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/position.slt @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT position('bar', 'foobarbar'); +## PySpark 3.5.5 Result: {'position(bar, foobarbar, 1)': 4, 'typeof(position(bar, foobarbar, 1))': 'int', 'typeof(bar)': 'string', 'typeof(foobarbar)': 'string'} +#query +#SELECT position('bar'::string, 'foobarbar'::string); + +## Original Query: SELECT position('bar', 'foobarbar', 5); +## PySpark 3.5.5 Result: {'position(bar, foobarbar, 5)': 7, 'typeof(position(bar, foobarbar, 5))': 'int', 'typeof(bar)': 'string', 'typeof(foobarbar)': 'string', 'typeof(5)': 'int'} +#query +#SELECT position('bar'::string, 'foobarbar'::string, 5::int); + diff --git a/datafusion/sqllogictest/test_files/spark/string/printf.slt b/datafusion/sqllogictest/test_files/spark/string/printf.slt new file mode 100644 index 000000000000..7a1991801974 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/printf.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT printf("Hello World %d %s", 100, "days"); +## PySpark 3.5.5 Result: {'printf(Hello World %d %s, 100, days)': 'Hello World 100 days', 'typeof(printf(Hello World %d %s, 100, days))': 'string', 'typeof(Hello World %d %s)': 'string', 'typeof(100)': 'int', 'typeof(days)': 'string'} +#query +#SELECT printf('Hello World %d %s'::string, 100::int, 'days'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/string/repeat.slt b/datafusion/sqllogictest/test_files/spark/string/repeat.slt new file mode 100644 index 000000000000..3c1d097e5948 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/repeat.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT repeat('123', 2); +## PySpark 3.5.5 Result: {'repeat(123, 2)': '123123', 'typeof(repeat(123, 2))': 'string', 'typeof(123)': 'string', 'typeof(2)': 'int'} +#query +#SELECT repeat('123'::string, 2::int); + diff --git a/datafusion/sqllogictest/test_files/spark/string/replace.slt b/datafusion/sqllogictest/test_files/spark/string/replace.slt new file mode 100644 index 000000000000..ee74c8e91a2a --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/replace.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT replace('ABCabc', 'abc', 'DEF'); +## PySpark 3.5.5 Result: {'replace(ABCabc, abc, DEF)': 'ABCDEF', 'typeof(replace(ABCabc, abc, DEF))': 'string', 'typeof(ABCabc)': 'string', 'typeof(abc)': 'string', 'typeof(DEF)': 'string'} +#query +#SELECT replace('ABCabc'::string, 'abc'::string, 'DEF'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/string/right.slt b/datafusion/sqllogictest/test_files/spark/string/right.slt new file mode 100644 index 000000000000..239705eab83c --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/right.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT right('Spark SQL', 3); +## PySpark 3.5.5 Result: {'right(Spark SQL, 3)': 'SQL', 'typeof(right(Spark SQL, 3))': 'string', 'typeof(Spark SQL)': 'string', 'typeof(3)': 'int'} +#query +#SELECT right('Spark SQL'::string, 3::int); + diff --git a/datafusion/sqllogictest/test_files/spark/string/rpad.slt b/datafusion/sqllogictest/test_files/spark/string/rpad.slt new file mode 100644 index 000000000000..98ebd1f810d2 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/rpad.slt @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT rpad('hi', 1, '??'); +## PySpark 3.5.5 Result: {'rpad(hi, 1, ??)': 'h', 'typeof(rpad(hi, 1, ??))': 'string', 'typeof(hi)': 'string', 'typeof(1)': 'int', 'typeof(??)': 'string'} +#query +#SELECT rpad('hi'::string, 1::int, '??'::string); + +## Original Query: SELECT rpad('hi', 5); +## PySpark 3.5.5 Result: {'rpad(hi, 5, )': 'hi ', 'typeof(rpad(hi, 5, ))': 'string', 'typeof(hi)': 'string', 'typeof(5)': 'int'} +#query +#SELECT rpad('hi'::string, 5::int); + +## Original Query: SELECT rpad('hi', 5, '??'); +## PySpark 3.5.5 Result: {'rpad(hi, 5, ??)': 'hi???', 'typeof(rpad(hi, 5, ??))': 'string', 'typeof(hi)': 'string', 'typeof(5)': 'int', 'typeof(??)': 'string'} +#query +#SELECT rpad('hi'::string, 5::int, '??'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/string/rtrim.slt b/datafusion/sqllogictest/test_files/spark/string/rtrim.slt new file mode 100644 index 000000000000..c86264b6e781 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/rtrim.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT rtrim(' SparkSQL '); +## PySpark 3.5.5 Result: {'rtrim( SparkSQL )': ' SparkSQL', 'typeof(rtrim( SparkSQL ))': 'string', 'typeof( SparkSQL )': 'string'} +#query +#SELECT rtrim(' SparkSQL '::string); + diff --git a/datafusion/sqllogictest/test_files/spark/string/sentences.slt b/datafusion/sqllogictest/test_files/spark/string/sentences.slt new file mode 100644 index 000000000000..f6c69a64ce48 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/sentences.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT sentences('Hi there! Good morning.'); +## PySpark 3.5.5 Result: {'sentences(Hi there! Good morning., , )': [['Hi', 'there'], ['Good', 'morning']], 'typeof(sentences(Hi there! Good morning., , ))': 'array<array<string>>', 'typeof(Hi there! Good morning.)': 'string'} +#query +#SELECT sentences('Hi there! Good morning.'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/string/soundex.slt b/datafusion/sqllogictest/test_files/spark/string/soundex.slt new file mode 100644 index 000000000000..1222f02c0d5b --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/soundex.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT soundex('Miller'); +## PySpark 3.5.5 Result: {'soundex(Miller)': 'M460', 'typeof(soundex(Miller))': 'string', 'typeof(Miller)': 'string'} +#query +#SELECT soundex('Miller'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/string/split_part.slt b/datafusion/sqllogictest/test_files/spark/string/split_part.slt new file mode 100644 index 000000000000..89c752183f9b --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/split_part.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT split_part('11.12.13', '.', 3); +## PySpark 3.5.5 Result: {'split_part(11.12.13, ., 3)': '13', 'typeof(split_part(11.12.13, ., 3))': 'string', 'typeof(11.12.13)': 'string', 'typeof(.)': 'string', 'typeof(3)': 'int'} +#query +#SELECT split_part('11.12.13'::string, '.'::string, 3::int); + diff --git a/datafusion/sqllogictest/test_files/spark/string/startswith.slt b/datafusion/sqllogictest/test_files/spark/string/startswith.slt new file mode 100644 index 000000000000..296ec4beb0a9 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/startswith.slt @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT startswith('Spark SQL', 'SQL'); +## PySpark 3.5.5 Result: {'startswith(Spark SQL, SQL)': False, 'typeof(startswith(Spark SQL, SQL))': 'boolean', 'typeof(Spark SQL)': 'string', 'typeof(SQL)': 'string'} +#query +#SELECT startswith('Spark SQL'::string, 'SQL'::string); + +## Original Query: SELECT startswith('Spark SQL', 'Spark'); +## PySpark 3.5.5 Result: {'startswith(Spark SQL, Spark)': True, 'typeof(startswith(Spark SQL, Spark))': 'boolean', 'typeof(Spark SQL)': 'string', 'typeof(Spark)': 'string'} +#query +#SELECT startswith('Spark SQL'::string, 'Spark'::string); + +## Original Query: SELECT startswith('Spark SQL', null); +## PySpark 3.5.5 Result: {'startswith(Spark SQL, NULL)': None, 'typeof(startswith(Spark SQL, NULL))': 'boolean', 'typeof(Spark SQL)': 'string', 'typeof(NULL)': 'void'} +#query +#SELECT startswith('Spark SQL'::string, NULL::void); + +## Original Query: SELECT startswith(x'537061726b2053514c', x'53514c'); +## PySpark 3.5.5 Result: {"startswith(X'537061726B2053514C', X'53514C')": False, "typeof(startswith(X'537061726B2053514C', X'53514C'))": 'boolean', "typeof(X'537061726B2053514C')": 'binary', "typeof(X'53514C')": 'binary'} +#query +#SELECT startswith(X'537061726B2053514C'::binary, X'53514C'::binary); + +## Original Query: SELECT startswith(x'537061726b2053514c', x'537061726b'); +## PySpark 3.5.5 Result: {"startswith(X'537061726B2053514C', X'537061726B')": True, "typeof(startswith(X'537061726B2053514C', X'537061726B'))": 'boolean', "typeof(X'537061726B2053514C')": 'binary', "typeof(X'537061726B')": 'binary'} +#query +#SELECT startswith(X'537061726B2053514C'::binary, X'537061726B'::binary); + diff --git a/datafusion/sqllogictest/test_files/spark/string/substr.slt b/datafusion/sqllogictest/test_files/spark/string/substr.slt new file mode 100644 index 000000000000..79bdee021317 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/substr.slt @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT substr('Spark SQL', -3); +## PySpark 3.5.5 Result: {'substr(Spark SQL, -3, 2147483647)': 'SQL', 'typeof(substr(Spark SQL, -3, 2147483647))': 'string', 'typeof(Spark SQL)': 'string', 'typeof(-3)': 'int'} +#query +#SELECT substr('Spark SQL'::string, -3::int); + +## Original Query: SELECT substr('Spark SQL', 5); +## PySpark 3.5.5 Result: {'substr(Spark SQL, 5, 2147483647)': 'k SQL', 'typeof(substr(Spark SQL, 5, 2147483647))': 'string', 'typeof(Spark SQL)': 'string', 'typeof(5)': 'int'} +#query +#SELECT substr('Spark SQL'::string, 5::int); + +## Original Query: SELECT substr('Spark SQL', 5, 1); +## PySpark 3.5.5 Result: {'substr(Spark SQL, 5, 1)': 'k', 'typeof(substr(Spark SQL, 5, 1))': 'string', 'typeof(Spark SQL)': 'string', 'typeof(5)': 'int', 'typeof(1)': 'int'} +#query +#SELECT substr('Spark SQL'::string, 5::int, 1::int); + diff --git a/datafusion/sqllogictest/test_files/spark/string/substring.slt b/datafusion/sqllogictest/test_files/spark/string/substring.slt new file mode 100644 index 000000000000..be37a02529e9 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/substring.slt @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT substring('Spark SQL', -3); +## PySpark 3.5.5 Result: {'substring(Spark SQL, -3, 2147483647)': 'SQL', 'typeof(substring(Spark SQL, -3, 2147483647))': 'string', 'typeof(Spark SQL)': 'string', 'typeof(-3)': 'int'} +#query +#SELECT substring('Spark SQL'::string, -3::int); + +## Original Query: SELECT substring('Spark SQL', 5); +## PySpark 3.5.5 Result: {'substring(Spark SQL, 5, 2147483647)': 'k SQL', 'typeof(substring(Spark SQL, 5, 2147483647))': 'string', 'typeof(Spark SQL)': 'string', 'typeof(5)': 'int'} +#query +#SELECT substring('Spark SQL'::string, 5::int); + +## Original Query: SELECT substring('Spark SQL', 5, 1); +## PySpark 3.5.5 Result: {'substring(Spark SQL, 5, 1)': 'k', 'typeof(substring(Spark SQL, 5, 1))': 'string', 'typeof(Spark SQL)': 'string', 'typeof(5)': 'int', 'typeof(1)': 'int'} +#query +#SELECT substring('Spark SQL'::string, 5::int, 1::int); + diff --git a/datafusion/sqllogictest/test_files/spark/string/substring_index.slt b/datafusion/sqllogictest/test_files/spark/string/substring_index.slt new file mode 100644 index 000000000000..99b6fae28be0 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/substring_index.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT substring_index('www.apache.org', '.', 2); +## PySpark 3.5.5 Result: {'substring_index(www.apache.org, ., 2)': 'www.apache', 'typeof(substring_index(www.apache.org, ., 2))': 'string', 'typeof(www.apache.org)': 'string', 'typeof(.)': 'string', 'typeof(2)': 'int'} +#query +#SELECT substring_index('www.apache.org'::string, '.'::string, 2::int); + diff --git a/datafusion/sqllogictest/test_files/spark/string/to_binary.slt b/datafusion/sqllogictest/test_files/spark/string/to_binary.slt new file mode 100644 index 000000000000..2cfe2f9c3f9a --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/to_binary.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT to_binary('abc', 'utf-8'); +## PySpark 3.5.5 Result: {'to_binary(abc, utf-8)': bytearray(b'abc'), 'typeof(to_binary(abc, utf-8))': 'binary', 'typeof(abc)': 'string', 'typeof(utf-8)': 'string'} +#query +#SELECT to_binary('abc'::string, 'utf-8'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/string/to_char.slt b/datafusion/sqllogictest/test_files/spark/string/to_char.slt new file mode 100644 index 000000000000..ca9d843fb305 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/to_char.slt @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT to_char(-12454.8, '99G999D9S'); +## PySpark 3.5.5 Result: {'to_char(-12454.8, 99G999D9S)': '12,454.8-', 'typeof(to_char(-12454.8, 99G999D9S))': 'string', 'typeof(-12454.8)': 'decimal(6,1)', 'typeof(99G999D9S)': 'string'} +#query +#SELECT to_char(-12454.8::decimal(6,1), '99G999D9S'::string); + +## Original Query: SELECT to_char(12454, '99G999'); +## PySpark 3.5.5 Result: {'to_char(12454, 99G999)': '12,454', 'typeof(to_char(12454, 99G999))': 'string', 'typeof(12454)': 'int', 'typeof(99G999)': 'string'} +#query +#SELECT to_char(12454::int, '99G999'::string); + +## Original Query: SELECT to_char(454, '999'); +## PySpark 3.5.5 Result: {'to_char(454, 999)': '454', 'typeof(to_char(454, 999))': 'string', 'typeof(454)': 'int', 'typeof(999)': 'string'} +#query +#SELECT to_char(454::int, '999'::string); + +## Original Query: SELECT to_char(454.00, '000D00'); +## PySpark 3.5.5 Result: {'to_char(454.00, 000D00)': '454.00', 'typeof(to_char(454.00, 000D00))': 'string', 'typeof(454.00)': 'decimal(5,2)', 'typeof(000D00)': 'string'} +#query +#SELECT to_char(454.00::decimal(5,2), '000D00'::string); + +## Original Query: SELECT to_char(78.12, '$99.99'); +## PySpark 3.5.5 Result: {'to_char(78.12, $99.99)': '$78.12', 'typeof(to_char(78.12, $99.99))': 'string', 'typeof(78.12)': 'decimal(4,2)', 'typeof($99.99)': 'string'} +#query +#SELECT to_char(78.12::decimal(4,2), '$99.99'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/string/to_number.slt b/datafusion/sqllogictest/test_files/spark/string/to_number.slt new file mode 100644 index 000000000000..6c10f6afc94b --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/to_number.slt @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT to_number('$78.12', '$99.99'); +## PySpark 3.5.5 Result: {'to_number($78.12, $99.99)': Decimal('78.12'), 'typeof(to_number($78.12, $99.99))': 'decimal(4,2)', 'typeof($78.12)': 'string', 'typeof($99.99)': 'string'} +#query +#SELECT to_number('$78.12'::string, '$99.99'::string); + +## Original Query: SELECT to_number('12,454', '99,999'); +## PySpark 3.5.5 Result: {'to_number(12,454, 99,999)': Decimal('12454'), 'typeof(to_number(12,454, 99,999))': 'decimal(5,0)', 'typeof(12,454)': 'string', 'typeof(99,999)': 'string'} +#query +#SELECT to_number('12,454'::string, '99,999'::string); + +## Original Query: SELECT to_number('12,454.8-', '99,999.9S'); +## PySpark 3.5.5 Result: {'to_number(12,454.8-, 99,999.9S)': Decimal('-12454.8'), 'typeof(to_number(12,454.8-, 99,999.9S))': 'decimal(6,1)', 'typeof(12,454.8-)': 'string', 'typeof(99,999.9S)': 'string'} +#query +#SELECT to_number('12,454.8-'::string, '99,999.9S'::string); + +## Original Query: SELECT to_number('454', '999'); +## PySpark 3.5.5 Result: {'to_number(454, 999)': Decimal('454'), 'typeof(to_number(454, 999))': 'decimal(3,0)', 'typeof(454)': 'string', 'typeof(999)': 'string'} +#query +#SELECT to_number('454'::string, '999'::string); + +## Original Query: SELECT to_number('454.00', '000.00'); +## PySpark 3.5.5 Result: {'to_number(454.00, 000.00)': Decimal('454.00'), 'typeof(to_number(454.00, 000.00))': 'decimal(5,2)', 'typeof(454.00)': 'string', 'typeof(000.00)': 'string'} +#query +#SELECT to_number('454.00'::string, '000.00'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/string/to_varchar.slt b/datafusion/sqllogictest/test_files/spark/string/to_varchar.slt new file mode 100644 index 000000000000..f1303324bebe --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/to_varchar.slt @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT to_varchar(-12454.8, '99G999D9S'); +## PySpark 3.5.5 Result: {'to_char(-12454.8, 99G999D9S)': '12,454.8-', 'typeof(to_char(-12454.8, 99G999D9S))': 'string', 'typeof(-12454.8)': 'decimal(6,1)', 'typeof(99G999D9S)': 'string'} +#query +#SELECT to_varchar(-12454.8::decimal(6,1), '99G999D9S'::string); + +## Original Query: SELECT to_varchar(12454, '99G999'); +## PySpark 3.5.5 Result: {'to_char(12454, 99G999)': '12,454', 'typeof(to_char(12454, 99G999))': 'string', 'typeof(12454)': 'int', 'typeof(99G999)': 'string'} +#query +#SELECT to_varchar(12454::int, '99G999'::string); + +## Original Query: SELECT to_varchar(454, '999'); +## PySpark 3.5.5 Result: {'to_char(454, 999)': '454', 'typeof(to_char(454, 999))': 'string', 'typeof(454)': 'int', 'typeof(999)': 'string'} +#query +#SELECT to_varchar(454::int, '999'::string); + +## Original Query: SELECT to_varchar(454.00, '000D00'); +## PySpark 3.5.5 Result: {'to_char(454.00, 000D00)': '454.00', 'typeof(to_char(454.00, 000D00))': 'string', 'typeof(454.00)': 'decimal(5,2)', 'typeof(000D00)': 'string'} +#query +#SELECT to_varchar(454.00::decimal(5,2), '000D00'::string); + +## Original Query: SELECT to_varchar(78.12, '$99.99'); +## PySpark 3.5.5 Result: {'to_char(78.12, $99.99)': '$78.12', 'typeof(to_char(78.12, $99.99))': 'string', 'typeof(78.12)': 'decimal(4,2)', 'typeof($99.99)': 'string'} +#query +#SELECT to_varchar(78.12::decimal(4,2), '$99.99'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/string/translate.slt b/datafusion/sqllogictest/test_files/spark/string/translate.slt new file mode 100644 index 000000000000..179ac93427db --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/translate.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT translate('AaBbCc', 'abc', '123'); +## PySpark 3.5.5 Result: {'translate(AaBbCc, abc, 123)': 'A1B2C3', 'typeof(translate(AaBbCc, abc, 123))': 'string', 'typeof(AaBbCc)': 'string', 'typeof(abc)': 'string', 'typeof(123)': 'string'} +#query +#SELECT translate('AaBbCc'::string, 'abc'::string, '123'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/string/trim.slt b/datafusion/sqllogictest/test_files/spark/string/trim.slt new file mode 100644 index 000000000000..ebd637b71fff --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/trim.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT trim(' SparkSQL '); +## PySpark 3.5.5 Result: {'trim( SparkSQL )': 'SparkSQL', 'typeof(trim( SparkSQL ))': 'string', 'typeof( SparkSQL )': 'string'} +#query +#SELECT trim(' SparkSQL '::string); + diff --git a/datafusion/sqllogictest/test_files/spark/string/try_to_binary.slt b/datafusion/sqllogictest/test_files/spark/string/try_to_binary.slt new file mode 100644 index 000000000000..3f6935327f66 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/try_to_binary.slt @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT try_to_binary('abc', 'utf-8'); +## PySpark 3.5.5 Result: {'try_to_binary(abc, utf-8)': bytearray(b'abc'), 'typeof(try_to_binary(abc, utf-8))': 'binary', 'typeof(abc)': 'string', 'typeof(utf-8)': 'string'} +#query +#SELECT try_to_binary('abc'::string, 'utf-8'::string); + +## Original Query: select try_to_binary('a!', 'base64'); +## PySpark 3.5.5 Result: {'try_to_binary(a!, base64)': None, 'typeof(try_to_binary(a!, base64))': 'binary', 'typeof(a!)': 'string', 'typeof(base64)': 'string'} +#query +#SELECT try_to_binary('a!'::string, 'base64'::string); + +## Original Query: select try_to_binary('abc', 'invalidFormat'); +## PySpark 3.5.5 Result: {'try_to_binary(abc, invalidFormat)': None, 'typeof(try_to_binary(abc, invalidFormat))': 'binary', 'typeof(abc)': 'string', 'typeof(invalidFormat)': 'string'} +#query +#SELECT try_to_binary('abc'::string, 'invalidFormat'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/string/try_to_number.slt b/datafusion/sqllogictest/test_files/spark/string/try_to_number.slt new file mode 100644 index 000000000000..77e8f5a4cb75 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/try_to_number.slt @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT try_to_number('$78.12', '$99.99'); +## PySpark 3.5.5 Result: {'try_to_number($78.12, $99.99)': Decimal('78.12'), 'typeof(try_to_number($78.12, $99.99))': 'decimal(4,2)', 'typeof($78.12)': 'string', 'typeof($99.99)': 'string'} +#query +#SELECT try_to_number('$78.12'::string, '$99.99'::string); + +## Original Query: SELECT try_to_number('12,454', '99,999'); +## PySpark 3.5.5 Result: {'try_to_number(12,454, 99,999)': Decimal('12454'), 'typeof(try_to_number(12,454, 99,999))': 'decimal(5,0)', 'typeof(12,454)': 'string', 'typeof(99,999)': 'string'} +#query +#SELECT try_to_number('12,454'::string, '99,999'::string); + +## Original Query: SELECT try_to_number('12,454.8-', '99,999.9S'); +## PySpark 3.5.5 Result: {'try_to_number(12,454.8-, 99,999.9S)': Decimal('-12454.8'), 'typeof(try_to_number(12,454.8-, 99,999.9S))': 'decimal(6,1)', 'typeof(12,454.8-)': 'string', 'typeof(99,999.9S)': 'string'} +#query +#SELECT try_to_number('12,454.8-'::string, '99,999.9S'::string); + +## Original Query: SELECT try_to_number('454', '999'); +## PySpark 3.5.5 Result: {'try_to_number(454, 999)': Decimal('454'), 'typeof(try_to_number(454, 999))': 'decimal(3,0)', 'typeof(454)': 'string', 'typeof(999)': 'string'} +#query +#SELECT try_to_number('454'::string, '999'::string); + +## Original Query: SELECT try_to_number('454.00', '000.00'); +## PySpark 3.5.5 Result: {'try_to_number(454.00, 000.00)': Decimal('454.00'), 'typeof(try_to_number(454.00, 000.00))': 'decimal(5,2)', 'typeof(454.00)': 'string', 'typeof(000.00)': 'string'} +#query +#SELECT try_to_number('454.00'::string, '000.00'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/string/ucase.slt b/datafusion/sqllogictest/test_files/spark/string/ucase.slt new file mode 100644 index 000000000000..ff0d2a367452 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/ucase.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT ucase('SparkSql'); +## PySpark 3.5.5 Result: {'ucase(SparkSql)': 'SPARKSQL', 'typeof(ucase(SparkSql))': 'string', 'typeof(SparkSql)': 'string'} +#query +#SELECT ucase('SparkSql'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/string/unbase64.slt b/datafusion/sqllogictest/test_files/spark/string/unbase64.slt new file mode 100644 index 000000000000..f9afde439046 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/unbase64.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT unbase64('U3BhcmsgU1FM'); +## PySpark 3.5.5 Result: {'unbase64(U3BhcmsgU1FM)': bytearray(b'Spark SQL'), 'typeof(unbase64(U3BhcmsgU1FM))': 'binary', 'typeof(U3BhcmsgU1FM)': 'string'} +#query +#SELECT unbase64('U3BhcmsgU1FM'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/string/upper.slt b/datafusion/sqllogictest/test_files/spark/string/upper.slt new file mode 100644 index 000000000000..62124c315de7 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/string/upper.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT upper('SparkSql'); +## PySpark 3.5.5 Result: {'upper(SparkSql)': 'SPARKSQL', 'typeof(upper(SparkSql))': 'string', 'typeof(SparkSql)': 'string'} +#query +#SELECT upper('SparkSql'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/struct/named_struct.slt b/datafusion/sqllogictest/test_files/spark/struct/named_struct.slt new file mode 100644 index 000000000000..ae08770e8b6e --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/struct/named_struct.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT named_struct("a", 1, "b", 2, "c", 3); +## PySpark 3.5.5 Result: {'named_struct(a, 1, b, 2, c, 3)': Row(a=1, b=2, c=3), 'typeof(named_struct(a, 1, b, 2, c, 3))': 'struct<a:int,b:int,c:int>', 'typeof(a)': 'string', 'typeof(1)': 'int', 'typeof(b)': 'string', 'typeof(2)': 'int', 'typeof(c)': 'string', 'typeof(3)': 'int'} +#query +#SELECT named_struct('a'::string, 1::int, 'b'::string, 2::int, 'c'::string, 3::int); + diff --git a/datafusion/sqllogictest/test_files/spark/struct/struct.slt b/datafusion/sqllogictest/test_files/spark/struct/struct.slt new file mode 100644 index 000000000000..abfd65462477 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/struct/struct.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT struct(1, 2, 3); +## PySpark 3.5.5 Result: {'struct(1, 2, 3)': Row(col1=1, col2=2, col3=3), 'typeof(struct(1, 2, 3))': 'struct', 'typeof(1)': 'int', 'typeof(2)': 'int', 'typeof(3)': 'int'} +#query +#SELECT struct(1::int, 2::int, 3::int); + diff --git a/datafusion/sqllogictest/test_files/spark/url/parse_url.slt b/datafusion/sqllogictest/test_files/spark/url/parse_url.slt new file mode 100644 index 000000000000..eeeb154d7ff2 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/url/parse_url.slt @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT parse_url('http://spark.apache.org/path?query=1', 'HOST'); +## PySpark 3.5.5 Result: {'parse_url(http://spark.apache.org/path?query=1, HOST)': 'spark.apache.org', 'typeof(parse_url(http://spark.apache.org/path?query=1, HOST))': 'string', 'typeof(http://spark.apache.org/path?query=1)': 'string', 'typeof(HOST)': 'string'} +#query +#SELECT parse_url('http://spark.apache.org/path?query=1'::string, 'HOST'::string); + +## Original Query: SELECT parse_url('http://spark.apache.org/path?query=1', 'QUERY'); +## PySpark 3.5.5 Result: {'parse_url(http://spark.apache.org/path?query=1, QUERY)': 'query=1', 'typeof(parse_url(http://spark.apache.org/path?query=1, QUERY))': 'string', 'typeof(http://spark.apache.org/path?query=1)': 'string', 'typeof(QUERY)': 'string'} +#query +#SELECT parse_url('http://spark.apache.org/path?query=1'::string, 'QUERY'::string); + +## Original Query: SELECT parse_url('http://spark.apache.org/path?query=1', 'QUERY', 'query'); +## PySpark 3.5.5 Result: {'parse_url(http://spark.apache.org/path?query=1, QUERY, query)': '1', 'typeof(parse_url(http://spark.apache.org/path?query=1, QUERY, query))': 'string', 'typeof(http://spark.apache.org/path?query=1)': 'string', 'typeof(QUERY)': 'string', 'typeof(query)': 'string'} +#query +#SELECT parse_url('http://spark.apache.org/path?query=1'::string, 'QUERY'::string, 'query'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/url/url_decode.slt b/datafusion/sqllogictest/test_files/spark/url/url_decode.slt new file mode 100644 index 000000000000..cbb89b29984e --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/url/url_decode.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more 
contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT url_decode('https%3A%2F%2Fspark.apache.org'); +## PySpark 3.5.5 Result: {'url_decode(https%3A%2F%2Fspark.apache.org)': 'https://spark.apache.org', 'typeof(url_decode(https%3A%2F%2Fspark.apache.org))': 'string', 'typeof(https%3A%2F%2Fspark.apache.org)': 'string'} +#query +#SELECT url_decode('https%3A%2F%2Fspark.apache.org'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/url/url_encode.slt b/datafusion/sqllogictest/test_files/spark/url/url_encode.slt new file mode 100644 index 000000000000..c66e11319332 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/url/url_encode.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT url_encode('https://spark.apache.org'); +## PySpark 3.5.5 Result: {'url_encode(https://spark.apache.org)': 'https%3A%2F%2Fspark.apache.org', 'typeof(url_encode(https://spark.apache.org))': 'string', 'typeof(https://spark.apache.org)': 'string'} +#query +#SELECT url_encode('https://spark.apache.org'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/xml/xpath.slt b/datafusion/sqllogictest/test_files/spark/xml/xpath.slt new file mode 100644 index 000000000000..a1aa1b85c347 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/xml/xpath.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT xpath('b1b2b3c1c2','a/b'); +## PySpark 3.5.5 Result: {'xpath(b1b2b3c1c2, a/b)': [None, None, None], 'typeof(xpath(b1b2b3c1c2, a/b))': 'array', 'typeof(b1b2b3c1c2)': 'string', 'typeof(a/b)': 'string'} +#query +#SELECT xpath('b1b2b3c1c2'::string, 'a/b'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/xml/xpath_boolean.slt b/datafusion/sqllogictest/test_files/spark/xml/xpath_boolean.slt new file mode 100644 index 000000000000..6d4a2c6db7fd --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/xml/xpath_boolean.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. 
+# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT xpath_boolean('1','a/b'); +## PySpark 3.5.5 Result: {'xpath_boolean(1, a/b)': True, 'typeof(xpath_boolean(1, a/b))': 'boolean', 'typeof(1)': 'string', 'typeof(a/b)': 'string'} +#query +#SELECT xpath_boolean('1'::string, 'a/b'::string); + diff --git a/datafusion/sqllogictest/test_files/spark/xml/xpath_string.slt b/datafusion/sqllogictest/test_files/spark/xml/xpath_string.slt new file mode 100644 index 000000000000..4b725ced11c9 --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/xml/xpath_string.slt @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was originally created by a porting script from: +# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function +# This file is part of the implementation of the datafusion-spark function library. +# For more information, please see: +# https://github.com/apache/datafusion/issues/15914 + +## Original Query: SELECT xpath_string('bcc','a/c'); +## PySpark 3.5.5 Result: {'xpath_string(bcc, a/c)': 'cc', 'typeof(xpath_string(bcc, a/c))': 'string', 'typeof(bcc)': 'string', 'typeof(a/c)': 'string'} +#query +#SELECT xpath_string('bcc'::string, 'a/c'::string); + diff --git a/datafusion/sqllogictest/test_files/strings.slt b/datafusion/sqllogictest/test_files/strings.slt index 81b8f4b2da9a..9fa453fa0252 100644 --- a/datafusion/sqllogictest/test_files/strings.slt +++ b/datafusion/sqllogictest/test_files/strings.slt @@ -115,6 +115,12 @@ p1 p1e1 p1m1e1 +query T rowsort +SELECT s FROM test WHERE s ILIKE 'p1'; +---- +P1 +p1 + # NOT ILIKE query T rowsort SELECT s FROM test WHERE s NOT ILIKE 'p1%'; diff --git a/datafusion/sqllogictest/test_files/struct.slt b/datafusion/sqllogictest/test_files/struct.slt index bdba73876103..95eeffc31903 100644 --- a/datafusion/sqllogictest/test_files/struct.slt +++ b/datafusion/sqllogictest/test_files/struct.slt @@ -53,9 +53,9 @@ select * from struct_values; query TT select arrow_typeof(s1), arrow_typeof(s2) from struct_values; ---- -Struct([Field { name: "c0", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) -Struct([Field { name: "c0", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: 
{} }]) -Struct([Field { name: "c0", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct(c0 Int32) Struct(a Int32, b Utf8View) +Struct(c0 Int32) Struct(a Int32, b Utf8View) +Struct(c0 Int32) Struct(a Int32, b Utf8View) # struct[i] @@ -229,12 +229,12 @@ select named_struct('field_a', 1, 'field_b', 2); query T select arrow_typeof(named_struct('first', 1, 'second', 2, 'third', 3)); ---- -Struct([Field { name: "first", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "second", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "third", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct(first Int64, second Int64, third Int64) query T select arrow_typeof({'first': 1, 'second': 2, 'third': 3}); ---- -Struct([Field { name: "first", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "second", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "third", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct(first Int64, second Int64, third Int64) # test nested struct literal query ? @@ -271,12 +271,33 @@ select a from values where (a, c) = (1, 'a'); ---- 1 +query I +select a from values as v where (v.a, v.c) = (1, 'a'); +---- +1 + +query I +select a from values as v where (v.a, v.c) != (1, 'a'); +---- +2 +3 + +query I +select a from values as v where (v.a, v.c) = (1, 'b'); +---- + query I select a from values where (a, c) IN ((1, 'a'), (2, 'b')); ---- 1 2 +query I +select a from values as v where (v.a, v.c) IN ((1, 'a'), (2, 'b')); +---- +1 +2 + statement ok drop table values; @@ -392,7 +413,7 @@ create table t(a struct, b struct) as valu query T select arrow_typeof([a, b]) from t; ---- -List(Field { name: "item", data_type: Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) +List(Field { name: "item", data_type: Struct([Field { name: "r", data_type: Utf8View, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) query ? 
select [a, b] from t; @@ -443,12 +464,12 @@ select * from t; query T select arrow_typeof(c1) from t; ---- -Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct(r Utf8View, b Int32) query T select arrow_typeof(c2) from t; ---- -Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct(r Utf8View, b Float32) statement ok drop table t; @@ -465,8 +486,8 @@ select * from t; query T select arrow_typeof(column1) from t; ---- -Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) -Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct(r Utf8, c Float64) +Struct(r Utf8, c Float64) statement ok drop table t; @@ -498,9 +519,9 @@ select coalesce(s1) from t; query T select arrow_typeof(coalesce(s1, s2)) from t; ---- -Struct([Field { name: "a", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) -Struct([Field { name: "a", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) -Struct([Field { name: "a", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct(a Float32, b Utf8View) +Struct(a Float32, b Utf8View) +Struct(a Float32, b Utf8View) statement ok drop table t; @@ -525,9 +546,9 @@ select coalesce(s1, s2) from t; query T select arrow_typeof(coalesce(s1, s2)) from t; ---- -Struct([Field { name: "a", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) -Struct([Field { name: "a", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) -Struct([Field { name: "a", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct(a Float32, b Utf8View) +Struct(a Float32, b Utf8View) +Struct(a Float32, b Utf8View) statement ok drop table t; @@ -562,7 +583,7 @@ create table t(a struct(r varchar, c int), b struct(r varchar, c float)) as valu query T select arrow_typeof([a, b]) from t; ---- -List(Field { name: "item", data_type: Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) 
+List(Field { name: "item", data_type: Struct([Field { name: "r", data_type: Utf8View, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) statement ok drop table t; @@ -585,13 +606,13 @@ create table t(a struct(r varchar, c int, g float), b struct(r varchar, c float, query T select arrow_typeof(a) from t; ---- -Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "g", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct(r Utf8View, c Int32, g Float32) # type of each column should not coerced but perserve as it is query T select arrow_typeof(b) from t; ---- -Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "g", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct(r Utf8View, c Float32, g Int32) statement ok drop table t; diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index a0ac15b740d7..796570633f67 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -400,7 +400,7 @@ logical_plan 01)LeftSemi Join: 02)--TableScan: t1 projection=[t1_id, t1_name, t1_int] 03)--SubqueryAlias: __correlated_sq_1 -04)----Projection: +04)----Projection: 05)------Filter: t1.t1_int < t1.t1_id 06)--------TableScan: t1 projection=[t1_id, t1_int] @@ -1453,7 +1453,7 @@ logical_plan 01)LeftSemi Join: 02)--TableScan: t1 projection=[a] 03)--SubqueryAlias: __correlated_sq_1 -04)----Projection: +04)----Projection: 05)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] 06)--------TableScan: t2 projection=[] diff --git a/datafusion/sqllogictest/test_files/subquery_sort.slt b/datafusion/sqllogictest/test_files/subquery_sort.slt index 5d22bf92e7e6..d993515f4de9 100644 --- a/datafusion/sqllogictest/test_files/subquery_sort.slt +++ b/datafusion/sqllogictest/test_files/subquery_sort.slt @@ -100,7 +100,7 @@ physical_plan 01)ProjectionExec: expr=[c1@0 as c1, r@1 as r] 02)--SortExec: TopK(fetch=2), expr=[c1@0 ASC NULLS LAST, c3@2 ASC NULLS LAST, c9@3 ASC NULLS LAST], preserve_partitioning=[false] 03)----ProjectionExec: expr=[c1@0 as c1, rank() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as r, c3@1 as c3, c9@2 as c9] -04)------BoundedWindowAggExec: wdw=[rank() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "rank() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Utf8(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[rank() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "rank() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT 
ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Utf8View(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] 05)--------SortExec: expr=[c1@0 DESC], preserve_partitioning=[false] 06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c3, c9], file_type=csv, has_header=true @@ -127,9 +127,8 @@ physical_plan 02)--SortExec: TopK(fetch=2), expr=[c1@0 ASC NULLS LAST, c3@2 ASC NULLS LAST, c9@3 ASC NULLS LAST], preserve_partitioning=[false] 03)----ProjectionExec: expr=[c1@0 as c1, rank() ORDER BY [sink_table_with_utf8view.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as r, c3@1 as c3, c9@2 as c9] 04)------BoundedWindowAggExec: wdw=[rank() ORDER BY [sink_table_with_utf8view.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "rank() ORDER BY [sink_table_with_utf8view.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Utf8View(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] -05)--------SortPreservingMergeExec: [c1@0 DESC] -06)----------SortExec: expr=[c1@0 DESC], preserve_partitioning=[true] -07)------------DataSourceExec: partitions=4, partition_sizes=[1, 0, 0, 0] +05)--------SortExec: expr=[c1@0 DESC], preserve_partitioning=[false] +06)----------DataSourceExec: partitions=1, partition_sizes=[1] statement ok DROP TABLE sink_table_with_utf8view; diff --git a/datafusion/sqllogictest/test_files/table_functions.slt b/datafusion/sqllogictest/test_files/table_functions.slt index 7d318c50bacf..0159abe8d06b 100644 --- a/datafusion/sqllogictest/test_files/table_functions.slt +++ b/datafusion/sqllogictest/test_files/table_functions.slt @@ -153,23 +153,23 @@ SELECT * FROM generate_series(1, 5, NULL) query TT EXPLAIN SELECT * FROM generate_series(1, 5) ---- -logical_plan TableScan: tmp_table projection=[value] +logical_plan TableScan: generate_series() projection=[value] physical_plan LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=5, batch_size=8192] # # Test generate_series with invalid arguments # -query error DataFusion error: Error during planning: start is bigger than end, but increment is positive: cannot generate infinite series +query error DataFusion error: Error during planning: Start is bigger than end, but increment is positive: Cannot generate infinite series SELECT * FROM generate_series(5, 1) -query error DataFusion error: Error during planning: start is smaller than end, but increment is negative: cannot generate infinite series +query error DataFusion error: Error during planning: Start is smaller than end, but increment is negative: Cannot generate infinite series SELECT * FROM generate_series(-6, 6, -1) -query error DataFusion error: Error during planning: step cannot be zero +query error DataFusion error: Error during planning: Step cannot be zero SELECT * FROM generate_series(-6, 6, 0) -query error DataFusion error: Error during planning: start is bigger than end, but increment is positive: cannot generate infinite series +query error DataFusion error: Error during planning: Start is bigger than end, but increment is positive: Cannot generate infinite series SELECT * FROM generate_series(6, -6, 1) @@ -177,7 +177,7 @@ 
statement error DataFusion error: Error during planning: generate_series functio SELECT * FROM generate_series(1, 2, 3, 4) -statement error DataFusion error: Error during planning: First argument must be an integer literal +statement error DataFusion error: Error during planning: Argument \#1 must be an INTEGER, TIMESTAMP, DATE or NULL, got Utf8 SELECT * FROM generate_series('foo', 'bar') # UDF and UDTF `generate_series` can be used simultaneously @@ -220,6 +220,12 @@ SELECT * FROM range(3, 6) 4 5 +query I rowsort +SELECT * FROM range(1, 1+2) +---- +1 +2 + # #generated_data > batch_size query I SELECT count(v1) FROM range(-66666,66666) t1(v1) @@ -270,23 +276,23 @@ SELECT * FROM range(1, 5, NULL) query TT EXPLAIN SELECT * FROM range(1, 5) ---- -logical_plan TableScan: tmp_table projection=[value] +logical_plan TableScan: range() projection=[value] physical_plan LazyMemoryExec: partitions=1, batch_generators=[range: start=1, end=5, batch_size=8192] # # Test range with invalid arguments # -query error DataFusion error: Error during planning: start is bigger than end, but increment is positive: cannot generate infinite series +query error DataFusion error: Error during planning: Start is bigger than end, but increment is positive: Cannot generate infinite series SELECT * FROM range(5, 1) -query error DataFusion error: Error during planning: start is smaller than end, but increment is negative: cannot generate infinite series +query error DataFusion error: Error during planning: Start is smaller than end, but increment is negative: Cannot generate infinite series SELECT * FROM range(-6, 6, -1) -query error DataFusion error: Error during planning: step cannot be zero +query error DataFusion error: Error during planning: Step cannot be zero SELECT * FROM range(-6, 6, 0) -query error DataFusion error: Error during planning: start is bigger than end, but increment is positive: cannot generate infinite series +query error DataFusion error: Error during planning: Start is bigger than end, but increment is positive: Cannot generate infinite series SELECT * FROM range(6, -6, 1) @@ -294,12 +300,197 @@ statement error DataFusion error: Error during planning: range function requires SELECT * FROM range(1, 2, 3, 4) -statement error DataFusion error: Error during planning: First argument must be an integer literal +statement error DataFusion error: Error during planning: Argument \#1 must be an INTEGER, TIMESTAMP, DATE or NULL, got Utf8 SELECT * FROM range('foo', 'bar') +statement error DataFusion error: Error during planning: Argument #2 must be an INTEGER or NULL, got Literal\(Utf8\("bar"\), None\) +SELECT * FROM range(1, 'bar') + # UDF and UDTF `range` can be used simultaneously query ? 
rowsort SELECT range(1, t1.end) FROM range(3, 5) as t1(end) ---- [1, 2, 3] [1, 2] + +# +# Test timestamp ranges +# + +# Basic timestamp range with 1 day interval +query P rowsort +SELECT * FROM range(TIMESTAMP '2023-01-01T00:00:00', TIMESTAMP '2023-01-04T00:00:00', INTERVAL '1' DAY) +---- +2023-01-01T00:00:00 +2023-01-02T00:00:00 +2023-01-03T00:00:00 + +# Timestamp range with hour interval +query P rowsort +SELECT * FROM range(TIMESTAMP '2023-01-01T00:00:00', TIMESTAMP '2023-01-01T03:00:00', INTERVAL '1' HOUR) +---- +2023-01-01T00:00:00 +2023-01-01T01:00:00 +2023-01-01T02:00:00 + +# Timestamp range with month interval +query P rowsort +SELECT * FROM range(TIMESTAMP '2023-01-01T00:00:00', TIMESTAMP '2023-04-01T00:00:00', INTERVAL '1' MONTH) +---- +2023-01-01T00:00:00 +2023-02-01T00:00:00 +2023-03-01T00:00:00 + +# Timestamp generate_series (includes end) +query P rowsort +SELECT * FROM generate_series(TIMESTAMP '2023-01-01T00:00:00', TIMESTAMP '2023-01-03T00:00:00', INTERVAL '1' DAY) +---- +2023-01-01T00:00:00 +2023-01-02T00:00:00 +2023-01-03T00:00:00 + +# Timestamp range with timezone +query P +SELECT * FROM range(TIMESTAMP '2023-01-01T00:00:00+00:00', TIMESTAMP '2023-01-03T00:00:00+00:00', INTERVAL '1' DAY) +---- +2023-01-01T00:00:00 +2023-01-02T00:00:00 + +# Negative timestamp range (going backwards) +query P +SELECT * FROM range(TIMESTAMP '2023-01-03T00:00:00', TIMESTAMP '2023-01-01T00:00:00', INTERVAL '-1' DAY) +---- +2023-01-03T00:00:00 +2023-01-02T00:00:00 + +query error DataFusion error: Error during planning: Start is bigger than end, but increment is positive: Cannot generate infinite series +SELECT * FROM range(TIMESTAMP '2023-01-03T00:00:00', TIMESTAMP '2023-01-01T00:00:00', INTERVAL '1' DAY) + +query error DataFusion error: Error during planning: Start is smaller than end, but increment is negative: Cannot generate infinite series +SELECT * FROM range(TIMESTAMP '2023-01-01T00:00:00', TIMESTAMP '2023-01-02T00:00:00', INTERVAL '-1' DAY) + +query error DataFusion error: Error during planning: range function with timestamps requires exactly 3 arguments +SELECT * FROM range(TIMESTAMP '2023-01-03T00:00:00', TIMESTAMP '2023-01-01T00:00:00') + +# Single timestamp (start == end) +query P +SELECT * FROM range(TIMESTAMP '2023-01-01T00:00:00', TIMESTAMP '2023-01-01T00:00:00', INTERVAL '1' DAY) +---- + +# Timestamp range with NULL values +query P +SELECT * FROM range(NULL::TIMESTAMP, TIMESTAMP '2023-01-03T00:00:00', INTERVAL '1' DAY) +---- + +query P +SELECT * FROM range(TIMESTAMP '2023-01-01T00:00:00', NULL::TIMESTAMP, INTERVAL '1' DAY) +---- + +# No interval gives no rows +query P +SELECT * FROM range(TIMESTAMP '2023-01-01T00:00:00', TIMESTAMP '2023-01-03T00:00:00', NULL::INTERVAL) +---- + +# Zero-length interval gives error +query error DataFusion error: Error during planning: Step interval cannot be zero +SELECT * FROM range(TIMESTAMP '2023-01-01T00:00:00', TIMESTAMP '2023-01-03T00:00:00', INTERVAL '0' DAY) + +# Timezone-aware +query P +SELECT * FROM range(TIMESTAMPTZ '2023-02-01T00:00:00-07:00', TIMESTAMPTZ '2023-02-01T09:00:00+01:00', INTERVAL '1' HOUR); +---- +2023-02-01T07:00:00Z + +# Basic date range with hour interval +query P +SELECT * FROM range(DATE '1992-01-01', DATE '1992-01-03', INTERVAL '6' HOUR); +---- +1992-01-01T00:00:00 +1992-01-01T06:00:00 +1992-01-01T12:00:00 +1992-01-01T18:00:00 +1992-01-02T00:00:00 +1992-01-02T06:00:00 +1992-01-02T12:00:00 +1992-01-02T18:00:00 + +# Date range with day interval +query P +SELECT * FROM range(DATE '1992-09-01', DATE '1992-09-05', 
INTERVAL '1' DAY); +---- +1992-09-01T00:00:00 +1992-09-02T00:00:00 +1992-09-03T00:00:00 +1992-09-04T00:00:00 + +# Date range with month interval +query P +SELECT * FROM range(DATE '1992-09-01', DATE '1993-01-01', INTERVAL '1' MONTH); +---- +1992-09-01T00:00:00 +1992-10-01T00:00:00 +1992-11-01T00:00:00 +1992-12-01T00:00:00 + +# Date range generate_series includes end +query P +SELECT * FROM generate_series(DATE '1992-09-01', DATE '1992-09-03', INTERVAL '1' DAY); +---- +1992-09-01T00:00:00 +1992-09-02T00:00:00 +1992-09-03T00:00:00 + +query TT +EXPLAIN SELECT * FROM generate_series(DATE '1992-09-01', DATE '1992-09-03', INTERVAL '1' DAY); +---- +logical_plan TableScan: generate_series() projection=[value] +physical_plan LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=715305600000000000, end=715478400000000000, batch_size=8192] + +# Backwards date range +query P +SELECT * FROM range(DATE '1992-09-05', DATE '1992-09-01', INTERVAL '-1' DAY); +---- +1992-09-05T00:00:00 +1992-09-04T00:00:00 +1992-09-03T00:00:00 +1992-09-02T00:00:00 + +# NULL handling for dates +query P +SELECT * FROM range(DATE '1992-09-01', NULL::DATE, INTERVAL '1' MONTH) +---- + +query TT +EXPLAIN SELECT * FROM range(DATE '1992-09-01', NULL::DATE, INTERVAL '1' MONTH) +---- +logical_plan TableScan: range() projection=[value] +physical_plan LazyMemoryExec: partitions=1, batch_generators=[range: empty] + +query P +SELECT * FROM range(NULL::DATE, DATE '1992-09-01', INTERVAL '1' MONTH) +---- + +query P +SELECT * FROM range(DATE '1992-09-01', DATE '1992-10-01', NULL::INTERVAL) +---- + +query error DataFusion error: Error during planning: Start is bigger than end, but increment is positive: Cannot generate infinite series +SELECT * FROM range(DATE '2023-01-03', DATE '2023-01-01', INTERVAL '1' DAY) + +query error DataFusion error: Error during planning: Start is smaller than end, but increment is negative: Cannot generate infinite series +SELECT * FROM range(DATE '2023-01-01', DATE '2023-01-02', INTERVAL '-1' DAY) + +query error DataFusion error: Error during planning: range function with dates requires exactly 3 arguments +SELECT * FROM range(DATE '2023-01-01', DATE '2023-01-03') + +# Table function as relation +statement ok +CREATE OR REPLACE TABLE json_table (c INT) AS VALUES (1), (2); + +query II +SELECT c, f.* FROM json_table, LATERAL generate_series(1,2) f; +---- +1 1 +1 2 +2 1 +2 2 diff --git a/datafusion/sqllogictest/test_files/topk.slt b/datafusion/sqllogictest/test_files/topk.slt index ce23fe26528c..9ff382d32af9 100644 --- a/datafusion/sqllogictest/test_files/topk.slt +++ b/datafusion/sqllogictest/test_files/topk.slt @@ -316,7 +316,7 @@ explain select number, letter, age from partial_sorted order by number desc, let ---- physical_plan 01)SortExec: TopK(fetch=3), expr=[number@0 DESC, letter@1 ASC NULLS LAST, age@2 DESC], preserve_partitioning=[false], sort_prefix=[number@0 DESC, letter@1 ASC NULLS LAST] -02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, letter, age], output_ordering=[number@0 DESC, letter@1 ASC NULLS LAST], file_type=parquet +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, letter, age], output_ordering=[number@0 DESC, letter@1 ASC NULLS LAST], file_type=parquet, predicate=DynamicFilterPhysicalExpr [ true ] # Explain variations of the above query with different orderings, and 
different sort prefixes. @@ -326,28 +326,28 @@ explain select number, letter, age from partial_sorted order by age desc limit 3 ---- physical_plan 01)SortExec: TopK(fetch=3), expr=[age@2 DESC], preserve_partitioning=[false] -02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, letter, age], output_ordering=[number@0 DESC, letter@1 ASC NULLS LAST], file_type=parquet +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, letter, age], output_ordering=[number@0 DESC, letter@1 ASC NULLS LAST], file_type=parquet, predicate=DynamicFilterPhysicalExpr [ true ] query TT explain select number, letter, age from partial_sorted order by number desc, letter desc limit 3; ---- physical_plan 01)SortExec: TopK(fetch=3), expr=[number@0 DESC, letter@1 DESC], preserve_partitioning=[false], sort_prefix=[number@0 DESC] -02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, letter, age], output_ordering=[number@0 DESC, letter@1 ASC NULLS LAST], file_type=parquet +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, letter, age], output_ordering=[number@0 DESC, letter@1 ASC NULLS LAST], file_type=parquet, predicate=DynamicFilterPhysicalExpr [ true ] query TT explain select number, letter, age from partial_sorted order by number asc limit 3; ---- physical_plan 01)SortExec: TopK(fetch=3), expr=[number@0 ASC NULLS LAST], preserve_partitioning=[false] -02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, letter, age], output_ordering=[number@0 DESC, letter@1 ASC NULLS LAST], file_type=parquet +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, letter, age], output_ordering=[number@0 DESC, letter@1 ASC NULLS LAST], file_type=parquet, predicate=DynamicFilterPhysicalExpr [ true ] query TT explain select number, letter, age from partial_sorted order by letter asc, number desc limit 3; ---- physical_plan 01)SortExec: TopK(fetch=3), expr=[letter@1 ASC NULLS LAST, number@0 DESC], preserve_partitioning=[false] -02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, letter, age], output_ordering=[number@0 DESC, letter@1 ASC NULLS LAST], file_type=parquet +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, letter, age], output_ordering=[number@0 DESC, letter@1 ASC NULLS LAST], file_type=parquet, predicate=DynamicFilterPhysicalExpr [ true ] # Explicit NULLS ordering cases (reversing the order of the NULLS on the number and letter orderings) query TT @@ -355,14 +355,14 @@ explain select number, letter, age from partial_sorted order by number desc, let ---- physical_plan 01)SortExec: TopK(fetch=3), expr=[number@0 DESC, letter@1 ASC], preserve_partitioning=[false], sort_prefix=[number@0 DESC] -02)--DataSourceExec: file_groups={1 group: 
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, letter, age], output_ordering=[number@0 DESC, letter@1 ASC NULLS LAST], file_type=parquet +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, letter, age], output_ordering=[number@0 DESC, letter@1 ASC NULLS LAST], file_type=parquet, predicate=DynamicFilterPhysicalExpr [ true ] query TT explain select number, letter, age from partial_sorted order by number desc NULLS LAST, letter asc limit 3; ---- physical_plan 01)SortExec: TopK(fetch=3), expr=[number@0 DESC NULLS LAST, letter@1 ASC NULLS LAST], preserve_partitioning=[false] -02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, letter, age], output_ordering=[number@0 DESC, letter@1 ASC NULLS LAST], file_type=parquet +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, letter, age], output_ordering=[number@0 DESC, letter@1 ASC NULLS LAST], file_type=parquet, predicate=DynamicFilterPhysicalExpr [ true ] # Verify that the sort prefix is correctly computed on the normalized ordering (removing redundant aliased columns) @@ -370,7 +370,7 @@ query TT explain select number, letter, age, number as column4, letter as column5 from partial_sorted order by number desc, column4 desc, letter asc, column5 asc, age desc limit 3; ---- physical_plan -01)SortExec: TopK(fetch=3), expr=[number@0 DESC, column4@3 DESC, letter@1 ASC NULLS LAST, column5@4 ASC NULLS LAST, age@2 DESC], preserve_partitioning=[false], sort_prefix=[number@0 DESC, letter@1 ASC NULLS LAST] +01)SortExec: TopK(fetch=3), expr=[number@0 DESC, letter@1 ASC NULLS LAST, age@2 DESC], preserve_partitioning=[false], sort_prefix=[number@0 DESC, letter@1 ASC NULLS LAST] 02)--ProjectionExec: expr=[number@0 as number, letter@1 as letter, age@2 as age, number@0 as column4, letter@1 as column5] 03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, letter, age], output_ordering=[number@0 DESC, letter@1 ASC NULLS LAST], file_type=parquet @@ -380,7 +380,7 @@ explain select number + 1 as number_plus, number, number + 1 as other_number_plu ---- physical_plan 01)SortPreservingMergeExec: [number_plus@0 DESC, number@1 DESC, other_number_plus@2 DESC, age@3 ASC NULLS LAST], fetch=3 -02)--SortExec: TopK(fetch=3), expr=[number_plus@0 DESC, number@1 DESC, other_number_plus@2 DESC, age@3 ASC NULLS LAST], preserve_partitioning=[true], sort_prefix=[number_plus@0 DESC, number@1 DESC] +02)--SortExec: TopK(fetch=3), expr=[number_plus@0 DESC, number@1 DESC, age@3 ASC NULLS LAST], preserve_partitioning=[true], sort_prefix=[number_plus@0 DESC, number@1 DESC] 03)----ProjectionExec: expr=[__common_expr_1@0 as number_plus, number@1 as number, __common_expr_1@0 as other_number_plus, age@2 as age] 04)------ProjectionExec: expr=[CAST(number@0 AS Int64) + 1 as __common_expr_1, number@0 as number, age@1 as age] 05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q10.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q10.slt.part index fee496f92055..04de9153a047 100644 --- 
a/datafusion/sqllogictest/test_files/tpch/plans/q10.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q10.slt.part @@ -65,8 +65,8 @@ logical_plan 12)--------------------Filter: orders.o_orderdate >= Date32("1993-10-01") AND orders.o_orderdate < Date32("1994-01-01") 13)----------------------TableScan: orders projection=[o_orderkey, o_custkey, o_orderdate], partial_filters=[orders.o_orderdate >= Date32("1993-10-01"), orders.o_orderdate < Date32("1994-01-01")] 14)--------------Projection: lineitem.l_orderkey, lineitem.l_extendedprice, lineitem.l_discount -15)----------------Filter: lineitem.l_returnflag = Utf8("R") -16)------------------TableScan: lineitem projection=[l_orderkey, l_extendedprice, l_discount, l_returnflag], partial_filters=[lineitem.l_returnflag = Utf8("R")] +15)----------------Filter: lineitem.l_returnflag = Utf8View("R") +16)------------------TableScan: lineitem projection=[l_orderkey, l_extendedprice, l_discount, l_returnflag], partial_filters=[lineitem.l_returnflag = Utf8View("R")] 17)----------TableScan: nation projection=[n_nationkey, n_name] physical_plan 01)SortPreservingMergeExec: [revenue@2 DESC], fetch=10 diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q11.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q11.slt.part index 1dba8c053720..a6225daae436 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q11.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q11.slt.part @@ -58,8 +58,8 @@ logical_plan 09)----------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost], partial_filters=[Boolean(true)] 10)----------------TableScan: supplier projection=[s_suppkey, s_nationkey] 11)------------Projection: nation.n_nationkey -12)--------------Filter: nation.n_name = Utf8("GERMANY") -13)----------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("GERMANY")] +12)--------------Filter: nation.n_name = Utf8View("GERMANY") +13)----------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8View("GERMANY")] 14)------SubqueryAlias: __scalar_sq_1 15)--------Projection: CAST(CAST(sum(partsupp.ps_supplycost * partsupp.ps_availqty) AS Float64) * Float64(0.0001) AS Decimal128(38, 15)) 16)----------Aggregate: groupBy=[[]], aggr=[[sum(partsupp.ps_supplycost * CAST(partsupp.ps_availqty AS Decimal128(10, 0)))]] @@ -70,8 +70,8 @@ logical_plan 21)--------------------TableScan: partsupp projection=[ps_suppkey, ps_availqty, ps_supplycost] 22)--------------------TableScan: supplier projection=[s_suppkey, s_nationkey] 23)----------------Projection: nation.n_nationkey -24)------------------Filter: nation.n_name = Utf8("GERMANY") -25)--------------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("GERMANY")] +24)------------------Filter: nation.n_name = Utf8View("GERMANY") +25)--------------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8View("GERMANY")] physical_plan 01)SortExec: TopK(fetch=10), expr=[value@1 DESC], preserve_partitioning=[false] 02)--ProjectionExec: expr=[ps_partkey@0 as ps_partkey, sum(partsupp.ps_supplycost * partsupp.ps_availqty)@1 as value] diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q12.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q12.slt.part index 3757fc48dba0..f7344daed8c7 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q12.slt.part +++ 
b/datafusion/sqllogictest/test_files/tpch/plans/q12.slt.part @@ -51,12 +51,12 @@ order by logical_plan 01)Sort: lineitem.l_shipmode ASC NULLS LAST 02)--Projection: lineitem.l_shipmode, sum(CASE WHEN orders.o_orderpriority = Utf8("1-URGENT") OR orders.o_orderpriority = Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END) AS high_line_count, sum(CASE WHEN orders.o_orderpriority != Utf8("1-URGENT") AND orders.o_orderpriority != Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END) AS low_line_count -03)----Aggregate: groupBy=[[lineitem.l_shipmode]], aggr=[[sum(CASE WHEN orders.o_orderpriority = Utf8("1-URGENT") OR orders.o_orderpriority = Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END), sum(CASE WHEN orders.o_orderpriority != Utf8("1-URGENT") AND orders.o_orderpriority != Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END)]] +03)----Aggregate: groupBy=[[lineitem.l_shipmode]], aggr=[[sum(CASE WHEN orders.o_orderpriority = Utf8View("1-URGENT") OR orders.o_orderpriority = Utf8View("2-HIGH") THEN Int64(1) ELSE Int64(0) END) AS sum(CASE WHEN orders.o_orderpriority = Utf8("1-URGENT") OR orders.o_orderpriority = Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END), sum(CASE WHEN orders.o_orderpriority != Utf8View("1-URGENT") AND orders.o_orderpriority != Utf8View("2-HIGH") THEN Int64(1) ELSE Int64(0) END) AS sum(CASE WHEN orders.o_orderpriority != Utf8("1-URGENT") AND orders.o_orderpriority != Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END)]] 04)------Projection: lineitem.l_shipmode, orders.o_orderpriority 05)--------Inner Join: lineitem.l_orderkey = orders.o_orderkey 06)----------Projection: lineitem.l_orderkey, lineitem.l_shipmode -07)------------Filter: (lineitem.l_shipmode = Utf8("MAIL") OR lineitem.l_shipmode = Utf8("SHIP")) AND lineitem.l_receiptdate > lineitem.l_commitdate AND lineitem.l_shipdate < lineitem.l_commitdate AND lineitem.l_receiptdate >= Date32("1994-01-01") AND lineitem.l_receiptdate < Date32("1995-01-01") -08)--------------TableScan: lineitem projection=[l_orderkey, l_shipdate, l_commitdate, l_receiptdate, l_shipmode], partial_filters=[lineitem.l_shipmode = Utf8("MAIL") OR lineitem.l_shipmode = Utf8("SHIP"), lineitem.l_receiptdate > lineitem.l_commitdate, lineitem.l_shipdate < lineitem.l_commitdate, lineitem.l_receiptdate >= Date32("1994-01-01"), lineitem.l_receiptdate < Date32("1995-01-01")] +07)------------Filter: (lineitem.l_shipmode = Utf8View("MAIL") OR lineitem.l_shipmode = Utf8View("SHIP")) AND lineitem.l_receiptdate > lineitem.l_commitdate AND lineitem.l_shipdate < lineitem.l_commitdate AND lineitem.l_receiptdate >= Date32("1994-01-01") AND lineitem.l_receiptdate < Date32("1995-01-01") +08)--------------TableScan: lineitem projection=[l_orderkey, l_shipdate, l_commitdate, l_receiptdate, l_shipmode], partial_filters=[lineitem.l_shipmode = Utf8View("MAIL") OR lineitem.l_shipmode = Utf8View("SHIP"), lineitem.l_receiptdate > lineitem.l_commitdate, lineitem.l_shipdate < lineitem.l_commitdate, lineitem.l_receiptdate >= Date32("1994-01-01"), lineitem.l_receiptdate < Date32("1995-01-01")] 09)----------TableScan: orders projection=[o_orderkey, o_orderpriority] physical_plan 01)SortPreservingMergeExec: [l_shipmode@0 ASC NULLS LAST] diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q13.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q13.slt.part index e9d9cf141d10..96f3bd6edf32 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q13.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q13.slt.part @@ -50,8 +50,8 @@ logical_plan 08)--------------Left Join: 
customer.c_custkey = orders.o_custkey 09)----------------TableScan: customer projection=[c_custkey] 10)----------------Projection: orders.o_orderkey, orders.o_custkey -11)------------------Filter: orders.o_comment NOT LIKE Utf8("%special%requests%") -12)--------------------TableScan: orders projection=[o_orderkey, o_custkey, o_comment], partial_filters=[orders.o_comment NOT LIKE Utf8("%special%requests%")] +11)------------------Filter: orders.o_comment NOT LIKE Utf8View("%special%requests%") +12)--------------------TableScan: orders projection=[o_orderkey, o_custkey, o_comment], partial_filters=[orders.o_comment NOT LIKE Utf8View("%special%requests%")] physical_plan 01)SortPreservingMergeExec: [custdist@1 DESC, c_count@0 DESC], fetch=10 02)--SortExec: TopK(fetch=10), expr=[custdist@1 DESC, c_count@0 DESC], preserve_partitioning=[true] diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q14.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q14.slt.part index 1104af2bdc64..8d8dd68c3d7b 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q14.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q14.slt.part @@ -33,7 +33,7 @@ where ---- logical_plan 01)Projection: Float64(100) * CAST(sum(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END) AS Float64) / CAST(sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS Float64) AS promo_revenue -02)--Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN __common_expr_1 ELSE Decimal128(Some(0),38,4) END) AS sum(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), sum(__common_expr_1) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] +02)--Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN part.p_type LIKE Utf8View("PROMO%") THEN __common_expr_1 ELSE Decimal128(Some(0),38,4) END) AS sum(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), sum(__common_expr_1) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] 03)----Projection: lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) AS __common_expr_1, part.p_type 04)------Inner Join: lineitem.l_partkey = part.p_partkey 05)--------Projection: lineitem.l_partkey, lineitem.l_extendedprice, lineitem.l_discount diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q16.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q16.slt.part index c648f164c809..39f99a0fcf98 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q16.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q16.slt.part @@ -58,12 +58,12 @@ logical_plan 06)----------Projection: partsupp.ps_suppkey, part.p_brand, part.p_type, part.p_size 07)------------Inner Join: partsupp.ps_partkey = part.p_partkey 08)--------------TableScan: partsupp projection=[ps_partkey, ps_suppkey] -09)--------------Filter: part.p_brand != Utf8("Brand#45") AND part.p_type NOT LIKE Utf8("MEDIUM POLISHED%") AND part.p_size IN ([Int32(49), Int32(14), Int32(23), Int32(45), Int32(19), Int32(3), Int32(36), Int32(9)]) -10)----------------TableScan: part projection=[p_partkey, p_brand, p_type, p_size], partial_filters=[part.p_brand != Utf8("Brand#45"), part.p_type NOT LIKE Utf8("MEDIUM POLISHED%"), part.p_size IN ([Int32(49), Int32(14), Int32(23), Int32(45), Int32(19), Int32(3), Int32(36), Int32(9)])] +09)--------------Filter: 
part.p_brand != Utf8View("Brand#45") AND part.p_type NOT LIKE Utf8View("MEDIUM POLISHED%") AND part.p_size IN ([Int32(49), Int32(14), Int32(23), Int32(45), Int32(19), Int32(3), Int32(36), Int32(9)]) +10)----------------TableScan: part projection=[p_partkey, p_brand, p_type, p_size], partial_filters=[part.p_brand != Utf8View("Brand#45"), part.p_type NOT LIKE Utf8View("MEDIUM POLISHED%"), part.p_size IN ([Int32(49), Int32(14), Int32(23), Int32(45), Int32(19), Int32(3), Int32(36), Int32(9)])] 11)----------SubqueryAlias: __correlated_sq_1 12)------------Projection: supplier.s_suppkey -13)--------------Filter: supplier.s_comment LIKE Utf8("%Customer%Complaints%") -14)----------------TableScan: supplier projection=[s_suppkey, s_comment], partial_filters=[supplier.s_comment LIKE Utf8("%Customer%Complaints%")] +13)--------------Filter: supplier.s_comment LIKE Utf8View("%Customer%Complaints%") +14)----------------TableScan: supplier projection=[s_suppkey, s_comment], partial_filters=[supplier.s_comment LIKE Utf8View("%Customer%Complaints%")] physical_plan 01)SortPreservingMergeExec: [supplier_cnt@3 DESC, p_brand@0 ASC NULLS LAST, p_type@1 ASC NULLS LAST, p_size@2 ASC NULLS LAST], fetch=10 02)--SortExec: TopK(fetch=10), expr=[supplier_cnt@3 DESC, p_brand@0 ASC NULLS LAST, p_type@1 ASC NULLS LAST, p_size@2 ASC NULLS LAST], preserve_partitioning=[true] @@ -88,7 +88,7 @@ physical_plan 21)----------------------------------CoalesceBatchesExec: target_batch_size=8192 22)------------------------------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 23)--------------------------------------CoalesceBatchesExec: target_batch_size=8192 -24)----------------------------------------FilterExec: p_brand@1 != Brand#45 AND p_type@2 NOT LIKE MEDIUM POLISHED% AND Use p_size@3 IN (SET) ([Literal { value: Int32(49) }, Literal { value: Int32(14) }, Literal { value: Int32(23) }, Literal { value: Int32(45) }, Literal { value: Int32(19) }, Literal { value: Int32(3) }, Literal { value: Int32(36) }, Literal { value: Int32(9) }]) +24)----------------------------------------FilterExec: p_brand@1 != Brand#45 AND p_type@2 NOT LIKE MEDIUM POLISHED% AND Use p_size@3 IN (SET) ([Literal { value: Int32(49), field: Field { name: "lit", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Int32(14), field: Field { name: "lit", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Int32(23), field: Field { name: "lit", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Int32(45), field: Field { name: "lit", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Int32(19), field: Field { name: "lit", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Int32(3), field: Field { name: "lit", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Int32(36), field: Field { name: "lit", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Int32(9), field: Field { name: "lit", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]) 25)------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 26)--------------------------------------------DataSourceExec: file_groups={1 
group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_type, p_size], file_type=csv, has_header=false 27)--------------------------CoalesceBatchesExec: target_batch_size=8192 diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q17.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q17.slt.part index 02553890bcf5..51a0d096428c 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q17.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q17.slt.part @@ -44,8 +44,8 @@ logical_plan 06)----------Inner Join: lineitem.l_partkey = part.p_partkey 07)------------TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice] 08)------------Projection: part.p_partkey -09)--------------Filter: part.p_brand = Utf8("Brand#23") AND part.p_container = Utf8("MED BOX") -10)----------------TableScan: part projection=[p_partkey, p_brand, p_container], partial_filters=[part.p_brand = Utf8("Brand#23"), part.p_container = Utf8("MED BOX")] +09)--------------Filter: part.p_brand = Utf8View("Brand#23") AND part.p_container = Utf8View("MED BOX") +10)----------------TableScan: part projection=[p_partkey, p_brand, p_container], partial_filters=[part.p_brand = Utf8View("Brand#23"), part.p_container = Utf8View("MED BOX")] 11)--------SubqueryAlias: __scalar_sq_1 12)----------Projection: CAST(Float64(0.2) * CAST(avg(lineitem.l_quantity) AS Float64) AS Decimal128(30, 15)), lineitem.l_partkey 13)------------Aggregate: groupBy=[[lineitem.l_partkey]], aggr=[[avg(lineitem.l_quantity)]] diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q19.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q19.slt.part index b0e5b2e904d0..4cfbdc18ca50 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q19.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q19.slt.part @@ -57,19 +57,19 @@ logical_plan 01)Projection: sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS revenue 02)--Aggregate: groupBy=[[]], aggr=[[sum(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] 03)----Projection: lineitem.l_extendedprice, lineitem.l_discount -04)------Inner Join: lineitem.l_partkey = part.p_partkey Filter: part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15) +04)------Inner Join: lineitem.l_partkey = part.p_partkey Filter: part.p_brand = Utf8View("Brand#12") AND part.p_container IN ([Utf8View("SM CASE"), Utf8View("SM BOX"), Utf8View("SM PACK"), Utf8View("SM PKG")]) AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_container IN ([Utf8View("MED BAG"), 
Utf8View("MED BOX"), Utf8View("MED PKG"), Utf8View("MED PACK")]) AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_container IN ([Utf8View("LG CASE"), Utf8View("LG BOX"), Utf8View("LG PACK"), Utf8View("LG PKG")]) AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15) 05)--------Projection: lineitem.l_partkey, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount -06)----------Filter: (lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)) AND (lineitem.l_shipmode = Utf8("AIR") OR lineitem.l_shipmode = Utf8("AIR REG")) AND lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON") -07)------------TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode], partial_filters=[lineitem.l_shipmode = Utf8("AIR") OR lineitem.l_shipmode = Utf8("AIR REG"), lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON"), lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)] -08)--------Filter: (part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1) -09)----------TableScan: part projection=[p_partkey, p_brand, p_size, p_container], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND part.p_size <= Int32(15)] +06)----------Filter: (lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)) AND (lineitem.l_shipmode = Utf8View("AIR") OR lineitem.l_shipmode = Utf8View("AIR REG")) AND lineitem.l_shipinstruct = Utf8View("DELIVER IN PERSON") +07)------------TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode], partial_filters=[lineitem.l_shipmode = Utf8View("AIR") OR lineitem.l_shipmode = Utf8View("AIR REG"), lineitem.l_shipinstruct = 
Utf8View("DELIVER IN PERSON"), lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)] +08)--------Filter: (part.p_brand = Utf8View("Brand#12") AND part.p_container IN ([Utf8View("SM CASE"), Utf8View("SM BOX"), Utf8View("SM PACK"), Utf8View("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_container IN ([Utf8View("MED BAG"), Utf8View("MED BOX"), Utf8View("MED PKG"), Utf8View("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_container IN ([Utf8View("LG CASE"), Utf8View("LG BOX"), Utf8View("LG PACK"), Utf8View("LG PKG")]) AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1) +09)----------TableScan: part projection=[p_partkey, p_brand, p_size, p_container], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8View("Brand#12") AND part.p_container IN ([Utf8View("SM CASE"), Utf8View("SM BOX"), Utf8View("SM PACK"), Utf8View("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_container IN ([Utf8View("MED BAG"), Utf8View("MED BOX"), Utf8View("MED PKG"), Utf8View("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_container IN ([Utf8View("LG CASE"), Utf8View("LG BOX"), Utf8View("LG PACK"), Utf8View("LG PKG")]) AND part.p_size <= Int32(15)] physical_plan 01)ProjectionExec: expr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@0 as revenue] 02)--AggregateExec: mode=Final, gby=[], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] 03)----CoalescePartitionsExec 04)------AggregateExec: mode=Partial, gby=[], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] 05)--------CoalesceBatchesExec: target_batch_size=8192 -06)----------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)], filter=p_brand@1 = Brand#12 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("SM CASE") }, Literal { value: Utf8("SM BOX") }, Literal { value: Utf8("SM PACK") }, Literal { value: Utf8("SM PKG") }]) AND l_quantity@0 >= Some(100),15,2 AND l_quantity@0 <= Some(1100),15,2 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("MED BAG") }, Literal { value: Utf8("MED BOX") }, Literal { value: Utf8("MED PKG") }, Literal { value: Utf8("MED PACK") }]) AND l_quantity@0 >= Some(1000),15,2 AND l_quantity@0 <= Some(2000),15,2 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("LG CASE") }, Literal { value: Utf8("LG BOX") }, Literal { value: Utf8("LG PACK") }, Literal { value: Utf8("LG PKG") }]) AND l_quantity@0 >= Some(2000),15,2 AND l_quantity@0 <= Some(3000),15,2 AND p_size@2 <= 15, projection=[l_extendedprice@2, l_discount@3] +06)----------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)], filter=p_brand@1 = Brand#12 AND p_container@3 IN ([Literal { value: Utf8View("SM CASE"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("SM BOX"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("SM PACK"), field: Field { 
name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("SM PKG"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]) AND l_quantity@0 >= Some(100),15,2 AND l_quantity@0 <= Some(1100),15,2 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND p_container@3 IN ([Literal { value: Utf8View("MED BAG"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("MED BOX"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("MED PKG"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("MED PACK"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]) AND l_quantity@0 >= Some(1000),15,2 AND l_quantity@0 <= Some(2000),15,2 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND p_container@3 IN ([Literal { value: Utf8View("LG CASE"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("LG BOX"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("LG PACK"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("LG PKG"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]) AND l_quantity@0 >= Some(2000),15,2 AND l_quantity@0 <= Some(3000),15,2 AND p_size@2 <= 15, projection=[l_extendedprice@2, l_discount@3] 07)------------CoalesceBatchesExec: target_batch_size=8192 08)--------------RepartitionExec: partitioning=Hash([l_partkey@0], 4), input_partitions=4 09)----------------CoalesceBatchesExec: target_batch_size=8192 @@ -78,6 +78,6 @@ physical_plan 12)------------CoalesceBatchesExec: target_batch_size=8192 13)--------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 14)----------------CoalesceBatchesExec: target_batch_size=8192 -15)------------------FilterExec: (p_brand@1 = Brand#12 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("SM CASE") }, Literal { value: Utf8("SM BOX") }, Literal { value: Utf8("SM PACK") }, Literal { value: Utf8("SM PKG") }]) AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("MED BAG") }, Literal { value: Utf8("MED BOX") }, Literal { value: Utf8("MED PKG") }, Literal { value: Utf8("MED PACK") }]) AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("LG CASE") }, Literal { value: Utf8("LG BOX") }, Literal { value: Utf8("LG PACK") }, Literal { value: Utf8("LG PKG") }]) AND p_size@2 <= 15) AND p_size@2 >= 1 +15)------------------FilterExec: (p_brand@1 = Brand#12 AND p_container@3 IN ([Literal { value: Utf8View("SM CASE"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("SM BOX"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("SM PACK"), field: Field { 
name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("SM PKG"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]) AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND p_container@3 IN ([Literal { value: Utf8View("MED BAG"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("MED BOX"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("MED PKG"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("MED PACK"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]) AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND p_container@3 IN ([Literal { value: Utf8View("LG CASE"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("LG BOX"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("LG PACK"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("LG PKG"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]) AND p_size@2 <= 15) AND p_size@2 >= 1 16)--------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 17)----------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_size, p_container], file_type=csv, has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q2.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q2.slt.part index 2a8ee9f229b7..b2e0fb0cd1cc 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q2.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q2.slt.part @@ -75,14 +75,14 @@ logical_plan 10)------------------Projection: part.p_partkey, part.p_mfgr, partsupp.ps_suppkey, partsupp.ps_supplycost 11)--------------------Inner Join: part.p_partkey = partsupp.ps_partkey 12)----------------------Projection: part.p_partkey, part.p_mfgr -13)------------------------Filter: part.p_size = Int32(15) AND part.p_type LIKE Utf8("%BRASS") -14)--------------------------TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size], partial_filters=[part.p_size = Int32(15), part.p_type LIKE Utf8("%BRASS")] +13)------------------------Filter: part.p_size = Int32(15) AND part.p_type LIKE Utf8View("%BRASS") +14)--------------------------TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size], partial_filters=[part.p_size = Int32(15), part.p_type LIKE Utf8View("%BRASS")] 15)----------------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost] 16)------------------TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment] 17)--------------TableScan: nation projection=[n_nationkey, n_name, n_regionkey] 18)----------Projection: region.r_regionkey -19)------------Filter: region.r_name = Utf8("EUROPE") 
-20)--------------TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8("EUROPE")] +19)------------Filter: region.r_name = Utf8View("EUROPE") +20)--------------TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8View("EUROPE")] 21)------SubqueryAlias: __scalar_sq_1 22)--------Projection: min(partsupp.ps_supplycost), partsupp.ps_partkey 23)----------Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[min(partsupp.ps_supplycost)]] @@ -96,8 +96,8 @@ logical_plan 31)------------------------TableScan: supplier projection=[s_suppkey, s_nationkey] 32)--------------------TableScan: nation projection=[n_nationkey, n_regionkey] 33)----------------Projection: region.r_regionkey -34)------------------Filter: region.r_name = Utf8("EUROPE") -35)--------------------TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8("EUROPE")] +34)------------------Filter: region.r_name = Utf8View("EUROPE") +35)--------------------TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8View("EUROPE")] physical_plan 01)SortPreservingMergeExec: [s_acctbal@0 DESC, n_name@2 ASC NULLS LAST, s_name@1 ASC NULLS LAST, p_partkey@3 ASC NULLS LAST], fetch=10 02)--SortExec: TopK(fetch=10), expr=[s_acctbal@0 DESC, n_name@2 ASC NULLS LAST, s_name@1 ASC NULLS LAST, p_partkey@3 ASC NULLS LAST], preserve_partitioning=[true] diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q20.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q20.slt.part index 4844d5fae60b..0b994de411ea 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q20.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q20.slt.part @@ -63,8 +63,8 @@ logical_plan 05)--------Inner Join: supplier.s_nationkey = nation.n_nationkey 06)----------TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey] 07)----------Projection: nation.n_nationkey -08)------------Filter: nation.n_name = Utf8("CANADA") -09)--------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("CANADA")] +08)------------Filter: nation.n_name = Utf8View("CANADA") +09)--------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8View("CANADA")] 10)------SubqueryAlias: __correlated_sq_2 11)--------Projection: partsupp.ps_suppkey 12)----------Inner Join: partsupp.ps_partkey = __scalar_sq_3.l_partkey, partsupp.ps_suppkey = __scalar_sq_3.l_suppkey Filter: CAST(partsupp.ps_availqty AS Float64) > __scalar_sq_3.Float64(0.5) * sum(lineitem.l_quantity) @@ -72,8 +72,8 @@ logical_plan 14)--------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty] 15)--------------SubqueryAlias: __correlated_sq_1 16)----------------Projection: part.p_partkey -17)------------------Filter: part.p_name LIKE Utf8("forest%") -18)--------------------TableScan: part projection=[p_partkey, p_name], partial_filters=[part.p_name LIKE Utf8("forest%")] +17)------------------Filter: part.p_name LIKE Utf8View("forest%") +18)--------------------TableScan: part projection=[p_partkey, p_name], partial_filters=[part.p_name LIKE Utf8View("forest%")] 19)------------SubqueryAlias: __scalar_sq_3 20)--------------Projection: Float64(0.5) * CAST(sum(lineitem.l_quantity) AS Float64), lineitem.l_partkey, lineitem.l_suppkey 21)----------------Aggregate: groupBy=[[lineitem.l_partkey, lineitem.l_suppkey]], aggr=[[sum(lineitem.l_quantity)]] diff --git 
a/datafusion/sqllogictest/test_files/tpch/plans/q21.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q21.slt.part index bb3e884e27be..e52171524007 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q21.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q21.slt.part @@ -76,11 +76,11 @@ logical_plan 16)----------------------------Filter: lineitem.l_receiptdate > lineitem.l_commitdate 17)------------------------------TableScan: lineitem projection=[l_orderkey, l_suppkey, l_commitdate, l_receiptdate], partial_filters=[lineitem.l_receiptdate > lineitem.l_commitdate] 18)--------------------Projection: orders.o_orderkey -19)----------------------Filter: orders.o_orderstatus = Utf8("F") -20)------------------------TableScan: orders projection=[o_orderkey, o_orderstatus], partial_filters=[orders.o_orderstatus = Utf8("F")] +19)----------------------Filter: orders.o_orderstatus = Utf8View("F") +20)------------------------TableScan: orders projection=[o_orderkey, o_orderstatus], partial_filters=[orders.o_orderstatus = Utf8View("F")] 21)----------------Projection: nation.n_nationkey -22)------------------Filter: nation.n_name = Utf8("SAUDI ARABIA") -23)--------------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("SAUDI ARABIA")] +22)------------------Filter: nation.n_name = Utf8View("SAUDI ARABIA") +23)--------------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8View("SAUDI ARABIA")] 24)------------SubqueryAlias: __correlated_sq_1 25)--------------SubqueryAlias: l2 26)----------------TableScan: lineitem projection=[l_orderkey, l_suppkey] diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part index 828bf967d8f4..e9b533f2044f 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part @@ -90,7 +90,7 @@ physical_plan 14)--------------------------CoalesceBatchesExec: target_batch_size=8192 15)----------------------------RepartitionExec: partitioning=Hash([c_custkey@0], 4), input_partitions=4 16)------------------------------CoalesceBatchesExec: target_batch_size=8192 -17)--------------------------------FilterExec: substr(c_phone@1, 1, 2) IN ([Literal { value: Utf8View("13") }, Literal { value: Utf8View("31") }, Literal { value: Utf8View("23") }, Literal { value: Utf8View("29") }, Literal { value: Utf8View("30") }, Literal { value: Utf8View("18") }, Literal { value: Utf8View("17") }]) +17)--------------------------------FilterExec: substr(c_phone@1, 1, 2) IN ([Literal { value: Utf8View("13"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("31"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("23"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("29"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("30"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("18"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, 
dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("17"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]) 18)----------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 19)------------------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_phone, c_acctbal], file_type=csv, has_header=false 20)--------------------------CoalesceBatchesExec: target_batch_size=8192 @@ -100,6 +100,6 @@ physical_plan 24)----------------------CoalescePartitionsExec 25)------------------------AggregateExec: mode=Partial, gby=[], aggr=[avg(customer.c_acctbal)] 26)--------------------------CoalesceBatchesExec: target_batch_size=8192 -27)----------------------------FilterExec: c_acctbal@1 > Some(0),15,2 AND substr(c_phone@0, 1, 2) IN ([Literal { value: Utf8View("13") }, Literal { value: Utf8View("31") }, Literal { value: Utf8View("23") }, Literal { value: Utf8View("29") }, Literal { value: Utf8View("30") }, Literal { value: Utf8View("18") }, Literal { value: Utf8View("17") }]), projection=[c_acctbal@1] +27)----------------------------FilterExec: c_acctbal@1 > Some(0),15,2 AND substr(c_phone@0, 1, 2) IN ([Literal { value: Utf8View("13"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("31"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("23"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("29"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("30"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("18"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }, Literal { value: Utf8View("17"), field: Field { name: "lit", data_type: Utf8View, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} } }]), projection=[c_acctbal@1] 28)------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 29)--------------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_phone, c_acctbal], file_type=csv, has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q3.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q3.slt.part index 2ad496ef26fd..d982ec32e954 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q3.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q3.slt.part @@ -50,8 +50,8 @@ logical_plan 06)----------Projection: orders.o_orderkey, orders.o_orderdate, orders.o_shippriority 07)------------Inner Join: customer.c_custkey = orders.o_custkey 08)--------------Projection: customer.c_custkey -09)----------------Filter: customer.c_mktsegment = Utf8("BUILDING") -10)------------------TableScan: customer projection=[c_custkey, c_mktsegment], partial_filters=[customer.c_mktsegment = Utf8("BUILDING")] +09)----------------Filter: customer.c_mktsegment = 
Utf8View("BUILDING") +10)------------------TableScan: customer projection=[c_custkey, c_mktsegment], partial_filters=[customer.c_mktsegment = Utf8View("BUILDING")] 11)--------------Filter: orders.o_orderdate < Date32("1995-03-15") 12)----------------TableScan: orders projection=[o_orderkey, o_custkey, o_orderdate, o_shippriority], partial_filters=[orders.o_orderdate < Date32("1995-03-15")] 13)----------Projection: lineitem.l_orderkey, lineitem.l_extendedprice, lineitem.l_discount diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q5.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q5.slt.part index f192f987b3ef..15636056b871 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q5.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q5.slt.part @@ -64,8 +64,8 @@ logical_plan 19)------------------TableScan: supplier projection=[s_suppkey, s_nationkey] 20)--------------TableScan: nation projection=[n_nationkey, n_name, n_regionkey] 21)----------Projection: region.r_regionkey -22)------------Filter: region.r_name = Utf8("ASIA") -23)--------------TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8("ASIA")] +22)------------Filter: region.r_name = Utf8View("ASIA") +23)--------------TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8View("ASIA")] physical_plan 01)SortPreservingMergeExec: [revenue@1 DESC] 02)--SortExec: expr=[revenue@1 DESC], preserve_partitioning=[true] diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q7.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q7.slt.part index e03de9596fbe..291d56e43f2d 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q7.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q7.slt.part @@ -63,7 +63,7 @@ logical_plan 03)----Aggregate: groupBy=[[shipping.supp_nation, shipping.cust_nation, shipping.l_year]], aggr=[[sum(shipping.volume)]] 04)------SubqueryAlias: shipping 05)--------Projection: n1.n_name AS supp_nation, n2.n_name AS cust_nation, date_part(Utf8("YEAR"), lineitem.l_shipdate) AS l_year, lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) AS volume -06)----------Inner Join: customer.c_nationkey = n2.n_nationkey Filter: n1.n_name = Utf8("FRANCE") AND n2.n_name = Utf8("GERMANY") OR n1.n_name = Utf8("GERMANY") AND n2.n_name = Utf8("FRANCE") +06)----------Inner Join: customer.c_nationkey = n2.n_nationkey Filter: n1.n_name = Utf8View("FRANCE") AND n2.n_name = Utf8View("GERMANY") OR n1.n_name = Utf8View("GERMANY") AND n2.n_name = Utf8View("FRANCE") 07)------------Projection: lineitem.l_extendedprice, lineitem.l_discount, lineitem.l_shipdate, customer.c_nationkey, n1.n_name 08)--------------Inner Join: supplier.s_nationkey = n1.n_nationkey 09)----------------Projection: supplier.s_nationkey, lineitem.l_extendedprice, lineitem.l_discount, lineitem.l_shipdate, customer.c_nationkey @@ -78,11 +78,11 @@ logical_plan 18)------------------------TableScan: orders projection=[o_orderkey, o_custkey] 19)--------------------TableScan: customer projection=[c_custkey, c_nationkey] 20)----------------SubqueryAlias: n1 -21)------------------Filter: nation.n_name = Utf8("FRANCE") OR nation.n_name = Utf8("GERMANY") -22)--------------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("FRANCE") OR nation.n_name = Utf8("GERMANY")] +21)------------------Filter: nation.n_name = Utf8View("FRANCE") OR nation.n_name = Utf8View("GERMANY") +22)--------------------TableScan: 
nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8View("FRANCE") OR nation.n_name = Utf8View("GERMANY")] 23)------------SubqueryAlias: n2 -24)--------------Filter: nation.n_name = Utf8("GERMANY") OR nation.n_name = Utf8("FRANCE") -25)----------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("GERMANY") OR nation.n_name = Utf8("FRANCE")] +24)--------------Filter: nation.n_name = Utf8View("GERMANY") OR nation.n_name = Utf8View("FRANCE") +25)----------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8View("GERMANY") OR nation.n_name = Utf8View("FRANCE")] physical_plan 01)SortPreservingMergeExec: [supp_nation@0 ASC NULLS LAST, cust_nation@1 ASC NULLS LAST, l_year@2 ASC NULLS LAST] 02)--SortExec: expr=[supp_nation@0 ASC NULLS LAST, cust_nation@1 ASC NULLS LAST, l_year@2 ASC NULLS LAST], preserve_partitioning=[true] diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q8.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q8.slt.part index 88ceffd62ad3..50171c528db6 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q8.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q8.slt.part @@ -58,7 +58,7 @@ order by logical_plan 01)Sort: all_nations.o_year ASC NULLS LAST 02)--Projection: all_nations.o_year, CAST(CAST(sum(CASE WHEN all_nations.nation = Utf8("BRAZIL") THEN all_nations.volume ELSE Int64(0) END) AS Decimal128(12, 2)) / CAST(sum(all_nations.volume) AS Decimal128(12, 2)) AS Decimal128(15, 2)) AS mkt_share -03)----Aggregate: groupBy=[[all_nations.o_year]], aggr=[[sum(CASE WHEN all_nations.nation = Utf8("BRAZIL") THEN all_nations.volume ELSE Decimal128(Some(0),38,4) END) AS sum(CASE WHEN all_nations.nation = Utf8("BRAZIL") THEN all_nations.volume ELSE Int64(0) END), sum(all_nations.volume)]] +03)----Aggregate: groupBy=[[all_nations.o_year]], aggr=[[sum(CASE WHEN all_nations.nation = Utf8View("BRAZIL") THEN all_nations.volume ELSE Decimal128(Some(0),38,4) END) AS sum(CASE WHEN all_nations.nation = Utf8("BRAZIL") THEN all_nations.volume ELSE Int64(0) END), sum(all_nations.volume)]] 04)------SubqueryAlias: all_nations 05)--------Projection: date_part(Utf8("YEAR"), orders.o_orderdate) AS o_year, lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) AS volume, n2.n_name AS nation 06)----------Inner Join: n1.n_regionkey = region.r_regionkey @@ -75,8 +75,8 @@ logical_plan 17)--------------------------------Projection: lineitem.l_orderkey, lineitem.l_suppkey, lineitem.l_extendedprice, lineitem.l_discount 18)----------------------------------Inner Join: part.p_partkey = lineitem.l_partkey 19)------------------------------------Projection: part.p_partkey -20)--------------------------------------Filter: part.p_type = Utf8("ECONOMY ANODIZED STEEL") -21)----------------------------------------TableScan: part projection=[p_partkey, p_type], partial_filters=[part.p_type = Utf8("ECONOMY ANODIZED STEEL")] +20)--------------------------------------Filter: part.p_type = Utf8View("ECONOMY ANODIZED STEEL") +21)----------------------------------------TableScan: part projection=[p_partkey, p_type], partial_filters=[part.p_type = Utf8View("ECONOMY ANODIZED STEEL")] 22)------------------------------------TableScan: lineitem projection=[l_orderkey, l_partkey, l_suppkey, l_extendedprice, l_discount] 23)--------------------------------TableScan: supplier projection=[s_suppkey, s_nationkey] 24)----------------------------Filter: orders.o_orderdate >= 
Date32("1995-01-01") AND orders.o_orderdate <= Date32("1996-12-31") @@ -87,8 +87,8 @@ logical_plan 29)----------------SubqueryAlias: n2 30)------------------TableScan: nation projection=[n_nationkey, n_name] 31)------------Projection: region.r_regionkey -32)--------------Filter: region.r_name = Utf8("AMERICA") -33)----------------TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8("AMERICA")] +32)--------------Filter: region.r_name = Utf8View("AMERICA") +33)----------------TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8View("AMERICA")] physical_plan 01)SortPreservingMergeExec: [o_year@0 ASC NULLS LAST] 02)--SortExec: expr=[o_year@0 ASC NULLS LAST], preserve_partitioning=[true] diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q9.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q9.slt.part index 8ccf967187d7..3b31c1bc2e8e 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q9.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q9.slt.part @@ -67,8 +67,8 @@ logical_plan 13)------------------------Projection: lineitem.l_orderkey, lineitem.l_partkey, lineitem.l_suppkey, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount 14)--------------------------Inner Join: part.p_partkey = lineitem.l_partkey 15)----------------------------Projection: part.p_partkey -16)------------------------------Filter: part.p_name LIKE Utf8("%green%") -17)--------------------------------TableScan: part projection=[p_partkey, p_name], partial_filters=[part.p_name LIKE Utf8("%green%")] +16)------------------------------Filter: part.p_name LIKE Utf8View("%green%") +17)--------------------------------TableScan: part projection=[p_partkey, p_name], partial_filters=[part.p_name LIKE Utf8View("%green%")] 18)----------------------------TableScan: lineitem projection=[l_orderkey, l_partkey, l_suppkey, l_quantity, l_extendedprice, l_discount] 19)------------------------TableScan: supplier projection=[s_suppkey, s_nationkey] 20)--------------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost] diff --git a/datafusion/sqllogictest/test_files/union.slt b/datafusion/sqllogictest/test_files/union.slt index 356f1598bc0f..f901a4d373a3 100644 --- a/datafusion/sqllogictest/test_files/union.slt +++ b/datafusion/sqllogictest/test_files/union.slt @@ -230,7 +230,7 @@ logical_plan 02)--Union 03)----TableScan: t1 projection=[name] 04)----TableScan: t2 projection=[name] -05)----Projection: t2.name || Utf8("_new") AS name +05)----Projection: t2.name || Utf8View("_new") AS name 06)------TableScan: t2 projection=[name] physical_plan 01)AggregateExec: mode=FinalPartitioned, gby=[name@0 as name], aggr=[] @@ -266,7 +266,7 @@ logical_plan 01)Union 02)--TableScan: t1 projection=[name] 03)--TableScan: t2 projection=[name] -04)--Projection: t2.name || Utf8("_new") AS name +04)--Projection: t2.name || Utf8View("_new") AS name 05)----TableScan: t2 projection=[name] physical_plan 01)UnionExec @@ -489,11 +489,11 @@ logical_plan 04)------Limit: skip=0, fetch=3 05)--------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] 06)----------SubqueryAlias: a -07)------------Projection: +07)------------Projection: 08)--------------Aggregate: groupBy=[[aggregate_test_100.c1]], aggr=[[]] 09)----------------Projection: aggregate_test_100.c1 -10)------------------Filter: aggregate_test_100.c13 != Utf8("C2GT5KVyOPZpgKVl110TyZO0NcJ434") -11)--------------------TableScan: aggregate_test_100 projection=[c1, c13], 
partial_filters=[aggregate_test_100.c13 != Utf8("C2GT5KVyOPZpgKVl110TyZO0NcJ434")] +10)------------------Filter: aggregate_test_100.c13 != Utf8View("C2GT5KVyOPZpgKVl110TyZO0NcJ434") +11)--------------------TableScan: aggregate_test_100 projection=[c1, c13], partial_filters=[aggregate_test_100.c13 != Utf8View("C2GT5KVyOPZpgKVl110TyZO0NcJ434")] 12)----Projection: Int64(1) AS cnt 13)------Limit: skip=0, fetch=3 14)--------EmptyRelation @@ -829,10 +829,10 @@ ORDER BY c1 logical_plan 01)Sort: c1 ASC NULLS LAST 02)--Union -03)----Filter: aggregate_test_100.c1 = Utf8("a") -04)------TableScan: aggregate_test_100 projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], partial_filters=[aggregate_test_100.c1 = Utf8("a")] -05)----Filter: aggregate_test_100.c1 = Utf8("a") -06)------TableScan: aggregate_test_100 projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], partial_filters=[aggregate_test_100.c1 = Utf8("a")] +03)----Filter: aggregate_test_100.c1 = Utf8View("a") +04)------TableScan: aggregate_test_100 projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], partial_filters=[aggregate_test_100.c1 = Utf8View("a")] +05)----Filter: aggregate_test_100.c1 = Utf8View("a") +06)------TableScan: aggregate_test_100 projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], partial_filters=[aggregate_test_100.c1 = Utf8View("a")] physical_plan 01)CoalescePartitionsExec 02)--UnionExec diff --git a/datafusion/sqllogictest/test_files/union_by_name.slt b/datafusion/sqllogictest/test_files/union_by_name.slt index 9572e6efc3e6..233885618f83 100644 --- a/datafusion/sqllogictest/test_files/union_by_name.slt +++ b/datafusion/sqllogictest/test_files/union_by_name.slt @@ -348,7 +348,7 @@ Schema { fields: [ Field { name: "x", - data_type: Utf8, + data_type: Utf8View, nullable: true, dict_id: 0, dict_is_ordered: false, @@ -356,7 +356,7 @@ Schema { }, Field { name: "y", - data_type: Utf8, + data_type: Utf8View, nullable: true, dict_id: 0, dict_is_ordered: false, @@ -364,7 +364,7 @@ Schema { }, Field { name: "z", - data_type: Utf8, + data_type: Utf8View, nullable: true, dict_id: 0, dict_is_ordered: false, @@ -387,7 +387,7 @@ Schema { fields: [ Field { name: "x", - data_type: Utf8, + data_type: Utf8View, nullable: true, dict_id: 0, dict_is_ordered: false, @@ -395,7 +395,7 @@ Schema { }, Field { name: "y", - data_type: Utf8, + data_type: Utf8View, nullable: true, dict_id: 0, dict_is_ordered: false, @@ -403,7 +403,7 @@ Schema { }, Field { name: "z", - data_type: Utf8, + data_type: Utf8View, nullable: true, dict_id: 0, dict_is_ordered: false, diff --git a/datafusion/sqllogictest/test_files/union_function.slt b/datafusion/sqllogictest/test_files/union_function.slt index 9c70b1011f58..74616490ab70 100644 --- a/datafusion/sqllogictest/test_files/union_function.slt +++ b/datafusion/sqllogictest/test_files/union_function.slt @@ -15,6 +15,9 @@ # specific language governing permissions and limitations # under the License. 
+# Note: union_table is registered via Rust code in the sqllogictest test harness +# because there is no way to create a union type in SQL today + ########## ## UNION DataType Tests ########## @@ -23,7 +26,8 @@ query ?I select union_column, union_extract(union_column, 'int') from union_table; ---- {int=1} 1 -{int=2} 2 +{string=bar} NULL +{int=3} 3 query error DataFusion error: Execution error: field bool not found on union select union_extract(union_column, 'bool') from union_table; @@ -45,3 +49,19 @@ select union_extract(union_column, 1) from union_table; query error DataFusion error: Error during planning: The function 'union_extract' expected 2 arguments but received 3 select union_extract(union_column, 'a', 'b') from union_table; + +query ?T +select union_column, union_tag(union_column) from union_table; +---- +{int=1} int +{string=bar} string +{int=3} int + +query error DataFusion error: Error during planning: 'union_tag' does not support zero arguments +select union_tag() from union_table; + +query error DataFusion error: Error during planning: The function 'union_tag' expected 1 arguments but received 2 +select union_tag(union_column, 'int') from union_table; + +query error DataFusion error: Execution error: union_tag only support unions, got Utf8 +select union_tag('int') from union_table; diff --git a/datafusion/sqllogictest/test_files/unnest.slt b/datafusion/sqllogictest/test_files/unnest.slt index b9c13582952a..92e6f9995ae3 100644 --- a/datafusion/sqllogictest/test_files/unnest.slt +++ b/datafusion/sqllogictest/test_files/unnest.slt @@ -91,12 +91,12 @@ select * from unnest(null); ## Unnest empty array in select list -query I +query ? select unnest([]); ---- ## Unnest empty array in from clause -query I +query ? select * from unnest([]); ---- @@ -243,7 +243,7 @@ query error DataFusion error: This feature is not implemented: unnest\(\) does n select unnest(null) from unnest_table; ## Multiple unnest functions in selection -query II +query ?I select unnest([]), unnest(NULL::int[]); ---- @@ -263,10 +263,10 @@ NULL 10 NULL NULL NULL 17 NULL NULL 18 -query IIIT -select - unnest(column1), unnest(column2) + 2, - column3 * 10, unnest(array_remove(column1, '4')) +query IIII +select + unnest(column1), unnest(column2) + 2, + column3 * 10, unnest(array_remove(column1, 4)) from unnest_table; ---- 1 9 10 1 @@ -316,7 +316,7 @@ select * from unnest( 2 b NULL NULL NULL c NULL NULL -query II +query ?I select * from unnest([], NULL::int[]); ---- diff --git a/datafusion/sqllogictest/test_files/update.slt b/datafusion/sqllogictest/test_files/update.slt index 908d2b34aea4..9f2c16b21106 100644 --- a/datafusion/sqllogictest/test_files/update.slt +++ b/datafusion/sqllogictest/test_files/update.slt @@ -31,7 +31,7 @@ explain update t1 set a=1, b=2, c=3.0, d=NULL; ---- logical_plan 01)Dml: op=[Update] table=[t1] -02)--Projection: CAST(Int64(1) AS Int32) AS a, CAST(Int64(2) AS Utf8) AS b, Float64(3) AS c, CAST(NULL AS Int32) AS d +02)--Projection: CAST(Int64(1) AS Int32) AS a, CAST(Int64(2) AS Utf8View) AS b, Float64(3) AS c, CAST(NULL AS Int32) AS d 03)----TableScan: t1 physical_plan_error This feature is not implemented: Unsupported logical plan: Dml(Update) @@ -40,7 +40,7 @@ explain update t1 set a=c+1, b=a, c=c+1.0, d=b; ---- logical_plan 01)Dml: op=[Update] table=[t1] -02)--Projection: CAST(t1.c + CAST(Int64(1) AS Float64) AS Int32) AS a, CAST(t1.a AS Utf8) AS b, t1.c + Float64(1) AS c, CAST(t1.b AS Int32) AS d +02)--Projection: CAST(t1.c + CAST(Int64(1) AS Float64) AS Int32) AS a, CAST(t1.a AS Utf8View) 
AS b, t1.c + Float64(1) AS c, CAST(t1.b AS Int32) AS d 03)----TableScan: t1 physical_plan_error This feature is not implemented: Unsupported logical plan: Dml(Update) @@ -69,7 +69,7 @@ explain update t1 set b = t2.b, c = t2.a, d = 1 from t2 where t1.a = t2.a and t1 logical_plan 01)Dml: op=[Update] table=[t1] 02)--Projection: t1.a AS a, t2.b AS b, CAST(t2.a AS Float64) AS c, CAST(Int64(1) AS Int32) AS d -03)----Filter: t1.a = t2.a AND t1.b > Utf8("foo") AND t2.c > Float64(1) +03)----Filter: t1.a = t2.a AND t1.b > CAST(Utf8("foo") AS Utf8View) AND t2.c > Float64(1) 04)------Cross Join: 05)--------TableScan: t1 06)--------TableScan: t2 @@ -89,7 +89,7 @@ explain update t1 as T set b = t2.b, c = t.a, d = 1 from t2 where t.a = t2.a and logical_plan 01)Dml: op=[Update] table=[t1] 02)--Projection: t.a AS a, t2.b AS b, CAST(t.a AS Float64) AS c, CAST(Int64(1) AS Int32) AS d -03)----Filter: t.a = t2.a AND t.b > Utf8("foo") AND t2.c > Float64(1) +03)----Filter: t.a = t2.a AND t.b > CAST(Utf8("foo") AS Utf8View) AND t2.c > Float64(1) 04)------Cross Join: 05)--------SubqueryAlias: t 06)----------TableScan: t1 diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 52cc80eae1c8..c86921012f9b 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -1767,11 +1767,11 @@ logical_plan 01)Projection: count(Int64(1)) AS count(*) AS global_count 02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] 03)----SubqueryAlias: a -04)------Projection: +04)------Projection: 05)--------Aggregate: groupBy=[[aggregate_test_100.c1]], aggr=[[]] 06)----------Projection: aggregate_test_100.c1 -07)------------Filter: aggregate_test_100.c13 != Utf8("C2GT5KVyOPZpgKVl110TyZO0NcJ434") -08)--------------TableScan: aggregate_test_100 projection=[c1, c13], partial_filters=[aggregate_test_100.c13 != Utf8("C2GT5KVyOPZpgKVl110TyZO0NcJ434")] +07)------------Filter: aggregate_test_100.c13 != Utf8View("C2GT5KVyOPZpgKVl110TyZO0NcJ434") +08)--------------TableScan: aggregate_test_100 projection=[c1, c13], partial_filters=[aggregate_test_100.c13 != Utf8View("C2GT5KVyOPZpgKVl110TyZO0NcJ434")] physical_plan 01)ProjectionExec: expr=[count(Int64(1))@0 as global_count] 02)--AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))] diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index d23e986914fc..8bc3eccc684d 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -39,7 +39,7 @@ itertools = { workspace = true } object_store = { workspace = true } pbjson-types = { workspace = true } prost = { workspace = true } -substrait = { version = "0.55", features = ["serde"] } +substrait = { version = "0.57", features = ["serde"] } url = { workspace = true } tokio = { workspace = true, features = ["fs"] } diff --git a/datafusion/substrait/README.md b/datafusion/substrait/README.md index 92bb9abcc690..8e7f99b7df38 100644 --- a/datafusion/substrait/README.md +++ b/datafusion/substrait/README.md @@ -19,8 +19,9 @@ # Apache DataFusion Substrait -This crate contains a [Substrait] producer and consumer for Apache Arrow -[DataFusion] plans. See [API Docs] for details and examples. +This crate contains a [Substrait] producer and consumer for [Apache DataFusion] +plans. See [API Docs] for details and examples. 
[substrait]: https://substrait.io +[apache datafusion]: https://datafusion.apache.org [api docs]: https://docs.rs/datafusion-substrait/latest diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs deleted file mode 100644 index 1442267d3dbb..000000000000 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ /dev/null @@ -1,3452 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano}; -use arrow::buffer::OffsetBuffer; -use async_recursion::async_recursion; -use datafusion::arrow::array::MapArray; -use datafusion::arrow::datatypes::{ - DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit, -}; -use datafusion::common::{ - not_impl_datafusion_err, not_impl_err, plan_datafusion_err, plan_err, - substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef, Spans, - TableReference, -}; -use datafusion::datasource::provider_as_source; -use datafusion::logical_expr::expr::{Exists, InSubquery, Sort, WindowFunctionParams}; - -use datafusion::logical_expr::{ - Aggregate, BinaryExpr, Case, Cast, EmptyRelation, Expr, ExprSchemable, Extension, - LogicalPlan, Operator, Projection, SortExpr, Subquery, TryCast, Values, -}; -use substrait::proto::aggregate_rel::Grouping; -use substrait::proto::expression as substrait_expression; -use substrait::proto::expression::subquery::set_predicate::PredicateOp; -use substrait::proto::expression_reference::ExprType; -use url::Url; - -use crate::extensions::Extensions; -use crate::variation_const::{ - DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, - DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, - DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, - LARGE_CONTAINER_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, - VIEW_CONTAINER_TYPE_VARIATION_REF, -}; -#[allow(deprecated)] -use crate::variation_const::{ - INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_NAME, - INTERVAL_MONTH_DAY_NANO_TYPE_REF, INTERVAL_YEAR_MONTH_TYPE_REF, - TIMESTAMP_MICRO_TYPE_VARIATION_REF, TIMESTAMP_MILLI_TYPE_VARIATION_REF, - TIMESTAMP_NANO_TYPE_VARIATION_REF, TIMESTAMP_SECOND_TYPE_VARIATION_REF, -}; -use async_trait::async_trait; -use datafusion::arrow::array::{new_empty_array, AsArray}; -use datafusion::arrow::temporal_conversions::NANOSECONDS; -use datafusion::catalog::TableProvider; -use datafusion::common::scalar::ScalarStructBuilder; -use datafusion::execution::{FunctionRegistry, SessionState}; -use datafusion::logical_expr::builder::project; -use datafusion::logical_expr::expr::InList; -use datafusion::logical_expr::{ - col, expr, GroupingSet, Like, LogicalPlanBuilder, Partitioning, Repartition, - WindowFrameBound, WindowFrameUnits, 
WindowFunctionDefinition, -}; -use datafusion::prelude::{lit, JoinType}; -use datafusion::{ - arrow, error::Result, logical_expr::utils::split_conjunction, - logical_expr::utils::split_conjunction_owned, prelude::Column, scalar::ScalarValue, -}; -use std::collections::HashSet; -use std::sync::Arc; -use substrait::proto; -use substrait::proto::exchange_rel::ExchangeKind; -use substrait::proto::expression::cast::FailureBehavior::ReturnNull; -use substrait::proto::expression::literal::user_defined::Val; -use substrait::proto::expression::literal::{ - interval_day_to_second, IntervalCompound, IntervalDayToSecond, IntervalYearToMonth, -}; -use substrait::proto::expression::subquery::SubqueryType; -use substrait::proto::expression::{ - Enum, FieldReference, IfThen, Literal, MultiOrList, Nested, ScalarFunction, - SingularOrList, SwitchExpression, WindowFunction, -}; -use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile; -use substrait::proto::rel_common::{Emit, EmitKind}; -use substrait::proto::set_rel::SetOp; -use substrait::proto::{ - aggregate_function::AggregationInvocation, - expression::{ - field_reference::ReferenceType::DirectReference, literal::LiteralType, - reference_segment::ReferenceType::StructField, - window_function::bound as SubstraitBound, - window_function::bound::Kind as BoundKind, window_function::Bound, - window_function::BoundsType, MaskExpression, RexType, - }, - fetch_rel, - function_argument::ArgType, - join_rel, plan_rel, r#type, - read_rel::ReadType, - rel::RelType, - rel_common, - sort_field::{SortDirection, SortKind::*}, - AggregateFunction, AggregateRel, ConsistentPartitionWindowRel, CrossRel, - DynamicParameter, ExchangeRel, Expression, ExtendedExpression, ExtensionLeafRel, - ExtensionMultiRel, ExtensionSingleRel, FetchRel, FilterRel, FunctionArgument, - JoinRel, NamedStruct, Plan, ProjectRel, ReadRel, Rel, RelCommon, SetRel, SortField, - SortRel, Type, -}; - -#[async_trait] -/// This trait is used to consume Substrait plans, converting them into DataFusion Logical Plans. -/// It can be implemented by users to allow for custom handling of relations, expressions, etc. -/// -/// Combined with the [crate::logical_plan::producer::SubstraitProducer] this allows for fully -/// customizable Substrait serde. 
-/// -/// # Example Usage -/// -/// ``` -/// # use async_trait::async_trait; -/// # use datafusion::catalog::TableProvider; -/// # use datafusion::common::{not_impl_err, substrait_err, DFSchema, ScalarValue, TableReference}; -/// # use datafusion::error::Result; -/// # use datafusion::execution::{FunctionRegistry, SessionState}; -/// # use datafusion::logical_expr::{Expr, LogicalPlan, LogicalPlanBuilder}; -/// # use std::sync::Arc; -/// # use substrait::proto; -/// # use substrait::proto::{ExtensionLeafRel, FilterRel, ProjectRel}; -/// # use datafusion::arrow::datatypes::DataType; -/// # use datafusion::logical_expr::expr::ScalarFunction; -/// # use datafusion_substrait::extensions::Extensions; -/// # use datafusion_substrait::logical_plan::consumer::{ -/// # from_project_rel, from_substrait_rel, from_substrait_rex, SubstraitConsumer -/// # }; -/// -/// struct CustomSubstraitConsumer { -/// extensions: Arc<Extensions>, -/// state: Arc<SessionState>, -/// } -/// -/// #[async_trait] -/// impl SubstraitConsumer for CustomSubstraitConsumer { -/// async fn resolve_table_ref( -/// &self, -/// table_ref: &TableReference, -/// ) -> Result<Option<Arc<dyn TableProvider>>> { -/// let table = table_ref.table().to_string(); -/// let schema = self.state.schema_for_ref(table_ref.clone())?; -/// let table_provider = schema.table(&table).await?; -/// Ok(table_provider) -/// } -/// -/// fn get_extensions(&self) -> &Extensions { -/// self.extensions.as_ref() -/// } -/// -/// fn get_function_registry(&self) -> &impl FunctionRegistry { -/// self.state.as_ref() -/// } -/// -/// // You can reuse existing consumer code to assist in handling advanced extensions -/// async fn consume_project(&self, rel: &ProjectRel) -> Result<LogicalPlan> { -/// let df_plan = from_project_rel(self, rel).await?; -/// if let Some(advanced_extension) = rel.advanced_extension.as_ref() { -/// not_impl_err!( -/// "decode and handle an advanced extension: {:?}", -/// advanced_extension -/// ) -/// } else { -/// Ok(df_plan) -/// } -/// } -/// -/// // You can implement a fully custom consumer method if you need special handling -/// async fn consume_filter(&self, rel: &FilterRel) -> Result<LogicalPlan> { -/// let input = self.consume_rel(rel.input.as_ref().unwrap()).await?; -/// let expression = -/// self.consume_expression(rel.condition.as_ref().unwrap(), input.schema()) -/// .await?; -/// // though this one is quite boring -/// LogicalPlanBuilder::from(input).filter(expression)?.build() -/// } -/// -/// // You can add handlers for extension relations -/// async fn consume_extension_leaf( -/// &self, -/// rel: &ExtensionLeafRel, -/// ) -> Result<LogicalPlan> { -/// not_impl_err!( -/// "handle protobuf Any {} as you need", -/// rel.detail.as_ref().unwrap().type_url -/// ) -/// } -/// -/// // and handlers for user-defined types -/// fn consume_user_defined_type(&self, typ: &proto::r#type::UserDefined) -> Result<DataType> { -/// let type_string = self.extensions.types.get(&typ.type_reference).unwrap(); -/// match type_string.as_str() { -/// "u!foo" => not_impl_err!("handle foo conversion"), -/// "u!bar" => not_impl_err!("handle bar conversion"), -/// _ => substrait_err!("unexpected type") -/// } -/// } -/// -/// // and user-defined literals -/// fn consume_user_defined_literal(&self, literal: &proto::expression::literal::UserDefined) -> Result<ScalarValue> { -/// let type_string = self.extensions.types.get(&literal.type_reference).unwrap(); -/// match type_string.as_str() { -/// "u!foo" => not_impl_err!("handle foo conversion"), -/// "u!bar" => not_impl_err!("handle bar conversion"), -/// _ => substrait_err!("unexpected type") -/// } -/// } -/// } -/// ```
-/// -pub trait SubstraitConsumer: Send + Sync + Sized { - async fn resolve_table_ref( - &self, - table_ref: &TableReference, - ) -> Result>>; - - // TODO: Remove these two methods - // Ideally, the abstract consumer should not place any constraints on implementations. - // The functionality for which the Extensions and FunctionRegistry is needed should be abstracted - // out into methods on the trait. As an example, resolve_table_reference is such a method. - // See: https://github.com/apache/datafusion/issues/13863 - fn get_extensions(&self) -> &Extensions; - fn get_function_registry(&self) -> &impl FunctionRegistry; - - // Relation Methods - // There is one method per Substrait relation to allow for easy overriding of consumer behaviour. - // These methods have default implementations calling the common handler code, to allow for users - // to re-use common handling logic. - - /// All [Rel]s to be converted pass through this method. - /// You can provide your own implementation if you wish to customize the conversion behaviour. - async fn consume_rel(&self, rel: &Rel) -> Result { - from_substrait_rel(self, rel).await - } - - async fn consume_read(&self, rel: &ReadRel) -> Result { - from_read_rel(self, rel).await - } - - async fn consume_filter(&self, rel: &FilterRel) -> Result { - from_filter_rel(self, rel).await - } - - async fn consume_fetch(&self, rel: &FetchRel) -> Result { - from_fetch_rel(self, rel).await - } - - async fn consume_aggregate(&self, rel: &AggregateRel) -> Result { - from_aggregate_rel(self, rel).await - } - - async fn consume_sort(&self, rel: &SortRel) -> Result { - from_sort_rel(self, rel).await - } - - async fn consume_join(&self, rel: &JoinRel) -> Result { - from_join_rel(self, rel).await - } - - async fn consume_project(&self, rel: &ProjectRel) -> Result { - from_project_rel(self, rel).await - } - - async fn consume_set(&self, rel: &SetRel) -> Result { - from_set_rel(self, rel).await - } - - async fn consume_cross(&self, rel: &CrossRel) -> Result { - from_cross_rel(self, rel).await - } - - async fn consume_consistent_partition_window( - &self, - _rel: &ConsistentPartitionWindowRel, - ) -> Result { - not_impl_err!("Consistent Partition Window Rel not supported") - } - - async fn consume_exchange(&self, rel: &ExchangeRel) -> Result { - from_exchange_rel(self, rel).await - } - - // Expression Methods - // There is one method per Substrait expression to allow for easy overriding of consumer behaviour - // These methods have default implementations calling the common handler code, to allow for users - // to re-use common handling logic. - - /// All [Expression]s to be converted pass through this method. - /// You can provide your own implementation if you wish to customize the conversion behaviour. 
- async fn consume_expression( - &self, - expr: &Expression, - input_schema: &DFSchema, - ) -> Result { - from_substrait_rex(self, expr, input_schema).await - } - - async fn consume_literal(&self, expr: &Literal) -> Result { - from_literal(self, expr).await - } - - async fn consume_field_reference( - &self, - expr: &FieldReference, - input_schema: &DFSchema, - ) -> Result { - from_field_reference(self, expr, input_schema).await - } - - async fn consume_scalar_function( - &self, - expr: &ScalarFunction, - input_schema: &DFSchema, - ) -> Result { - from_scalar_function(self, expr, input_schema).await - } - - async fn consume_window_function( - &self, - expr: &WindowFunction, - input_schema: &DFSchema, - ) -> Result { - from_window_function(self, expr, input_schema).await - } - - async fn consume_if_then( - &self, - expr: &IfThen, - input_schema: &DFSchema, - ) -> Result { - from_if_then(self, expr, input_schema).await - } - - async fn consume_switch( - &self, - _expr: &SwitchExpression, - _input_schema: &DFSchema, - ) -> Result { - not_impl_err!("Switch expression not supported") - } - - async fn consume_singular_or_list( - &self, - expr: &SingularOrList, - input_schema: &DFSchema, - ) -> Result { - from_singular_or_list(self, expr, input_schema).await - } - - async fn consume_multi_or_list( - &self, - _expr: &MultiOrList, - _input_schema: &DFSchema, - ) -> Result { - not_impl_err!("Multi Or List expression not supported") - } - - async fn consume_cast( - &self, - expr: &substrait_expression::Cast, - input_schema: &DFSchema, - ) -> Result { - from_cast(self, expr, input_schema).await - } - - async fn consume_subquery( - &self, - expr: &substrait_expression::Subquery, - input_schema: &DFSchema, - ) -> Result { - from_subquery(self, expr, input_schema).await - } - - async fn consume_nested( - &self, - _expr: &Nested, - _input_schema: &DFSchema, - ) -> Result { - not_impl_err!("Nested expression not supported") - } - - async fn consume_enum(&self, _expr: &Enum, _input_schema: &DFSchema) -> Result { - not_impl_err!("Enum expression not supported") - } - - async fn consume_dynamic_parameter( - &self, - _expr: &DynamicParameter, - _input_schema: &DFSchema, - ) -> Result { - not_impl_err!("Dynamic Parameter expression not supported") - } - - // User-Defined Functionality - - // The details of extension relations, and how to handle them, are fully up to users to specify. 
- // The following methods allow users to customize the consumer behaviour - - async fn consume_extension_leaf( - &self, - rel: &ExtensionLeafRel, - ) -> Result { - if let Some(detail) = rel.detail.as_ref() { - return substrait_err!( - "Missing handler for ExtensionLeafRel: {}", - detail.type_url - ); - } - substrait_err!("Missing handler for ExtensionLeafRel") - } - - async fn consume_extension_single( - &self, - rel: &ExtensionSingleRel, - ) -> Result { - if let Some(detail) = rel.detail.as_ref() { - return substrait_err!( - "Missing handler for ExtensionSingleRel: {}", - detail.type_url - ); - } - substrait_err!("Missing handler for ExtensionSingleRel") - } - - async fn consume_extension_multi( - &self, - rel: &ExtensionMultiRel, - ) -> Result { - if let Some(detail) = rel.detail.as_ref() { - return substrait_err!( - "Missing handler for ExtensionMultiRel: {}", - detail.type_url - ); - } - substrait_err!("Missing handler for ExtensionMultiRel") - } - - // Users can bring their own types to Substrait which require custom handling - - fn consume_user_defined_type( - &self, - user_defined_type: &r#type::UserDefined, - ) -> Result { - substrait_err!( - "Missing handler for user-defined type: {}", - user_defined_type.type_reference - ) - } - - fn consume_user_defined_literal( - &self, - user_defined_literal: &proto::expression::literal::UserDefined, - ) -> Result { - substrait_err!( - "Missing handler for user-defined literals {}", - user_defined_literal.type_reference - ) - } -} - -/// Convert Substrait Rel to DataFusion DataFrame -#[async_recursion] -pub async fn from_substrait_rel( - consumer: &impl SubstraitConsumer, - relation: &Rel, -) -> Result { - let plan: Result = match &relation.rel_type { - Some(rel_type) => match rel_type { - RelType::Read(rel) => consumer.consume_read(rel).await, - RelType::Filter(rel) => consumer.consume_filter(rel).await, - RelType::Fetch(rel) => consumer.consume_fetch(rel).await, - RelType::Aggregate(rel) => consumer.consume_aggregate(rel).await, - RelType::Sort(rel) => consumer.consume_sort(rel).await, - RelType::Join(rel) => consumer.consume_join(rel).await, - RelType::Project(rel) => consumer.consume_project(rel).await, - RelType::Set(rel) => consumer.consume_set(rel).await, - RelType::ExtensionSingle(rel) => consumer.consume_extension_single(rel).await, - RelType::ExtensionMulti(rel) => consumer.consume_extension_multi(rel).await, - RelType::ExtensionLeaf(rel) => consumer.consume_extension_leaf(rel).await, - RelType::Cross(rel) => consumer.consume_cross(rel).await, - RelType::Window(rel) => { - consumer.consume_consistent_partition_window(rel).await - } - RelType::Exchange(rel) => consumer.consume_exchange(rel).await, - rt => not_impl_err!("{rt:?} rel not supported yet"), - }, - None => return substrait_err!("rel must set rel_type"), - }; - apply_emit_kind(retrieve_rel_common(relation), plan?) -} - -/// Default SubstraitConsumer for converting standard Substrait without user-defined extensions. 
-/// -/// Used as the consumer in [from_substrait_plan] -pub struct DefaultSubstraitConsumer<'a> { - extensions: &'a Extensions, - state: &'a SessionState, -} - -impl<'a> DefaultSubstraitConsumer<'a> { - pub fn new(extensions: &'a Extensions, state: &'a SessionState) -> Self { - DefaultSubstraitConsumer { extensions, state } - } -} - -#[async_trait] -impl SubstraitConsumer for DefaultSubstraitConsumer<'_> { - async fn resolve_table_ref( - &self, - table_ref: &TableReference, - ) -> Result>> { - let table = table_ref.table().to_string(); - let schema = self.state.schema_for_ref(table_ref.clone())?; - let table_provider = schema.table(&table).await?; - Ok(table_provider) - } - - fn get_extensions(&self) -> &Extensions { - self.extensions - } - - fn get_function_registry(&self) -> &impl FunctionRegistry { - self.state - } - - async fn consume_extension_leaf( - &self, - rel: &ExtensionLeafRel, - ) -> Result { - let Some(ext_detail) = &rel.detail else { - return substrait_err!("Unexpected empty detail in ExtensionLeafRel"); - }; - let plan = self - .state - .serializer_registry() - .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; - Ok(LogicalPlan::Extension(Extension { node: plan })) - } - - async fn consume_extension_single( - &self, - rel: &ExtensionSingleRel, - ) -> Result { - let Some(ext_detail) = &rel.detail else { - return substrait_err!("Unexpected empty detail in ExtensionSingleRel"); - }; - let plan = self - .state - .serializer_registry() - .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; - let Some(input_rel) = &rel.input else { - return substrait_err!( - "ExtensionSingleRel missing input rel, try using ExtensionLeafRel instead" - ); - }; - let input_plan = self.consume_rel(input_rel).await?; - let plan = plan.with_exprs_and_inputs(plan.expressions(), vec![input_plan])?; - Ok(LogicalPlan::Extension(Extension { node: plan })) - } - - async fn consume_extension_multi( - &self, - rel: &ExtensionMultiRel, - ) -> Result { - let Some(ext_detail) = &rel.detail else { - return substrait_err!("Unexpected empty detail in ExtensionMultiRel"); - }; - let plan = self - .state - .serializer_registry() - .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; - let mut inputs = Vec::with_capacity(rel.inputs.len()); - for input in &rel.inputs { - let input_plan = self.consume_rel(input).await?; - inputs.push(input_plan); - } - let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?; - Ok(LogicalPlan::Extension(Extension { node: plan })) - } -} - -// Substrait PrecisionTimestampTz indicates that the timestamp is relative to UTC, which -// is the same as the expectation for any non-empty timezone in DF, so any non-empty timezone -// results in correct points on the timeline, and we pick UTC as a reasonable default. -// However, DF uses the timezone also for some arithmetic and display purposes (see e.g. -// https://github.com/apache/arrow-rs/blob/ee5694078c86c8201549654246900a4232d531a9/arrow-cast/src/cast/mod.rs#L1749). 
-const DEFAULT_TIMEZONE: &str = "UTC"; - -pub fn name_to_op(name: &str) -> Option { - match name { - "equal" => Some(Operator::Eq), - "not_equal" => Some(Operator::NotEq), - "lt" => Some(Operator::Lt), - "lte" => Some(Operator::LtEq), - "gt" => Some(Operator::Gt), - "gte" => Some(Operator::GtEq), - "add" => Some(Operator::Plus), - "subtract" => Some(Operator::Minus), - "multiply" => Some(Operator::Multiply), - "divide" => Some(Operator::Divide), - "mod" => Some(Operator::Modulo), - "modulus" => Some(Operator::Modulo), - "and" => Some(Operator::And), - "or" => Some(Operator::Or), - "is_distinct_from" => Some(Operator::IsDistinctFrom), - "is_not_distinct_from" => Some(Operator::IsNotDistinctFrom), - "regex_match" => Some(Operator::RegexMatch), - "regex_imatch" => Some(Operator::RegexIMatch), - "regex_not_match" => Some(Operator::RegexNotMatch), - "regex_not_imatch" => Some(Operator::RegexNotIMatch), - "bitwise_and" => Some(Operator::BitwiseAnd), - "bitwise_or" => Some(Operator::BitwiseOr), - "str_concat" => Some(Operator::StringConcat), - "at_arrow" => Some(Operator::AtArrow), - "arrow_at" => Some(Operator::ArrowAt), - "bitwise_xor" => Some(Operator::BitwiseXor), - "bitwise_shift_right" => Some(Operator::BitwiseShiftRight), - "bitwise_shift_left" => Some(Operator::BitwiseShiftLeft), - _ => None, - } -} - -pub fn substrait_fun_name(name: &str) -> &str { - let name = match name.rsplit_once(':') { - // Since 0.32.0, Substrait requires the function names to be in a compound format - // https://substrait.io/extensions/#function-signature-compound-names - // for example, `add:i8_i8`. - // On the consumer side, we don't really care about the signature though, just the name. - Some((name, _)) => name, - None => name, - }; - name -} - -fn split_eq_and_noneq_join_predicate_with_nulls_equality( - filter: &Expr, -) -> (Vec<(Column, Column)>, bool, Option) { - let exprs = split_conjunction(filter); - - let mut accum_join_keys: Vec<(Column, Column)> = vec![]; - let mut accum_filters: Vec = vec![]; - let mut nulls_equal_nulls = false; - - for expr in exprs { - #[allow(clippy::collapsible_match)] - match expr { - Expr::BinaryExpr(binary_expr) => match binary_expr { - x @ (BinaryExpr { - left, - op: Operator::Eq, - right, - } - | BinaryExpr { - left, - op: Operator::IsNotDistinctFrom, - right, - }) => { - nulls_equal_nulls = match x.op { - Operator::Eq => false, - Operator::IsNotDistinctFrom => true, - _ => unreachable!(), - }; - - match (left.as_ref(), right.as_ref()) { - (Expr::Column(l), Expr::Column(r)) => { - accum_join_keys.push((l.clone(), r.clone())); - } - _ => accum_filters.push(expr.clone()), - } - } - _ => accum_filters.push(expr.clone()), - }, - _ => accum_filters.push(expr.clone()), - } - } - - let join_filter = accum_filters.into_iter().reduce(Expr::and); - (accum_join_keys, nulls_equal_nulls, join_filter) -} - -async fn union_rels( - consumer: &impl SubstraitConsumer, - rels: &[Rel], - is_all: bool, -) -> Result { - let mut union_builder = Ok(LogicalPlanBuilder::from( - consumer.consume_rel(&rels[0]).await?, - )); - for input in &rels[1..] { - let rel_plan = consumer.consume_rel(input).await?; - - union_builder = if is_all { - union_builder?.union(rel_plan) - } else { - union_builder?.union_distinct(rel_plan) - }; - } - union_builder?.build() -} - -async fn intersect_rels( - consumer: &impl SubstraitConsumer, - rels: &[Rel], - is_all: bool, -) -> Result { - let mut rel = consumer.consume_rel(&rels[0]).await?; - - for input in &rels[1..] 
{ - rel = LogicalPlanBuilder::intersect( - rel, - consumer.consume_rel(input).await?, - is_all, - )? - } - - Ok(rel) -} - -async fn except_rels( - consumer: &impl SubstraitConsumer, - rels: &[Rel], - is_all: bool, -) -> Result { - let mut rel = consumer.consume_rel(&rels[0]).await?; - - for input in &rels[1..] { - rel = LogicalPlanBuilder::except(rel, consumer.consume_rel(input).await?, is_all)? - } - - Ok(rel) -} - -/// Convert Substrait Plan to DataFusion LogicalPlan -pub async fn from_substrait_plan( - state: &SessionState, - plan: &Plan, -) -> Result { - // Register function extension - let extensions = Extensions::try_from(&plan.extensions)?; - if !extensions.type_variations.is_empty() { - return not_impl_err!("Type variation extensions are not supported"); - } - - let consumer = DefaultSubstraitConsumer { - extensions: &extensions, - state, - }; - from_substrait_plan_with_consumer(&consumer, plan).await -} - -/// Convert Substrait Plan to DataFusion LogicalPlan using the given consumer -pub async fn from_substrait_plan_with_consumer( - consumer: &impl SubstraitConsumer, - plan: &Plan, -) -> Result { - match plan.relations.len() { - 1 => { - match plan.relations[0].rel_type.as_ref() { - Some(rt) => match rt { - plan_rel::RelType::Rel(rel) => Ok(consumer.consume_rel(rel).await?), - plan_rel::RelType::Root(root) => { - let plan = consumer.consume_rel(root.input.as_ref().unwrap()).await?; - if root.names.is_empty() { - // Backwards compatibility for plans missing names - return Ok(plan); - } - let renamed_schema = make_renamed_schema(plan.schema(), &root.names)?; - if renamed_schema.has_equivalent_names_and_types(plan.schema()).is_ok() { - // Nothing to do if the schema is already equivalent - return Ok(plan); - } - match plan { - // If the last node of the plan produces expressions, bake the renames into those expressions. - // This isn't necessary for correctness, but helps with roundtrip tests. - LogicalPlan::Projection(p) => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(p.expr, p.input.schema(), renamed_schema.fields())?, p.input)?)), - LogicalPlan::Aggregate(a) => { - let (group_fields, expr_fields) = renamed_schema.fields().split_at(a.group_expr.len()); - let new_group_exprs = rename_expressions(a.group_expr, a.input.schema(), group_fields)?; - let new_aggr_exprs = rename_expressions(a.aggr_expr, a.input.schema(), expr_fields)?; - Ok(LogicalPlan::Aggregate(Aggregate::try_new(a.input, new_group_exprs, new_aggr_exprs)?)) - }, - // There are probably more plans where we could bake things in, can add them later as needed. - // Otherwise, add a new Project to handle the renaming. - _ => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(plan.schema().columns().iter().map(|c| col(c.to_owned())), plan.schema(), renamed_schema.fields())?, Arc::new(plan))?)) - } - } - }, - None => plan_err!("Cannot parse plan relation: None") - } - }, - _ => not_impl_err!( - "Substrait plan with more than 1 relation trees not supported. Number of relation trees: {:?}", - plan.relations.len() - ) - } -} - -/// An ExprContainer is a container for a collection of expressions with a common input schema -/// -/// In addition, each expression is associated with a field, which defines the -/// expression's output. The data type and nullability of the field are calculated from the -/// expression and the input schema. However the names of the field (and its nested fields) are -/// derived from the Substrait message. 
-pub struct ExprContainer { - /// The input schema for the expressions - pub input_schema: DFSchemaRef, - /// The expressions - /// - /// Each item contains an expression and the field that defines the expected nullability and name of the expr's output - pub exprs: Vec<(Expr, Field)>, -} - -/// Convert Substrait ExtendedExpression to ExprContainer -/// -/// A Substrait ExtendedExpression message contains one or more expressions, -/// with names for the outputs, and an input schema. These pieces are all included -/// in the ExprContainer. -/// -/// This is a top-level message and can be used to send expressions (not plans) -/// between systems. This is often useful for scenarios like pushdown where filter -/// expressions need to be sent to remote systems. -pub async fn from_substrait_extended_expr( - state: &SessionState, - extended_expr: &ExtendedExpression, -) -> Result { - // Register function extension - let extensions = Extensions::try_from(&extended_expr.extensions)?; - if !extensions.type_variations.is_empty() { - return not_impl_err!("Type variation extensions are not supported"); - } - - let consumer = DefaultSubstraitConsumer { - extensions: &extensions, - state, - }; - - let input_schema = DFSchemaRef::new(match &extended_expr.base_schema { - Some(base_schema) => from_substrait_named_struct(&consumer, base_schema), - None => { - plan_err!("required property `base_schema` missing from Substrait ExtendedExpression message") - } - }?); - - // Parse expressions - let mut exprs = Vec::with_capacity(extended_expr.referred_expr.len()); - for (expr_idx, substrait_expr) in extended_expr.referred_expr.iter().enumerate() { - let scalar_expr = match &substrait_expr.expr_type { - Some(ExprType::Expression(scalar_expr)) => Ok(scalar_expr), - Some(ExprType::Measure(_)) => { - not_impl_err!("Measure expressions are not yet supported") - } - None => { - plan_err!("required property `expr_type` missing from Substrait ExpressionReference message") - } - }?; - let expr = consumer - .consume_expression(scalar_expr, &input_schema) - .await?; - let (output_type, expected_nullability) = - expr.data_type_and_nullable(&input_schema)?; - let output_field = Field::new("", output_type, expected_nullability); - let mut names_idx = 0; - let output_field = rename_field( - &output_field, - &substrait_expr.output_names, - expr_idx, - &mut names_idx, - /*rename_self=*/ true, - )?; - exprs.push((expr, output_field)); - } - - Ok(ExprContainer { - input_schema, - exprs, - }) -} - -pub fn apply_masking( - schema: DFSchema, - mask_expression: &::core::option::Option, -) -> Result { - match mask_expression { - Some(MaskExpression { select, .. }) => match &select.as_ref() { - Some(projection) => { - let column_indices: Vec = projection - .struct_items - .iter() - .map(|item| item.field as usize) - .collect(); - - let fields = column_indices - .iter() - .map(|i| schema.qualified_field(*i)) - .map(|(qualifier, field)| { - (qualifier.cloned(), Arc::new(field.clone())) - }) - .collect(); - - Ok(DFSchema::new_with_metadata( - fields, - schema.metadata().clone(), - )?) - } - None => Ok(schema), - }, - None => Ok(schema), - } -} - -/// Ensure the expressions have the right name(s) according to the new schema. -/// This includes the top-level (column) name, which will be renamed through aliasing if needed, -/// as well as nested names (if the expression produces any struct types), which will be renamed -/// through casting if needed. 
-fn rename_expressions( - exprs: impl IntoIterator, - input_schema: &DFSchema, - new_schema_fields: &[Arc], -) -> Result> { - exprs - .into_iter() - .zip(new_schema_fields) - .map(|(old_expr, new_field)| { - // Check if type (i.e. nested struct field names) match, use Cast to rename if needed - let new_expr = if &old_expr.get_type(input_schema)? != new_field.data_type() { - Expr::Cast(Cast::new( - Box::new(old_expr), - new_field.data_type().to_owned(), - )) - } else { - old_expr - }; - // Alias column if needed to fix the top-level name - match &new_expr { - // If expr is a column reference, alias_if_changed would cause an aliasing if the old expr has a qualifier - Expr::Column(c) if &c.name == new_field.name() => Ok(new_expr), - _ => new_expr.alias_if_changed(new_field.name().to_owned()), - } - }) - .collect() -} - -fn rename_field( - field: &Field, - dfs_names: &Vec, - unnamed_field_suffix: usize, // If Substrait doesn't provide a name, we'll use this "c{unnamed_field_suffix}" - name_idx: &mut usize, // Index into dfs_names - rename_self: bool, // Some fields (e.g. list items) don't have names in Substrait and this will be false to keep old name -) -> Result { - let name = if rename_self { - next_struct_field_name(unnamed_field_suffix, dfs_names, name_idx)? - } else { - field.name().to_string() - }; - match field.data_type() { - DataType::Struct(children) => { - let children = children - .iter() - .enumerate() - .map(|(child_idx, f)| { - rename_field( - f.as_ref(), - dfs_names, - child_idx, - name_idx, - /*rename_self=*/ true, - ) - }) - .collect::>()?; - Ok(field - .to_owned() - .with_name(name) - .with_data_type(DataType::Struct(children))) - } - DataType::List(inner) => { - let renamed_inner = rename_field( - inner.as_ref(), - dfs_names, - 0, - name_idx, - /*rename_self=*/ false, - )?; - Ok(field - .to_owned() - .with_data_type(DataType::List(FieldRef::new(renamed_inner))) - .with_name(name)) - } - DataType::LargeList(inner) => { - let renamed_inner = rename_field( - inner.as_ref(), - dfs_names, - 0, - name_idx, - /*rename_self= */ false, - )?; - Ok(field - .to_owned() - .with_data_type(DataType::LargeList(FieldRef::new(renamed_inner))) - .with_name(name)) - } - _ => Ok(field.to_owned().with_name(name)), - } -} - -/// Produce a version of the given schema with names matching the given list of names. -/// Substrait doesn't deal with column (incl. nested struct field) names within the schema, -/// but it does give us the list of expected names at the end of the plan, so we use this -/// to rename the schema to match the expected names. -fn make_renamed_schema( - schema: &DFSchemaRef, - dfs_names: &Vec, -) -> Result { - let mut name_idx = 0; - - let (qualifiers, fields): (_, Vec) = schema - .iter() - .enumerate() - .map(|(field_idx, (q, f))| { - let renamed_f = rename_field( - f.as_ref(), - dfs_names, - field_idx, - &mut name_idx, - /*rename_self=*/ true, - )?; - Ok((q.cloned(), renamed_f)) - }) - .collect::>>()? 
- .into_iter() - .unzip(); - - if name_idx != dfs_names.len() { - return substrait_err!( - "Names list must match exactly to nested schema, but found {} uses for {} names", - name_idx, - dfs_names.len()); - } - - DFSchema::from_field_specific_qualified_schema( - qualifiers, - &Arc::new(Schema::new(fields)), - ) -} - -#[async_recursion] -pub async fn from_project_rel( - consumer: &impl SubstraitConsumer, - p: &ProjectRel, -) -> Result { - if let Some(input) = p.input.as_ref() { - let input = consumer.consume_rel(input).await?; - let original_schema = Arc::clone(input.schema()); - - // Ensure that all expressions have a unique display name, so that - // validate_unique_names does not fail when constructing the project. - let mut name_tracker = NameTracker::new(); - - // By default, a Substrait Project emits all inputs fields followed by all expressions. - // We build the explicit expressions first, and then the input expressions to avoid - // adding aliases to the explicit expressions (as part of ensuring unique names). - // - // This is helpful for plan visualization and tests, because when DataFusion produces - // Substrait Projects it adds an output mapping that excludes all input columns - // leaving only explicit expressions. - - let mut explicit_exprs: Vec = vec![]; - // For WindowFunctions, we need to wrap them in a Window relation. If there are duplicates, - // we can do the window'ing only once, then the project will duplicate the result. - // Order here doesn't matter since LPB::window_plan sorts the expressions. - let mut window_exprs: HashSet = HashSet::new(); - for expr in &p.expressions { - let e = consumer - .consume_expression(expr, input.clone().schema()) - .await?; - // if the expression is WindowFunction, wrap in a Window relation - if let Expr::WindowFunction(_) = &e { - // Adding the same expression here and in the project below - // works because the project's builder uses columnize_expr(..) - // to transform it into a column reference - window_exprs.insert(e.clone()); - } - explicit_exprs.push(name_tracker.get_uniquely_named_expr(e)?); - } - - let input = if !window_exprs.is_empty() { - LogicalPlanBuilder::window_plan(input, window_exprs)? 
- } else { - input - }; - - let mut final_exprs: Vec = vec![]; - for index in 0..original_schema.fields().len() { - let e = Expr::Column(Column::from(original_schema.qualified_field(index))); - final_exprs.push(name_tracker.get_uniquely_named_expr(e)?); - } - final_exprs.append(&mut explicit_exprs); - project(input, final_exprs) - } else { - not_impl_err!("Projection without an input is not supported") - } -} - -#[async_recursion] -pub async fn from_filter_rel( - consumer: &impl SubstraitConsumer, - filter: &FilterRel, -) -> Result { - if let Some(input) = filter.input.as_ref() { - let input = LogicalPlanBuilder::from(consumer.consume_rel(input).await?); - if let Some(condition) = filter.condition.as_ref() { - let expr = consumer - .consume_expression(condition, input.schema()) - .await?; - input.filter(expr)?.build() - } else { - not_impl_err!("Filter without an condition is not valid") - } - } else { - not_impl_err!("Filter without an input is not valid") - } -} - -#[async_recursion] -pub async fn from_fetch_rel( - consumer: &impl SubstraitConsumer, - fetch: &FetchRel, -) -> Result { - if let Some(input) = fetch.input.as_ref() { - let input = LogicalPlanBuilder::from(consumer.consume_rel(input).await?); - let empty_schema = DFSchemaRef::new(DFSchema::empty()); - let offset = match &fetch.offset_mode { - Some(fetch_rel::OffsetMode::Offset(offset)) => Some(lit(*offset)), - Some(fetch_rel::OffsetMode::OffsetExpr(expr)) => { - Some(consumer.consume_expression(expr, &empty_schema).await?) - } - None => None, - }; - let count = match &fetch.count_mode { - Some(fetch_rel::CountMode::Count(count)) => { - // -1 means that ALL records should be returned, equivalent to None - (*count != -1).then(|| lit(*count)) - } - Some(fetch_rel::CountMode::CountExpr(expr)) => { - Some(consumer.consume_expression(expr, &empty_schema).await?) - } - None => None, - }; - input.limit_by_expr(offset, count)?.build() - } else { - not_impl_err!("Fetch without an input is not valid") - } -} - -pub async fn from_sort_rel( - consumer: &impl SubstraitConsumer, - sort: &SortRel, -) -> Result { - if let Some(input) = sort.input.as_ref() { - let input = LogicalPlanBuilder::from(consumer.consume_rel(input).await?); - let sorts = from_substrait_sorts(consumer, &sort.sorts, input.schema()).await?; - input.sort(sorts)?.build() - } else { - not_impl_err!("Sort without an input is not valid") - } -} - -pub async fn from_aggregate_rel( - consumer: &impl SubstraitConsumer, - agg: &AggregateRel, -) -> Result { - if let Some(input) = agg.input.as_ref() { - let input = LogicalPlanBuilder::from(consumer.consume_rel(input).await?); - let mut ref_group_exprs = vec![]; - - for e in &agg.grouping_expressions { - let x = consumer.consume_expression(e, input.schema()).await?; - ref_group_exprs.push(x); - } - - let mut group_exprs = vec![]; - let mut aggr_exprs = vec![]; - - match agg.groupings.len() { - 1 => { - group_exprs.extend_from_slice( - &from_substrait_grouping( - consumer, - &agg.groupings[0], - &ref_group_exprs, - input.schema(), - ) - .await?, - ); - } - _ => { - let mut grouping_sets = vec![]; - for grouping in &agg.groupings { - let grouping_set = from_substrait_grouping( - consumer, - grouping, - &ref_group_exprs, - input.schema(), - ) - .await?; - grouping_sets.push(grouping_set); - } - // Single-element grouping expression of type Expr::GroupingSet. 
- // Note that GroupingSet::Rollup would become GroupingSet::GroupingSets, when - // parsed by the producer and consumer, since Substrait does not have a type dedicated - // to ROLLUP. Only vector of Groupings (grouping sets) is available. - group_exprs - .push(Expr::GroupingSet(GroupingSet::GroupingSets(grouping_sets))); - } - }; - - for m in &agg.measures { - let filter = match &m.filter { - Some(fil) => Some(Box::new( - consumer.consume_expression(fil, input.schema()).await?, - )), - None => None, - }; - let agg_func = match &m.measure { - Some(f) => { - let distinct = match f.invocation { - _ if f.invocation == AggregationInvocation::Distinct as i32 => { - true - } - _ if f.invocation == AggregationInvocation::All as i32 => false, - _ => false, - }; - let order_by = if !f.sorts.is_empty() { - Some( - from_substrait_sorts(consumer, &f.sorts, input.schema()) - .await?, - ) - } else { - None - }; - - from_substrait_agg_func( - consumer, - f, - input.schema(), - filter, - order_by, - distinct, - ) - .await - } - None => { - not_impl_err!("Aggregate without aggregate function is not supported") - } - }; - aggr_exprs.push(agg_func?.as_ref().clone()); - } - input.aggregate(group_exprs, aggr_exprs)?.build() - } else { - not_impl_err!("Aggregate without an input is not valid") - } -} - -pub async fn from_join_rel( - consumer: &impl SubstraitConsumer, - join: &JoinRel, -) -> Result { - if join.post_join_filter.is_some() { - return not_impl_err!("JoinRel with post_join_filter is not yet supported"); - } - - let left: LogicalPlanBuilder = LogicalPlanBuilder::from( - consumer.consume_rel(join.left.as_ref().unwrap()).await?, - ); - let right = LogicalPlanBuilder::from( - consumer.consume_rel(join.right.as_ref().unwrap()).await?, - ); - let (left, right) = requalify_sides_if_needed(left, right)?; - - let join_type = from_substrait_jointype(join.r#type)?; - // The join condition expression needs full input schema and not the output schema from join since we lose columns from - // certain join types such as semi and anti joins - let in_join_schema = left.schema().join(right.schema())?; - - // If join expression exists, parse the `on` condition expression, build join and return - // Otherwise, build join with only the filter, without join keys - match &join.expression.as_ref() { - Some(expr) => { - let on = consumer.consume_expression(expr, &in_join_schema).await?; - // The join expression can contain both equal and non-equal ops. - // As of datafusion 31.0.0, the equal and non equal join conditions are in separate fields. - // So we extract each part as follows: - // - If an Eq or IsNotDistinctFrom op is encountered, add the left column, right column and is_null_equal_nulls to `join_ons` vector - // - Otherwise we add the expression to join_filter (use conjunction if filter already exists) - let (join_ons, nulls_equal_nulls, join_filter) = - split_eq_and_noneq_join_predicate_with_nulls_equality(&on); - let (left_cols, right_cols): (Vec<_>, Vec<_>) = - itertools::multiunzip(join_ons); - left.join_detailed( - right.build()?, - join_type, - (left_cols, right_cols), - join_filter, - nulls_equal_nulls, - )? - .build() - } - None => { - let on: Vec = vec![]; - left.join_detailed(right.build()?, join_type, (on.clone(), on), None, false)? 
- .build() - } - } -} - -pub async fn from_cross_rel( - consumer: &impl SubstraitConsumer, - cross: &CrossRel, -) -> Result { - let left = LogicalPlanBuilder::from( - consumer.consume_rel(cross.left.as_ref().unwrap()).await?, - ); - let right = LogicalPlanBuilder::from( - consumer.consume_rel(cross.right.as_ref().unwrap()).await?, - ); - let (left, right) = requalify_sides_if_needed(left, right)?; - left.cross_join(right.build()?)?.build() -} - -#[allow(deprecated)] -pub async fn from_read_rel( - consumer: &impl SubstraitConsumer, - read: &ReadRel, -) -> Result { - async fn read_with_schema( - consumer: &impl SubstraitConsumer, - table_ref: TableReference, - schema: DFSchema, - projection: &Option, - filter: &Option>, - ) -> Result { - let schema = schema.replace_qualifier(table_ref.clone()); - - let filters = if let Some(f) = filter { - let filter_expr = consumer.consume_expression(f, &schema).await?; - split_conjunction_owned(filter_expr) - } else { - vec![] - }; - - let plan = { - let provider = match consumer.resolve_table_ref(&table_ref).await? { - Some(ref provider) => Arc::clone(provider), - _ => return plan_err!("No table named '{table_ref}'"), - }; - - LogicalPlanBuilder::scan_with_filters( - table_ref, - provider_as_source(Arc::clone(&provider)), - None, - filters, - )? - .build()? - }; - - ensure_schema_compatibility(plan.schema(), schema.clone())?; - - let schema = apply_masking(schema, projection)?; - - apply_projection(plan, schema) - } - - let named_struct = read.base_schema.as_ref().ok_or_else(|| { - substrait_datafusion_err!("No base schema provided for Read Relation") - })?; - - let substrait_schema = from_substrait_named_struct(consumer, named_struct)?; - - match &read.read_type { - Some(ReadType::NamedTable(nt)) => { - let table_reference = match nt.names.len() { - 0 => { - return plan_err!("No table name found in NamedTable"); - } - 1 => TableReference::Bare { - table: nt.names[0].clone().into(), - }, - 2 => TableReference::Partial { - schema: nt.names[0].clone().into(), - table: nt.names[1].clone().into(), - }, - _ => TableReference::Full { - catalog: nt.names[0].clone().into(), - schema: nt.names[1].clone().into(), - table: nt.names[2].clone().into(), - }, - }; - - read_with_schema( - consumer, - table_reference, - substrait_schema, - &read.projection, - &read.filter, - ) - .await - } - Some(ReadType::VirtualTable(vt)) => { - if vt.values.is_empty() { - return Ok(LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row: false, - schema: DFSchemaRef::new(substrait_schema), - })); - } - - let values = vt - .values - .iter() - .map(|row| { - let mut name_idx = 0; - let lits = row - .fields - .iter() - .map(|lit| { - name_idx += 1; // top-level names are provided through schema - Ok(Expr::Literal(from_substrait_literal( - consumer, - lit, - &named_struct.names, - &mut name_idx, - )?)) - }) - .collect::>()?; - if name_idx != named_struct.names.len() { - return substrait_err!( - "Names list must match exactly to nested schema, but found {} uses for {} names", - name_idx, - named_struct.names.len() - ); - } - Ok(lits) - }) - .collect::>()?; - - Ok(LogicalPlan::Values(Values { - schema: DFSchemaRef::new(substrait_schema), - values, - })) - } - Some(ReadType::LocalFiles(lf)) => { - fn extract_filename(name: &str) -> Option { - let corrected_url = - if name.starts_with("file://") && !name.starts_with("file:///") { - name.replacen("file://", "file:///", 1) - } else { - name.to_string() - }; - - Url::parse(&corrected_url).ok().and_then(|url| { - let path = url.path(); - 
std::path::Path::new(path) - .file_name() - .map(|filename| filename.to_string_lossy().to_string()) - }) - } - - // we could use the file name to check the original table provider - // TODO: currently does not support multiple local files - let filename: Option = - lf.items.first().and_then(|x| match x.path_type.as_ref() { - Some(UriFile(name)) => extract_filename(name), - _ => None, - }); - - if lf.items.len() > 1 || filename.is_none() { - return not_impl_err!("Only single file reads are supported"); - } - let name = filename.unwrap(); - // directly use unwrap here since we could determine it is a valid one - let table_reference = TableReference::Bare { table: name.into() }; - - read_with_schema( - consumer, - table_reference, - substrait_schema, - &read.projection, - &read.filter, - ) - .await - } - _ => { - not_impl_err!("Unsupported ReadType: {:?}", read.read_type) - } - } -} - -pub async fn from_set_rel( - consumer: &impl SubstraitConsumer, - set: &SetRel, -) -> Result { - if set.inputs.len() < 2 { - substrait_err!("Set operation requires at least two inputs") - } else { - match set.op() { - SetOp::UnionAll => union_rels(consumer, &set.inputs, true).await, - SetOp::UnionDistinct => union_rels(consumer, &set.inputs, false).await, - SetOp::IntersectionPrimary => LogicalPlanBuilder::intersect( - consumer.consume_rel(&set.inputs[0]).await?, - union_rels(consumer, &set.inputs[1..], true).await?, - false, - ), - SetOp::IntersectionMultiset => { - intersect_rels(consumer, &set.inputs, false).await - } - SetOp::IntersectionMultisetAll => { - intersect_rels(consumer, &set.inputs, true).await - } - SetOp::MinusPrimary => except_rels(consumer, &set.inputs, false).await, - SetOp::MinusPrimaryAll => except_rels(consumer, &set.inputs, true).await, - set_op => not_impl_err!("Unsupported set operator: {set_op:?}"), - } - } -} - -pub async fn from_exchange_rel( - consumer: &impl SubstraitConsumer, - exchange: &ExchangeRel, -) -> Result { - let Some(input) = exchange.input.as_ref() else { - return substrait_err!("Unexpected empty input in ExchangeRel"); - }; - let input = Arc::new(consumer.consume_rel(input).await?); - - let Some(exchange_kind) = &exchange.exchange_kind else { - return substrait_err!("Unexpected empty input in ExchangeRel"); - }; - - // ref: https://substrait.io/relations/physical_relations/#exchange-types - let partitioning_scheme = match exchange_kind { - ExchangeKind::ScatterByFields(scatter_fields) => { - let mut partition_columns = vec![]; - let input_schema = input.schema(); - for field_ref in &scatter_fields.fields { - let column = from_substrait_field_reference(field_ref, input_schema)?; - partition_columns.push(column); - } - Partitioning::Hash(partition_columns, exchange.partition_count as usize) - } - ExchangeKind::RoundRobin(_) => { - Partitioning::RoundRobinBatch(exchange.partition_count as usize) - } - ExchangeKind::SingleTarget(_) - | ExchangeKind::MultiTarget(_) - | ExchangeKind::Broadcast(_) => { - return not_impl_err!("Unsupported exchange kind: {exchange_kind:?}"); - } - }; - Ok(LogicalPlan::Repartition(Repartition { - input, - partitioning_scheme, - })) -} - -fn retrieve_rel_common(rel: &Rel) -> Option<&RelCommon> { - match rel.rel_type.as_ref() { - None => None, - Some(rt) => match rt { - RelType::Read(r) => r.common.as_ref(), - RelType::Filter(f) => f.common.as_ref(), - RelType::Fetch(f) => f.common.as_ref(), - RelType::Aggregate(a) => a.common.as_ref(), - RelType::Sort(s) => s.common.as_ref(), - RelType::Join(j) => j.common.as_ref(), - RelType::Project(p) => 
p.common.as_ref(), - RelType::Set(s) => s.common.as_ref(), - RelType::ExtensionSingle(e) => e.common.as_ref(), - RelType::ExtensionMulti(e) => e.common.as_ref(), - RelType::ExtensionLeaf(e) => e.common.as_ref(), - RelType::Cross(c) => c.common.as_ref(), - RelType::Reference(_) => None, - RelType::Write(w) => w.common.as_ref(), - RelType::Ddl(d) => d.common.as_ref(), - RelType::HashJoin(j) => j.common.as_ref(), - RelType::MergeJoin(j) => j.common.as_ref(), - RelType::NestedLoopJoin(j) => j.common.as_ref(), - RelType::Window(w) => w.common.as_ref(), - RelType::Exchange(e) => e.common.as_ref(), - RelType::Expand(e) => e.common.as_ref(), - RelType::Update(_) => None, - }, - } -} - -fn retrieve_emit_kind(rel_common: Option<&RelCommon>) -> EmitKind { - // the default EmitKind is Direct if it is not set explicitly - let default = EmitKind::Direct(rel_common::Direct {}); - rel_common - .and_then(|rc| rc.emit_kind.as_ref()) - .map_or(default, |ek| ek.clone()) -} - -fn contains_volatile_expr(proj: &Projection) -> bool { - proj.expr.iter().any(|e| e.is_volatile()) -} - -fn apply_emit_kind( - rel_common: Option<&RelCommon>, - plan: LogicalPlan, -) -> Result { - match retrieve_emit_kind(rel_common) { - EmitKind::Direct(_) => Ok(plan), - EmitKind::Emit(Emit { output_mapping }) => { - // It is valid to reference the same field multiple times in the Emit - // In this case, we need to provide unique names to avoid collisions - let mut name_tracker = NameTracker::new(); - match plan { - // To avoid adding a projection on top of a projection, we apply special case - // handling to flatten Substrait Emits. This is only applicable if none of the - // expressions in the projection are volatile. This is to avoid issues like - // converting a single call of the random() function into multiple calls due to - // duplicate fields in the output_mapping. - LogicalPlan::Projection(proj) if !contains_volatile_expr(&proj) => { - let mut exprs: Vec = vec![]; - for field in output_mapping { - let expr = proj.expr - .get(field as usize) - .ok_or_else(|| substrait_datafusion_err!( - "Emit output field {} cannot be resolved in input schema {}", - field, proj.input.schema() - ))?; - exprs.push(name_tracker.get_uniquely_named_expr(expr.clone())?); - } - - let input = Arc::unwrap_or_clone(proj.input); - project(input, exprs) - } - // Otherwise we just handle the output_mapping as a projection - _ => { - let input_schema = plan.schema(); - - let mut exprs: Vec = vec![]; - for index in output_mapping.into_iter() { - let column = Expr::Column(Column::from( - input_schema.qualified_field(index as usize), - )); - let expr = name_tracker.get_uniquely_named_expr(column)?; - exprs.push(expr); - } - - project(plan, exprs) - } - } - } - } -} - -struct NameTracker { - seen_names: HashSet, -} - -enum NameTrackerStatus { - NeverSeen, - SeenBefore, -} - -impl NameTracker { - fn new() -> Self { - NameTracker { - seen_names: HashSet::default(), - } - } - fn get_unique_name(&mut self, name: String) -> (String, NameTrackerStatus) { - match self.seen_names.insert(name.clone()) { - true => (name, NameTrackerStatus::NeverSeen), - false => { - let mut counter = 0; - loop { - let candidate_name = format!("{}__temp__{}", name, counter); - if self.seen_names.insert(candidate_name.clone()) { - return (candidate_name, NameTrackerStatus::SeenBefore); - } - counter += 1; - } - } - } - } - - fn get_uniquely_named_expr(&mut self, expr: Expr) -> Result { - match self.get_unique_name(expr.name_for_alias()?) 
{ - (_, NameTrackerStatus::NeverSeen) => Ok(expr), - (name, NameTrackerStatus::SeenBefore) => Ok(expr.alias(name)), - } - } -} - -/// Ensures that the given Substrait schema is compatible with the schema as given by DataFusion -/// -/// This means: -/// 1. All fields present in the Substrait schema are present in the DataFusion schema. The -/// DataFusion schema may have MORE fields, but not the other way around. -/// 2. All fields are compatible. See [`ensure_field_compatibility`] for details -fn ensure_schema_compatibility( - table_schema: &DFSchema, - substrait_schema: DFSchema, -) -> Result<()> { - substrait_schema - .strip_qualifiers() - .fields() - .iter() - .try_for_each(|substrait_field| { - let df_field = - table_schema.field_with_unqualified_name(substrait_field.name())?; - ensure_field_compatibility(df_field, substrait_field) - }) -} - -/// This function returns a DataFrame with fields adjusted if necessary in the event that the -/// Substrait schema is a subset of the DataFusion schema. -fn apply_projection( - plan: LogicalPlan, - substrait_schema: DFSchema, -) -> Result { - let df_schema = plan.schema(); - - if df_schema.logically_equivalent_names_and_types(&substrait_schema) { - return Ok(plan); - } - - let df_schema = df_schema.to_owned(); - - match plan { - LogicalPlan::TableScan(mut scan) => { - let column_indices: Vec = substrait_schema - .strip_qualifiers() - .fields() - .iter() - .map(|substrait_field| { - Ok(df_schema - .index_of_column_by_name(None, substrait_field.name().as_str()) - .unwrap()) - }) - .collect::>()?; - - let fields = column_indices - .iter() - .map(|i| df_schema.qualified_field(*i)) - .map(|(qualifier, field)| (qualifier.cloned(), Arc::new(field.clone()))) - .collect(); - - scan.projected_schema = DFSchemaRef::new(DFSchema::new_with_metadata( - fields, - df_schema.metadata().clone(), - )?); - scan.projection = Some(column_indices); - - Ok(LogicalPlan::TableScan(scan)) - } - _ => plan_err!("DataFrame passed to apply_projection must be a TableScan"), - } -} - -/// Ensures that the given Substrait field is compatible with the given DataFusion field -/// -/// A field is compatible between Substrait and DataFusion if: -/// 1. They have logically equivalent types. -/// 2. They have the same nullability OR the Substrait field is nullable and the DataFusion fields -/// is not nullable. -/// -/// If a Substrait field is not nullable, the Substrait plan may be built around assuming it is not -/// nullable. As such if DataFusion has that field as nullable the plan should be rejected. -fn ensure_field_compatibility( - datafusion_field: &Field, - substrait_field: &Field, -) -> Result<()> { - if !DFSchema::datatype_is_logically_equal( - datafusion_field.data_type(), - substrait_field.data_type(), - ) { - return substrait_err!( - "Field '{}' in Substrait schema has a different type ({}) than the corresponding field in the table schema ({}).", - substrait_field.name(), - substrait_field.data_type(), - datafusion_field.data_type() - ); - } - - if !compatible_nullabilities( - datafusion_field.is_nullable(), - substrait_field.is_nullable(), - ) { - // TODO: from_substrait_struct_type needs to be updated to set the nullability correctly. It defaults to true for now. 
- return substrait_err!( - "Field '{}' is nullable in the DataFusion schema but not nullable in the Substrait schema.", - substrait_field.name() - ); - } - Ok(()) -} - -/// Returns true if the DataFusion and Substrait nullabilities are compatible, false otherwise -fn compatible_nullabilities( - datafusion_nullability: bool, - substrait_nullability: bool, -) -> bool { - // DataFusion and Substrait have the same nullability - (datafusion_nullability == substrait_nullability) - // DataFusion is not nullable and Substrait is nullable - || (!datafusion_nullability && substrait_nullability) -} - -/// (Re)qualify the sides of a join if needed, i.e. if the columns from one side would otherwise -/// conflict with the columns from the other. -/// Substrait doesn't currently allow specifying aliases, neither for columns nor for tables. For -/// Substrait the names don't matter since it only refers to columns by indices, however DataFusion -/// requires columns to be uniquely identifiable, in some places (see e.g. DFSchema::check_names). -fn requalify_sides_if_needed( - left: LogicalPlanBuilder, - right: LogicalPlanBuilder, -) -> Result<(LogicalPlanBuilder, LogicalPlanBuilder)> { - let left_cols = left.schema().columns(); - let right_cols = right.schema().columns(); - if left_cols.iter().any(|l| { - right_cols.iter().any(|r| { - l == r || (l.name == r.name && (l.relation.is_none() || r.relation.is_none())) - }) - }) { - // These names have no connection to the original plan, but they'll make the columns - // (mostly) unique. - Ok(( - left.alias(TableReference::bare("left"))?, - right.alias(TableReference::bare("right"))?, - )) - } else { - Ok((left, right)) - } -} - -fn from_substrait_jointype(join_type: i32) -> Result { - if let Ok(substrait_join_type) = join_rel::JoinType::try_from(join_type) { - match substrait_join_type { - join_rel::JoinType::Inner => Ok(JoinType::Inner), - join_rel::JoinType::Left => Ok(JoinType::Left), - join_rel::JoinType::Right => Ok(JoinType::Right), - join_rel::JoinType::Outer => Ok(JoinType::Full), - join_rel::JoinType::LeftAnti => Ok(JoinType::LeftAnti), - join_rel::JoinType::LeftSemi => Ok(JoinType::LeftSemi), - join_rel::JoinType::LeftMark => Ok(JoinType::LeftMark), - _ => plan_err!("unsupported join type {substrait_join_type:?}"), - } - } else { - plan_err!("invalid join type variant {join_type:?}") - } -} - -/// Convert Substrait Sorts to DataFusion Exprs -pub async fn from_substrait_sorts( - consumer: &impl SubstraitConsumer, - substrait_sorts: &Vec, - input_schema: &DFSchema, -) -> Result> { - let mut sorts: Vec = vec![]; - for s in substrait_sorts { - let expr = consumer - .consume_expression(s.expr.as_ref().unwrap(), input_schema) - .await?; - let asc_nullfirst = match &s.sort_kind { - Some(k) => match k { - Direction(d) => { - let Ok(direction) = SortDirection::try_from(*d) else { - return not_impl_err!( - "Unsupported Substrait SortDirection value {d}" - ); - }; - - match direction { - SortDirection::AscNullsFirst => Ok((true, true)), - SortDirection::AscNullsLast => Ok((true, false)), - SortDirection::DescNullsFirst => Ok((false, true)), - SortDirection::DescNullsLast => Ok((false, false)), - SortDirection::Clustered => not_impl_err!( - "Sort with direction clustered is not yet supported" - ), - SortDirection::Unspecified => { - not_impl_err!("Unspecified sort direction is invalid") - } - } - } - ComparisonFunctionReference(_) => not_impl_err!( - "Sort using comparison function reference is not supported" - ), - }, - None => not_impl_err!("Sort without sort 
kind is invalid"), - }; - let (asc, nulls_first) = asc_nullfirst.unwrap(); - sorts.push(Sort { - expr, - asc, - nulls_first, - }); - } - Ok(sorts) -} - -/// Convert Substrait Expressions to DataFusion Exprs -pub async fn from_substrait_rex_vec( - consumer: &impl SubstraitConsumer, - exprs: &Vec, - input_schema: &DFSchema, -) -> Result> { - let mut expressions: Vec = vec![]; - for expr in exprs { - let expression = consumer.consume_expression(expr, input_schema).await?; - expressions.push(expression); - } - Ok(expressions) -} - -/// Convert Substrait FunctionArguments to DataFusion Exprs -pub async fn from_substrait_func_args( - consumer: &impl SubstraitConsumer, - arguments: &Vec, - input_schema: &DFSchema, -) -> Result> { - let mut args: Vec = vec![]; - for arg in arguments { - let arg_expr = match &arg.arg_type { - Some(ArgType::Value(e)) => consumer.consume_expression(e, input_schema).await, - _ => not_impl_err!("Function argument non-Value type not supported"), - }; - args.push(arg_expr?); - } - Ok(args) -} - -/// Convert Substrait AggregateFunction to DataFusion Expr -pub async fn from_substrait_agg_func( - consumer: &impl SubstraitConsumer, - f: &AggregateFunction, - input_schema: &DFSchema, - filter: Option>, - order_by: Option>, - distinct: bool, -) -> Result> { - let Some(fn_signature) = consumer - .get_extensions() - .functions - .get(&f.function_reference) - else { - return plan_err!( - "Aggregate function not registered: function anchor = {:?}", - f.function_reference - ); - }; - - let fn_name = substrait_fun_name(fn_signature); - let udaf = consumer.get_function_registry().udaf(fn_name); - let udaf = udaf.map_err(|_| { - not_impl_datafusion_err!( - "Aggregate function {} is not supported: function anchor = {:?}", - fn_signature, - f.function_reference - ) - })?; - - let args = from_substrait_func_args(consumer, &f.arguments, input_schema).await?; - - // Datafusion does not support aggregate functions with no arguments, so - // we inject a dummy argument that does not affect the query, but allows - // us to bypass this limitation. 
- let args = if udaf.name() == "count" && args.is_empty() { - vec![Expr::Literal(ScalarValue::Int64(Some(1)))] - } else { - args - }; - - Ok(Arc::new(Expr::AggregateFunction( - expr::AggregateFunction::new_udf(udaf, args, distinct, filter, order_by, None), - ))) -} - -/// Convert Substrait Rex to DataFusion Expr -pub async fn from_substrait_rex( - consumer: &impl SubstraitConsumer, - expression: &Expression, - input_schema: &DFSchema, -) -> Result { - match &expression.rex_type { - Some(t) => match t { - RexType::Literal(expr) => consumer.consume_literal(expr).await, - RexType::Selection(expr) => { - consumer.consume_field_reference(expr, input_schema).await - } - RexType::ScalarFunction(expr) => { - consumer.consume_scalar_function(expr, input_schema).await - } - RexType::WindowFunction(expr) => { - consumer.consume_window_function(expr, input_schema).await - } - RexType::IfThen(expr) => consumer.consume_if_then(expr, input_schema).await, - RexType::SwitchExpression(expr) => { - consumer.consume_switch(expr, input_schema).await - } - RexType::SingularOrList(expr) => { - consumer.consume_singular_or_list(expr, input_schema).await - } - - RexType::MultiOrList(expr) => { - consumer.consume_multi_or_list(expr, input_schema).await - } - - RexType::Cast(expr) => { - consumer.consume_cast(expr.as_ref(), input_schema).await - } - - RexType::Subquery(expr) => { - consumer.consume_subquery(expr.as_ref(), input_schema).await - } - RexType::Nested(expr) => consumer.consume_nested(expr, input_schema).await, - RexType::Enum(expr) => consumer.consume_enum(expr, input_schema).await, - RexType::DynamicParameter(expr) => { - consumer.consume_dynamic_parameter(expr, input_schema).await - } - }, - None => substrait_err!("Expression must set rex_type: {:?}", expression), - } -} - -pub async fn from_singular_or_list( - consumer: &impl SubstraitConsumer, - expr: &SingularOrList, - input_schema: &DFSchema, -) -> Result { - let substrait_expr = expr.value.as_ref().unwrap(); - let substrait_list = expr.options.as_ref(); - Ok(Expr::InList(InList { - expr: Box::new( - consumer - .consume_expression(substrait_expr, input_schema) - .await?, - ), - list: from_substrait_rex_vec(consumer, substrait_list, input_schema).await?, - negated: false, - })) -} - -pub async fn from_field_reference( - _consumer: &impl SubstraitConsumer, - field_ref: &FieldReference, - input_schema: &DFSchema, -) -> Result { - from_substrait_field_reference(field_ref, input_schema) -} - -pub async fn from_if_then( - consumer: &impl SubstraitConsumer, - if_then: &IfThen, - input_schema: &DFSchema, -) -> Result { - // Parse `ifs` - // If the first element does not have a `then` part, then we can assume it's a base expression - let mut when_then_expr: Vec<(Box, Box)> = vec![]; - let mut expr = None; - for (i, if_expr) in if_then.ifs.iter().enumerate() { - if i == 0 { - // Check if the first element is type base expression - if if_expr.then.is_none() { - expr = Some(Box::new( - consumer - .consume_expression(if_expr.r#if.as_ref().unwrap(), input_schema) - .await?, - )); - continue; - } - } - when_then_expr.push(( - Box::new( - consumer - .consume_expression(if_expr.r#if.as_ref().unwrap(), input_schema) - .await?, - ), - Box::new( - consumer - .consume_expression(if_expr.then.as_ref().unwrap(), input_schema) - .await?, - ), - )); - } - // Parse `else` - let else_expr = match &if_then.r#else { - Some(e) => Some(Box::new( - consumer.consume_expression(e, input_schema).await?, - )), - None => None, - }; - Ok(Expr::Case(Case { - expr, - when_then_expr, - 
else_expr, - })) -} - -pub async fn from_scalar_function( - consumer: &impl SubstraitConsumer, - f: &ScalarFunction, - input_schema: &DFSchema, -) -> Result { - let Some(fn_signature) = consumer - .get_extensions() - .functions - .get(&f.function_reference) - else { - return plan_err!( - "Scalar function not found: function reference = {:?}", - f.function_reference - ); - }; - let fn_name = substrait_fun_name(fn_signature); - let args = from_substrait_func_args(consumer, &f.arguments, input_schema).await?; - - // try to first match the requested function into registered udfs, then built-in ops - // and finally built-in expressions - if let Ok(func) = consumer.get_function_registry().udf(fn_name) { - Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( - func.to_owned(), - args, - ))) - } else if let Some(op) = name_to_op(fn_name) { - if f.arguments.len() < 2 { - return not_impl_err!( - "Expect at least two arguments for binary operator {op:?}, the provided number of operators is {:?}", - f.arguments.len() - ); - } - // Some expressions are binary in DataFusion but take in a variadic number of args in Substrait. - // In those cases we iterate through all the arguments, applying the binary expression against them all - let combined_expr = args - .into_iter() - .fold(None, |combined_expr: Option, arg: Expr| { - Some(match combined_expr { - Some(expr) => Expr::BinaryExpr(BinaryExpr { - left: Box::new(expr), - op, - right: Box::new(arg), - }), - None => arg, - }) - }) - .unwrap(); - - Ok(combined_expr) - } else if let Some(builder) = BuiltinExprBuilder::try_from_name(fn_name) { - builder.build(consumer, f, input_schema).await - } else { - not_impl_err!("Unsupported function name: {fn_name:?}") - } -} - -pub async fn from_literal( - consumer: &impl SubstraitConsumer, - expr: &Literal, -) -> Result { - let scalar_value = from_substrait_literal_without_names(consumer, expr)?; - Ok(Expr::Literal(scalar_value)) -} - -pub async fn from_cast( - consumer: &impl SubstraitConsumer, - cast: &substrait_expression::Cast, - input_schema: &DFSchema, -) -> Result { - match cast.r#type.as_ref() { - Some(output_type) => { - let input_expr = Box::new( - consumer - .consume_expression( - cast.input.as_ref().unwrap().as_ref(), - input_schema, - ) - .await?, - ); - let data_type = from_substrait_type_without_names(consumer, output_type)?; - if cast.failure_behavior() == ReturnNull { - Ok(Expr::TryCast(TryCast::new(input_expr, data_type))) - } else { - Ok(Expr::Cast(Cast::new(input_expr, data_type))) - } - } - None => substrait_err!("Cast expression without output type is not allowed"), - } -} - -pub async fn from_window_function( - consumer: &impl SubstraitConsumer, - window: &WindowFunction, - input_schema: &DFSchema, -) -> Result { - let Some(fn_signature) = consumer - .get_extensions() - .functions - .get(&window.function_reference) - else { - return plan_err!( - "Window function not found: function reference = {:?}", - window.function_reference - ); - }; - let fn_name = substrait_fun_name(fn_signature); - - // check udwf first, then udaf, then built-in window and aggregate functions - let fun = if let Ok(udwf) = consumer.get_function_registry().udwf(fn_name) { - Ok(WindowFunctionDefinition::WindowUDF(udwf)) - } else if let Ok(udaf) = consumer.get_function_registry().udaf(fn_name) { - Ok(WindowFunctionDefinition::AggregateUDF(udaf)) - } else { - not_impl_err!( - "Window function {} is not supported: function anchor = {:?}", - fn_name, - window.function_reference - ) - }?; - - let mut order_by = - 
from_substrait_sorts(consumer, &window.sorts, input_schema).await?; - - let bound_units = match BoundsType::try_from(window.bounds_type).map_err(|e| { - plan_datafusion_err!("Invalid bound type {}: {e}", window.bounds_type) - })? { - BoundsType::Rows => WindowFrameUnits::Rows, - BoundsType::Range => WindowFrameUnits::Range, - BoundsType::Unspecified => { - // If the plan does not specify the bounds type, then we use a simple logic to determine the units - // If there is no `ORDER BY`, then by default, the frame counts each row from the lower up to upper boundary - // If there is `ORDER BY`, then by default, each frame is a range starting from unbounded preceding to current row - if order_by.is_empty() { - WindowFrameUnits::Rows - } else { - WindowFrameUnits::Range - } - } - }; - let window_frame = datafusion::logical_expr::WindowFrame::new_bounds( - bound_units, - from_substrait_bound(&window.lower_bound, true)?, - from_substrait_bound(&window.upper_bound, false)?, - ); - - window_frame.regularize_order_bys(&mut order_by)?; - - // Datafusion does not support aggregate functions with no arguments, so - // we inject a dummy argument that does not affect the query, but allows - // us to bypass this limitation. - let args = if fun.name() == "count" && window.arguments.is_empty() { - vec![Expr::Literal(ScalarValue::Int64(Some(1)))] - } else { - from_substrait_func_args(consumer, &window.arguments, input_schema).await? - }; - - Ok(Expr::WindowFunction(expr::WindowFunction { - fun, - params: WindowFunctionParams { - args, - partition_by: from_substrait_rex_vec( - consumer, - &window.partitions, - input_schema, - ) - .await?, - order_by, - window_frame, - null_treatment: None, - }, - })) -} - -pub async fn from_subquery( - consumer: &impl SubstraitConsumer, - subquery: &substrait_expression::Subquery, - input_schema: &DFSchema, -) -> Result { - match &subquery.subquery_type { - Some(subquery_type) => match subquery_type { - SubqueryType::InPredicate(in_predicate) => { - if in_predicate.needles.len() != 1 { - substrait_err!("InPredicate Subquery type must have exactly one Needle expression") - } else { - let needle_expr = &in_predicate.needles[0]; - let haystack_expr = &in_predicate.haystack; - if let Some(haystack_expr) = haystack_expr { - let haystack_expr = consumer.consume_rel(haystack_expr).await?; - let outer_refs = haystack_expr.all_out_ref_exprs(); - Ok(Expr::InSubquery(InSubquery { - expr: Box::new( - consumer - .consume_expression(needle_expr, input_schema) - .await?, - ), - subquery: Subquery { - subquery: Arc::new(haystack_expr), - outer_ref_columns: outer_refs, - spans: Spans::new(), - }, - negated: false, - })) - } else { - substrait_err!( - "InPredicate Subquery type must have a Haystack expression" - ) - } - } - } - SubqueryType::Scalar(query) => { - let plan = consumer - .consume_rel(&(query.input.clone()).unwrap_or_default()) - .await?; - let outer_ref_columns = plan.all_out_ref_exprs(); - Ok(Expr::ScalarSubquery(Subquery { - subquery: Arc::new(plan), - outer_ref_columns, - spans: Spans::new(), - })) - } - SubqueryType::SetPredicate(predicate) => { - match predicate.predicate_op() { - // exist - PredicateOp::Exists => { - let relation = &predicate.tuples; - let plan = consumer - .consume_rel(&relation.clone().unwrap_or_default()) - .await?; - let outer_ref_columns = plan.all_out_ref_exprs(); - Ok(Expr::Exists(Exists::new( - Subquery { - subquery: Arc::new(plan), - outer_ref_columns, - spans: Spans::new(), - }, - false, - ))) - } - other_type => substrait_err!( - "unimplemented 
type {:?} for set predicate", - other_type - ), - } - } - other_type => { - substrait_err!("Subquery type {:?} not implemented", other_type) - } - }, - None => { - substrait_err!("Subquery expression without SubqueryType is not allowed") - } - } -} - -pub(crate) fn from_substrait_type_without_names( - consumer: &impl SubstraitConsumer, - dt: &Type, -) -> Result { - from_substrait_type(consumer, dt, &[], &mut 0) -} - -fn from_substrait_type( - consumer: &impl SubstraitConsumer, - dt: &Type, - dfs_names: &[String], - name_idx: &mut usize, -) -> Result { - match &dt.kind { - Some(s_kind) => match s_kind { - r#type::Kind::Bool(_) => Ok(DataType::Boolean), - r#type::Kind::I8(integer) => match integer.type_variation_reference { - DEFAULT_TYPE_VARIATION_REF => Ok(DataType::Int8), - UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(DataType::UInt8), - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {s_kind:?}" - ), - }, - r#type::Kind::I16(integer) => match integer.type_variation_reference { - DEFAULT_TYPE_VARIATION_REF => Ok(DataType::Int16), - UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(DataType::UInt16), - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {s_kind:?}" - ), - }, - r#type::Kind::I32(integer) => match integer.type_variation_reference { - DEFAULT_TYPE_VARIATION_REF => Ok(DataType::Int32), - UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(DataType::UInt32), - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {s_kind:?}" - ), - }, - r#type::Kind::I64(integer) => match integer.type_variation_reference { - DEFAULT_TYPE_VARIATION_REF => Ok(DataType::Int64), - UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(DataType::UInt64), - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {s_kind:?}" - ), - }, - r#type::Kind::Fp32(_) => Ok(DataType::Float32), - r#type::Kind::Fp64(_) => Ok(DataType::Float64), - r#type::Kind::Timestamp(ts) => { - // Kept for backwards compatibility, new plans should use PrecisionTimestamp(Tz) instead - #[allow(deprecated)] - match ts.type_variation_reference { - TIMESTAMP_SECOND_TYPE_VARIATION_REF => { - Ok(DataType::Timestamp(TimeUnit::Second, None)) - } - TIMESTAMP_MILLI_TYPE_VARIATION_REF => { - Ok(DataType::Timestamp(TimeUnit::Millisecond, None)) - } - TIMESTAMP_MICRO_TYPE_VARIATION_REF => { - Ok(DataType::Timestamp(TimeUnit::Microsecond, None)) - } - TIMESTAMP_NANO_TYPE_VARIATION_REF => { - Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) - } - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {s_kind:?}" - ), - } - } - r#type::Kind::PrecisionTimestamp(pts) => { - let unit = match pts.precision { - 0 => Ok(TimeUnit::Second), - 3 => Ok(TimeUnit::Millisecond), - 6 => Ok(TimeUnit::Microsecond), - 9 => Ok(TimeUnit::Nanosecond), - p => not_impl_err!( - "Unsupported Substrait precision {p} for PrecisionTimestamp" - ), - }?; - Ok(DataType::Timestamp(unit, None)) - } - r#type::Kind::PrecisionTimestampTz(pts) => { - let unit = match pts.precision { - 0 => Ok(TimeUnit::Second), - 3 => Ok(TimeUnit::Millisecond), - 6 => Ok(TimeUnit::Microsecond), - 9 => Ok(TimeUnit::Nanosecond), - p => not_impl_err!( - "Unsupported Substrait precision {p} for PrecisionTimestampTz" - ), - }?; - Ok(DataType::Timestamp(unit, Some(DEFAULT_TIMEZONE.into()))) - } - r#type::Kind::Date(date) => match date.type_variation_reference { - DATE_32_TYPE_VARIATION_REF => Ok(DataType::Date32), - DATE_64_TYPE_VARIATION_REF => Ok(DataType::Date64), - v => not_impl_err!( - "Unsupported Substrait type variation {v} of 
type {s_kind:?}" - ), - }, - r#type::Kind::Binary(binary) => match binary.type_variation_reference { - DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::Binary), - LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::LargeBinary), - VIEW_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::BinaryView), - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {s_kind:?}" - ), - }, - r#type::Kind::FixedBinary(fixed) => { - Ok(DataType::FixedSizeBinary(fixed.length)) - } - r#type::Kind::String(string) => match string.type_variation_reference { - DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::Utf8), - LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::LargeUtf8), - VIEW_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::Utf8View), - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {s_kind:?}" - ), - }, - r#type::Kind::List(list) => { - let inner_type = list.r#type.as_ref().ok_or_else(|| { - substrait_datafusion_err!("List type must have inner type") - })?; - let field = Arc::new(Field::new_list_field( - from_substrait_type(consumer, inner_type, dfs_names, name_idx)?, - // We ignore Substrait's nullability here to match to_substrait_literal - // which always creates nullable lists - true, - )); - match list.type_variation_reference { - DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::List(field)), - LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::LargeList(field)), - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {s_kind:?}" - )?, - } - } - r#type::Kind::Map(map) => { - let key_type = map.key.as_ref().ok_or_else(|| { - substrait_datafusion_err!("Map type must have key type") - })?; - let value_type = map.value.as_ref().ok_or_else(|| { - substrait_datafusion_err!("Map type must have value type") - })?; - let key_field = Arc::new(Field::new( - "key", - from_substrait_type(consumer, key_type, dfs_names, name_idx)?, - false, - )); - let value_field = Arc::new(Field::new( - "value", - from_substrait_type(consumer, value_type, dfs_names, name_idx)?, - true, - )); - Ok(DataType::Map( - Arc::new(Field::new_struct( - "entries", - [key_field, value_field], - false, // The inner map field is always non-nullable (Arrow #1697), - )), - false, // whether keys are sorted - )) - } - r#type::Kind::Decimal(d) => match d.type_variation_reference { - DECIMAL_128_TYPE_VARIATION_REF => { - Ok(DataType::Decimal128(d.precision as u8, d.scale as i8)) - } - DECIMAL_256_TYPE_VARIATION_REF => { - Ok(DataType::Decimal256(d.precision as u8, d.scale as i8)) - } - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {s_kind:?}" - ), - }, - r#type::Kind::IntervalYear(_) => { - Ok(DataType::Interval(IntervalUnit::YearMonth)) - } - r#type::Kind::IntervalDay(_) => Ok(DataType::Interval(IntervalUnit::DayTime)), - r#type::Kind::IntervalCompound(_) => { - Ok(DataType::Interval(IntervalUnit::MonthDayNano)) - } - r#type::Kind::UserDefined(u) => { - if let Ok(data_type) = consumer.consume_user_defined_type(u) { - return Ok(data_type); - } - - // TODO: remove the code below once the producer has been updated - if let Some(name) = consumer.get_extensions().types.get(&u.type_reference) - { - #[allow(deprecated)] - match name.as_ref() { - // Kept for backwards compatibility, producers should use IntervalCompound instead - INTERVAL_MONTH_DAY_NANO_TYPE_NAME => Ok(DataType::Interval(IntervalUnit::MonthDayNano)), - _ => not_impl_err!( - "Unsupported Substrait user defined type with ref {} and variation {}", - u.type_reference, - u.type_variation_reference - ), 
- } - } else { - #[allow(deprecated)] - match u.type_reference { - // Kept for backwards compatibility, producers should use IntervalYear instead - INTERVAL_YEAR_MONTH_TYPE_REF => { - Ok(DataType::Interval(IntervalUnit::YearMonth)) - } - // Kept for backwards compatibility, producers should use IntervalDay instead - INTERVAL_DAY_TIME_TYPE_REF => { - Ok(DataType::Interval(IntervalUnit::DayTime)) - } - // Kept for backwards compatibility, producers should use IntervalCompound instead - INTERVAL_MONTH_DAY_NANO_TYPE_REF => { - Ok(DataType::Interval(IntervalUnit::MonthDayNano)) - } - _ => not_impl_err!( - "Unsupported Substrait user defined type with ref {} and variation {}", - u.type_reference, - u.type_variation_reference - ), - } - } - } - r#type::Kind::Struct(s) => Ok(DataType::Struct(from_substrait_struct_type( - consumer, s, dfs_names, name_idx, - )?)), - r#type::Kind::Varchar(_) => Ok(DataType::Utf8), - r#type::Kind::FixedChar(_) => Ok(DataType::Utf8), - _ => not_impl_err!("Unsupported Substrait type: {s_kind:?}"), - }, - _ => not_impl_err!("`None` Substrait kind is not supported"), - } -} - -fn from_substrait_struct_type( - consumer: &impl SubstraitConsumer, - s: &r#type::Struct, - dfs_names: &[String], - name_idx: &mut usize, -) -> Result { - let mut fields = vec![]; - for (i, f) in s.types.iter().enumerate() { - let field = Field::new( - next_struct_field_name(i, dfs_names, name_idx)?, - from_substrait_type(consumer, f, dfs_names, name_idx)?, - true, // We assume everything to be nullable since that's easier than ensuring it matches - ); - fields.push(field); - } - Ok(fields.into()) -} - -fn next_struct_field_name( - column_idx: usize, - dfs_names: &[String], - name_idx: &mut usize, -) -> Result { - if dfs_names.is_empty() { - // If names are not given, create dummy names - // c0, c1, ... align with e.g. 
SqlToRel::create_named_struct - Ok(format!("c{column_idx}")) - } else { - let name = dfs_names.get(*name_idx).cloned().ok_or_else(|| { - substrait_datafusion_err!("Named schema must contain names for all fields") - })?; - *name_idx += 1; - Ok(name) - } -} - -/// Convert Substrait NamedStruct to DataFusion DFSchemaRef -pub fn from_substrait_named_struct( - consumer: &impl SubstraitConsumer, - base_schema: &NamedStruct, -) -> Result { - let mut name_idx = 0; - let fields = from_substrait_struct_type( - consumer, - base_schema.r#struct.as_ref().ok_or_else(|| { - substrait_datafusion_err!("Named struct must contain a struct") - })?, - &base_schema.names, - &mut name_idx, - ); - if name_idx != base_schema.names.len() { - return substrait_err!( - "Names list must match exactly to nested schema, but found {} uses for {} names", - name_idx, - base_schema.names.len() - ); - } - DFSchema::try_from(Schema::new(fields?)) -} - -fn from_substrait_bound( - bound: &Option, - is_lower: bool, -) -> Result { - match bound { - Some(b) => match &b.kind { - Some(k) => match k { - BoundKind::CurrentRow(SubstraitBound::CurrentRow {}) => { - Ok(WindowFrameBound::CurrentRow) - } - BoundKind::Preceding(SubstraitBound::Preceding { offset }) => { - if *offset <= 0 { - return plan_err!("Preceding bound must be positive"); - } - Ok(WindowFrameBound::Preceding(ScalarValue::UInt64(Some( - *offset as u64, - )))) - } - BoundKind::Following(SubstraitBound::Following { offset }) => { - if *offset <= 0 { - return plan_err!("Following bound must be positive"); - } - Ok(WindowFrameBound::Following(ScalarValue::UInt64(Some( - *offset as u64, - )))) - } - BoundKind::Unbounded(SubstraitBound::Unbounded {}) => { - if is_lower { - Ok(WindowFrameBound::Preceding(ScalarValue::Null)) - } else { - Ok(WindowFrameBound::Following(ScalarValue::Null)) - } - } - }, - None => substrait_err!("WindowFunction missing Substrait Bound kind"), - }, - None => { - if is_lower { - Ok(WindowFrameBound::Preceding(ScalarValue::Null)) - } else { - Ok(WindowFrameBound::Following(ScalarValue::Null)) - } - } - } -} - -pub(crate) fn from_substrait_literal_without_names( - consumer: &impl SubstraitConsumer, - lit: &Literal, -) -> Result { - from_substrait_literal(consumer, lit, &vec![], &mut 0) -} - -fn from_substrait_literal( - consumer: &impl SubstraitConsumer, - lit: &Literal, - dfs_names: &Vec, - name_idx: &mut usize, -) -> Result { - let scalar_value = match &lit.literal_type { - Some(LiteralType::Boolean(b)) => ScalarValue::Boolean(Some(*b)), - Some(LiteralType::I8(n)) => match lit.type_variation_reference { - DEFAULT_TYPE_VARIATION_REF => ScalarValue::Int8(Some(*n as i8)), - UNSIGNED_INTEGER_TYPE_VARIATION_REF => ScalarValue::UInt8(Some(*n as u8)), - others => { - return substrait_err!("Unknown type variation reference {others}"); - } - }, - Some(LiteralType::I16(n)) => match lit.type_variation_reference { - DEFAULT_TYPE_VARIATION_REF => ScalarValue::Int16(Some(*n as i16)), - UNSIGNED_INTEGER_TYPE_VARIATION_REF => ScalarValue::UInt16(Some(*n as u16)), - others => { - return substrait_err!("Unknown type variation reference {others}"); - } - }, - Some(LiteralType::I32(n)) => match lit.type_variation_reference { - DEFAULT_TYPE_VARIATION_REF => ScalarValue::Int32(Some(*n)), - UNSIGNED_INTEGER_TYPE_VARIATION_REF => ScalarValue::UInt32(Some(*n as u32)), - others => { - return substrait_err!("Unknown type variation reference {others}"); - } - }, - Some(LiteralType::I64(n)) => match lit.type_variation_reference { - DEFAULT_TYPE_VARIATION_REF => 
ScalarValue::Int64(Some(*n)), - UNSIGNED_INTEGER_TYPE_VARIATION_REF => ScalarValue::UInt64(Some(*n as u64)), - others => { - return substrait_err!("Unknown type variation reference {others}"); - } - }, - Some(LiteralType::Fp32(f)) => ScalarValue::Float32(Some(*f)), - Some(LiteralType::Fp64(f)) => ScalarValue::Float64(Some(*f)), - Some(LiteralType::Timestamp(t)) => { - // Kept for backwards compatibility, new plans should use PrecisionTimestamp(Tz) instead - #[allow(deprecated)] - match lit.type_variation_reference { - TIMESTAMP_SECOND_TYPE_VARIATION_REF => { - ScalarValue::TimestampSecond(Some(*t), None) - } - TIMESTAMP_MILLI_TYPE_VARIATION_REF => { - ScalarValue::TimestampMillisecond(Some(*t), None) - } - TIMESTAMP_MICRO_TYPE_VARIATION_REF => { - ScalarValue::TimestampMicrosecond(Some(*t), None) - } - TIMESTAMP_NANO_TYPE_VARIATION_REF => { - ScalarValue::TimestampNanosecond(Some(*t), None) - } - others => { - return substrait_err!("Unknown type variation reference {others}"); - } - } - } - Some(LiteralType::PrecisionTimestamp(pt)) => match pt.precision { - 0 => ScalarValue::TimestampSecond(Some(pt.value), None), - 3 => ScalarValue::TimestampMillisecond(Some(pt.value), None), - 6 => ScalarValue::TimestampMicrosecond(Some(pt.value), None), - 9 => ScalarValue::TimestampNanosecond(Some(pt.value), None), - p => { - return not_impl_err!( - "Unsupported Substrait precision {p} for PrecisionTimestamp" - ); - } - }, - Some(LiteralType::PrecisionTimestampTz(pt)) => match pt.precision { - 0 => ScalarValue::TimestampSecond( - Some(pt.value), - Some(DEFAULT_TIMEZONE.into()), - ), - 3 => ScalarValue::TimestampMillisecond( - Some(pt.value), - Some(DEFAULT_TIMEZONE.into()), - ), - 6 => ScalarValue::TimestampMicrosecond( - Some(pt.value), - Some(DEFAULT_TIMEZONE.into()), - ), - 9 => ScalarValue::TimestampNanosecond( - Some(pt.value), - Some(DEFAULT_TIMEZONE.into()), - ), - p => { - return not_impl_err!( - "Unsupported Substrait precision {p} for PrecisionTimestamp" - ); - } - }, - Some(LiteralType::Date(d)) => ScalarValue::Date32(Some(*d)), - Some(LiteralType::String(s)) => match lit.type_variation_reference { - DEFAULT_CONTAINER_TYPE_VARIATION_REF => ScalarValue::Utf8(Some(s.clone())), - LARGE_CONTAINER_TYPE_VARIATION_REF => ScalarValue::LargeUtf8(Some(s.clone())), - VIEW_CONTAINER_TYPE_VARIATION_REF => ScalarValue::Utf8View(Some(s.clone())), - others => { - return substrait_err!("Unknown type variation reference {others}"); - } - }, - Some(LiteralType::Binary(b)) => match lit.type_variation_reference { - DEFAULT_CONTAINER_TYPE_VARIATION_REF => ScalarValue::Binary(Some(b.clone())), - LARGE_CONTAINER_TYPE_VARIATION_REF => { - ScalarValue::LargeBinary(Some(b.clone())) - } - VIEW_CONTAINER_TYPE_VARIATION_REF => ScalarValue::BinaryView(Some(b.clone())), - others => { - return substrait_err!("Unknown type variation reference {others}"); - } - }, - Some(LiteralType::FixedBinary(b)) => { - ScalarValue::FixedSizeBinary(b.len() as _, Some(b.clone())) - } - Some(LiteralType::Decimal(d)) => { - let value: [u8; 16] = d - .value - .clone() - .try_into() - .or(substrait_err!("Failed to parse decimal value"))?; - let p = d.precision.try_into().map_err(|e| { - substrait_datafusion_err!("Failed to parse decimal precision: {e}") - })?; - let s = d.scale.try_into().map_err(|e| { - substrait_datafusion_err!("Failed to parse decimal scale: {e}") - })?; - ScalarValue::Decimal128(Some(i128::from_le_bytes(value)), p, s) - } - Some(LiteralType::List(l)) => { - // Each element should start the name index from the same value, then 
we increase it - // once at the end - let mut element_name_idx = *name_idx; - let elements = l - .values - .iter() - .map(|el| { - element_name_idx = *name_idx; - from_substrait_literal(consumer, el, dfs_names, &mut element_name_idx) - }) - .collect::>>()?; - *name_idx = element_name_idx; - if elements.is_empty() { - return substrait_err!( - "Empty list must be encoded as EmptyList literal type, not List" - ); - } - let element_type = elements[0].data_type(); - match lit.type_variation_reference { - DEFAULT_CONTAINER_TYPE_VARIATION_REF => ScalarValue::List( - ScalarValue::new_list_nullable(elements.as_slice(), &element_type), - ), - LARGE_CONTAINER_TYPE_VARIATION_REF => ScalarValue::LargeList( - ScalarValue::new_large_list(elements.as_slice(), &element_type), - ), - others => { - return substrait_err!("Unknown type variation reference {others}"); - } - } - } - Some(LiteralType::EmptyList(l)) => { - let element_type = from_substrait_type( - consumer, - l.r#type.clone().unwrap().as_ref(), - dfs_names, - name_idx, - )?; - match lit.type_variation_reference { - DEFAULT_CONTAINER_TYPE_VARIATION_REF => { - ScalarValue::List(ScalarValue::new_list_nullable(&[], &element_type)) - } - LARGE_CONTAINER_TYPE_VARIATION_REF => ScalarValue::LargeList( - ScalarValue::new_large_list(&[], &element_type), - ), - others => { - return substrait_err!("Unknown type variation reference {others}"); - } - } - } - Some(LiteralType::Map(m)) => { - // Each entry should start the name index from the same value, then we increase it - // once at the end - let mut entry_name_idx = *name_idx; - let entries = m - .key_values - .iter() - .map(|kv| { - entry_name_idx = *name_idx; - let key_sv = from_substrait_literal( - consumer, - kv.key.as_ref().unwrap(), - dfs_names, - &mut entry_name_idx, - )?; - let value_sv = from_substrait_literal( - consumer, - kv.value.as_ref().unwrap(), - dfs_names, - &mut entry_name_idx, - )?; - ScalarStructBuilder::new() - .with_scalar(Field::new("key", key_sv.data_type(), false), key_sv) - .with_scalar( - Field::new("value", value_sv.data_type(), true), - value_sv, - ) - .build() - }) - .collect::>>()?; - *name_idx = entry_name_idx; - - if entries.is_empty() { - return substrait_err!( - "Empty map must be encoded as EmptyMap literal type, not Map" - ); - } - - ScalarValue::Map(Arc::new(MapArray::new( - Arc::new(Field::new("entries", entries[0].data_type(), false)), - OffsetBuffer::new(vec![0, entries.len() as i32].into()), - ScalarValue::iter_to_array(entries)?.as_struct().to_owned(), - None, - false, - ))) - } - Some(LiteralType::EmptyMap(m)) => { - let key = match &m.key { - Some(k) => Ok(k), - _ => plan_err!("Missing key type for empty map"), - }?; - let value = match &m.value { - Some(v) => Ok(v), - _ => plan_err!("Missing value type for empty map"), - }?; - let key_type = from_substrait_type(consumer, key, dfs_names, name_idx)?; - let value_type = from_substrait_type(consumer, value, dfs_names, name_idx)?; - - // new_empty_array on a MapType creates a too empty array - // We want it to contain an empty struct array to align with an empty MapBuilder one - let entries = Field::new_struct( - "entries", - vec![ - Field::new("key", key_type, false), - Field::new("value", value_type, true), - ], - false, - ); - let struct_array = - new_empty_array(entries.data_type()).as_struct().to_owned(); - ScalarValue::Map(Arc::new(MapArray::new( - Arc::new(entries), - OffsetBuffer::new(vec![0, 0].into()), - struct_array, - None, - false, - ))) - } - Some(LiteralType::Struct(s)) => { - let mut builder = 
ScalarStructBuilder::new(); - for (i, field) in s.fields.iter().enumerate() { - let name = next_struct_field_name(i, dfs_names, name_idx)?; - let sv = from_substrait_literal(consumer, field, dfs_names, name_idx)?; - // We assume everything to be nullable, since Arrow's strict about things matching - // and it's hard to match otherwise. - builder = builder.with_scalar(Field::new(name, sv.data_type(), true), sv); - } - builder.build()? - } - Some(LiteralType::Null(null_type)) => { - let data_type = - from_substrait_type(consumer, null_type, dfs_names, name_idx)?; - ScalarValue::try_from(&data_type)? - } - Some(LiteralType::IntervalDayToSecond(IntervalDayToSecond { - days, - seconds, - subseconds, - precision_mode, - })) => { - use interval_day_to_second::PrecisionMode; - // DF only supports millisecond precision, so for any more granular type we lose precision - let milliseconds = match precision_mode { - Some(PrecisionMode::Microseconds(ms)) => ms / 1000, - None => - if *subseconds != 0 { - return substrait_err!("Cannot set subseconds field of IntervalDayToSecond without setting precision"); - } else { - 0_i32 - } - Some(PrecisionMode::Precision(0)) => *subseconds as i32 * 1000, - Some(PrecisionMode::Precision(3)) => *subseconds as i32, - Some(PrecisionMode::Precision(6)) => (subseconds / 1000) as i32, - Some(PrecisionMode::Precision(9)) => (subseconds / 1000 / 1000) as i32, - _ => { - return not_impl_err!( - "Unsupported Substrait interval day to second precision mode: {precision_mode:?}") - } - }; - - ScalarValue::new_interval_dt(*days, (seconds * 1000) + milliseconds) - } - Some(LiteralType::IntervalYearToMonth(IntervalYearToMonth { years, months })) => { - ScalarValue::new_interval_ym(*years, *months) - } - Some(LiteralType::IntervalCompound(IntervalCompound { - interval_year_to_month, - interval_day_to_second, - })) => match (interval_year_to_month, interval_day_to_second) { - ( - Some(IntervalYearToMonth { years, months }), - Some(IntervalDayToSecond { - days, - seconds, - subseconds, - precision_mode: - Some(interval_day_to_second::PrecisionMode::Precision(p)), - }), - ) => { - if *p < 0 || *p > 9 { - return plan_err!( - "Unsupported Substrait interval day to second precision: {}", - p - ); - } - let nanos = *subseconds * i64::pow(10, (9 - p) as u32); - ScalarValue::new_interval_mdn( - *years * 12 + months, - *days, - *seconds as i64 * NANOSECONDS + nanos, - ) - } - _ => return plan_err!("Substrait compound interval missing components"), - }, - Some(LiteralType::FixedChar(c)) => ScalarValue::Utf8(Some(c.clone())), - Some(LiteralType::UserDefined(user_defined)) => { - if let Ok(value) = consumer.consume_user_defined_literal(user_defined) { - return Ok(value); - } - - // TODO: remove the code below once the producer has been updated - - // Helper function to prevent duplicating this code - can be inlined once the non-extension path is removed - let interval_month_day_nano = - |user_defined: &proto::expression::literal::UserDefined| -> Result { - let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { - return substrait_err!("Interval month day nano value is empty"); - }; - let value_slice: [u8; 16] = - (*raw_val.value).try_into().map_err(|_| { - substrait_datafusion_err!( - "Failed to parse interval month day nano value" - ) - })?; - let months = - i32::from_le_bytes(value_slice[0..4].try_into().unwrap()); - let days = i32::from_le_bytes(value_slice[4..8].try_into().unwrap()); - let nanoseconds = - i64::from_le_bytes(value_slice[8..16].try_into().unwrap()); - 
Ok(ScalarValue::IntervalMonthDayNano(Some( - IntervalMonthDayNano { - months, - days, - nanoseconds, - }, - ))) - }; - - if let Some(name) = consumer - .get_extensions() - .types - .get(&user_defined.type_reference) - { - match name.as_ref() { - // Kept for backwards compatibility - producers should use IntervalCompound instead - #[allow(deprecated)] - INTERVAL_MONTH_DAY_NANO_TYPE_NAME => { - interval_month_day_nano(user_defined)? - } - _ => { - return not_impl_err!( - "Unsupported Substrait user defined type with ref {} and name {}", - user_defined.type_reference, - name - ) - } - } - } else { - #[allow(deprecated)] - match user_defined.type_reference { - // Kept for backwards compatibility, producers should useIntervalYearToMonth instead - INTERVAL_YEAR_MONTH_TYPE_REF => { - let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { - return substrait_err!("Interval year month value is empty"); - }; - let value_slice: [u8; 4] = - (*raw_val.value).try_into().map_err(|_| { - substrait_datafusion_err!( - "Failed to parse interval year month value" - ) - })?; - ScalarValue::IntervalYearMonth(Some(i32::from_le_bytes( - value_slice, - ))) - } - // Kept for backwards compatibility, producers should useIntervalDayToSecond instead - INTERVAL_DAY_TIME_TYPE_REF => { - let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { - return substrait_err!("Interval day time value is empty"); - }; - let value_slice: [u8; 8] = - (*raw_val.value).try_into().map_err(|_| { - substrait_datafusion_err!( - "Failed to parse interval day time value" - ) - })?; - let days = - i32::from_le_bytes(value_slice[0..4].try_into().unwrap()); - let milliseconds = - i32::from_le_bytes(value_slice[4..8].try_into().unwrap()); - ScalarValue::IntervalDayTime(Some(IntervalDayTime { - days, - milliseconds, - })) - } - // Kept for backwards compatibility, producers should useIntervalCompound instead - INTERVAL_MONTH_DAY_NANO_TYPE_REF => { - interval_month_day_nano(user_defined)? - } - _ => { - return not_impl_err!( - "Unsupported Substrait user defined type literal with ref {}", - user_defined.type_reference - ) - } - } - } - } - _ => return not_impl_err!("Unsupported literal_type: {:?}", lit.literal_type), - }; - - Ok(scalar_value) -} - -#[allow(deprecated)] -async fn from_substrait_grouping( - consumer: &impl SubstraitConsumer, - grouping: &Grouping, - expressions: &[Expr], - input_schema: &DFSchemaRef, -) -> Result> { - let mut group_exprs = vec![]; - if !grouping.grouping_expressions.is_empty() { - for e in &grouping.grouping_expressions { - let expr = consumer.consume_expression(e, input_schema).await?; - group_exprs.push(expr); - } - return Ok(group_exprs); - } - for idx in &grouping.expression_references { - let e = &expressions[*idx as usize]; - group_exprs.push(e.clone()); - } - Ok(group_exprs) -} - -fn from_substrait_field_reference( - field_ref: &FieldReference, - input_schema: &DFSchema, -) -> Result { - match &field_ref.reference_type { - Some(DirectReference(direct)) => match &direct.reference_type.as_ref() { - Some(StructField(x)) => match &x.child.as_ref() { - Some(_) => not_impl_err!( - "Direct reference StructField with child is not supported" - ), - None => Ok(Expr::Column(Column::from( - input_schema.qualified_field(x.field as usize), - ))), - }, - _ => not_impl_err!( - "Direct reference with types other than StructField is not supported" - ), - }, - _ => not_impl_err!("unsupported field ref type"), - } -} - -/// Build [`Expr`] from its name and required inputs. 
-struct BuiltinExprBuilder { - expr_name: String, -} - -impl BuiltinExprBuilder { - pub fn try_from_name(name: &str) -> Option { - match name { - "not" | "like" | "ilike" | "is_null" | "is_not_null" | "is_true" - | "is_false" | "is_not_true" | "is_not_false" | "is_unknown" - | "is_not_unknown" | "negative" | "negate" => Some(Self { - expr_name: name.to_string(), - }), - _ => None, - } - } - - pub async fn build( - self, - consumer: &impl SubstraitConsumer, - f: &ScalarFunction, - input_schema: &DFSchema, - ) -> Result { - match self.expr_name.as_str() { - "like" => Self::build_like_expr(consumer, false, f, input_schema).await, - "ilike" => Self::build_like_expr(consumer, true, f, input_schema).await, - "not" | "negative" | "negate" | "is_null" | "is_not_null" | "is_true" - | "is_false" | "is_not_true" | "is_not_false" | "is_unknown" - | "is_not_unknown" => { - Self::build_unary_expr(consumer, &self.expr_name, f, input_schema).await - } - _ => { - not_impl_err!("Unsupported builtin expression: {}", self.expr_name) - } - } - } - - async fn build_unary_expr( - consumer: &impl SubstraitConsumer, - fn_name: &str, - f: &ScalarFunction, - input_schema: &DFSchema, - ) -> Result { - if f.arguments.len() != 1 { - return substrait_err!("Expect one argument for {fn_name} expr"); - } - let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { - return substrait_err!("Invalid arguments type for {fn_name} expr"); - }; - let arg = consumer - .consume_expression(expr_substrait, input_schema) - .await?; - let arg = Box::new(arg); - - let expr = match fn_name { - "not" => Expr::Not(arg), - "negative" | "negate" => Expr::Negative(arg), - "is_null" => Expr::IsNull(arg), - "is_not_null" => Expr::IsNotNull(arg), - "is_true" => Expr::IsTrue(arg), - "is_false" => Expr::IsFalse(arg), - "is_not_true" => Expr::IsNotTrue(arg), - "is_not_false" => Expr::IsNotFalse(arg), - "is_unknown" => Expr::IsUnknown(arg), - "is_not_unknown" => Expr::IsNotUnknown(arg), - _ => return not_impl_err!("Unsupported builtin expression: {}", fn_name), - }; - - Ok(expr) - } - - async fn build_like_expr( - consumer: &impl SubstraitConsumer, - case_insensitive: bool, - f: &ScalarFunction, - input_schema: &DFSchema, - ) -> Result { - let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" }; - if f.arguments.len() != 2 && f.arguments.len() != 3 { - return substrait_err!("Expect two or three arguments for `{fn_name}` expr"); - } - - let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { - return substrait_err!("Invalid arguments type for `{fn_name}` expr"); - }; - let expr = consumer - .consume_expression(expr_substrait, input_schema) - .await?; - let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type else { - return substrait_err!("Invalid arguments type for `{fn_name}` expr"); - }; - let pattern = consumer - .consume_expression(pattern_substrait, input_schema) - .await?; - - // Default case: escape character is Literal(Utf8(None)) - let escape_char = if f.arguments.len() == 3 { - let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type - else { - return substrait_err!("Invalid arguments type for `{fn_name}` expr"); - }; - - let escape_char_expr = consumer - .consume_expression(escape_char_substrait, input_schema) - .await?; - - match escape_char_expr { - Expr::Literal(ScalarValue::Utf8(escape_char_string)) => { - // Convert Option to Option - escape_char_string.and_then(|s| s.chars().next()) - } - _ => { - return substrait_err!( - "Expect Utf8 literal for escape char, 
but found {escape_char_expr:?}" - ) - } - } - } else { - None - }; - - Ok(Expr::Like(Like { - negated: false, - expr: Box::new(expr), - pattern: Box::new(pattern), - escape_char, - case_insensitive, - })) - } -} - -#[cfg(test)] -mod test { - use crate::extensions::Extensions; - use crate::logical_plan::consumer::{ - from_substrait_literal_without_names, from_substrait_rex, - DefaultSubstraitConsumer, - }; - use arrow::array::types::IntervalMonthDayNano; - use datafusion::arrow; - use datafusion::common::DFSchema; - use datafusion::error::Result; - use datafusion::execution::SessionState; - use datafusion::prelude::{Expr, SessionContext}; - use datafusion::scalar::ScalarValue; - use std::sync::LazyLock; - use substrait::proto::expression::literal::{ - interval_day_to_second, IntervalCompound, IntervalDayToSecond, - IntervalYearToMonth, LiteralType, - }; - use substrait::proto::expression::window_function::BoundsType; - use substrait::proto::expression::Literal; - - static TEST_SESSION_STATE: LazyLock = - LazyLock::new(|| SessionContext::default().state()); - static TEST_EXTENSIONS: LazyLock = LazyLock::new(Extensions::default); - fn test_consumer() -> DefaultSubstraitConsumer<'static> { - let extensions = &TEST_EXTENSIONS; - let state = &TEST_SESSION_STATE; - DefaultSubstraitConsumer::new(extensions, state) - } - - #[test] - fn interval_compound_different_precision() -> Result<()> { - // DF producer (and thus roundtrip) always uses precision = 9, - // this test exists to test with some other value. - let substrait = Literal { - nullable: false, - type_variation_reference: 0, - literal_type: Some(LiteralType::IntervalCompound(IntervalCompound { - interval_year_to_month: Some(IntervalYearToMonth { - years: 1, - months: 2, - }), - interval_day_to_second: Some(IntervalDayToSecond { - days: 3, - seconds: 4, - subseconds: 5, - precision_mode: Some( - interval_day_to_second::PrecisionMode::Precision(6), - ), - }), - })), - }; - - let consumer = test_consumer(); - assert_eq!( - from_substrait_literal_without_names(&consumer, &substrait)?, - ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano { - months: 14, - days: 3, - nanoseconds: 4_000_005_000 - })) - ); - - Ok(()) - } - - #[tokio::test] - async fn window_function_with_range_unit_and_no_order_by() -> Result<()> { - let substrait = substrait::proto::Expression { - rex_type: Some(substrait::proto::expression::RexType::WindowFunction( - substrait::proto::expression::WindowFunction { - function_reference: 0, - bounds_type: BoundsType::Range as i32, - sorts: vec![], - ..Default::default() - }, - )), - }; - - let mut consumer = test_consumer(); - - // Just registering a single function (index 0) so that the plan - // does not throw a "function not found" error. - let mut extensions = Extensions::default(); - extensions.register_function("count".to_string()); - consumer.extensions = &extensions; - - match from_substrait_rex(&consumer, &substrait, &DFSchema::empty()).await? 
{ - Expr::WindowFunction(window_function) => { - assert_eq!(window_function.params.order_by.len(), 1) - } - _ => panic!("expr was not a WindowFunction"), - }; - - Ok(()) - } - - #[tokio::test] - async fn window_function_with_count() -> Result<()> { - let substrait = substrait::proto::Expression { - rex_type: Some(substrait::proto::expression::RexType::WindowFunction( - substrait::proto::expression::WindowFunction { - function_reference: 0, - ..Default::default() - }, - )), - }; - - let mut consumer = test_consumer(); - - let mut extensions = Extensions::default(); - extensions.register_function("count".to_string()); - consumer.extensions = &extensions; - - match from_substrait_rex(&consumer, &substrait, &DFSchema::empty()).await? { - Expr::WindowFunction(window_function) => { - assert_eq!(window_function.params.args.len(), 1) - } - _ => panic!("expr was not a WindowFunction"), - }; - - Ok(()) - } -} diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/aggregate_function.rs b/datafusion/substrait/src/logical_plan/consumer/expr/aggregate_function.rs new file mode 100644 index 000000000000..114fe1e7aecd --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/expr/aggregate_function.rs @@ -0,0 +1,71 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::{ + from_substrait_func_args, substrait_fun_name, SubstraitConsumer, +}; +use datafusion::common::{not_impl_datafusion_err, plan_err, DFSchema, ScalarValue}; +use datafusion::execution::FunctionRegistry; +use datafusion::logical_expr::{expr, Expr, SortExpr}; +use std::sync::Arc; +use substrait::proto::AggregateFunction; + +/// Convert Substrait AggregateFunction to DataFusion Expr +pub async fn from_substrait_agg_func( + consumer: &impl SubstraitConsumer, + f: &AggregateFunction, + input_schema: &DFSchema, + filter: Option>, + order_by: Option>, + distinct: bool, +) -> datafusion::common::Result> { + let Some(fn_signature) = consumer + .get_extensions() + .functions + .get(&f.function_reference) + else { + return plan_err!( + "Aggregate function not registered: function anchor = {:?}", + f.function_reference + ); + }; + + let fn_name = substrait_fun_name(fn_signature); + let udaf = consumer.get_function_registry().udaf(fn_name); + let udaf = udaf.map_err(|_| { + not_impl_datafusion_err!( + "Aggregate function {} is not supported: function anchor = {:?}", + fn_signature, + f.function_reference + ) + })?; + + let args = from_substrait_func_args(consumer, &f.arguments, input_schema).await?; + + // Datafusion does not support aggregate functions with no arguments, so + // we inject a dummy argument that does not affect the query, but allows + // us to bypass this limitation. 
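Aside (an illustration, not part of the patch): the rewrite just below means a zero-argument Substrait `count()` reaches DataFusion as `count(1)`, which counts every input row and therefore matches `count(*)`. A minimal sketch of that normalization, using the same two-field `Expr::Literal` form the new module uses; the helper name is hypothetical:

    use datafusion::logical_expr::Expr;
    use datafusion::scalar::ScalarValue;

    // Hypothetical helper mirroring the rewrite: a zero-argument `count`
    // becomes `count(1)`; the dummy literal does not change the result.
    fn normalize_count_args(udaf_name: &str, args: Vec<Expr>) -> Vec<Expr> {
        if udaf_name == "count" && args.is_empty() {
            vec![Expr::Literal(ScalarValue::Int64(Some(1)), None)]
        } else {
            args
        }
    }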
+ let args = if udaf.name() == "count" && args.is_empty() { + vec![Expr::Literal(ScalarValue::Int64(Some(1)), None)] + } else { + args + }; + + Ok(Arc::new(Expr::AggregateFunction( + expr::AggregateFunction::new_udf(udaf, args, distinct, filter, order_by, None), + ))) +} diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/cast.rs b/datafusion/substrait/src/logical_plan/consumer/expr/cast.rs new file mode 100644 index 000000000000..5e8d3d93065f --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/expr/cast.rs @@ -0,0 +1,49 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::types::from_substrait_type_without_names; +use crate::logical_plan::consumer::SubstraitConsumer; +use datafusion::common::{substrait_err, DFSchema}; +use datafusion::logical_expr::{Cast, Expr, TryCast}; +use substrait::proto::expression as substrait_expression; +use substrait::proto::expression::cast::FailureBehavior::ReturnNull; + +pub async fn from_cast( + consumer: &impl SubstraitConsumer, + cast: &substrait_expression::Cast, + input_schema: &DFSchema, +) -> datafusion::common::Result { + match cast.r#type.as_ref() { + Some(output_type) => { + let input_expr = Box::new( + consumer + .consume_expression( + cast.input.as_ref().unwrap().as_ref(), + input_schema, + ) + .await?, + ); + let data_type = from_substrait_type_without_names(consumer, output_type)?; + if cast.failure_behavior() == ReturnNull { + Ok(Expr::TryCast(TryCast::new(input_expr, data_type))) + } else { + Ok(Expr::Cast(Cast::new(input_expr, data_type))) + } + } + None => substrait_err!("Cast expression without output type is not allowed"), + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/field_reference.rs b/datafusion/substrait/src/logical_plan/consumer/expr/field_reference.rs new file mode 100644 index 000000000000..90b5b6418149 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/expr/field_reference.rs @@ -0,0 +1,53 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::logical_plan::consumer::SubstraitConsumer;
+use datafusion::common::{not_impl_err, Column, DFSchema};
+use datafusion::logical_expr::Expr;
+use substrait::proto::expression::field_reference::ReferenceType::DirectReference;
+use substrait::proto::expression::reference_segment::ReferenceType::StructField;
+use substrait::proto::expression::FieldReference;
+
+pub async fn from_field_reference(
+    _consumer: &impl SubstraitConsumer,
+    field_ref: &FieldReference,
+    input_schema: &DFSchema,
+) -> datafusion::common::Result<Expr> {
+    from_substrait_field_reference(field_ref, input_schema)
+}
+
+pub(crate) fn from_substrait_field_reference(
+    field_ref: &FieldReference,
+    input_schema: &DFSchema,
+) -> datafusion::common::Result<Expr> {
+    match &field_ref.reference_type {
+        Some(DirectReference(direct)) => match &direct.reference_type.as_ref() {
+            Some(StructField(x)) => match &x.child.as_ref() {
+                Some(_) => not_impl_err!(
+                    "Direct reference StructField with child is not supported"
+                ),
+                None => Ok(Expr::Column(Column::from(
+                    input_schema.qualified_field(x.field as usize),
+                ))),
+            },
+            _ => not_impl_err!(
+                "Direct reference with types other than StructField is not supported"
+            ),
+        },
+        _ => not_impl_err!("unsupported field ref type"),
+    }
+}
diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/function_arguments.rs b/datafusion/substrait/src/logical_plan/consumer/expr/function_arguments.rs
new file mode 100644
index 000000000000..0b610b61b1de
--- /dev/null
+++ b/datafusion/substrait/src/logical_plan/consumer/expr/function_arguments.rs
@@ -0,0 +1,39 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
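For context (a sketch, not part of the patch): `from_substrait_field_reference` above resolves a child-less direct StructField reference purely by ordinal position against the input schema. Assuming a hypothetical unqualified three-column schema, a reference with field index 2 resolves to the third column:

    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::common::{Column, DFSchema, Result};
    use datafusion::logical_expr::Expr;

    fn main() -> Result<()> {
        // Example input schema (a, b, c); field index 2 names column `c`.
        let input_schema = DFSchema::try_from(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Utf8, true),
            Field::new("c", DataType::Float64, true),
        ]))?;
        // Same construction the consumer uses for a child-less StructField.
        let expr = Expr::Column(Column::from(input_schema.qualified_field(2)));
        println!("{expr}"); // prints: c
        Ok(())
    }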
+
+use crate::logical_plan::consumer::SubstraitConsumer;
+use datafusion::common::{not_impl_err, DFSchema};
+use datafusion::logical_expr::Expr;
+use substrait::proto::function_argument::ArgType;
+use substrait::proto::FunctionArgument;
+
+/// Convert Substrait FunctionArguments to DataFusion Exprs
+pub async fn from_substrait_func_args(
+    consumer: &impl SubstraitConsumer,
+    arguments: &Vec<FunctionArgument>,
+    input_schema: &DFSchema,
+) -> datafusion::common::Result<Vec<Expr>> {
+    let mut args: Vec<Expr> = vec![];
+    for arg in arguments {
+        let arg_expr = match &arg.arg_type {
+            Some(ArgType::Value(e)) => consumer.consume_expression(e, input_schema).await,
+            _ => not_impl_err!("Function argument non-Value type not supported"),
+        };
+        args.push(arg_expr?);
+    }
+    Ok(args)
+}
diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/if_then.rs b/datafusion/substrait/src/logical_plan/consumer/expr/if_then.rs
new file mode 100644
index 000000000000..c4cc6c2fcd24
--- /dev/null
+++ b/datafusion/substrait/src/logical_plan/consumer/expr/if_then.rs
@@ -0,0 +1,69 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
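As a rough illustration (not from the patch) of what `from_substrait_func_args` above consumes: only value arguments are accepted, i.e. a `FunctionArgument` whose `arg_type` wraps an `Expression`; anything else hits the `not_impl_err!` branch. A sketch that builds one such argument from the substrait proto types the module already imports; the helper name is made up:

    use substrait::proto::expression::literal::LiteralType;
    use substrait::proto::expression::{Literal, RexType};
    use substrait::proto::function_argument::ArgType;
    use substrait::proto::{Expression, FunctionArgument};

    // A value argument carrying the literal 42; `from_substrait_func_args`
    // turns it into a single DataFusion `Expr` via `consume_expression`.
    fn literal_i64_argument(value: i64) -> FunctionArgument {
        FunctionArgument {
            arg_type: Some(ArgType::Value(Expression {
                rex_type: Some(RexType::Literal(Literal {
                    nullable: false,
                    type_variation_reference: 0,
                    literal_type: Some(LiteralType::I64(value)),
                })),
            })),
        }
    }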
+ +use crate::logical_plan::consumer::SubstraitConsumer; +use datafusion::common::DFSchema; +use datafusion::logical_expr::{Case, Expr}; +use substrait::proto::expression::IfThen; + +pub async fn from_if_then( + consumer: &impl SubstraitConsumer, + if_then: &IfThen, + input_schema: &DFSchema, +) -> datafusion::common::Result { + // Parse `ifs` + // If the first element does not have a `then` part, then we can assume it's a base expression + let mut when_then_expr: Vec<(Box, Box)> = vec![]; + let mut expr = None; + for (i, if_expr) in if_then.ifs.iter().enumerate() { + if i == 0 { + // Check if the first element is type base expression + if if_expr.then.is_none() { + expr = Some(Box::new( + consumer + .consume_expression(if_expr.r#if.as_ref().unwrap(), input_schema) + .await?, + )); + continue; + } + } + when_then_expr.push(( + Box::new( + consumer + .consume_expression(if_expr.r#if.as_ref().unwrap(), input_schema) + .await?, + ), + Box::new( + consumer + .consume_expression(if_expr.then.as_ref().unwrap(), input_schema) + .await?, + ), + )); + } + // Parse `else` + let else_expr = match &if_then.r#else { + Some(e) => Some(Box::new( + consumer.consume_expression(e, input_schema).await?, + )), + None => None, + }; + Ok(Expr::Case(Case { + expr, + when_then_expr, + else_expr, + })) +} diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/literal.rs b/datafusion/substrait/src/logical_plan/consumer/expr/literal.rs new file mode 100644 index 000000000000..d054e5267554 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/expr/literal.rs @@ -0,0 +1,547 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
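For reference (a sketch, not part of the patch): the `Expr::Case` that `from_if_then` above produces for a Substrait IfThen with one if/then clause and an else branch corresponds to `CASE WHEN a > 10 THEN 'big' ELSE 'small' END`. Built directly with the same `Case` struct the module uses, over hypothetical column `a`:

    use datafusion::logical_expr::{Case, Expr};
    use datafusion::prelude::{col, lit};

    // No leading base expression (`expr: None`), one WHEN/THEN pair, and an
    // ELSE branch: the same shape `from_if_then` builds.
    fn equivalent_case_expr() -> Expr {
        Expr::Case(Case {
            expr: None,
            when_then_expr: vec![(
                Box::new(col("a").gt(lit(10))),
                Box::new(lit("big")),
            )],
            else_expr: Some(Box::new(lit("small"))),
        })
    }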
+ +use crate::logical_plan::consumer::types::from_substrait_type; +use crate::logical_plan::consumer::utils::{next_struct_field_name, DEFAULT_TIMEZONE}; +use crate::logical_plan::consumer::SubstraitConsumer; +#[allow(deprecated)] +use crate::variation_const::{ + DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, + INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_NAME, + INTERVAL_MONTH_DAY_NANO_TYPE_REF, INTERVAL_YEAR_MONTH_TYPE_REF, + LARGE_CONTAINER_TYPE_VARIATION_REF, TIMESTAMP_MICRO_TYPE_VARIATION_REF, + TIMESTAMP_MILLI_TYPE_VARIATION_REF, TIMESTAMP_NANO_TYPE_VARIATION_REF, + TIMESTAMP_SECOND_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, + VIEW_CONTAINER_TYPE_VARIATION_REF, +}; +use datafusion::arrow::array::{new_empty_array, AsArray, MapArray}; +use datafusion::arrow::buffer::OffsetBuffer; +use datafusion::arrow::datatypes::{Field, IntervalDayTime, IntervalMonthDayNano}; +use datafusion::arrow::temporal_conversions::NANOSECONDS; +use datafusion::common::scalar::ScalarStructBuilder; +use datafusion::common::{ + not_impl_err, plan_err, substrait_datafusion_err, substrait_err, ScalarValue, +}; +use datafusion::logical_expr::Expr; +use std::sync::Arc; +use substrait::proto; +use substrait::proto::expression::literal::user_defined::Val; +use substrait::proto::expression::literal::{ + interval_day_to_second, IntervalCompound, IntervalDayToSecond, IntervalYearToMonth, + LiteralType, +}; +use substrait::proto::expression::Literal; + +pub async fn from_literal( + consumer: &impl SubstraitConsumer, + expr: &Literal, +) -> datafusion::common::Result { + let scalar_value = from_substrait_literal_without_names(consumer, expr)?; + Ok(Expr::Literal(scalar_value, None)) +} + +pub(crate) fn from_substrait_literal_without_names( + consumer: &impl SubstraitConsumer, + lit: &Literal, +) -> datafusion::common::Result { + from_substrait_literal(consumer, lit, &vec![], &mut 0) +} + +pub(crate) fn from_substrait_literal( + consumer: &impl SubstraitConsumer, + lit: &Literal, + dfs_names: &Vec, + name_idx: &mut usize, +) -> datafusion::common::Result { + let scalar_value = match &lit.literal_type { + Some(LiteralType::Boolean(b)) => ScalarValue::Boolean(Some(*b)), + Some(LiteralType::I8(n)) => match lit.type_variation_reference { + DEFAULT_TYPE_VARIATION_REF => ScalarValue::Int8(Some(*n as i8)), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => ScalarValue::UInt8(Some(*n as u8)), + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + }, + Some(LiteralType::I16(n)) => match lit.type_variation_reference { + DEFAULT_TYPE_VARIATION_REF => ScalarValue::Int16(Some(*n as i16)), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => ScalarValue::UInt16(Some(*n as u16)), + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + }, + Some(LiteralType::I32(n)) => match lit.type_variation_reference { + DEFAULT_TYPE_VARIATION_REF => ScalarValue::Int32(Some(*n)), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => ScalarValue::UInt32(Some(*n as u32)), + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + }, + Some(LiteralType::I64(n)) => match lit.type_variation_reference { + DEFAULT_TYPE_VARIATION_REF => ScalarValue::Int64(Some(*n)), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => ScalarValue::UInt64(Some(*n as u64)), + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + }, + Some(LiteralType::Fp32(f)) => ScalarValue::Float32(Some(*f)), + Some(LiteralType::Fp64(f)) => 
ScalarValue::Float64(Some(*f)), + Some(LiteralType::Timestamp(t)) => { + // Kept for backwards compatibility, new plans should use PrecisionTimestamp(Tz) instead + #[allow(deprecated)] + match lit.type_variation_reference { + TIMESTAMP_SECOND_TYPE_VARIATION_REF => { + ScalarValue::TimestampSecond(Some(*t), None) + } + TIMESTAMP_MILLI_TYPE_VARIATION_REF => { + ScalarValue::TimestampMillisecond(Some(*t), None) + } + TIMESTAMP_MICRO_TYPE_VARIATION_REF => { + ScalarValue::TimestampMicrosecond(Some(*t), None) + } + TIMESTAMP_NANO_TYPE_VARIATION_REF => { + ScalarValue::TimestampNanosecond(Some(*t), None) + } + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + } + } + Some(LiteralType::PrecisionTimestamp(pt)) => match pt.precision { + 0 => ScalarValue::TimestampSecond(Some(pt.value), None), + 3 => ScalarValue::TimestampMillisecond(Some(pt.value), None), + 6 => ScalarValue::TimestampMicrosecond(Some(pt.value), None), + 9 => ScalarValue::TimestampNanosecond(Some(pt.value), None), + p => { + return not_impl_err!( + "Unsupported Substrait precision {p} for PrecisionTimestamp" + ); + } + }, + Some(LiteralType::PrecisionTimestampTz(pt)) => match pt.precision { + 0 => ScalarValue::TimestampSecond( + Some(pt.value), + Some(DEFAULT_TIMEZONE.into()), + ), + 3 => ScalarValue::TimestampMillisecond( + Some(pt.value), + Some(DEFAULT_TIMEZONE.into()), + ), + 6 => ScalarValue::TimestampMicrosecond( + Some(pt.value), + Some(DEFAULT_TIMEZONE.into()), + ), + 9 => ScalarValue::TimestampNanosecond( + Some(pt.value), + Some(DEFAULT_TIMEZONE.into()), + ), + p => { + return not_impl_err!( + "Unsupported Substrait precision {p} for PrecisionTimestamp" + ); + } + }, + Some(LiteralType::Date(d)) => ScalarValue::Date32(Some(*d)), + Some(LiteralType::String(s)) => match lit.type_variation_reference { + DEFAULT_CONTAINER_TYPE_VARIATION_REF => ScalarValue::Utf8(Some(s.clone())), + LARGE_CONTAINER_TYPE_VARIATION_REF => ScalarValue::LargeUtf8(Some(s.clone())), + VIEW_CONTAINER_TYPE_VARIATION_REF => ScalarValue::Utf8View(Some(s.clone())), + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + }, + Some(LiteralType::Binary(b)) => match lit.type_variation_reference { + DEFAULT_CONTAINER_TYPE_VARIATION_REF => ScalarValue::Binary(Some(b.clone())), + LARGE_CONTAINER_TYPE_VARIATION_REF => { + ScalarValue::LargeBinary(Some(b.clone())) + } + VIEW_CONTAINER_TYPE_VARIATION_REF => ScalarValue::BinaryView(Some(b.clone())), + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + }, + Some(LiteralType::FixedBinary(b)) => { + ScalarValue::FixedSizeBinary(b.len() as _, Some(b.clone())) + } + Some(LiteralType::Decimal(d)) => { + let value: [u8; 16] = d + .value + .clone() + .try_into() + .or(substrait_err!("Failed to parse decimal value"))?; + let p = d.precision.try_into().map_err(|e| { + substrait_datafusion_err!("Failed to parse decimal precision: {e}") + })?; + let s = d.scale.try_into().map_err(|e| { + substrait_datafusion_err!("Failed to parse decimal scale: {e}") + })?; + ScalarValue::Decimal128(Some(i128::from_le_bytes(value)), p, s) + } + Some(LiteralType::List(l)) => { + // Each element should start the name index from the same value, then we increase it + // once at the end + let mut element_name_idx = *name_idx; + let elements = l + .values + .iter() + .map(|el| { + element_name_idx = *name_idx; + from_substrait_literal(consumer, el, dfs_names, &mut element_name_idx) + }) + .collect::>>()?; + *name_idx = element_name_idx; + if 
elements.is_empty() { + return substrait_err!( + "Empty list must be encoded as EmptyList literal type, not List" + ); + } + let element_type = elements[0].data_type(); + match lit.type_variation_reference { + DEFAULT_CONTAINER_TYPE_VARIATION_REF => ScalarValue::List( + ScalarValue::new_list_nullable(elements.as_slice(), &element_type), + ), + LARGE_CONTAINER_TYPE_VARIATION_REF => ScalarValue::LargeList( + ScalarValue::new_large_list(elements.as_slice(), &element_type), + ), + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + } + } + Some(LiteralType::EmptyList(l)) => { + let element_type = from_substrait_type( + consumer, + l.r#type.clone().unwrap().as_ref(), + dfs_names, + name_idx, + )?; + match lit.type_variation_reference { + DEFAULT_CONTAINER_TYPE_VARIATION_REF => { + ScalarValue::List(ScalarValue::new_list_nullable(&[], &element_type)) + } + LARGE_CONTAINER_TYPE_VARIATION_REF => ScalarValue::LargeList( + ScalarValue::new_large_list(&[], &element_type), + ), + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + } + } + Some(LiteralType::Map(m)) => { + // Each entry should start the name index from the same value, then we increase it + // once at the end + let mut entry_name_idx = *name_idx; + let entries = m + .key_values + .iter() + .map(|kv| { + entry_name_idx = *name_idx; + let key_sv = from_substrait_literal( + consumer, + kv.key.as_ref().unwrap(), + dfs_names, + &mut entry_name_idx, + )?; + let value_sv = from_substrait_literal( + consumer, + kv.value.as_ref().unwrap(), + dfs_names, + &mut entry_name_idx, + )?; + ScalarStructBuilder::new() + .with_scalar(Field::new("key", key_sv.data_type(), false), key_sv) + .with_scalar( + Field::new("value", value_sv.data_type(), true), + value_sv, + ) + .build() + }) + .collect::>>()?; + *name_idx = entry_name_idx; + + if entries.is_empty() { + return substrait_err!( + "Empty map must be encoded as EmptyMap literal type, not Map" + ); + } + + ScalarValue::Map(Arc::new(MapArray::new( + Arc::new(Field::new("entries", entries[0].data_type(), false)), + OffsetBuffer::new(vec![0, entries.len() as i32].into()), + ScalarValue::iter_to_array(entries)?.as_struct().to_owned(), + None, + false, + ))) + } + Some(LiteralType::EmptyMap(m)) => { + let key = match &m.key { + Some(k) => Ok(k), + _ => plan_err!("Missing key type for empty map"), + }?; + let value = match &m.value { + Some(v) => Ok(v), + _ => plan_err!("Missing value type for empty map"), + }?; + let key_type = from_substrait_type(consumer, key, dfs_names, name_idx)?; + let value_type = from_substrait_type(consumer, value, dfs_names, name_idx)?; + + // new_empty_array on a MapType creates a too empty array + // We want it to contain an empty struct array to align with an empty MapBuilder one + let entries = Field::new_struct( + "entries", + vec![ + Field::new("key", key_type, false), + Field::new("value", value_type, true), + ], + false, + ); + let struct_array = + new_empty_array(entries.data_type()).as_struct().to_owned(); + ScalarValue::Map(Arc::new(MapArray::new( + Arc::new(entries), + OffsetBuffer::new(vec![0, 0].into()), + struct_array, + None, + false, + ))) + } + Some(LiteralType::Struct(s)) => { + let mut builder = ScalarStructBuilder::new(); + for (i, field) in s.fields.iter().enumerate() { + let name = next_struct_field_name(i, dfs_names, name_idx)?; + let sv = from_substrait_literal(consumer, field, dfs_names, name_idx)?; + // We assume everything to be nullable, since Arrow's strict about things matching + 
// and it's hard to match otherwise. + builder = builder.with_scalar(Field::new(name, sv.data_type(), true), sv); + } + builder.build()? + } + Some(LiteralType::Null(null_type)) => { + let data_type = + from_substrait_type(consumer, null_type, dfs_names, name_idx)?; + ScalarValue::try_from(&data_type)? + } + Some(LiteralType::IntervalDayToSecond(IntervalDayToSecond { + days, + seconds, + subseconds, + precision_mode, + })) => { + use interval_day_to_second::PrecisionMode; + // DF only supports millisecond precision, so for any more granular type we lose precision + let milliseconds = match precision_mode { + Some(PrecisionMode::Microseconds(ms)) => ms / 1000, + None => + if *subseconds != 0 { + return substrait_err!("Cannot set subseconds field of IntervalDayToSecond without setting precision"); + } else { + 0_i32 + } + Some(PrecisionMode::Precision(0)) => *subseconds as i32 * 1000, + Some(PrecisionMode::Precision(3)) => *subseconds as i32, + Some(PrecisionMode::Precision(6)) => (subseconds / 1000) as i32, + Some(PrecisionMode::Precision(9)) => (subseconds / 1000 / 1000) as i32, + _ => { + return not_impl_err!( + "Unsupported Substrait interval day to second precision mode: {precision_mode:?}") + } + }; + + ScalarValue::new_interval_dt(*days, (seconds * 1000) + milliseconds) + } + Some(LiteralType::IntervalYearToMonth(IntervalYearToMonth { years, months })) => { + ScalarValue::new_interval_ym(*years, *months) + } + Some(LiteralType::IntervalCompound(IntervalCompound { + interval_year_to_month, + interval_day_to_second, + })) => match (interval_year_to_month, interval_day_to_second) { + ( + Some(IntervalYearToMonth { years, months }), + Some(IntervalDayToSecond { + days, + seconds, + subseconds, + precision_mode: + Some(interval_day_to_second::PrecisionMode::Precision(p)), + }), + ) => { + if *p < 0 || *p > 9 { + return plan_err!( + "Unsupported Substrait interval day to second precision: {}", + p + ); + } + let nanos = *subseconds * i64::pow(10, (9 - p) as u32); + ScalarValue::new_interval_mdn( + *years * 12 + months, + *days, + *seconds as i64 * NANOSECONDS + nanos, + ) + } + _ => return plan_err!("Substrait compound interval missing components"), + }, + Some(LiteralType::FixedChar(c)) => ScalarValue::Utf8(Some(c.clone())), + Some(LiteralType::UserDefined(user_defined)) => { + if let Ok(value) = consumer.consume_user_defined_literal(user_defined) { + return Ok(value); + } + + // TODO: remove the code below once the producer has been updated + + // Helper function to prevent duplicating this code - can be inlined once the non-extension path is removed + let interval_month_day_nano = + |user_defined: &proto::expression::literal::UserDefined| -> datafusion::common::Result { + let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { + return substrait_err!("Interval month day nano value is empty"); + }; + let value_slice: [u8; 16] = + (*raw_val.value).try_into().map_err(|_| { + substrait_datafusion_err!( + "Failed to parse interval month day nano value" + ) + })?; + let months = + i32::from_le_bytes(value_slice[0..4].try_into().unwrap()); + let days = i32::from_le_bytes(value_slice[4..8].try_into().unwrap()); + let nanoseconds = + i64::from_le_bytes(value_slice[8..16].try_into().unwrap()); + Ok(ScalarValue::IntervalMonthDayNano(Some( + IntervalMonthDayNano { + months, + days, + nanoseconds, + }, + ))) + }; + + if let Some(name) = consumer + .get_extensions() + .types + .get(&user_defined.type_reference) + { + match name.as_ref() { + // Kept for backwards compatibility - producers 
should use IntervalCompound instead + #[allow(deprecated)] + INTERVAL_MONTH_DAY_NANO_TYPE_NAME => { + interval_month_day_nano(user_defined)? + } + _ => { + return not_impl_err!( + "Unsupported Substrait user defined type with ref {} and name {}", + user_defined.type_reference, + name + ) + } + } + } else { + #[allow(deprecated)] + match user_defined.type_reference { + // Kept for backwards compatibility, producers should useIntervalYearToMonth instead + INTERVAL_YEAR_MONTH_TYPE_REF => { + let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { + return substrait_err!("Interval year month value is empty"); + }; + let value_slice: [u8; 4] = + (*raw_val.value).try_into().map_err(|_| { + substrait_datafusion_err!( + "Failed to parse interval year month value" + ) + })?; + ScalarValue::IntervalYearMonth(Some(i32::from_le_bytes( + value_slice, + ))) + } + // Kept for backwards compatibility, producers should useIntervalDayToSecond instead + INTERVAL_DAY_TIME_TYPE_REF => { + let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { + return substrait_err!("Interval day time value is empty"); + }; + let value_slice: [u8; 8] = + (*raw_val.value).try_into().map_err(|_| { + substrait_datafusion_err!( + "Failed to parse interval day time value" + ) + })?; + let days = + i32::from_le_bytes(value_slice[0..4].try_into().unwrap()); + let milliseconds = + i32::from_le_bytes(value_slice[4..8].try_into().unwrap()); + ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days, + milliseconds, + })) + } + // Kept for backwards compatibility, producers should useIntervalCompound instead + INTERVAL_MONTH_DAY_NANO_TYPE_REF => { + interval_month_day_nano(user_defined)? + } + _ => { + return not_impl_err!( + "Unsupported Substrait user defined type literal with ref {}", + user_defined.type_reference + ) + } + } + } + } + _ => return not_impl_err!("Unsupported literal_type: {:?}", lit.literal_type), + }; + + Ok(scalar_value) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::logical_plan::consumer::utils::tests::test_consumer; + + #[test] + fn interval_compound_different_precision() -> datafusion::common::Result<()> { + // DF producer (and thus roundtrip) always uses precision = 9, + // this test exists to test with some other value. + let substrait = Literal { + nullable: false, + type_variation_reference: 0, + literal_type: Some(LiteralType::IntervalCompound(IntervalCompound { + interval_year_to_month: Some(IntervalYearToMonth { + years: 1, + months: 2, + }), + interval_day_to_second: Some(IntervalDayToSecond { + days: 3, + seconds: 4, + subseconds: 5, + precision_mode: Some( + interval_day_to_second::PrecisionMode::Precision(6), + ), + }), + })), + }; + + let consumer = test_consumer(); + assert_eq!( + from_substrait_literal_without_names(&consumer, &substrait)?, + ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano { + months: 14, + days: 3, + nanoseconds: 4_000_005_000 + })) + ); + + Ok(()) + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/mod.rs b/datafusion/substrait/src/logical_plan/consumer/expr/mod.rs new file mode 100644 index 000000000000..d70182767190 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/expr/mod.rs @@ -0,0 +1,263 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod aggregate_function; +mod cast; +mod field_reference; +mod function_arguments; +mod if_then; +mod literal; +mod scalar_function; +mod singular_or_list; +mod subquery; +mod window_function; + +pub use aggregate_function::*; +pub use cast::*; +pub use field_reference::*; +pub use function_arguments::*; +pub use if_then::*; +pub use literal::*; +pub use scalar_function::*; +pub use singular_or_list::*; +pub use subquery::*; +pub use window_function::*; + +use crate::extensions::Extensions; +use crate::logical_plan::consumer::{ + from_substrait_named_struct, rename_field, DefaultSubstraitConsumer, + SubstraitConsumer, +}; +use datafusion::arrow::datatypes::Field; +use datafusion::common::{not_impl_err, plan_err, substrait_err, DFSchema, DFSchemaRef}; +use datafusion::execution::SessionState; +use datafusion::logical_expr::{Expr, ExprSchemable}; +use substrait::proto::expression::RexType; +use substrait::proto::expression_reference::ExprType; +use substrait::proto::{Expression, ExtendedExpression}; + +/// Convert Substrait Rex to DataFusion Expr +pub async fn from_substrait_rex( + consumer: &impl SubstraitConsumer, + expression: &Expression, + input_schema: &DFSchema, +) -> datafusion::common::Result { + match &expression.rex_type { + Some(t) => match t { + RexType::Literal(expr) => consumer.consume_literal(expr).await, + RexType::Selection(expr) => { + consumer.consume_field_reference(expr, input_schema).await + } + RexType::ScalarFunction(expr) => { + consumer.consume_scalar_function(expr, input_schema).await + } + RexType::WindowFunction(expr) => { + consumer.consume_window_function(expr, input_schema).await + } + RexType::IfThen(expr) => consumer.consume_if_then(expr, input_schema).await, + RexType::SwitchExpression(expr) => { + consumer.consume_switch(expr, input_schema).await + } + RexType::SingularOrList(expr) => { + consumer.consume_singular_or_list(expr, input_schema).await + } + + RexType::MultiOrList(expr) => { + consumer.consume_multi_or_list(expr, input_schema).await + } + + RexType::Cast(expr) => { + consumer.consume_cast(expr.as_ref(), input_schema).await + } + + RexType::Subquery(expr) => { + consumer.consume_subquery(expr.as_ref(), input_schema).await + } + RexType::Nested(expr) => consumer.consume_nested(expr, input_schema).await, + RexType::Enum(expr) => consumer.consume_enum(expr, input_schema).await, + RexType::DynamicParameter(expr) => { + consumer.consume_dynamic_parameter(expr, input_schema).await + } + }, + None => substrait_err!("Expression must set rex_type: {:?}", expression), + } +} + +/// Convert Substrait ExtendedExpression to ExprContainer +/// +/// A Substrait ExtendedExpression message contains one or more expressions, +/// with names for the outputs, and an input schema. These pieces are all included +/// in the ExprContainer. +/// +/// This is a top-level message and can be used to send expressions (not plans) +/// between systems. 
This is often useful for scenarios like pushdown where filter +/// expressions need to be sent to remote systems. +pub async fn from_substrait_extended_expr( + state: &SessionState, + extended_expr: &ExtendedExpression, +) -> datafusion::common::Result { + // Register function extension + let extensions = Extensions::try_from(&extended_expr.extensions)?; + if !extensions.type_variations.is_empty() { + return not_impl_err!("Type variation extensions are not supported"); + } + + let consumer = DefaultSubstraitConsumer { + extensions: &extensions, + state, + }; + + let input_schema = DFSchemaRef::new(match &extended_expr.base_schema { + Some(base_schema) => from_substrait_named_struct(&consumer, base_schema), + None => { + plan_err!("required property `base_schema` missing from Substrait ExtendedExpression message") + } + }?); + + // Parse expressions + let mut exprs = Vec::with_capacity(extended_expr.referred_expr.len()); + for (expr_idx, substrait_expr) in extended_expr.referred_expr.iter().enumerate() { + let scalar_expr = match &substrait_expr.expr_type { + Some(ExprType::Expression(scalar_expr)) => Ok(scalar_expr), + Some(ExprType::Measure(_)) => { + not_impl_err!("Measure expressions are not yet supported") + } + None => { + plan_err!("required property `expr_type` missing from Substrait ExpressionReference message") + } + }?; + let expr = consumer + .consume_expression(scalar_expr, &input_schema) + .await?; + let (output_type, expected_nullability) = + expr.data_type_and_nullable(&input_schema)?; + let output_field = Field::new("", output_type, expected_nullability); + let mut names_idx = 0; + let output_field = rename_field( + &output_field, + &substrait_expr.output_names, + expr_idx, + &mut names_idx, + )?; + exprs.push((expr, output_field)); + } + + Ok(ExprContainer { + input_schema, + exprs, + }) +} + +/// An ExprContainer is a container for a collection of expressions with a common input schema +/// +/// In addition, each expression is associated with a field, which defines the +/// expression's output. The data type and nullability of the field are calculated from the +/// expression and the input schema. However the names of the field (and its nested fields) are +/// derived from the Substrait message. 
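+///
+/// Illustrative sketch (not a doctest); `state` is assumed to be a `SessionState`
+/// and `extended_expr` an already-decoded Substrait `ExtendedExpression`:
+///
+/// ```text
+/// let container = from_substrait_extended_expr(&state, &extended_expr).await?;
+/// for (expr, field) in &container.exprs {
+///     println!("{}: {expr}", field.name());
+/// }
+/// ```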
+pub struct ExprContainer { + /// The input schema for the expressions + pub input_schema: DFSchemaRef, + /// The expressions + /// + /// Each item contains an expression and the field that defines the expected nullability and name of the expr's output + pub exprs: Vec<(Expr, Field)>, +} + +/// Convert Substrait Expressions to DataFusion Exprs +pub async fn from_substrait_rex_vec( + consumer: &impl SubstraitConsumer, + exprs: &Vec, + input_schema: &DFSchema, +) -> datafusion::common::Result> { + let mut expressions: Vec = vec![]; + for expr in exprs { + let expression = consumer.consume_expression(expr, input_schema).await?; + expressions.push(expression); + } + Ok(expressions) +} + +#[cfg(test)] +mod tests { + use crate::extensions::Extensions; + use crate::logical_plan::consumer::utils::tests::test_consumer; + use crate::logical_plan::consumer::*; + use datafusion::common::DFSchema; + use datafusion::logical_expr::Expr; + use substrait::proto::expression::window_function::BoundsType; + use substrait::proto::expression::RexType; + use substrait::proto::Expression; + + #[tokio::test] + async fn window_function_with_range_unit_and_no_order_by( + ) -> datafusion::common::Result<()> { + let substrait = Expression { + rex_type: Some(RexType::WindowFunction( + substrait::proto::expression::WindowFunction { + function_reference: 0, + bounds_type: BoundsType::Range as i32, + sorts: vec![], + ..Default::default() + }, + )), + }; + + let mut consumer = test_consumer(); + + // Just registering a single function (index 0) so that the plan + // does not throw a "function not found" error. + let mut extensions = Extensions::default(); + extensions.register_function("count".to_string()); + consumer.extensions = &extensions; + + match from_substrait_rex(&consumer, &substrait, &DFSchema::empty()).await? { + Expr::WindowFunction(window_function) => { + assert_eq!(window_function.params.order_by.len(), 1) + } + _ => panic!("expr was not a WindowFunction"), + }; + + Ok(()) + } + + #[tokio::test] + async fn window_function_with_count() -> datafusion::common::Result<()> { + let substrait = Expression { + rex_type: Some(RexType::WindowFunction( + substrait::proto::expression::WindowFunction { + function_reference: 0, + ..Default::default() + }, + )), + }; + + let mut consumer = test_consumer(); + + let mut extensions = Extensions::default(); + extensions.register_function("count".to_string()); + consumer.extensions = &extensions; + + match from_substrait_rex(&consumer, &substrait, &DFSchema::empty()).await? { + Expr::WindowFunction(window_function) => { + assert_eq!(window_function.params.args.len(), 1) + } + _ => panic!("expr was not a WindowFunction"), + }; + + Ok(()) + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs b/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs new file mode 100644 index 000000000000..7797c935211f --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs @@ -0,0 +1,372 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::{from_substrait_func_args, SubstraitConsumer}; +use datafusion::common::Result; +use datafusion::common::{ + not_impl_err, plan_err, substrait_err, DFSchema, DataFusionError, ScalarValue, +}; +use datafusion::execution::FunctionRegistry; +use datafusion::logical_expr::{expr, BinaryExpr, Expr, Like, Operator}; +use std::vec::Drain; +use substrait::proto::expression::ScalarFunction; +use substrait::proto::function_argument::ArgType; + +pub async fn from_scalar_function( + consumer: &impl SubstraitConsumer, + f: &ScalarFunction, + input_schema: &DFSchema, +) -> Result { + let Some(fn_signature) = consumer + .get_extensions() + .functions + .get(&f.function_reference) + else { + return plan_err!( + "Scalar function not found: function reference = {:?}", + f.function_reference + ); + }; + let fn_name = substrait_fun_name(fn_signature); + let args = from_substrait_func_args(consumer, &f.arguments, input_schema).await?; + + // try to first match the requested function into registered udfs, then built-in ops + // and finally built-in expressions + if let Ok(func) = consumer.get_function_registry().udf(fn_name) { + Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( + func.to_owned(), + args, + ))) + } else if let Some(op) = name_to_op(fn_name) { + if args.len() < 2 { + return not_impl_err!( + "Expect at least two arguments for binary operator {op:?}, the provided number of operators is {:?}", + f.arguments.len() + ); + } + // In those cases we build a balanced tree of BinaryExprs + arg_list_to_binary_op_tree(op, args) + } else if let Some(builder) = BuiltinExprBuilder::try_from_name(fn_name) { + builder.build(consumer, f, input_schema).await + } else { + not_impl_err!("Unsupported function name: {fn_name:?}") + } +} + +pub fn substrait_fun_name(name: &str) -> &str { + let name = match name.rsplit_once(':') { + // Since 0.32.0, Substrait requires the function names to be in a compound format + // https://substrait.io/extensions/#function-signature-compound-names + // for example, `add:i8_i8`. + // On the consumer side, we don't really care about the signature though, just the name. 
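+        // e.g. `substrait_fun_name("add:i8_i8")` returns "add", while a bare
+        // name such as "add" has no ':' and is returned unchanged.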
+        Some((name, _)) => name,
+        None => name,
+    };
+    name
+}
+
+pub fn name_to_op(name: &str) -> Option<Operator> {
+    match name {
+        "equal" => Some(Operator::Eq),
+        "not_equal" => Some(Operator::NotEq),
+        "lt" => Some(Operator::Lt),
+        "lte" => Some(Operator::LtEq),
+        "gt" => Some(Operator::Gt),
+        "gte" => Some(Operator::GtEq),
+        "add" => Some(Operator::Plus),
+        "subtract" => Some(Operator::Minus),
+        "multiply" => Some(Operator::Multiply),
+        "divide" => Some(Operator::Divide),
+        "mod" => Some(Operator::Modulo),
+        "modulus" => Some(Operator::Modulo),
+        "and" => Some(Operator::And),
+        "or" => Some(Operator::Or),
+        "is_distinct_from" => Some(Operator::IsDistinctFrom),
+        "is_not_distinct_from" => Some(Operator::IsNotDistinctFrom),
+        "regex_match" => Some(Operator::RegexMatch),
+        "regex_imatch" => Some(Operator::RegexIMatch),
+        "regex_not_match" => Some(Operator::RegexNotMatch),
+        "regex_not_imatch" => Some(Operator::RegexNotIMatch),
+        "bitwise_and" => Some(Operator::BitwiseAnd),
+        "bitwise_or" => Some(Operator::BitwiseOr),
+        "str_concat" => Some(Operator::StringConcat),
+        "at_arrow" => Some(Operator::AtArrow),
+        "arrow_at" => Some(Operator::ArrowAt),
+        "bitwise_xor" => Some(Operator::BitwiseXor),
+        "bitwise_shift_right" => Some(Operator::BitwiseShiftRight),
+        "bitwise_shift_left" => Some(Operator::BitwiseShiftLeft),
+        _ => None,
+    }
+}
+
+/// Build a balanced tree of binary operations from a binary operator and a list of arguments.
+///
+/// For example, `OR(a, b, c, d, e)` will be converted to: `OR(OR(a, b), OR(c, OR(d, e)))`.
+///
+/// `args` must not be empty.
+fn arg_list_to_binary_op_tree(op: Operator, mut args: Vec<Expr>) -> Result<Expr> {
+    let n_args = args.len();
+    let mut drained_args = args.drain(..);
+    arg_list_to_binary_op_tree_inner(op, &mut drained_args, n_args)
+}
+
+/// Helper function for [`arg_list_to_binary_op_tree`] implementation
+///
+/// `take_len` represents the number of elements to take from `args` before returning.
+/// We use `take_len` to avoid recursively building a `Take<Take<...>>` type.
+fn arg_list_to_binary_op_tree_inner(
+    op: Operator,
+    args: &mut Drain<'_, Expr>,
+    take_len: usize,
+) -> Result<Expr> {
+    if take_len == 1 {
+        return args.next().ok_or_else(|| {
+            DataFusionError::Substrait(
+                "Expected one more available element in iterator, found none".to_string(),
+            )
+        });
+    } else if take_len == 0 {
+        return substrait_err!("Cannot build binary operation tree with 0 arguments");
+    }
+    // Cut argument list in 2 balanced parts
+    let left_take = take_len / 2;
+    let right_take = take_len - left_take;
+    let left = arg_list_to_binary_op_tree_inner(op, args, left_take)?;
+    let right = arg_list_to_binary_op_tree_inner(op, args, right_take)?;
+    Ok(Expr::BinaryExpr(BinaryExpr {
+        left: Box::new(left),
+        op,
+        right: Box::new(right),
+    }))
+}
+
+/// Build [`Expr`] from its name and required inputs.
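+///
+/// These are Substrait functions that map to dedicated `Expr` variants rather than
+/// to UDFs or binary operators, e.g. "not" -> `Expr::Not`, "is_null" -> `Expr::IsNull`,
+/// "like"/"ilike" -> `Expr::Like`; see `try_from_name` for the full list.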
+struct BuiltinExprBuilder { + expr_name: String, +} + +impl BuiltinExprBuilder { + pub fn try_from_name(name: &str) -> Option { + match name { + "not" | "like" | "ilike" | "is_null" | "is_not_null" | "is_true" + | "is_false" | "is_not_true" | "is_not_false" | "is_unknown" + | "is_not_unknown" | "negative" | "negate" => Some(Self { + expr_name: name.to_string(), + }), + _ => None, + } + } + + pub async fn build( + self, + consumer: &impl SubstraitConsumer, + f: &ScalarFunction, + input_schema: &DFSchema, + ) -> Result { + match self.expr_name.as_str() { + "like" => Self::build_like_expr(consumer, false, f, input_schema).await, + "ilike" => Self::build_like_expr(consumer, true, f, input_schema).await, + "not" | "negative" | "negate" | "is_null" | "is_not_null" | "is_true" + | "is_false" | "is_not_true" | "is_not_false" | "is_unknown" + | "is_not_unknown" => { + Self::build_unary_expr(consumer, &self.expr_name, f, input_schema).await + } + _ => { + not_impl_err!("Unsupported builtin expression: {}", self.expr_name) + } + } + } + + async fn build_unary_expr( + consumer: &impl SubstraitConsumer, + fn_name: &str, + f: &ScalarFunction, + input_schema: &DFSchema, + ) -> Result { + if f.arguments.len() != 1 { + return substrait_err!("Expect one argument for {fn_name} expr"); + } + let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { + return substrait_err!("Invalid arguments type for {fn_name} expr"); + }; + let arg = consumer + .consume_expression(expr_substrait, input_schema) + .await?; + let arg = Box::new(arg); + + let expr = match fn_name { + "not" => Expr::Not(arg), + "negative" | "negate" => Expr::Negative(arg), + "is_null" => Expr::IsNull(arg), + "is_not_null" => Expr::IsNotNull(arg), + "is_true" => Expr::IsTrue(arg), + "is_false" => Expr::IsFalse(arg), + "is_not_true" => Expr::IsNotTrue(arg), + "is_not_false" => Expr::IsNotFalse(arg), + "is_unknown" => Expr::IsUnknown(arg), + "is_not_unknown" => Expr::IsNotUnknown(arg), + _ => return not_impl_err!("Unsupported builtin expression: {}", fn_name), + }; + + Ok(expr) + } + + async fn build_like_expr( + consumer: &impl SubstraitConsumer, + case_insensitive: bool, + f: &ScalarFunction, + input_schema: &DFSchema, + ) -> Result { + let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" }; + if f.arguments.len() != 2 && f.arguments.len() != 3 { + return substrait_err!("Expect two or three arguments for `{fn_name}` expr"); + } + + let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { + return substrait_err!("Invalid arguments type for `{fn_name}` expr"); + }; + let expr = consumer + .consume_expression(expr_substrait, input_schema) + .await?; + let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type else { + return substrait_err!("Invalid arguments type for `{fn_name}` expr"); + }; + let pattern = consumer + .consume_expression(pattern_substrait, input_schema) + .await?; + + // Default case: escape character is Literal(Utf8(None)) + let escape_char = if f.arguments.len() == 3 { + let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type + else { + return substrait_err!("Invalid arguments type for `{fn_name}` expr"); + }; + + let escape_char_expr = consumer + .consume_expression(escape_char_substrait, input_schema) + .await?; + + match escape_char_expr { + Expr::Literal(ScalarValue::Utf8(escape_char_string), _) => { + // Convert Option to Option + escape_char_string.and_then(|s| s.chars().next()) + } + _ => { + return substrait_err!( + "Expect Utf8 literal for escape 
char, but found {escape_char_expr:?}" + ) + } + } + } else { + None + }; + + Ok(Expr::Like(Like { + negated: false, + expr: Box::new(expr), + pattern: Box::new(pattern), + escape_char, + case_insensitive, + })) + } +} + +#[cfg(test)] +mod tests { + use super::arg_list_to_binary_op_tree; + use crate::extensions::Extensions; + use crate::logical_plan::consumer::tests::TEST_SESSION_STATE; + use crate::logical_plan::consumer::{DefaultSubstraitConsumer, SubstraitConsumer}; + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use datafusion::common::{DFSchema, Result, ScalarValue}; + use datafusion::logical_expr::{Expr, Operator}; + use insta::assert_snapshot; + use substrait::proto::expression::literal::LiteralType; + use substrait::proto::expression::{Literal, RexType, ScalarFunction}; + use substrait::proto::function_argument::ArgType; + use substrait::proto::{Expression, FunctionArgument}; + + /// Test that large argument lists for binary operations do not crash the consumer + #[tokio::test] + async fn test_binary_op_large_argument_list() -> Result<()> { + // Build substrait extensions (we are using only one function) + let mut extensions = Extensions::default(); + extensions.functions.insert(0, String::from("or:bool_bool")); + // Build substrait consumer + let consumer = DefaultSubstraitConsumer::new(&extensions, &TEST_SESSION_STATE); + + // Build arguments for the function call, this is basically an OR(true, true, ..., true) + let arg = FunctionArgument { + arg_type: Some(ArgType::Value(Expression { + rex_type: Some(RexType::Literal(Literal { + nullable: false, + type_variation_reference: 0, + literal_type: Some(LiteralType::Boolean(true)), + })), + })), + }; + let arguments = vec![arg; 50000]; + let func = ScalarFunction { + function_reference: 0, + arguments, + ..Default::default() + }; + // Trivial input schema + let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]); + let df_schema = DFSchema::try_from(schema).unwrap(); + + // Consume the expression and ensure we don't crash + let _ = consumer.consume_scalar_function(&func, &df_schema).await?; + Ok(()) + } + + fn int64_literals(integers: &[i64]) -> Vec { + integers + .iter() + .map(|value| Expr::Literal(ScalarValue::Int64(Some(*value)), None)) + .collect() + } + + #[test] + fn arg_list_to_binary_op_tree_1_arg() -> Result<()> { + let expr = arg_list_to_binary_op_tree(Operator::Or, int64_literals(&[1]))?; + assert_snapshot!(expr.to_string(), @"Int64(1)"); + Ok(()) + } + + #[test] + fn arg_list_to_binary_op_tree_2_args() -> Result<()> { + let expr = arg_list_to_binary_op_tree(Operator::Or, int64_literals(&[1, 2]))?; + assert_snapshot!(expr.to_string(), @"Int64(1) OR Int64(2)"); + Ok(()) + } + + #[test] + fn arg_list_to_binary_op_tree_3_args() -> Result<()> { + let expr = arg_list_to_binary_op_tree(Operator::Or, int64_literals(&[1, 2, 3]))?; + assert_snapshot!(expr.to_string(), @"Int64(1) OR Int64(2) OR Int64(3)"); + Ok(()) + } + + #[test] + fn arg_list_to_binary_op_tree_4_args() -> Result<()> { + let expr = + arg_list_to_binary_op_tree(Operator::Or, int64_literals(&[1, 2, 3, 4]))?; + assert_snapshot!(expr.to_string(), @"Int64(1) OR Int64(2) OR Int64(3) OR Int64(4)"); + Ok(()) + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/singular_or_list.rs b/datafusion/substrait/src/logical_plan/consumer/expr/singular_or_list.rs new file mode 100644 index 000000000000..6d44ebcce590 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/expr/singular_or_list.rs @@ -0,0 +1,40 @@ +// 
Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::{from_substrait_rex_vec, SubstraitConsumer}; +use datafusion::common::DFSchema; +use datafusion::logical_expr::expr::InList; +use datafusion::logical_expr::Expr; +use substrait::proto::expression::SingularOrList; + +pub async fn from_singular_or_list( + consumer: &impl SubstraitConsumer, + expr: &SingularOrList, + input_schema: &DFSchema, +) -> datafusion::common::Result { + let substrait_expr = expr.value.as_ref().unwrap(); + let substrait_list = expr.options.as_ref(); + Ok(Expr::InList(InList { + expr: Box::new( + consumer + .consume_expression(substrait_expr, input_schema) + .await?, + ), + list: from_substrait_rex_vec(consumer, substrait_list, input_schema).await?, + negated: false, + })) +} diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/subquery.rs b/datafusion/substrait/src/logical_plan/consumer/expr/subquery.rs new file mode 100644 index 000000000000..f7e4c2bb0fbd --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/expr/subquery.rs @@ -0,0 +1,106 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::logical_plan::consumer::SubstraitConsumer; +use datafusion::common::{substrait_err, DFSchema, Spans}; +use datafusion::logical_expr::expr::{Exists, InSubquery}; +use datafusion::logical_expr::{Expr, Subquery}; +use std::sync::Arc; +use substrait::proto::expression as substrait_expression; +use substrait::proto::expression::subquery::set_predicate::PredicateOp; +use substrait::proto::expression::subquery::SubqueryType; + +pub async fn from_subquery( + consumer: &impl SubstraitConsumer, + subquery: &substrait_expression::Subquery, + input_schema: &DFSchema, +) -> datafusion::common::Result { + match &subquery.subquery_type { + Some(subquery_type) => match subquery_type { + SubqueryType::InPredicate(in_predicate) => { + if in_predicate.needles.len() != 1 { + substrait_err!("InPredicate Subquery type must have exactly one Needle expression") + } else { + let needle_expr = &in_predicate.needles[0]; + let haystack_expr = &in_predicate.haystack; + if let Some(haystack_expr) = haystack_expr { + let haystack_expr = consumer.consume_rel(haystack_expr).await?; + let outer_refs = haystack_expr.all_out_ref_exprs(); + Ok(Expr::InSubquery(InSubquery { + expr: Box::new( + consumer + .consume_expression(needle_expr, input_schema) + .await?, + ), + subquery: Subquery { + subquery: Arc::new(haystack_expr), + outer_ref_columns: outer_refs, + spans: Spans::new(), + }, + negated: false, + })) + } else { + substrait_err!( + "InPredicate Subquery type must have a Haystack expression" + ) + } + } + } + SubqueryType::Scalar(query) => { + let plan = consumer + .consume_rel(&(query.input.clone()).unwrap_or_default()) + .await?; + let outer_ref_columns = plan.all_out_ref_exprs(); + Ok(Expr::ScalarSubquery(Subquery { + subquery: Arc::new(plan), + outer_ref_columns, + spans: Spans::new(), + })) + } + SubqueryType::SetPredicate(predicate) => { + match predicate.predicate_op() { + // exist + PredicateOp::Exists => { + let relation = &predicate.tuples; + let plan = consumer + .consume_rel(&relation.clone().unwrap_or_default()) + .await?; + let outer_ref_columns = plan.all_out_ref_exprs(); + Ok(Expr::Exists(Exists::new( + Subquery { + subquery: Arc::new(plan), + outer_ref_columns, + spans: Spans::new(), + }, + false, + ))) + } + other_type => substrait_err!( + "unimplemented type {:?} for set predicate", + other_type + ), + } + } + other_type => { + substrait_err!("Subquery type {:?} not implemented", other_type) + } + }, + None => { + substrait_err!("Subquery expression without SubqueryType is not allowed") + } + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/window_function.rs b/datafusion/substrait/src/logical_plan/consumer/expr/window_function.rs new file mode 100644 index 000000000000..80b643a547ee --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/expr/window_function.rs @@ -0,0 +1,163 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::{ + from_substrait_func_args, from_substrait_rex_vec, from_substrait_sorts, + substrait_fun_name, SubstraitConsumer, +}; +use datafusion::common::{ + not_impl_err, plan_datafusion_err, plan_err, substrait_err, DFSchema, ScalarValue, +}; +use datafusion::execution::FunctionRegistry; +use datafusion::logical_expr::expr::WindowFunctionParams; +use datafusion::logical_expr::{ + expr, Expr, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, +}; +use substrait::proto::expression::window_function::{Bound, BoundsType}; +use substrait::proto::expression::WindowFunction; +use substrait::proto::expression::{ + window_function::bound as SubstraitBound, window_function::bound::Kind as BoundKind, +}; + +pub async fn from_window_function( + consumer: &impl SubstraitConsumer, + window: &WindowFunction, + input_schema: &DFSchema, +) -> datafusion::common::Result { + let Some(fn_signature) = consumer + .get_extensions() + .functions + .get(&window.function_reference) + else { + return plan_err!( + "Window function not found: function reference = {:?}", + window.function_reference + ); + }; + let fn_name = substrait_fun_name(fn_signature); + + // check udwf first, then udaf, then built-in window and aggregate functions + let fun = if let Ok(udwf) = consumer.get_function_registry().udwf(fn_name) { + Ok(WindowFunctionDefinition::WindowUDF(udwf)) + } else if let Ok(udaf) = consumer.get_function_registry().udaf(fn_name) { + Ok(WindowFunctionDefinition::AggregateUDF(udaf)) + } else { + not_impl_err!( + "Window function {} is not supported: function anchor = {:?}", + fn_name, + window.function_reference + ) + }?; + + let mut order_by = + from_substrait_sorts(consumer, &window.sorts, input_schema).await?; + + let bound_units = match BoundsType::try_from(window.bounds_type).map_err(|e| { + plan_datafusion_err!("Invalid bound type {}: {e}", window.bounds_type) + })? { + BoundsType::Rows => WindowFrameUnits::Rows, + BoundsType::Range => WindowFrameUnits::Range, + BoundsType::Unspecified => { + // If the plan does not specify the bounds type, then we use a simple logic to determine the units + // If there is no `ORDER BY`, then by default, the frame counts each row from the lower up to upper boundary + // If there is `ORDER BY`, then by default, each frame is a range starting from unbounded preceding to current row + if order_by.is_empty() { + WindowFrameUnits::Rows + } else { + WindowFrameUnits::Range + } + } + }; + let window_frame = datafusion::logical_expr::WindowFrame::new_bounds( + bound_units, + from_substrait_bound(&window.lower_bound, true)?, + from_substrait_bound(&window.upper_bound, false)?, + ); + + window_frame.regularize_order_bys(&mut order_by)?; + + // Datafusion does not support aggregate functions with no arguments, so + // we inject a dummy argument that does not affect the query, but allows + // us to bypass this limitation. 
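+    // For example, a bare `count()` window call arrives with no arguments and is
+    // rewritten as `count(Int64(1))` below.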
+ let args = if fun.name() == "count" && window.arguments.is_empty() { + vec![Expr::Literal(ScalarValue::Int64(Some(1)), None)] + } else { + from_substrait_func_args(consumer, &window.arguments, input_schema).await? + }; + + Ok(Expr::from(expr::WindowFunction { + fun, + params: WindowFunctionParams { + args, + partition_by: from_substrait_rex_vec( + consumer, + &window.partitions, + input_schema, + ) + .await?, + order_by, + window_frame, + null_treatment: None, + }, + })) +} + +fn from_substrait_bound( + bound: &Option, + is_lower: bool, +) -> datafusion::common::Result { + match bound { + Some(b) => match &b.kind { + Some(k) => match k { + BoundKind::CurrentRow(SubstraitBound::CurrentRow {}) => { + Ok(WindowFrameBound::CurrentRow) + } + BoundKind::Preceding(SubstraitBound::Preceding { offset }) => { + if *offset <= 0 { + return plan_err!("Preceding bound must be positive"); + } + Ok(WindowFrameBound::Preceding(ScalarValue::UInt64(Some( + *offset as u64, + )))) + } + BoundKind::Following(SubstraitBound::Following { offset }) => { + if *offset <= 0 { + return plan_err!("Following bound must be positive"); + } + Ok(WindowFrameBound::Following(ScalarValue::UInt64(Some( + *offset as u64, + )))) + } + BoundKind::Unbounded(SubstraitBound::Unbounded {}) => { + if is_lower { + Ok(WindowFrameBound::Preceding(ScalarValue::Null)) + } else { + Ok(WindowFrameBound::Following(ScalarValue::Null)) + } + } + }, + None => substrait_err!("WindowFunction missing Substrait Bound kind"), + }, + None => { + if is_lower { + Ok(WindowFrameBound::Preceding(ScalarValue::Null)) + } else { + Ok(WindowFrameBound::Following(ScalarValue::Null)) + } + } + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/mod.rs b/datafusion/substrait/src/logical_plan/consumer/mod.rs new file mode 100644 index 000000000000..0e01d6ded6e4 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/mod.rs @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod expr; +mod plan; +mod rel; +mod substrait_consumer; +mod types; +mod utils; + +pub use expr::*; +pub use plan::*; +pub use rel::*; +pub use substrait_consumer::*; +pub use types::*; +pub use utils::*; diff --git a/datafusion/substrait/src/logical_plan/consumer/plan.rs b/datafusion/substrait/src/logical_plan/consumer/plan.rs new file mode 100644 index 000000000000..f994f792a17e --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/plan.rs @@ -0,0 +1,90 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::utils::{make_renamed_schema, rename_expressions}; +use super::{DefaultSubstraitConsumer, SubstraitConsumer}; +use crate::extensions::Extensions; +use datafusion::common::{not_impl_err, plan_err}; +use datafusion::execution::SessionState; +use datafusion::logical_expr::{col, Aggregate, LogicalPlan, Projection}; +use std::sync::Arc; +use substrait::proto::{plan_rel, Plan}; + +/// Convert Substrait Plan to DataFusion LogicalPlan +pub async fn from_substrait_plan( + state: &SessionState, + plan: &Plan, +) -> datafusion::common::Result { + // Register function extension + let extensions = Extensions::try_from(&plan.extensions)?; + if !extensions.type_variations.is_empty() { + return not_impl_err!("Type variation extensions are not supported"); + } + + let consumer = DefaultSubstraitConsumer { + extensions: &extensions, + state, + }; + from_substrait_plan_with_consumer(&consumer, plan).await +} + +/// Convert Substrait Plan to DataFusion LogicalPlan using the given consumer +pub async fn from_substrait_plan_with_consumer( + consumer: &impl SubstraitConsumer, + plan: &Plan, +) -> datafusion::common::Result { + match plan.relations.len() { + 1 => { + match plan.relations[0].rel_type.as_ref() { + Some(rt) => match rt { + plan_rel::RelType::Rel(rel) => Ok(consumer.consume_rel(rel).await?), + plan_rel::RelType::Root(root) => { + let plan = consumer.consume_rel(root.input.as_ref().unwrap()).await?; + if root.names.is_empty() { + // Backwards compatibility for plans missing names + return Ok(plan); + } + let renamed_schema = make_renamed_schema(plan.schema(), &root.names)?; + if renamed_schema.has_equivalent_names_and_types(plan.schema()).is_ok() { + // Nothing to do if the schema is already equivalent + return Ok(plan); + } + match plan { + // If the last node of the plan produces expressions, bake the renames into those expressions. + // This isn't necessary for correctness, but helps with roundtrip tests. + LogicalPlan::Projection(p) => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(p.expr, p.input.schema(), renamed_schema.fields())?, p.input)?)), + LogicalPlan::Aggregate(a) => { + let (group_fields, expr_fields) = renamed_schema.fields().split_at(a.group_expr.len()); + let new_group_exprs = rename_expressions(a.group_expr, a.input.schema(), group_fields)?; + let new_aggr_exprs = rename_expressions(a.aggr_expr, a.input.schema(), expr_fields)?; + Ok(LogicalPlan::Aggregate(Aggregate::try_new(a.input, new_group_exprs, new_aggr_exprs)?)) + }, + // There are probably more plans where we could bake things in, can add them later as needed. + // Otherwise, add a new Project to handle the renaming. 
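+                        // e.g. a plan producing columns (a, b) with root names ["x", "y"] is
+                        // wrapped in a projection equivalent to `SELECT a AS x, b AS y`.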
+ _ => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(plan.schema().columns().iter().map(|c| col(c.to_owned())), plan.schema(), renamed_schema.fields())?, Arc::new(plan))?)) + } + } + }, + None => plan_err!("Cannot parse plan relation: None") + } + }, + _ => not_impl_err!( + "Substrait plan with more than 1 relation trees not supported. Number of relation trees: {:?}", + plan.relations.len() + ) + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/aggregate_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/aggregate_rel.rs new file mode 100644 index 000000000000..9421bb17c162 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/rel/aggregate_rel.rs @@ -0,0 +1,150 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::{from_substrait_agg_func, from_substrait_sorts}; +use crate::logical_plan::consumer::{NameTracker, SubstraitConsumer}; +use datafusion::common::{not_impl_err, DFSchemaRef}; +use datafusion::logical_expr::{Expr, GroupingSet, LogicalPlan, LogicalPlanBuilder}; +use substrait::proto::aggregate_function::AggregationInvocation; +use substrait::proto::aggregate_rel::Grouping; +use substrait::proto::AggregateRel; + +pub async fn from_aggregate_rel( + consumer: &impl SubstraitConsumer, + agg: &AggregateRel, +) -> datafusion::common::Result { + if let Some(input) = agg.input.as_ref() { + let input = LogicalPlanBuilder::from(consumer.consume_rel(input).await?); + let mut ref_group_exprs = vec![]; + + for e in &agg.grouping_expressions { + let x = consumer.consume_expression(e, input.schema()).await?; + ref_group_exprs.push(x); + } + + let mut group_exprs = vec![]; + let mut aggr_exprs = vec![]; + + match agg.groupings.len() { + 1 => { + group_exprs.extend_from_slice( + &from_substrait_grouping( + consumer, + &agg.groupings[0], + &ref_group_exprs, + input.schema(), + ) + .await?, + ); + } + _ => { + let mut grouping_sets = vec![]; + for grouping in &agg.groupings { + let grouping_set = from_substrait_grouping( + consumer, + grouping, + &ref_group_exprs, + input.schema(), + ) + .await?; + grouping_sets.push(grouping_set); + } + // Single-element grouping expression of type Expr::GroupingSet. + // Note that GroupingSet::Rollup would become GroupingSet::GroupingSets, when + // parsed by the producer and consumer, since Substrait does not have a type dedicated + // to ROLLUP. Only vector of Groupings (grouping sets) is available. 
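+                // e.g. `ROLLUP(a, b)` reaches this point as the grouping sets
+                // [[a, b], [a], []].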
+ group_exprs + .push(Expr::GroupingSet(GroupingSet::GroupingSets(grouping_sets))); + } + }; + + for m in &agg.measures { + let filter = match &m.filter { + Some(fil) => Some(Box::new( + consumer.consume_expression(fil, input.schema()).await?, + )), + None => None, + }; + let agg_func = match &m.measure { + Some(f) => { + let distinct = match f.invocation { + _ if f.invocation == AggregationInvocation::Distinct as i32 => { + true + } + _ if f.invocation == AggregationInvocation::All as i32 => false, + _ => false, + }; + let order_by = if !f.sorts.is_empty() { + Some( + from_substrait_sorts(consumer, &f.sorts, input.schema()) + .await?, + ) + } else { + None + }; + + from_substrait_agg_func( + consumer, + f, + input.schema(), + filter, + order_by, + distinct, + ) + .await + } + None => { + not_impl_err!("Aggregate without aggregate function is not supported") + } + }; + aggr_exprs.push(agg_func?.as_ref().clone()); + } + + // Ensure that all expressions have a unique name + let mut name_tracker = NameTracker::new(); + let group_exprs = group_exprs + .iter() + .map(|e| name_tracker.get_uniquely_named_expr(e.clone())) + .collect::, _>>()?; + + input.aggregate(group_exprs, aggr_exprs)?.build() + } else { + not_impl_err!("Aggregate without an input is not valid") + } +} + +#[allow(deprecated)] +async fn from_substrait_grouping( + consumer: &impl SubstraitConsumer, + grouping: &Grouping, + expressions: &[Expr], + input_schema: &DFSchemaRef, +) -> datafusion::common::Result> { + let mut group_exprs = vec![]; + if !grouping.grouping_expressions.is_empty() { + for e in &grouping.grouping_expressions { + let expr = consumer.consume_expression(e, input_schema).await?; + group_exprs.push(expr); + } + return Ok(group_exprs); + } + for idx in &grouping.expression_references { + let e = &expressions[*idx as usize]; + group_exprs.push(e.clone()); + } + Ok(group_exprs) +} diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/cross_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/cross_rel.rs new file mode 100644 index 000000000000..a91366e47742 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/rel/cross_rel.rs @@ -0,0 +1,35 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::logical_plan::consumer::utils::requalify_sides_if_needed; +use crate::logical_plan::consumer::SubstraitConsumer; +use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder}; +use substrait::proto::CrossRel; + +pub async fn from_cross_rel( + consumer: &impl SubstraitConsumer, + cross: &CrossRel, +) -> datafusion::common::Result { + let left = LogicalPlanBuilder::from( + consumer.consume_rel(cross.left.as_ref().unwrap()).await?, + ); + let right = LogicalPlanBuilder::from( + consumer.consume_rel(cross.right.as_ref().unwrap()).await?, + ); + let (left, right) = requalify_sides_if_needed(left, right)?; + left.cross_join(right.build()?)?.build() +} diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/exchange_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/exchange_rel.rs new file mode 100644 index 000000000000..d326fff44bbb --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/rel/exchange_rel.rs @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::logical_plan::consumer::from_substrait_field_reference; +use crate::logical_plan::consumer::SubstraitConsumer; +use datafusion::common::{not_impl_err, substrait_err}; +use datafusion::logical_expr::{LogicalPlan, Partitioning, Repartition}; +use std::sync::Arc; +use substrait::proto::exchange_rel::ExchangeKind; +use substrait::proto::ExchangeRel; + +pub async fn from_exchange_rel( + consumer: &impl SubstraitConsumer, + exchange: &ExchangeRel, +) -> datafusion::common::Result { + let Some(input) = exchange.input.as_ref() else { + return substrait_err!("Unexpected empty input in ExchangeRel"); + }; + let input = Arc::new(consumer.consume_rel(input).await?); + + let Some(exchange_kind) = &exchange.exchange_kind else { + return substrait_err!("Unexpected empty input in ExchangeRel"); + }; + + // ref: https://substrait.io/relations/physical_relations/#exchange-types + let partitioning_scheme = match exchange_kind { + ExchangeKind::ScatterByFields(scatter_fields) => { + let mut partition_columns = vec![]; + let input_schema = input.schema(); + for field_ref in &scatter_fields.fields { + let column = from_substrait_field_reference(field_ref, input_schema)?; + partition_columns.push(column); + } + Partitioning::Hash(partition_columns, exchange.partition_count as usize) + } + ExchangeKind::RoundRobin(_) => { + Partitioning::RoundRobinBatch(exchange.partition_count as usize) + } + ExchangeKind::SingleTarget(_) + | ExchangeKind::MultiTarget(_) + | ExchangeKind::Broadcast(_) => { + return not_impl_err!("Unsupported exchange kind: {exchange_kind:?}"); + } + }; + Ok(LogicalPlan::Repartition(Repartition { + input, + partitioning_scheme, + })) +} diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/fetch_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/fetch_rel.rs new file mode 100644 index 000000000000..74161d8600ea --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/rel/fetch_rel.rs @@ -0,0 +1,53 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::logical_plan::consumer::SubstraitConsumer; +use async_recursion::async_recursion; +use datafusion::common::{not_impl_err, DFSchema, DFSchemaRef}; +use datafusion::logical_expr::{lit, LogicalPlan, LogicalPlanBuilder}; +use substrait::proto::{fetch_rel, FetchRel}; + +#[async_recursion] +pub async fn from_fetch_rel( + consumer: &impl SubstraitConsumer, + fetch: &FetchRel, +) -> datafusion::common::Result<LogicalPlan> { + if let Some(input) = fetch.input.as_ref() { + let input = LogicalPlanBuilder::from(consumer.consume_rel(input).await?); + let empty_schema = DFSchemaRef::new(DFSchema::empty()); + let offset = match &fetch.offset_mode { + Some(fetch_rel::OffsetMode::Offset(offset)) => Some(lit(*offset)), + Some(fetch_rel::OffsetMode::OffsetExpr(expr)) => { + Some(consumer.consume_expression(expr, &empty_schema).await?) + } + None => None, + }; + let count = match &fetch.count_mode { + Some(fetch_rel::CountMode::Count(count)) => { + // -1 means that ALL records should be returned, equivalent to None + (*count != -1).then(|| lit(*count)) + } + Some(fetch_rel::CountMode::CountExpr(expr)) => { + Some(consumer.consume_expression(expr, &empty_schema).await?) + } + None => None, + }; + input.limit_by_expr(offset, count)?.build() + } else { + not_impl_err!("Fetch without an input is not valid") + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/filter_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/filter_rel.rs new file mode 100644 index 000000000000..645b98278208 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/rel/filter_rel.rs @@ -0,0 +1,42 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.
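Editor's note: the `FetchRel` handling above applies an optional offset followed by an optional count, where a literal count of -1 means "return all remaining rows". A simplified, std-only sketch of that offset/limit semantics (over a `Vec` rather than a logical plan; the helper name is made up):

```
// Simplified FetchRel semantics: skip `offset` rows, then keep at most `count`
// rows. A count of -1 means "all remaining rows", mirroring the comment above.
fn fetch<T: Clone>(rows: &[T], offset: i64, count: i64) -> Vec<T> {
    let offset = offset.max(0) as usize;
    let remaining = rows.iter().skip(offset).cloned();
    if count == -1 {
        remaining.collect()
    } else {
        remaining.take(count.max(0) as usize).collect()
    }
}

fn main() {
    let rows = vec![1, 2, 3, 4, 5];
    // Offset 1, count 2: skip the first row, keep the next two.
    assert_eq!(fetch(&rows, 1, 2), vec![2, 3]);
    // Count -1 keeps everything after the offset.
    assert_eq!(fetch(&rows, 3, -1), vec![4, 5]);
}
```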
+ +use crate::logical_plan::consumer::SubstraitConsumer; +use async_recursion::async_recursion; +use datafusion::common::not_impl_err; +use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder}; +use substrait::proto::FilterRel; + +#[async_recursion] +pub async fn from_filter_rel( + consumer: &impl SubstraitConsumer, + filter: &FilterRel, +) -> datafusion::common::Result<LogicalPlan> { + if let Some(input) = filter.input.as_ref() { + let input = LogicalPlanBuilder::from(consumer.consume_rel(input).await?); + if let Some(condition) = filter.condition.as_ref() { + let expr = consumer + .consume_expression(condition, input.schema()) + .await?; + input.filter(expr)?.build() + } else { + not_impl_err!("Filter without a condition is not valid") + } + } else { + not_impl_err!("Filter without an input is not valid") + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/join_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/join_rel.rs new file mode 100644 index 000000000000..0cf920dd6260 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/rel/join_rel.rs @@ -0,0 +1,152 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::utils::requalify_sides_if_needed; +use crate::logical_plan::consumer::SubstraitConsumer; +use datafusion::common::{not_impl_err, plan_err, Column, JoinType, NullEquality}; +use datafusion::logical_expr::utils::split_conjunction; +use datafusion::logical_expr::{ + BinaryExpr, Expr, LogicalPlan, LogicalPlanBuilder, Operator, +}; +use substrait::proto::{join_rel, JoinRel}; + +pub async fn from_join_rel( + consumer: &impl SubstraitConsumer, + join: &JoinRel, +) -> datafusion::common::Result<LogicalPlan> { + if join.post_join_filter.is_some() { + return not_impl_err!("JoinRel with post_join_filter is not yet supported"); + } + + let left: LogicalPlanBuilder = LogicalPlanBuilder::from( + consumer.consume_rel(join.left.as_ref().unwrap()).await?, + ); + let right = LogicalPlanBuilder::from( + consumer.consume_rel(join.right.as_ref().unwrap()).await?, + ); + let (left, right) = requalify_sides_if_needed(left, right)?; + + let join_type = from_substrait_jointype(join.r#type)?; + // The join condition expression needs full input schema and not the output schema from join since we lose columns from + // certain join types such as semi and anti joins + let in_join_schema = left.schema().join(right.schema())?; + + // If join expression exists, parse the `on` condition expression, build join and return + // Otherwise, build join with only the filter, without join keys + match &join.expression.as_ref() { + Some(expr) => { + let on = consumer.consume_expression(expr, &in_join_schema).await?; + // The join expression can contain both equal and non-equal ops.
+ // As of datafusion 31.0.0, the equal and non equal join conditions are in separate fields. + // So we extract each part as follows: + // - If an Eq or IsNotDistinctFrom op is encountered, add the left column, right column and is_null_equal_nulls to `join_ons` vector + // - Otherwise we add the expression to join_filter (use conjunction if filter already exists) + let (join_ons, nulls_equal_nulls, join_filter) = + split_eq_and_noneq_join_predicate_with_nulls_equality(&on); + let (left_cols, right_cols): (Vec<_>, Vec<_>) = + itertools::multiunzip(join_ons); + let null_equality = if nulls_equal_nulls { + NullEquality::NullEqualsNull + } else { + NullEquality::NullEqualsNothing + }; + left.join_detailed( + right.build()?, + join_type, + (left_cols, right_cols), + join_filter, + null_equality, + )? + .build() + } + None => { + let on: Vec = vec![]; + left.join_detailed( + right.build()?, + join_type, + (on.clone(), on), + None, + NullEquality::NullEqualsNothing, + )? + .build() + } + } +} + +fn split_eq_and_noneq_join_predicate_with_nulls_equality( + filter: &Expr, +) -> (Vec<(Column, Column)>, bool, Option) { + let exprs = split_conjunction(filter); + + let mut accum_join_keys: Vec<(Column, Column)> = vec![]; + let mut accum_filters: Vec = vec![]; + let mut nulls_equal_nulls = false; + + for expr in exprs { + #[allow(clippy::collapsible_match)] + match expr { + Expr::BinaryExpr(binary_expr) => match binary_expr { + x @ (BinaryExpr { + left, + op: Operator::Eq, + right, + } + | BinaryExpr { + left, + op: Operator::IsNotDistinctFrom, + right, + }) => { + nulls_equal_nulls = match x.op { + Operator::Eq => false, + Operator::IsNotDistinctFrom => true, + _ => unreachable!(), + }; + + match (left.as_ref(), right.as_ref()) { + (Expr::Column(l), Expr::Column(r)) => { + accum_join_keys.push((l.clone(), r.clone())); + } + _ => accum_filters.push(expr.clone()), + } + } + _ => accum_filters.push(expr.clone()), + }, + _ => accum_filters.push(expr.clone()), + } + } + + let join_filter = accum_filters.into_iter().reduce(Expr::and); + (accum_join_keys, nulls_equal_nulls, join_filter) +} + +fn from_substrait_jointype(join_type: i32) -> datafusion::common::Result { + if let Ok(substrait_join_type) = join_rel::JoinType::try_from(join_type) { + match substrait_join_type { + join_rel::JoinType::Inner => Ok(JoinType::Inner), + join_rel::JoinType::Left => Ok(JoinType::Left), + join_rel::JoinType::Right => Ok(JoinType::Right), + join_rel::JoinType::Outer => Ok(JoinType::Full), + join_rel::JoinType::LeftAnti => Ok(JoinType::LeftAnti), + join_rel::JoinType::LeftSemi => Ok(JoinType::LeftSemi), + join_rel::JoinType::LeftMark => Ok(JoinType::LeftMark), + join_rel::JoinType::RightMark => Ok(JoinType::RightMark), + _ => plan_err!("unsupported join type {substrait_join_type:?}"), + } + } else { + plan_err!("invalid join type variant {join_type:?}") + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/mod.rs b/datafusion/substrait/src/logical_plan/consumer/rel/mod.rs new file mode 100644 index 000000000000..a83ddd8997b2 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/rel/mod.rs @@ -0,0 +1,173 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod aggregate_rel; +mod cross_rel; +mod exchange_rel; +mod fetch_rel; +mod filter_rel; +mod join_rel; +mod project_rel; +mod read_rel; +mod set_rel; +mod sort_rel; + +pub use aggregate_rel::*; +pub use cross_rel::*; +pub use exchange_rel::*; +pub use fetch_rel::*; +pub use filter_rel::*; +pub use join_rel::*; +pub use project_rel::*; +pub use read_rel::*; +pub use set_rel::*; +pub use sort_rel::*; + +use crate::logical_plan::consumer::utils::NameTracker; +use crate::logical_plan::consumer::SubstraitConsumer; +use async_recursion::async_recursion; +use datafusion::common::{not_impl_err, substrait_datafusion_err, substrait_err, Column}; +use datafusion::logical_expr::builder::project; +use datafusion::logical_expr::{Expr, LogicalPlan, Projection}; +use std::sync::Arc; +use substrait::proto::rel::RelType; +use substrait::proto::rel_common::{Emit, EmitKind}; +use substrait::proto::{rel_common, Rel, RelCommon}; + +/// Convert Substrait Rel to DataFusion DataFrame +#[async_recursion] +pub async fn from_substrait_rel( + consumer: &impl SubstraitConsumer, + relation: &Rel, +) -> datafusion::common::Result { + let plan: datafusion::common::Result = match &relation.rel_type { + Some(rel_type) => match rel_type { + RelType::Read(rel) => consumer.consume_read(rel).await, + RelType::Filter(rel) => consumer.consume_filter(rel).await, + RelType::Fetch(rel) => consumer.consume_fetch(rel).await, + RelType::Aggregate(rel) => consumer.consume_aggregate(rel).await, + RelType::Sort(rel) => consumer.consume_sort(rel).await, + RelType::Join(rel) => consumer.consume_join(rel).await, + RelType::Project(rel) => consumer.consume_project(rel).await, + RelType::Set(rel) => consumer.consume_set(rel).await, + RelType::ExtensionSingle(rel) => consumer.consume_extension_single(rel).await, + RelType::ExtensionMulti(rel) => consumer.consume_extension_multi(rel).await, + RelType::ExtensionLeaf(rel) => consumer.consume_extension_leaf(rel).await, + RelType::Cross(rel) => consumer.consume_cross(rel).await, + RelType::Window(rel) => { + consumer.consume_consistent_partition_window(rel).await + } + RelType::Exchange(rel) => consumer.consume_exchange(rel).await, + rt => not_impl_err!("{rt:?} rel not supported yet"), + }, + None => return substrait_err!("rel must set rel_type"), + }; + apply_emit_kind(retrieve_rel_common(relation), plan?) +} + +fn apply_emit_kind( + rel_common: Option<&RelCommon>, + plan: LogicalPlan, +) -> datafusion::common::Result { + match retrieve_emit_kind(rel_common) { + EmitKind::Direct(_) => Ok(plan), + EmitKind::Emit(Emit { output_mapping }) => { + // It is valid to reference the same field multiple times in the Emit + // In this case, we need to provide unique names to avoid collisions + let mut name_tracker = NameTracker::new(); + match plan { + // To avoid adding a projection on top of a projection, we apply special case + // handling to flatten Substrait Emits. This is only applicable if none of the + // expressions in the projection are volatile. 
This is to avoid issues like + // converting a single call of the random() function into multiple calls due to + // duplicate fields in the output_mapping. + LogicalPlan::Projection(proj) if !contains_volatile_expr(&proj) => { + let mut exprs: Vec = vec![]; + for field in output_mapping { + let expr = proj.expr + .get(field as usize) + .ok_or_else(|| substrait_datafusion_err!( + "Emit output field {} cannot be resolved in input schema {}", + field, proj.input.schema() + ))?; + exprs.push(name_tracker.get_uniquely_named_expr(expr.clone())?); + } + + let input = Arc::unwrap_or_clone(proj.input); + project(input, exprs) + } + // Otherwise we just handle the output_mapping as a projection + _ => { + let input_schema = plan.schema(); + + let mut exprs: Vec = vec![]; + for index in output_mapping.into_iter() { + let column = Expr::Column(Column::from( + input_schema.qualified_field(index as usize), + )); + let expr = name_tracker.get_uniquely_named_expr(column)?; + exprs.push(expr); + } + + project(plan, exprs) + } + } + } + } +} + +fn retrieve_rel_common(rel: &Rel) -> Option<&RelCommon> { + match rel.rel_type.as_ref() { + None => None, + Some(rt) => match rt { + RelType::Read(r) => r.common.as_ref(), + RelType::Filter(f) => f.common.as_ref(), + RelType::Fetch(f) => f.common.as_ref(), + RelType::Aggregate(a) => a.common.as_ref(), + RelType::Sort(s) => s.common.as_ref(), + RelType::Join(j) => j.common.as_ref(), + RelType::Project(p) => p.common.as_ref(), + RelType::Set(s) => s.common.as_ref(), + RelType::ExtensionSingle(e) => e.common.as_ref(), + RelType::ExtensionMulti(e) => e.common.as_ref(), + RelType::ExtensionLeaf(e) => e.common.as_ref(), + RelType::Cross(c) => c.common.as_ref(), + RelType::Reference(_) => None, + RelType::Write(w) => w.common.as_ref(), + RelType::Ddl(d) => d.common.as_ref(), + RelType::HashJoin(j) => j.common.as_ref(), + RelType::MergeJoin(j) => j.common.as_ref(), + RelType::NestedLoopJoin(j) => j.common.as_ref(), + RelType::Window(w) => w.common.as_ref(), + RelType::Exchange(e) => e.common.as_ref(), + RelType::Expand(e) => e.common.as_ref(), + RelType::Update(_) => None, + }, + } +} + +fn retrieve_emit_kind(rel_common: Option<&RelCommon>) -> EmitKind { + // the default EmitKind is Direct if it is not set explicitly + let default = EmitKind::Direct(rel_common::Direct {}); + rel_common + .and_then(|rc| rc.emit_kind.as_ref()) + .map_or(default, |ek| ek.clone()) +} + +fn contains_volatile_expr(proj: &Projection) -> bool { + proj.expr.iter().any(|e| e.is_volatile()) +} diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/project_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/project_rel.rs new file mode 100644 index 000000000000..8ece6392974e --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/rel/project_rel.rs @@ -0,0 +1,84 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::utils::NameTracker; +use crate::logical_plan::consumer::SubstraitConsumer; +use async_recursion::async_recursion; +use datafusion::common::{not_impl_err, Column}; +use datafusion::logical_expr::builder::project; +use datafusion::logical_expr::{Expr, LogicalPlan, LogicalPlanBuilder}; +use std::collections::HashSet; +use std::sync::Arc; +use substrait::proto::ProjectRel; + +#[async_recursion] +pub async fn from_project_rel( + consumer: &impl SubstraitConsumer, + p: &ProjectRel, +) -> datafusion::common::Result { + if let Some(input) = p.input.as_ref() { + let input = consumer.consume_rel(input).await?; + let original_schema = Arc::clone(input.schema()); + + // Ensure that all expressions have a unique display name, so that + // validate_unique_names does not fail when constructing the project. + let mut name_tracker = NameTracker::new(); + + // By default, a Substrait Project emits all inputs fields followed by all expressions. + // We build the explicit expressions first, and then the input expressions to avoid + // adding aliases to the explicit expressions (as part of ensuring unique names). + // + // This is helpful for plan visualization and tests, because when DataFusion produces + // Substrait Projects it adds an output mapping that excludes all input columns + // leaving only explicit expressions. + + let mut explicit_exprs: Vec = vec![]; + // For WindowFunctions, we need to wrap them in a Window relation. If there are duplicates, + // we can do the window'ing only once, then the project will duplicate the result. + // Order here doesn't matter since LPB::window_plan sorts the expressions. + let mut window_exprs: HashSet = HashSet::new(); + for expr in &p.expressions { + let e = consumer + .consume_expression(expr, input.clone().schema()) + .await?; + // if the expression is WindowFunction, wrap in a Window relation + if let Expr::WindowFunction(_) = &e { + // Adding the same expression here and in the project below + // works because the project's builder uses columnize_expr(..) + // to transform it into a column reference + window_exprs.insert(e.clone()); + } + explicit_exprs.push(name_tracker.get_uniquely_named_expr(e)?); + } + + let input = if !window_exprs.is_empty() { + LogicalPlanBuilder::window_plan(input, window_exprs)? + } else { + input + }; + + let mut final_exprs: Vec = vec![]; + for index in 0..original_schema.fields().len() { + let e = Expr::Column(Column::from(original_schema.qualified_field(index))); + final_exprs.push(name_tracker.get_uniquely_named_expr(e)?); + } + final_exprs.append(&mut explicit_exprs); + project(input, final_exprs) + } else { + not_impl_err!("Projection without an input is not supported") + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/read_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/read_rel.rs new file mode 100644 index 000000000000..f1cbd16d2d8f --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/rel/read_rel.rs @@ -0,0 +1,280 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::from_substrait_literal; +use crate::logical_plan::consumer::from_substrait_named_struct; +use crate::logical_plan::consumer::utils::ensure_schema_compatibility; +use crate::logical_plan::consumer::SubstraitConsumer; +use datafusion::common::{ + not_impl_err, plan_err, substrait_datafusion_err, substrait_err, DFSchema, + DFSchemaRef, TableReference, +}; +use datafusion::datasource::provider_as_source; +use datafusion::logical_expr::utils::split_conjunction_owned; +use datafusion::logical_expr::{ + EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder, Values, +}; +use std::sync::Arc; +use substrait::proto::expression::MaskExpression; +use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile; +use substrait::proto::read_rel::ReadType; +use substrait::proto::{Expression, ReadRel}; +use url::Url; + +#[allow(deprecated)] +pub async fn from_read_rel( + consumer: &impl SubstraitConsumer, + read: &ReadRel, +) -> datafusion::common::Result { + async fn read_with_schema( + consumer: &impl SubstraitConsumer, + table_ref: TableReference, + schema: DFSchema, + projection: &Option, + filter: &Option>, + ) -> datafusion::common::Result { + let schema = schema.replace_qualifier(table_ref.clone()); + + let filters = if let Some(f) = filter { + let filter_expr = consumer.consume_expression(f, &schema).await?; + split_conjunction_owned(filter_expr) + } else { + vec![] + }; + + let plan = { + let provider = match consumer.resolve_table_ref(&table_ref).await? { + Some(ref provider) => Arc::clone(provider), + _ => return plan_err!("No table named '{table_ref}'"), + }; + + LogicalPlanBuilder::scan_with_filters( + table_ref, + provider_as_source(Arc::clone(&provider)), + None, + filters, + )? + .build()? 
+ }; + + ensure_schema_compatibility(plan.schema(), schema.clone())?; + + let schema = apply_masking(schema, projection)?; + + apply_projection(plan, schema) + } + + let named_struct = read.base_schema.as_ref().ok_or_else(|| { + substrait_datafusion_err!("No base schema provided for Read Relation") + })?; + + let substrait_schema = from_substrait_named_struct(consumer, named_struct)?; + + match &read.read_type { + Some(ReadType::NamedTable(nt)) => { + let table_reference = match nt.names.len() { + 0 => { + return plan_err!("No table name found in NamedTable"); + } + 1 => TableReference::Bare { + table: nt.names[0].clone().into(), + }, + 2 => TableReference::Partial { + schema: nt.names[0].clone().into(), + table: nt.names[1].clone().into(), + }, + _ => TableReference::Full { + catalog: nt.names[0].clone().into(), + schema: nt.names[1].clone().into(), + table: nt.names[2].clone().into(), + }, + }; + + read_with_schema( + consumer, + table_reference, + substrait_schema, + &read.projection, + &read.filter, + ) + .await + } + Some(ReadType::VirtualTable(vt)) => { + if vt.values.is_empty() { + return Ok(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: DFSchemaRef::new(substrait_schema), + })); + } + + let values = vt + .values + .iter() + .map(|row| { + let mut name_idx = 0; + let lits = row + .fields + .iter() + .map(|lit| { + name_idx += 1; // top-level names are provided through schema + Ok(Expr::Literal(from_substrait_literal( + consumer, + lit, + &named_struct.names, + &mut name_idx, + )?, None)) + }) + .collect::>()?; + if name_idx != named_struct.names.len() { + return substrait_err!( + "Names list must match exactly to nested schema, but found {} uses for {} names", + name_idx, + named_struct.names.len() + ); + } + Ok(lits) + }) + .collect::>()?; + + Ok(LogicalPlan::Values(Values { + schema: DFSchemaRef::new(substrait_schema), + values, + })) + } + Some(ReadType::LocalFiles(lf)) => { + fn extract_filename(name: &str) -> Option { + let corrected_url = + if name.starts_with("file://") && !name.starts_with("file:///") { + name.replacen("file://", "file:///", 1) + } else { + name.to_string() + }; + + Url::parse(&corrected_url).ok().and_then(|url| { + let path = url.path(); + std::path::Path::new(path) + .file_name() + .map(|filename| filename.to_string_lossy().to_string()) + }) + } + + // we could use the file name to check the original table provider + // TODO: currently does not support multiple local files + let filename: Option = + lf.items.first().and_then(|x| match x.path_type.as_ref() { + Some(UriFile(name)) => extract_filename(name), + _ => None, + }); + + if lf.items.len() > 1 || filename.is_none() { + return not_impl_err!("Only single file reads are supported"); + } + let name = filename.unwrap(); + // directly use unwrap here since we could determine it is a valid one + let table_reference = TableReference::Bare { table: name.into() }; + + read_with_schema( + consumer, + table_reference, + substrait_schema, + &read.projection, + &read.filter, + ) + .await + } + _ => { + not_impl_err!("Unsupported ReadType: {:?}", read.read_type) + } + } +} + +pub fn apply_masking( + schema: DFSchema, + mask_expression: &::core::option::Option, +) -> datafusion::common::Result { + match mask_expression { + Some(MaskExpression { select, .. 
}) => match &select.as_ref() { + Some(projection) => { + let column_indices: Vec = projection + .struct_items + .iter() + .map(|item| item.field as usize) + .collect(); + + let fields = column_indices + .iter() + .map(|i| schema.qualified_field(*i)) + .map(|(qualifier, field)| { + (qualifier.cloned(), Arc::new(field.clone())) + }) + .collect(); + + Ok(DFSchema::new_with_metadata( + fields, + schema.metadata().clone(), + )?) + } + None => Ok(schema), + }, + None => Ok(schema), + } +} + +/// This function returns a DataFrame with fields adjusted if necessary in the event that the +/// Substrait schema is a subset of the DataFusion schema. +fn apply_projection( + plan: LogicalPlan, + substrait_schema: DFSchema, +) -> datafusion::common::Result { + let df_schema = plan.schema(); + + if df_schema.logically_equivalent_names_and_types(&substrait_schema) { + return Ok(plan); + } + + let df_schema = df_schema.to_owned(); + + match plan { + LogicalPlan::TableScan(mut scan) => { + let column_indices: Vec = substrait_schema + .strip_qualifiers() + .fields() + .iter() + .map(|substrait_field| { + Ok(df_schema + .index_of_column_by_name(None, substrait_field.name().as_str()) + .unwrap()) + }) + .collect::>()?; + + let fields = column_indices + .iter() + .map(|i| df_schema.qualified_field(*i)) + .map(|(qualifier, field)| (qualifier.cloned(), Arc::new(field.clone()))) + .collect(); + + scan.projected_schema = DFSchemaRef::new(DFSchema::new_with_metadata( + fields, + df_schema.metadata().clone(), + )?); + scan.projection = Some(column_indices); + + Ok(LogicalPlan::TableScan(scan)) + } + _ => plan_err!("DataFrame passed to apply_projection must be a TableScan"), + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/set_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/set_rel.rs new file mode 100644 index 000000000000..6688a80f5274 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/rel/set_rel.rs @@ -0,0 +1,102 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
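Editor's note: `apply_masking` above narrows the Substrait-provided schema to the field indices listed in the `MaskExpression`, and `apply_projection` pushes the same index list down into the table scan. A small sketch of the index-selection step over a plain list of field names (the field names are hypothetical; the real code works on `DFSchema` qualified fields):

```
// Simplified masking: keep only the fields whose indices appear in the mask,
// in mask order. Assumes every index in the mask is valid for `fields`, just
// as the real code expects when it calls schema.qualified_field(i).
fn apply_mask(fields: &[&str], mask: &[usize]) -> Vec<String> {
    mask.iter().map(|&i| fields[i].to_string()).collect()
}

fn main() {
    let fields = ["id", "name", "price", "comment"];
    // A mask selecting fields 0 and 2 keeps only `id` and `price`.
    assert_eq!(apply_mask(&fields, &[0, 2]), vec!["id", "price"]);
}
```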
+ +use crate::logical_plan::consumer::SubstraitConsumer; +use datafusion::common::{not_impl_err, substrait_err}; +use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder}; +use substrait::proto::set_rel::SetOp; +use substrait::proto::{Rel, SetRel}; + +pub async fn from_set_rel( + consumer: &impl SubstraitConsumer, + set: &SetRel, +) -> datafusion::common::Result { + if set.inputs.len() < 2 { + substrait_err!("Set operation requires at least two inputs") + } else { + match set.op() { + SetOp::UnionAll => union_rels(consumer, &set.inputs, true).await, + SetOp::UnionDistinct => union_rels(consumer, &set.inputs, false).await, + SetOp::IntersectionPrimary => LogicalPlanBuilder::intersect( + consumer.consume_rel(&set.inputs[0]).await?, + union_rels(consumer, &set.inputs[1..], true).await?, + false, + ), + SetOp::IntersectionMultiset => { + intersect_rels(consumer, &set.inputs, false).await + } + SetOp::IntersectionMultisetAll => { + intersect_rels(consumer, &set.inputs, true).await + } + SetOp::MinusPrimary => except_rels(consumer, &set.inputs, false).await, + SetOp::MinusPrimaryAll => except_rels(consumer, &set.inputs, true).await, + set_op => not_impl_err!("Unsupported set operator: {set_op:?}"), + } + } +} + +async fn union_rels( + consumer: &impl SubstraitConsumer, + rels: &[Rel], + is_all: bool, +) -> datafusion::common::Result { + let mut union_builder = Ok(LogicalPlanBuilder::from( + consumer.consume_rel(&rels[0]).await?, + )); + for input in &rels[1..] { + let rel_plan = consumer.consume_rel(input).await?; + + union_builder = if is_all { + union_builder?.union(rel_plan) + } else { + union_builder?.union_distinct(rel_plan) + }; + } + union_builder?.build() +} + +async fn intersect_rels( + consumer: &impl SubstraitConsumer, + rels: &[Rel], + is_all: bool, +) -> datafusion::common::Result { + let mut rel = consumer.consume_rel(&rels[0]).await?; + + for input in &rels[1..] { + rel = LogicalPlanBuilder::intersect( + rel, + consumer.consume_rel(input).await?, + is_all, + )? + } + + Ok(rel) +} + +async fn except_rels( + consumer: &impl SubstraitConsumer, + rels: &[Rel], + is_all: bool, +) -> datafusion::common::Result { + let mut rel = consumer.consume_rel(&rels[0]).await?; + + for input in &rels[1..] { + rel = LogicalPlanBuilder::except(rel, consumer.consume_rel(input).await?, is_all)? + } + + Ok(rel) +} diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/sort_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/sort_rel.rs new file mode 100644 index 000000000000..56ca0ba03857 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/rel/sort_rel.rs @@ -0,0 +1,34 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
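Editor's note: `from_set_rel` above folds two or more inputs left to right: union keeps or removes duplicates depending on the `SetOp`, while intersect and except are applied pairwise in the same left-fold style. The following is a rough, std-only illustration of that fold pattern over integer "relations"; it is not the `LogicalPlanBuilder` API and ignores duplicate handling within a single input.

```
use std::collections::HashSet;

// Left-fold of a pairwise set operation over several inputs, mirroring how
// union_rels/intersect_rels/except_rels start from inputs[0] and fold each
// remaining input onto the accumulator. Assumes at least one input, just as
// from_set_rel requires at least two.
fn fold_sets<F>(inputs: &[HashSet<i32>], op: F) -> HashSet<i32>
where
    F: Fn(&HashSet<i32>, &HashSet<i32>) -> HashSet<i32>,
{
    let mut acc = inputs[0].clone();
    for next in &inputs[1..] {
        acc = op(&acc, next);
    }
    acc
}

fn main() {
    let a: HashSet<i32> = [1, 2, 3].into_iter().collect();
    let b: HashSet<i32> = [2, 3, 4].into_iter().collect();
    let c: HashSet<i32> = [3, 4, 5].into_iter().collect();

    // UNION DISTINCT across three inputs: {1, 2, 3, 4, 5}.
    let union = fold_sets(&[a.clone(), b.clone(), c.clone()], |x, y| {
        x.union(y).cloned().collect()
    });
    assert_eq!(union.len(), 5);

    // INTERSECT across three inputs: only 3 appears in all of them.
    let intersect = fold_sets(&[a, b, c], |x, y| x.intersection(y).cloned().collect());
    let expected: HashSet<i32> = [3].into_iter().collect();
    assert_eq!(intersect, expected);
}
```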
+ +use crate::logical_plan::consumer::{from_substrait_sorts, SubstraitConsumer}; +use datafusion::common::not_impl_err; +use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder}; +use substrait::proto::SortRel; + +pub async fn from_sort_rel( + consumer: &impl SubstraitConsumer, + sort: &SortRel, +) -> datafusion::common::Result { + if let Some(input) = sort.input.as_ref() { + let input = LogicalPlanBuilder::from(consumer.consume_rel(input).await?); + let sorts = from_substrait_sorts(consumer, &sort.sorts, input.schema()).await?; + input.sort(sorts)?.build() + } else { + not_impl_err!("Sort without an input is not valid") + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/substrait_consumer.rs b/datafusion/substrait/src/logical_plan/consumer/substrait_consumer.rs new file mode 100644 index 000000000000..5392dd77b576 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/substrait_consumer.rs @@ -0,0 +1,523 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::{ + from_aggregate_rel, from_cast, from_cross_rel, from_exchange_rel, from_fetch_rel, + from_field_reference, from_filter_rel, from_if_then, from_join_rel, from_literal, + from_project_rel, from_read_rel, from_scalar_function, from_set_rel, + from_singular_or_list, from_sort_rel, from_subquery, from_substrait_rel, + from_substrait_rex, from_window_function, +}; +use crate::extensions::Extensions; +use async_trait::async_trait; +use datafusion::arrow::datatypes::DataType; +use datafusion::catalog::TableProvider; +use datafusion::common::{ + not_impl_err, substrait_err, DFSchema, ScalarValue, TableReference, +}; +use datafusion::execution::{FunctionRegistry, SessionState}; +use datafusion::logical_expr::{Expr, Extension, LogicalPlan}; +use std::sync::Arc; +use substrait::proto; +use substrait::proto::expression as substrait_expression; +use substrait::proto::expression::{ + Enum, FieldReference, IfThen, Literal, MultiOrList, Nested, ScalarFunction, + SingularOrList, SwitchExpression, WindowFunction, +}; +use substrait::proto::{ + r#type, AggregateRel, ConsistentPartitionWindowRel, CrossRel, DynamicParameter, + ExchangeRel, Expression, ExtensionLeafRel, ExtensionMultiRel, ExtensionSingleRel, + FetchRel, FilterRel, JoinRel, ProjectRel, ReadRel, Rel, SetRel, SortRel, +}; + +#[async_trait] +/// This trait is used to consume Substrait plans, converting them into DataFusion Logical Plans. +/// It can be implemented by users to allow for custom handling of relations, expressions, etc. +/// +/// Combined with the [crate::logical_plan::producer::SubstraitProducer] this allows for fully +/// customizable Substrait serde. 
+/// +/// # Example Usage +/// +/// ``` +/// # use async_trait::async_trait; +/// # use datafusion::catalog::TableProvider; +/// # use datafusion::common::{not_impl_err, substrait_err, DFSchema, ScalarValue, TableReference}; +/// # use datafusion::error::Result; +/// # use datafusion::execution::{FunctionRegistry, SessionState}; +/// # use datafusion::logical_expr::{Expr, LogicalPlan, LogicalPlanBuilder}; +/// # use std::sync::Arc; +/// # use substrait::proto; +/// # use substrait::proto::{ExtensionLeafRel, FilterRel, ProjectRel}; +/// # use datafusion::arrow::datatypes::DataType; +/// # use datafusion::logical_expr::expr::ScalarFunction; +/// # use datafusion_substrait::extensions::Extensions; +/// # use datafusion_substrait::logical_plan::consumer::{ +/// # from_project_rel, from_substrait_rel, from_substrait_rex, SubstraitConsumer +/// # }; +/// +/// struct CustomSubstraitConsumer { +/// extensions: Arc, +/// state: Arc, +/// } +/// +/// #[async_trait] +/// impl SubstraitConsumer for CustomSubstraitConsumer { +/// async fn resolve_table_ref( +/// &self, +/// table_ref: &TableReference, +/// ) -> Result>> { +/// let table = table_ref.table().to_string(); +/// let schema = self.state.schema_for_ref(table_ref.clone())?; +/// let table_provider = schema.table(&table).await?; +/// Ok(table_provider) +/// } +/// +/// fn get_extensions(&self) -> &Extensions { +/// self.extensions.as_ref() +/// } +/// +/// fn get_function_registry(&self) -> &impl FunctionRegistry { +/// self.state.as_ref() +/// } +/// +/// // You can reuse existing consumer code to assist in handling advanced extensions +/// async fn consume_project(&self, rel: &ProjectRel) -> Result { +/// let df_plan = from_project_rel(self, rel).await?; +/// if let Some(advanced_extension) = rel.advanced_extension.as_ref() { +/// not_impl_err!( +/// "decode and handle an advanced extension: {:?}", +/// advanced_extension +/// ) +/// } else { +/// Ok(df_plan) +/// } +/// } +/// +/// // You can implement a fully custom consumer method if you need special handling +/// async fn consume_filter(&self, rel: &FilterRel) -> Result { +/// let input = self.consume_rel(rel.input.as_ref().unwrap()).await?; +/// let expression = +/// self.consume_expression(rel.condition.as_ref().unwrap(), input.schema()) +/// .await?; +/// // though this one is quite boring +/// LogicalPlanBuilder::from(input).filter(expression)?.build() +/// } +/// +/// // You can add handlers for extension relations +/// async fn consume_extension_leaf( +/// &self, +/// rel: &ExtensionLeafRel, +/// ) -> Result { +/// not_impl_err!( +/// "handle protobuf Any {} as you need", +/// rel.detail.as_ref().unwrap().type_url +/// ) +/// } +/// +/// // and handlers for user-define types +/// fn consume_user_defined_type(&self, typ: &proto::r#type::UserDefined) -> Result { +/// let type_string = self.extensions.types.get(&typ.type_reference).unwrap(); +/// match type_string.as_str() { +/// "u!foo" => not_impl_err!("handle foo conversion"), +/// "u!bar" => not_impl_err!("handle bar conversion"), +/// _ => substrait_err!("unexpected type") +/// } +/// } +/// +/// // and user-defined literals +/// fn consume_user_defined_literal(&self, literal: &proto::expression::literal::UserDefined) -> Result { +/// let type_string = self.extensions.types.get(&literal.type_reference).unwrap(); +/// match type_string.as_str() { +/// "u!foo" => not_impl_err!("handle foo conversion"), +/// "u!bar" => not_impl_err!("handle bar conversion"), +/// _ => substrait_err!("unexpected type") +/// } +/// } +/// } +/// ``` 
+/// +pub trait SubstraitConsumer: Send + Sync + Sized { + async fn resolve_table_ref( + &self, + table_ref: &TableReference, + ) -> datafusion::common::Result>>; + + // TODO: Remove these two methods + // Ideally, the abstract consumer should not place any constraints on implementations. + // The functionality for which the Extensions and FunctionRegistry is needed should be abstracted + // out into methods on the trait. As an example, resolve_table_reference is such a method. + // See: https://github.com/apache/datafusion/issues/13863 + fn get_extensions(&self) -> &Extensions; + fn get_function_registry(&self) -> &impl FunctionRegistry; + + // Relation Methods + // There is one method per Substrait relation to allow for easy overriding of consumer behaviour. + // These methods have default implementations calling the common handler code, to allow for users + // to re-use common handling logic. + + /// All [Rel]s to be converted pass through this method. + /// You can provide your own implementation if you wish to customize the conversion behaviour. + async fn consume_rel(&self, rel: &Rel) -> datafusion::common::Result { + from_substrait_rel(self, rel).await + } + + async fn consume_read( + &self, + rel: &ReadRel, + ) -> datafusion::common::Result { + from_read_rel(self, rel).await + } + + async fn consume_filter( + &self, + rel: &FilterRel, + ) -> datafusion::common::Result { + from_filter_rel(self, rel).await + } + + async fn consume_fetch( + &self, + rel: &FetchRel, + ) -> datafusion::common::Result { + from_fetch_rel(self, rel).await + } + + async fn consume_aggregate( + &self, + rel: &AggregateRel, + ) -> datafusion::common::Result { + from_aggregate_rel(self, rel).await + } + + async fn consume_sort( + &self, + rel: &SortRel, + ) -> datafusion::common::Result { + from_sort_rel(self, rel).await + } + + async fn consume_join( + &self, + rel: &JoinRel, + ) -> datafusion::common::Result { + from_join_rel(self, rel).await + } + + async fn consume_project( + &self, + rel: &ProjectRel, + ) -> datafusion::common::Result { + from_project_rel(self, rel).await + } + + async fn consume_set(&self, rel: &SetRel) -> datafusion::common::Result { + from_set_rel(self, rel).await + } + + async fn consume_cross( + &self, + rel: &CrossRel, + ) -> datafusion::common::Result { + from_cross_rel(self, rel).await + } + + async fn consume_consistent_partition_window( + &self, + _rel: &ConsistentPartitionWindowRel, + ) -> datafusion::common::Result { + not_impl_err!("Consistent Partition Window Rel not supported") + } + + async fn consume_exchange( + &self, + rel: &ExchangeRel, + ) -> datafusion::common::Result { + from_exchange_rel(self, rel).await + } + + // Expression Methods + // There is one method per Substrait expression to allow for easy overriding of consumer behaviour + // These methods have default implementations calling the common handler code, to allow for users + // to re-use common handling logic. + + /// All [Expression]s to be converted pass through this method. + /// You can provide your own implementation if you wish to customize the conversion behaviour. 
+ async fn consume_expression( + &self, + expr: &Expression, + input_schema: &DFSchema, + ) -> datafusion::common::Result { + from_substrait_rex(self, expr, input_schema).await + } + + async fn consume_literal(&self, expr: &Literal) -> datafusion::common::Result { + from_literal(self, expr).await + } + + async fn consume_field_reference( + &self, + expr: &FieldReference, + input_schema: &DFSchema, + ) -> datafusion::common::Result { + from_field_reference(self, expr, input_schema).await + } + + async fn consume_scalar_function( + &self, + expr: &ScalarFunction, + input_schema: &DFSchema, + ) -> datafusion::common::Result { + from_scalar_function(self, expr, input_schema).await + } + + async fn consume_window_function( + &self, + expr: &WindowFunction, + input_schema: &DFSchema, + ) -> datafusion::common::Result { + from_window_function(self, expr, input_schema).await + } + + async fn consume_if_then( + &self, + expr: &IfThen, + input_schema: &DFSchema, + ) -> datafusion::common::Result { + from_if_then(self, expr, input_schema).await + } + + async fn consume_switch( + &self, + _expr: &SwitchExpression, + _input_schema: &DFSchema, + ) -> datafusion::common::Result { + not_impl_err!("Switch expression not supported") + } + + async fn consume_singular_or_list( + &self, + expr: &SingularOrList, + input_schema: &DFSchema, + ) -> datafusion::common::Result { + from_singular_or_list(self, expr, input_schema).await + } + + async fn consume_multi_or_list( + &self, + _expr: &MultiOrList, + _input_schema: &DFSchema, + ) -> datafusion::common::Result { + not_impl_err!("Multi Or List expression not supported") + } + + async fn consume_cast( + &self, + expr: &substrait_expression::Cast, + input_schema: &DFSchema, + ) -> datafusion::common::Result { + from_cast(self, expr, input_schema).await + } + + async fn consume_subquery( + &self, + expr: &substrait_expression::Subquery, + input_schema: &DFSchema, + ) -> datafusion::common::Result { + from_subquery(self, expr, input_schema).await + } + + async fn consume_nested( + &self, + _expr: &Nested, + _input_schema: &DFSchema, + ) -> datafusion::common::Result { + not_impl_err!("Nested expression not supported") + } + + async fn consume_enum( + &self, + _expr: &Enum, + _input_schema: &DFSchema, + ) -> datafusion::common::Result { + not_impl_err!("Enum expression not supported") + } + + async fn consume_dynamic_parameter( + &self, + _expr: &DynamicParameter, + _input_schema: &DFSchema, + ) -> datafusion::common::Result { + not_impl_err!("Dynamic Parameter expression not supported") + } + + // User-Defined Functionality + + // The details of extension relations, and how to handle them, are fully up to users to specify. 
+ // The following methods allow users to customize the consumer behaviour + + async fn consume_extension_leaf( + &self, + rel: &ExtensionLeafRel, + ) -> datafusion::common::Result { + if let Some(detail) = rel.detail.as_ref() { + return substrait_err!( + "Missing handler for ExtensionLeafRel: {}", + detail.type_url + ); + } + substrait_err!("Missing handler for ExtensionLeafRel") + } + + async fn consume_extension_single( + &self, + rel: &ExtensionSingleRel, + ) -> datafusion::common::Result { + if let Some(detail) = rel.detail.as_ref() { + return substrait_err!( + "Missing handler for ExtensionSingleRel: {}", + detail.type_url + ); + } + substrait_err!("Missing handler for ExtensionSingleRel") + } + + async fn consume_extension_multi( + &self, + rel: &ExtensionMultiRel, + ) -> datafusion::common::Result { + if let Some(detail) = rel.detail.as_ref() { + return substrait_err!( + "Missing handler for ExtensionMultiRel: {}", + detail.type_url + ); + } + substrait_err!("Missing handler for ExtensionMultiRel") + } + + // Users can bring their own types to Substrait which require custom handling + + fn consume_user_defined_type( + &self, + user_defined_type: &r#type::UserDefined, + ) -> datafusion::common::Result { + substrait_err!( + "Missing handler for user-defined type: {}", + user_defined_type.type_reference + ) + } + + fn consume_user_defined_literal( + &self, + user_defined_literal: &proto::expression::literal::UserDefined, + ) -> datafusion::common::Result { + substrait_err!( + "Missing handler for user-defined literals {}", + user_defined_literal.type_reference + ) + } +} + +/// Default SubstraitConsumer for converting standard Substrait without user-defined extensions. +/// +/// Used as the consumer in [crate::logical_plan::consumer::from_substrait_plan] +pub struct DefaultSubstraitConsumer<'a> { + pub(super) extensions: &'a Extensions, + pub(super) state: &'a SessionState, +} + +impl<'a> DefaultSubstraitConsumer<'a> { + pub fn new(extensions: &'a Extensions, state: &'a SessionState) -> Self { + DefaultSubstraitConsumer { extensions, state } + } +} + +#[async_trait] +impl SubstraitConsumer for DefaultSubstraitConsumer<'_> { + async fn resolve_table_ref( + &self, + table_ref: &TableReference, + ) -> datafusion::common::Result>> { + let table = table_ref.table().to_string(); + let schema = self.state.schema_for_ref(table_ref.clone())?; + let table_provider = schema.table(&table).await?; + Ok(table_provider) + } + + fn get_extensions(&self) -> &Extensions { + self.extensions + } + + fn get_function_registry(&self) -> &impl FunctionRegistry { + self.state + } + + async fn consume_extension_leaf( + &self, + rel: &ExtensionLeafRel, + ) -> datafusion::common::Result { + let Some(ext_detail) = &rel.detail else { + return substrait_err!("Unexpected empty detail in ExtensionLeafRel"); + }; + let plan = self + .state + .serializer_registry() + .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; + Ok(LogicalPlan::Extension(Extension { node: plan })) + } + + async fn consume_extension_single( + &self, + rel: &ExtensionSingleRel, + ) -> datafusion::common::Result { + let Some(ext_detail) = &rel.detail else { + return substrait_err!("Unexpected empty detail in ExtensionSingleRel"); + }; + let plan = self + .state + .serializer_registry() + .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; + let Some(input_rel) = &rel.input else { + return substrait_err!( + "ExtensionSingleRel missing input rel, try using ExtensionLeafRel instead" + ); + }; + let input_plan = 
self.consume_rel(input_rel).await?; + let plan = plan.with_exprs_and_inputs(plan.expressions(), vec![input_plan])?; + Ok(LogicalPlan::Extension(Extension { node: plan })) + } + + async fn consume_extension_multi( + &self, + rel: &ExtensionMultiRel, + ) -> datafusion::common::Result { + let Some(ext_detail) = &rel.detail else { + return substrait_err!("Unexpected empty detail in ExtensionMultiRel"); + }; + let plan = self + .state + .serializer_registry() + .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; + let mut inputs = Vec::with_capacity(rel.inputs.len()); + for input in &rel.inputs { + let input_plan = self.consume_rel(input).await?; + inputs.push(input_plan); + } + let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?; + Ok(LogicalPlan::Extension(Extension { node: plan })) + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/types.rs b/datafusion/substrait/src/logical_plan/consumer/types.rs new file mode 100644 index 000000000000..4ea479e7cccd --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/types.rs @@ -0,0 +1,334 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
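Editor's note: `DefaultSubstraitConsumer` above overrides only table resolution and the extension-relation handlers; every other relation and expression falls through to the per-variant default methods on the `SubstraitConsumer` trait. The overall design, one overridable method per variant with shared default handlers, can be sketched in plain Rust as follows (the trait, enum, and strings here are hypothetical stand-ins, not the real consumer API):

```
// Minimal sketch of the "one method per variant, each with a default" pattern:
// a dispatch method routes every variant through an overridable handler, and a
// custom implementation only overrides the cases it cares about.
enum Rel {
    Read(String),
    Filter(String),
}

trait Consumer {
    // All relations pass through this method, like consume_rel above.
    fn consume(&self, rel: &Rel) -> String {
        match rel {
            Rel::Read(table) => self.consume_read(table),
            Rel::Filter(condition) => self.consume_filter(condition),
        }
    }
    fn consume_read(&self, table: &str) -> String {
        format!("default read of {table}")
    }
    fn consume_filter(&self, condition: &str) -> String {
        format!("default filter on {condition}")
    }
}

// Uses every default handler, analogous to DefaultSubstraitConsumer.
struct DefaultConsumer;
impl Consumer for DefaultConsumer {}

// Overrides only the filter handling, analogous to a custom consumer.
struct CustomConsumer;
impl Consumer for CustomConsumer {
    fn consume_filter(&self, condition: &str) -> String {
        format!("custom filter on {condition}")
    }
}

fn main() {
    let read = Rel::Read("lineitem".to_string());
    let filter = Rel::Filter("price > 10".to_string());
    assert_eq!(DefaultConsumer.consume(&read), "default read of lineitem");
    assert_eq!(DefaultConsumer.consume(&filter), "default filter on price > 10");
    assert_eq!(CustomConsumer.consume(&filter), "custom filter on price > 10");
}
```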
+ +use super::utils::{next_struct_field_name, DEFAULT_TIMEZONE}; +use super::SubstraitConsumer; +#[allow(deprecated)] +use crate::variation_const::{ + DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, + DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, + DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_INTERVAL_DAY_TYPE_VARIATION_REF, + DEFAULT_TYPE_VARIATION_REF, DURATION_INTERVAL_DAY_TYPE_VARIATION_REF, + INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_NAME, + INTERVAL_MONTH_DAY_NANO_TYPE_REF, INTERVAL_YEAR_MONTH_TYPE_REF, + LARGE_CONTAINER_TYPE_VARIATION_REF, TIMESTAMP_MICRO_TYPE_VARIATION_REF, + TIMESTAMP_MILLI_TYPE_VARIATION_REF, TIMESTAMP_NANO_TYPE_VARIATION_REF, + TIMESTAMP_SECOND_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, + VIEW_CONTAINER_TYPE_VARIATION_REF, +}; +use datafusion::arrow::datatypes::{ + DataType, Field, Fields, IntervalUnit, Schema, TimeUnit, +}; +use datafusion::common::{ + not_impl_err, substrait_datafusion_err, substrait_err, DFSchema, +}; +use std::sync::Arc; +use substrait::proto::{r#type, NamedStruct, Type}; + +pub(crate) fn from_substrait_type_without_names( + consumer: &impl SubstraitConsumer, + dt: &Type, +) -> datafusion::common::Result { + from_substrait_type(consumer, dt, &[], &mut 0) +} + +pub fn from_substrait_type( + consumer: &impl SubstraitConsumer, + dt: &Type, + dfs_names: &[String], + name_idx: &mut usize, +) -> datafusion::common::Result { + match &dt.kind { + Some(s_kind) => match s_kind { + r#type::Kind::Bool(_) => Ok(DataType::Boolean), + r#type::Kind::I8(integer) => match integer.type_variation_reference { + DEFAULT_TYPE_VARIATION_REF => Ok(DataType::Int8), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(DataType::UInt8), + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + }, + r#type::Kind::I16(integer) => match integer.type_variation_reference { + DEFAULT_TYPE_VARIATION_REF => Ok(DataType::Int16), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(DataType::UInt16), + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + }, + r#type::Kind::I32(integer) => match integer.type_variation_reference { + DEFAULT_TYPE_VARIATION_REF => Ok(DataType::Int32), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(DataType::UInt32), + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + }, + r#type::Kind::I64(integer) => match integer.type_variation_reference { + DEFAULT_TYPE_VARIATION_REF => Ok(DataType::Int64), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(DataType::UInt64), + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + }, + r#type::Kind::Fp32(_) => Ok(DataType::Float32), + r#type::Kind::Fp64(_) => Ok(DataType::Float64), + r#type::Kind::Timestamp(ts) => { + // Kept for backwards compatibility, new plans should use PrecisionTimestamp(Tz) instead + #[allow(deprecated)] + match ts.type_variation_reference { + TIMESTAMP_SECOND_TYPE_VARIATION_REF => { + Ok(DataType::Timestamp(TimeUnit::Second, None)) + } + TIMESTAMP_MILLI_TYPE_VARIATION_REF => { + Ok(DataType::Timestamp(TimeUnit::Millisecond, None)) + } + TIMESTAMP_MICRO_TYPE_VARIATION_REF => { + Ok(DataType::Timestamp(TimeUnit::Microsecond, None)) + } + TIMESTAMP_NANO_TYPE_VARIATION_REF => { + Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) + } + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + } + } + r#type::Kind::PrecisionTimestamp(pts) => { + let unit = match pts.precision { + 0 => 
Ok(TimeUnit::Second), + 3 => Ok(TimeUnit::Millisecond), + 6 => Ok(TimeUnit::Microsecond), + 9 => Ok(TimeUnit::Nanosecond), + p => not_impl_err!( + "Unsupported Substrait precision {p} for PrecisionTimestamp" + ), + }?; + Ok(DataType::Timestamp(unit, None)) + } + r#type::Kind::PrecisionTimestampTz(pts) => { + let unit = match pts.precision { + 0 => Ok(TimeUnit::Second), + 3 => Ok(TimeUnit::Millisecond), + 6 => Ok(TimeUnit::Microsecond), + 9 => Ok(TimeUnit::Nanosecond), + p => not_impl_err!( + "Unsupported Substrait precision {p} for PrecisionTimestampTz" + ), + }?; + Ok(DataType::Timestamp(unit, Some(DEFAULT_TIMEZONE.into()))) + } + r#type::Kind::Date(date) => match date.type_variation_reference { + DATE_32_TYPE_VARIATION_REF => Ok(DataType::Date32), + DATE_64_TYPE_VARIATION_REF => Ok(DataType::Date64), + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + }, + r#type::Kind::Binary(binary) => match binary.type_variation_reference { + DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::Binary), + LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::LargeBinary), + VIEW_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::BinaryView), + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + }, + r#type::Kind::FixedBinary(fixed) => { + Ok(DataType::FixedSizeBinary(fixed.length)) + } + r#type::Kind::String(string) => match string.type_variation_reference { + DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::Utf8), + LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::LargeUtf8), + VIEW_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::Utf8View), + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + }, + r#type::Kind::List(list) => { + let inner_type = list.r#type.as_ref().ok_or_else(|| { + substrait_datafusion_err!("List type must have inner type") + })?; + let field = Arc::new(Field::new_list_field( + from_substrait_type(consumer, inner_type, dfs_names, name_idx)?, + // We ignore Substrait's nullability here to match to_substrait_literal + // which always creates nullable lists + true, + )); + match list.type_variation_reference { + DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::List(field)), + LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::LargeList(field)), + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + )?, + } + } + r#type::Kind::Map(map) => { + let key_type = map.key.as_ref().ok_or_else(|| { + substrait_datafusion_err!("Map type must have key type") + })?; + let value_type = map.value.as_ref().ok_or_else(|| { + substrait_datafusion_err!("Map type must have value type") + })?; + let key_field = Arc::new(Field::new( + "key", + from_substrait_type(consumer, key_type, dfs_names, name_idx)?, + false, + )); + let value_field = Arc::new(Field::new( + "value", + from_substrait_type(consumer, value_type, dfs_names, name_idx)?, + true, + )); + Ok(DataType::Map( + Arc::new(Field::new_struct( + "entries", + [key_field, value_field], + false, // The inner map field is always non-nullable (Arrow #1697), + )), + false, // whether keys are sorted + )) + } + r#type::Kind::Decimal(d) => match d.type_variation_reference { + DECIMAL_128_TYPE_VARIATION_REF => { + Ok(DataType::Decimal128(d.precision as u8, d.scale as i8)) + } + DECIMAL_256_TYPE_VARIATION_REF => { + Ok(DataType::Decimal256(d.precision as u8, d.scale as i8)) + } + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + }, + r#type::Kind::IntervalYear(_) => { + 
Ok(DataType::Interval(IntervalUnit::YearMonth)) + } + r#type::Kind::IntervalDay(i) => match i.type_variation_reference { + DEFAULT_INTERVAL_DAY_TYPE_VARIATION_REF => { + Ok(DataType::Interval(IntervalUnit::DayTime)) + } + DURATION_INTERVAL_DAY_TYPE_VARIATION_REF => { + let duration_unit = match i.precision { + Some(0) => Ok(TimeUnit::Second), + Some(3) => Ok(TimeUnit::Millisecond), + Some(6) => Ok(TimeUnit::Microsecond), + Some(9) => Ok(TimeUnit::Nanosecond), + p => { + not_impl_err!( + "Unsupported Substrait precision {p:?} for Duration" + ) + } + }?; + Ok(DataType::Duration(duration_unit)) + } + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), + }, + r#type::Kind::IntervalCompound(_) => { + Ok(DataType::Interval(IntervalUnit::MonthDayNano)) + } + r#type::Kind::UserDefined(u) => { + if let Ok(data_type) = consumer.consume_user_defined_type(u) { + return Ok(data_type); + } + + // TODO: remove the code below once the producer has been updated + if let Some(name) = consumer.get_extensions().types.get(&u.type_reference) + { + #[allow(deprecated)] + match name.as_ref() { + // Kept for backwards compatibility, producers should use IntervalCompound instead + INTERVAL_MONTH_DAY_NANO_TYPE_NAME => Ok(DataType::Interval(IntervalUnit::MonthDayNano)), + _ => not_impl_err!( + "Unsupported Substrait user defined type with ref {} and variation {}", + u.type_reference, + u.type_variation_reference + ), + } + } else { + #[allow(deprecated)] + match u.type_reference { + // Kept for backwards compatibility, producers should use IntervalYear instead + INTERVAL_YEAR_MONTH_TYPE_REF => { + Ok(DataType::Interval(IntervalUnit::YearMonth)) + } + // Kept for backwards compatibility, producers should use IntervalDay instead + INTERVAL_DAY_TIME_TYPE_REF => { + Ok(DataType::Interval(IntervalUnit::DayTime)) + } + // Kept for backwards compatibility, producers should use IntervalCompound instead + INTERVAL_MONTH_DAY_NANO_TYPE_REF => { + Ok(DataType::Interval(IntervalUnit::MonthDayNano)) + } + _ => not_impl_err!( + "Unsupported Substrait user defined type with ref {} and variation {}", + u.type_reference, + u.type_variation_reference + ), + } + } + } + r#type::Kind::Struct(s) => Ok(DataType::Struct(from_substrait_struct_type( + consumer, s, dfs_names, name_idx, + )?)), + r#type::Kind::Varchar(_) => Ok(DataType::Utf8), + r#type::Kind::FixedChar(_) => Ok(DataType::Utf8), + _ => not_impl_err!("Unsupported Substrait type: {s_kind:?}"), + }, + _ => not_impl_err!("`None` Substrait kind is not supported"), + } +} + +/// Convert Substrait NamedStruct to DataFusion DFSchemaRef +pub fn from_substrait_named_struct( + consumer: &impl SubstraitConsumer, + base_schema: &NamedStruct, +) -> datafusion::common::Result { + let mut name_idx = 0; + let fields = from_substrait_struct_type( + consumer, + base_schema.r#struct.as_ref().ok_or_else(|| { + substrait_datafusion_err!("Named struct must contain a struct") + })?, + &base_schema.names, + &mut name_idx, + ); + if name_idx != base_schema.names.len() { + return substrait_err!( + "Names list must match exactly to nested schema, but found {} uses for {} names", + name_idx, + base_schema.names.len() + ); + } + DFSchema::try_from(Schema::new(fields?)) +} + +fn from_substrait_struct_type( + consumer: &impl SubstraitConsumer, + s: &r#type::Struct, + dfs_names: &[String], + name_idx: &mut usize, +) -> datafusion::common::Result { + let mut fields = vec![]; + for (i, f) in s.types.iter().enumerate() { + let field = Field::new( + next_struct_field_name(i, 
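The `PrecisionTimestamp`, `PrecisionTimestampTz`, and `IntervalDay`/`Duration` arms above all use the same convention: the Substrait precision is the number of sub-second digits. A minimal standalone sketch of that mapping; the helper name `precision_to_time_unit` is hypothetical and not part of the patch:

```rust
use datafusion::arrow::datatypes::TimeUnit;

// Sketch of the precision-to-unit rule shared by the timestamp and duration arms above.
fn precision_to_time_unit(precision: i32) -> Option<TimeUnit> {
    match precision {
        0 => Some(TimeUnit::Second),      // no sub-second digits
        3 => Some(TimeUnit::Millisecond), // 10^-3
        6 => Some(TimeUnit::Microsecond), // 10^-6
        9 => Some(TimeUnit::Nanosecond),  // 10^-9
        _ => None, // any other precision is rejected with not_impl_err! in the consumer
    }
}

fn main() {
    assert_eq!(precision_to_time_unit(3), Some(TimeUnit::Millisecond));
    assert_eq!(precision_to_time_unit(7), None);
}
```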
dfs_names, name_idx)?, + from_substrait_type(consumer, f, dfs_names, name_idx)?, + true, // We assume everything to be nullable since that's easier than ensuring it matches + ); + fields.push(field); + } + Ok(fields.into()) +} diff --git a/datafusion/substrait/src/logical_plan/consumer/utils.rs b/datafusion/substrait/src/logical_plan/consumer/utils.rs new file mode 100644 index 000000000000..396c5e673f85 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/consumer/utils.rs @@ -0,0 +1,653 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::consumer::SubstraitConsumer; +use datafusion::arrow::datatypes::{DataType, Field, Schema, UnionFields}; +use datafusion::common::{ + exec_err, not_impl_err, substrait_datafusion_err, substrait_err, DFSchema, + DFSchemaRef, TableReference, +}; +use datafusion::logical_expr::expr::Sort; +use datafusion::logical_expr::{Cast, Expr, ExprSchemable, LogicalPlanBuilder}; +use std::collections::HashSet; +use std::sync::Arc; +use substrait::proto::sort_field::SortDirection; +use substrait::proto::sort_field::SortKind::{ComparisonFunctionReference, Direction}; +use substrait::proto::SortField; + +// Substrait PrecisionTimestampTz indicates that the timestamp is relative to UTC, which +// is the same as the expectation for any non-empty timezone in DF, so any non-empty timezone +// results in correct points on the timeline, and we pick UTC as a reasonable default. +// However, DF uses the timezone also for some arithmetic and display purposes (see e.g. +// https://github.com/apache/arrow-rs/blob/ee5694078c86c8201549654246900a4232d531a9/arrow-cast/src/cast/mod.rs#L1749). +pub(super) const DEFAULT_TIMEZONE: &str = "UTC"; + +/// (Re)qualify the sides of a join if needed, i.e. if the columns from one side would otherwise +/// conflict with the columns from the other. +/// Substrait doesn't currently allow specifying aliases, neither for columns nor for tables. For +/// Substrait the names don't matter since it only refers to columns by indices, however DataFusion +/// requires columns to be uniquely identifiable, in some places (see e.g. DFSchema::check_names). +pub(super) fn requalify_sides_if_needed( + left: LogicalPlanBuilder, + right: LogicalPlanBuilder, +) -> datafusion::common::Result<(LogicalPlanBuilder, LogicalPlanBuilder)> { + let left_cols = left.schema().columns(); + let right_cols = right.schema().columns(); + if left_cols.iter().any(|l| { + right_cols.iter().any(|r| { + l == r || (l.name == r.name && (l.relation.is_none() || r.relation.is_none())) + }) + }) { + // These names have no connection to the original plan, but they'll make the columns + // (mostly) unique. 
+ Ok(( + left.alias(TableReference::bare("left"))?, + right.alias(TableReference::bare("right"))?, + )) + } else { + Ok((left, right)) + } +} + +pub(super) fn next_struct_field_name( + column_idx: usize, + dfs_names: &[String], + name_idx: &mut usize, +) -> datafusion::common::Result { + if dfs_names.is_empty() { + // If names are not given, create dummy names + // c0, c1, ... align with e.g. SqlToRel::create_named_struct + Ok(format!("c{column_idx}")) + } else { + let name = dfs_names.get(*name_idx).cloned().ok_or_else(|| { + substrait_datafusion_err!("Named schema must contain names for all fields") + })?; + *name_idx += 1; + Ok(name) + } +} + +/// Traverse through the field, renaming the provided field itself and all its inner struct fields. +pub fn rename_field( + field: &Field, + dfs_names: &Vec, + unnamed_field_suffix: usize, // If Substrait doesn't provide a name, we'll use this "c{unnamed_field_suffix}" + name_idx: &mut usize, // Index into dfs_names +) -> datafusion::common::Result { + let name = next_struct_field_name(unnamed_field_suffix, dfs_names, name_idx)?; + rename_fields_data_type(field.clone().with_name(name), dfs_names, name_idx) +} + +/// Rename the field's data type but not the field itself. +pub fn rename_fields_data_type( + field: Field, + dfs_names: &Vec, + name_idx: &mut usize, // Index into dfs_names +) -> datafusion::common::Result { + let dt = rename_data_type(field.data_type(), dfs_names, name_idx)?; + Ok(field.with_data_type(dt)) +} + +/// Traverse through the data type (incl. lists/maps/etc), renaming all inner struct fields. +pub fn rename_data_type( + data_type: &DataType, + dfs_names: &Vec, + name_idx: &mut usize, // Index into dfs_names +) -> datafusion::common::Result { + match data_type { + DataType::Struct(children) => { + let children = children + .iter() + .enumerate() + .map(|(field_idx, f)| { + rename_field(f.as_ref(), dfs_names, field_idx, name_idx) + }) + .collect::>()?; + Ok(DataType::Struct(children)) + } + DataType::List(inner) => Ok(DataType::List(Arc::new(rename_fields_data_type( + inner.as_ref().to_owned(), + dfs_names, + name_idx, + )?))), + DataType::LargeList(inner) => Ok(DataType::LargeList(Arc::new( + rename_fields_data_type(inner.as_ref().to_owned(), dfs_names, name_idx)?, + ))), + DataType::ListView(inner) => Ok(DataType::ListView(Arc::new( + rename_fields_data_type(inner.as_ref().to_owned(), dfs_names, name_idx)?, + ))), + DataType::LargeListView(inner) => Ok(DataType::LargeListView(Arc::new( + rename_fields_data_type(inner.as_ref().to_owned(), dfs_names, name_idx)?, + ))), + DataType::FixedSizeList(inner, len) => Ok(DataType::FixedSizeList( + Arc::new(rename_fields_data_type( + inner.as_ref().to_owned(), + dfs_names, + name_idx, + )?), + *len, + )), + DataType::Map(entries, sorted) => { + let entries_data_type = match entries.data_type() { + DataType::Struct(fields) => { + // This should be two fields, normally "key" and "value", but not guaranteed + let fields = fields + .iter() + .map(|f| { + rename_fields_data_type( + f.as_ref().to_owned(), + dfs_names, + name_idx, + ) + }) + .collect::>()?; + Ok(DataType::Struct(fields)) + } + _ => exec_err!("Expected map type to contain an inner struct type"), + }?; + Ok(DataType::Map( + Arc::new( + entries + .as_ref() + .to_owned() + .with_data_type(entries_data_type), + ), + *sorted, + )) + } + DataType::Dictionary(key_type, value_type) => { + // Dicts probably shouldn't contain structs, but support them just in case one does + Ok(DataType::Dictionary( + Box::new(rename_data_type(key_type, 
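To make the naming rule in `next_struct_field_name` above concrete: when Substrait provides no name list, fields get positional dummy names `c0, c1, ...`; when it does, names are consumed strictly left to right, independent of the column index. The following is a standalone restatement for illustration only (the function name `field_name` is made up):

```rust
use datafusion::common::{substrait_datafusion_err, Result};

// Simplified restatement of the naming rule used by next_struct_field_name.
fn field_name(column_idx: usize, dfs_names: &[String], name_idx: &mut usize) -> Result<String> {
    if dfs_names.is_empty() {
        // No names provided: fall back to positional dummy names.
        Ok(format!("c{column_idx}"))
    } else {
        let name = dfs_names.get(*name_idx).cloned().ok_or_else(|| {
            substrait_datafusion_err!("Named schema must contain names for all fields")
        })?;
        *name_idx += 1;
        Ok(name)
    }
}

fn main() -> Result<()> {
    let mut idx = 0;
    // No names: the column index drives the name.
    assert_eq!(field_name(2, &[], &mut idx)?, "c2");
    // With names: consumed in order, regardless of the column index.
    let names = vec!["a".to_string(), "b".to_string()];
    assert_eq!(field_name(0, &names, &mut idx)?, "a");
    assert_eq!(field_name(5, &names, &mut idx)?, "b");
    Ok(())
}
```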
dfs_names, name_idx)?), + Box::new(rename_data_type(value_type, dfs_names, name_idx)?), + )) + } + DataType::RunEndEncoded(run_ends_field, values_field) => { + // At least the run_ends_field shouldn't contain names (since it should be i16/i32/i64), + // but we'll try renaming its datatype just in case. + let run_ends_field = rename_fields_data_type( + run_ends_field.as_ref().clone(), + dfs_names, + name_idx, + )?; + let values_field = rename_fields_data_type( + values_field.as_ref().clone(), + dfs_names, + name_idx, + )?; + + Ok(DataType::RunEndEncoded( + Arc::new(run_ends_field), + Arc::new(values_field), + )) + } + DataType::Union(fields, mode) => { + let fields = fields + .iter() + .map(|(i, f)| { + Ok(( + i, + Arc::new(rename_fields_data_type( + f.as_ref().clone(), + dfs_names, + name_idx, + )?), + )) + }) + .collect::>()?; + Ok(DataType::Union(fields, *mode)) + } + // Explicitly listing the rest (which can not contain inner fields needing renaming) + // to ensure we're exhaustive + DataType::Null + | DataType::Boolean + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float16 + | DataType::Float32 + | DataType::Float64 + | DataType::Timestamp(_, _) + | DataType::Date32 + | DataType::Date64 + | DataType::Time32(_) + | DataType::Time64(_) + | DataType::Duration(_) + | DataType::Interval(_) + | DataType::Binary + | DataType::FixedSizeBinary(_) + | DataType::LargeBinary + | DataType::BinaryView + | DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Utf8View + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) => Ok(data_type.clone()), + } +} + +/// Produce a version of the given schema with names matching the given list of names. +/// Substrait doesn't deal with column (incl. nested struct field) names within the schema, +/// but it does give us the list of expected names at the end of the plan, so we use this +/// to rename the schema to match the expected names. +pub(super) fn make_renamed_schema( + schema: &DFSchemaRef, + dfs_names: &Vec, +) -> datafusion::common::Result { + let mut name_idx = 0; + + let (qualifiers, fields): (_, Vec) = schema + .iter() + .enumerate() + .map(|(field_idx, (q, f))| { + let renamed_f = + rename_field(f.as_ref(), dfs_names, field_idx, &mut name_idx)?; + Ok((q.cloned(), renamed_f)) + }) + .collect::>>()? + .into_iter() + .unzip(); + + if name_idx != dfs_names.len() { + return substrait_err!( + "Names list must match exactly to nested schema, but found {} uses for {} names", + name_idx, + dfs_names.len()); + } + + DFSchema::from_field_specific_qualified_schema( + qualifiers, + &Arc::new(Schema::new(fields)), + ) +} + +/// Ensure the expressions have the right name(s) according to the new schema. +/// This includes the top-level (column) name, which will be renamed through aliasing if needed, +/// as well as nested names (if the expression produces any struct types), which will be renamed +/// through casting if needed. +pub(super) fn rename_expressions( + exprs: impl IntoIterator, + input_schema: &DFSchema, + new_schema_fields: &[Arc], +) -> datafusion::common::Result> { + exprs + .into_iter() + .zip(new_schema_fields) + .map(|(old_expr, new_field)| { + // Check if type (i.e. nested struct field names) match, use Cast to rename if needed + let new_expr = if &old_expr.get_type(input_schema)? 
!= new_field.data_type() { + Expr::Cast(Cast::new( + Box::new(old_expr), + new_field.data_type().to_owned(), + )) + } else { + old_expr + }; + // Alias column if needed to fix the top-level name + match &new_expr { + // If expr is a column reference, alias_if_changed would cause an aliasing if the old expr has a qualifier + Expr::Column(c) if &c.name == new_field.name() => Ok(new_expr), + _ => new_expr.alias_if_changed(new_field.name().to_owned()), + } + }) + .collect() +} + +/// Ensures that the given Substrait schema is compatible with the schema as given by DataFusion +/// +/// This means: +/// 1. All fields present in the Substrait schema are present in the DataFusion schema. The +/// DataFusion schema may have MORE fields, but not the other way around. +/// 2. All fields are compatible. See [`ensure_field_compatibility`] for details +pub(super) fn ensure_schema_compatibility( + table_schema: &DFSchema, + substrait_schema: DFSchema, +) -> datafusion::common::Result<()> { + substrait_schema + .strip_qualifiers() + .fields() + .iter() + .try_for_each(|substrait_field| { + let df_field = + table_schema.field_with_unqualified_name(substrait_field.name())?; + ensure_field_compatibility(df_field, substrait_field) + }) +} + +/// Ensures that the given Substrait field is compatible with the given DataFusion field +/// +/// A field is compatible between Substrait and DataFusion if: +/// 1. They have logically equivalent types. +/// 2. They have the same nullability OR the Substrait field is nullable and the DataFusion fields +/// is not nullable. +/// +/// If a Substrait field is not nullable, the Substrait plan may be built around assuming it is not +/// nullable. As such if DataFusion has that field as nullable the plan should be rejected. +fn ensure_field_compatibility( + datafusion_field: &Field, + substrait_field: &Field, +) -> datafusion::common::Result<()> { + if !DFSchema::datatype_is_logically_equal( + datafusion_field.data_type(), + substrait_field.data_type(), + ) { + return substrait_err!( + "Field '{}' in Substrait schema has a different type ({}) than the corresponding field in the table schema ({}).", + substrait_field.name(), + substrait_field.data_type(), + datafusion_field.data_type() + ); + } + + if !compatible_nullabilities( + datafusion_field.is_nullable(), + substrait_field.is_nullable(), + ) { + // TODO: from_substrait_struct_type needs to be updated to set the nullability correctly. It defaults to true for now. 
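For reference, `rename_expressions` above uses two different mechanisms: top-level column names are fixed with an alias, while nested struct-field names can only be changed by casting to a struct type that carries the new names. A small sketch building both kinds of expression by hand; the column names `a` and `s` are invented for the example:

```rust
use datafusion::arrow::datatypes::{DataType, Field, Fields};
use datafusion::logical_expr::Cast;
use datafusion::prelude::{col, Expr};

fn main() {
    // Top-level rename: a plain alias is enough.
    let top_level = col("a").alias("renamed_a");

    // Nested rename: cast a struct-typed column to the same layout with a new inner
    // field name. Only the names differ, so the cast does not change the data.
    let new_type = DataType::Struct(Fields::from(vec![Field::new(
        "renamed_inner",
        DataType::Int32,
        true,
    )]));
    let nested = Expr::Cast(Cast::new(Box::new(col("s")), new_type));

    println!("{top_level:?} / {nested:?}");
}
```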
+ return substrait_err!( + "Field '{}' is nullable in the DataFusion schema but not nullable in the Substrait schema.", + substrait_field.name() + ); + } + Ok(()) +} + +/// Returns true if the DataFusion and Substrait nullabilities are compatible, false otherwise +fn compatible_nullabilities( + datafusion_nullability: bool, + substrait_nullability: bool, +) -> bool { + // DataFusion and Substrait have the same nullability + (datafusion_nullability == substrait_nullability) + // DataFusion is not nullable and Substrait is nullable + || (!datafusion_nullability && substrait_nullability) +} + +pub(super) struct NameTracker { + seen_names: HashSet, +} + +pub(super) enum NameTrackerStatus { + NeverSeen, + SeenBefore, +} + +impl NameTracker { + pub(super) fn new() -> Self { + NameTracker { + seen_names: HashSet::default(), + } + } + pub(super) fn get_unique_name( + &mut self, + name: String, + ) -> (String, NameTrackerStatus) { + match self.seen_names.insert(name.clone()) { + true => (name, NameTrackerStatus::NeverSeen), + false => { + let mut counter = 0; + loop { + let candidate_name = format!("{name}__temp__{counter}"); + if self.seen_names.insert(candidate_name.clone()) { + return (candidate_name, NameTrackerStatus::SeenBefore); + } + counter += 1; + } + } + } + } + + pub(super) fn get_uniquely_named_expr( + &mut self, + expr: Expr, + ) -> datafusion::common::Result { + match self.get_unique_name(expr.name_for_alias()?) { + (_, NameTrackerStatus::NeverSeen) => Ok(expr), + (name, NameTrackerStatus::SeenBefore) => Ok(expr.alias(name)), + } + } +} + +/// Convert Substrait Sorts to DataFusion Exprs +pub async fn from_substrait_sorts( + consumer: &impl SubstraitConsumer, + substrait_sorts: &Vec, + input_schema: &DFSchema, +) -> datafusion::common::Result> { + let mut sorts: Vec = vec![]; + for s in substrait_sorts { + let expr = consumer + .consume_expression(s.expr.as_ref().unwrap(), input_schema) + .await?; + let asc_nullfirst = match &s.sort_kind { + Some(k) => match k { + Direction(d) => { + let Ok(direction) = SortDirection::try_from(*d) else { + return not_impl_err!( + "Unsupported Substrait SortDirection value {d}" + ); + }; + + match direction { + SortDirection::AscNullsFirst => Ok((true, true)), + SortDirection::AscNullsLast => Ok((true, false)), + SortDirection::DescNullsFirst => Ok((false, true)), + SortDirection::DescNullsLast => Ok((false, false)), + SortDirection::Clustered => not_impl_err!( + "Sort with direction clustered is not yet supported" + ), + SortDirection::Unspecified => { + not_impl_err!("Unspecified sort direction is invalid") + } + } + } + ComparisonFunctionReference(_) => not_impl_err!( + "Sort using comparison function reference is not supported" + ), + }, + None => not_impl_err!("Sort without sort kind is invalid"), + }; + let (asc, nulls_first) = asc_nullfirst.unwrap(); + sorts.push(Sort { + expr, + asc, + nulls_first, + }); + } + Ok(sorts) +} + +#[cfg(test)] +pub(crate) mod tests { + use super::make_renamed_schema; + use crate::extensions::Extensions; + use crate::logical_plan::consumer::DefaultSubstraitConsumer; + use datafusion::arrow::datatypes::{DataType, Field}; + use datafusion::common::DFSchema; + use datafusion::error::Result; + use datafusion::execution::SessionState; + use datafusion::prelude::SessionContext; + use datafusion::sql::TableReference; + use std::collections::HashMap; + use std::sync::{Arc, LazyLock}; + + pub(crate) static TEST_SESSION_STATE: LazyLock = + LazyLock::new(|| SessionContext::default().state()); + pub(crate) static 
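A short standalone sketch of the deduplication scheme `NameTracker` implements above: the first occurrence of a name is kept as-is, and later occurrences get a `__temp__{counter}` suffix. The `Tracker` struct here is a simplified restatement, not the type from the patch:

```rust
use std::collections::HashSet;

struct Tracker {
    seen: HashSet<String>,
}

impl Tracker {
    fn unique(&mut self, name: &str) -> String {
        if self.seen.insert(name.to_string()) {
            return name.to_string(); // never seen before: keep as-is
        }
        // Seen before: probe suffixed candidates until one is free.
        let mut counter = 0;
        loop {
            let candidate = format!("{name}__temp__{counter}");
            if self.seen.insert(candidate.clone()) {
                return candidate;
            }
            counter += 1;
        }
    }
}

fn main() {
    let mut t = Tracker { seen: HashSet::new() };
    assert_eq!(t.unique("amount"), "amount");
    assert_eq!(t.unique("amount"), "amount__temp__0");
    assert_eq!(t.unique("amount"), "amount__temp__1");
}
```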
TEST_EXTENSIONS: LazyLock = + LazyLock::new(Extensions::default); + pub(crate) fn test_consumer() -> DefaultSubstraitConsumer<'static> { + let extensions = &TEST_EXTENSIONS; + let state = &TEST_SESSION_STATE; + DefaultSubstraitConsumer::new(extensions, state) + } + + #[tokio::test] + async fn rename_schema() -> Result<()> { + let table_ref = TableReference::bare("test"); + let fields = vec![ + ( + Some(table_ref.clone()), + Arc::new(Field::new("0", DataType::Int32, false)), + ), + ( + Some(table_ref.clone()), + Arc::new(Field::new_struct( + "1", + vec![ + Field::new("2", DataType::Int32, false), + Field::new_struct( + "3", + vec![Field::new("4", DataType::Int32, false)], + false, + ), + ], + false, + )), + ), + ( + Some(table_ref.clone()), + Arc::new(Field::new_list( + "5", + Arc::new(Field::new_struct( + "item", + vec![Field::new("6", DataType::Int32, false)], + false, + )), + false, + )), + ), + ( + Some(table_ref.clone()), + Arc::new(Field::new_large_list( + "7", + Arc::new(Field::new_struct( + "item", + vec![Field::new("8", DataType::Int32, false)], + false, + )), + false, + )), + ), + ( + Some(table_ref.clone()), + Arc::new(Field::new_map( + "9", + "entries", + Arc::new(Field::new_struct( + "keys", + vec![Field::new("10", DataType::Int32, false)], + false, + )), + Arc::new(Field::new_struct( + "values", + vec![Field::new("11", DataType::Int32, false)], + false, + )), + false, + false, + )), + ), + ]; + + let schema = Arc::new(DFSchema::new_with_metadata(fields, HashMap::default())?); + let dfs_names = vec![ + "a".to_string(), + "b".to_string(), + "c".to_string(), + "d".to_string(), + "e".to_string(), + "f".to_string(), + "g".to_string(), + "h".to_string(), + "i".to_string(), + "j".to_string(), + "k".to_string(), + "l".to_string(), + ]; + let renamed_schema = make_renamed_schema(&schema, &dfs_names)?; + + assert_eq!(renamed_schema.fields().len(), 5); + assert_eq!( + *renamed_schema.field(0), + Field::new("a", DataType::Int32, false) + ); + assert_eq!( + *renamed_schema.field(1), + Field::new_struct( + "b", + vec![ + Field::new("c", DataType::Int32, false), + Field::new_struct( + "d", + vec![Field::new("e", DataType::Int32, false)], + false, + ) + ], + false, + ) + ); + assert_eq!( + *renamed_schema.field(2), + Field::new_list( + "f", + Arc::new(Field::new_struct( + "item", + vec![Field::new("g", DataType::Int32, false)], + false, + )), + false, + ) + ); + assert_eq!( + *renamed_schema.field(3), + Field::new_large_list( + "h", + Arc::new(Field::new_struct( + "item", + vec![Field::new("i", DataType::Int32, false)], + false, + )), + false, + ) + ); + assert_eq!( + *renamed_schema.field(4), + Field::new_map( + "j", + "entries", + Arc::new(Field::new_struct( + "keys", + vec![Field::new("k", DataType::Int32, false)], + false, + )), + Arc::new(Field::new_struct( + "values", + vec![Field::new("l", DataType::Int32, false)], + false, + )), + false, + false, + ) + ); + Ok(()) + } +} diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs deleted file mode 100644 index 07bf0cb96aa3..000000000000 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ /dev/null @@ -1,2915 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. 
The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::sync::Arc; -use substrait::proto::expression_reference::ExprType; - -use datafusion::arrow::datatypes::{Field, IntervalUnit}; -use datafusion::logical_expr::{ - Aggregate, Distinct, EmptyRelation, Extension, Filter, Join, Like, Limit, - Partitioning, Projection, Repartition, Sort, SortExpr, SubqueryAlias, TableScan, - TryCast, Union, Values, Window, WindowFrameUnits, -}; -use datafusion::{ - arrow::datatypes::{DataType, TimeUnit}, - error::{DataFusionError, Result}, - logical_expr::{WindowFrame, WindowFrameBound}, - prelude::JoinType, - scalar::ScalarValue, -}; - -use crate::extensions::Extensions; -use crate::variation_const::{ - DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, - DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, - DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, - LARGE_CONTAINER_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, - VIEW_CONTAINER_TYPE_VARIATION_REF, -}; -use datafusion::arrow::array::{Array, GenericListArray, OffsetSizeTrait}; -use datafusion::arrow::temporal_conversions::NANOSECONDS; -use datafusion::common::{ - exec_err, internal_err, not_impl_err, plan_err, substrait_datafusion_err, - substrait_err, Column, DFSchema, DFSchemaRef, ToDFSchema, -}; -use datafusion::execution::registry::SerializerRegistry; -use datafusion::execution::SessionState; -use datafusion::logical_expr::expr::{ - AggregateFunctionParams, Alias, BinaryExpr, Case, Cast, GroupingSet, InList, - InSubquery, WindowFunction, WindowFunctionParams, -}; -use datafusion::logical_expr::utils::conjunction; -use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; -use datafusion::prelude::Expr; -use pbjson_types::Any as ProtoAny; -use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields}; -use substrait::proto::expression::cast::FailureBehavior; -use substrait::proto::expression::field_reference::{RootReference, RootType}; -use substrait::proto::expression::literal::interval_day_to_second::PrecisionMode; -use substrait::proto::expression::literal::map::KeyValue; -use substrait::proto::expression::literal::{ - IntervalCompound, IntervalDayToSecond, IntervalYearToMonth, List, Map, - PrecisionTimestamp, Struct, -}; -use substrait::proto::expression::subquery::InPredicate; -use substrait::proto::expression::window_function::BoundsType; -use substrait::proto::expression::ScalarFunction; -use substrait::proto::read_rel::VirtualTable; -use substrait::proto::rel_common::EmitKind; -use substrait::proto::rel_common::EmitKind::Emit; -use substrait::proto::{ - fetch_rel, rel_common, ExchangeRel, ExpressionReference, ExtendedExpression, - RelCommon, -}; -use substrait::{ - proto::{ - aggregate_function::AggregationInvocation, - aggregate_rel::{Grouping, Measure}, - expression::{ - field_reference::ReferenceType, - if_then::IfClause, - literal::{Decimal, LiteralType}, - mask_expression::{StructItem, StructSelect}, - 
reference_segment, - window_function::bound as SubstraitBound, - window_function::bound::Kind as BoundKind, - window_function::Bound, - FieldReference, IfThen, Literal, MaskExpression, ReferenceSegment, RexType, - SingularOrList, WindowFunction as SubstraitWindowFunction, - }, - function_argument::ArgType, - join_rel, plan_rel, r#type, - read_rel::{NamedTable, ReadType}, - rel::RelType, - set_rel, - sort_field::{SortDirection, SortKind}, - AggregateFunction, AggregateRel, AggregationPhase, Expression, ExtensionLeafRel, - ExtensionMultiRel, ExtensionSingleRel, FetchRel, FilterRel, FunctionArgument, - JoinRel, NamedStruct, Plan, PlanRel, ProjectRel, ReadRel, Rel, RelRoot, SetRel, - SortField, SortRel, - }, - version, -}; - -/// This trait is used to produce Substrait plans, converting them from DataFusion Logical Plans. -/// It can be implemented by users to allow for custom handling of relations, expressions, etc. -/// -/// Combined with the [crate::logical_plan::consumer::SubstraitConsumer] this allows for fully -/// customizable Substrait serde. -/// -/// # Example Usage -/// -/// ``` -/// # use std::sync::Arc; -/// # use substrait::proto::{Expression, Rel}; -/// # use substrait::proto::rel::RelType; -/// # use datafusion::common::DFSchemaRef; -/// # use datafusion::error::Result; -/// # use datafusion::execution::SessionState; -/// # use datafusion::logical_expr::{Between, Extension, Projection}; -/// # use datafusion_substrait::extensions::Extensions; -/// # use datafusion_substrait::logical_plan::producer::{from_projection, SubstraitProducer}; -/// -/// struct CustomSubstraitProducer { -/// extensions: Extensions, -/// state: Arc, -/// } -/// -/// impl SubstraitProducer for CustomSubstraitProducer { -/// -/// fn register_function(&mut self, signature: String) -> u32 { -/// self.extensions.register_function(signature) -/// } -/// -/// fn get_extensions(self) -> Extensions { -/// self.extensions -/// } -/// -/// // You can set additional metadata on the Rels you produce -/// fn handle_projection(&mut self, plan: &Projection) -> Result> { -/// let mut rel = from_projection(self, plan)?; -/// match rel.rel_type { -/// Some(RelType::Project(mut project)) => { -/// let mut project = project.clone(); -/// // set common metadata or advanced extension -/// project.common = None; -/// project.advanced_extension = None; -/// Ok(Box::new(Rel { -/// rel_type: Some(RelType::Project(project)), -/// })) -/// } -/// rel_type => Ok(Box::new(Rel { rel_type })), -/// } -/// } -/// -/// // You can tweak how you convert expressions for your target system -/// fn handle_between(&mut self, between: &Between, schema: &DFSchemaRef) -> Result { -/// // add your own encoding for Between -/// todo!() -/// } -/// -/// // You can fully control how you convert UserDefinedLogicalNodes into Substrait -/// fn handle_extension(&mut self, _plan: &Extension) -> Result> { -/// // implement your own serializer into Substrait -/// todo!() -/// } -/// } -/// ``` -pub trait SubstraitProducer: Send + Sync + Sized { - /// Within a Substrait plan, functions are referenced using function anchors that are stored at - /// the top level of the [Plan] within - /// [ExtensionFunction](substrait::proto::extensions::simple_extension_declaration::ExtensionFunction) - /// messages. - /// - /// When given a function signature, this method should return the existing anchor for it if - /// there is one. Otherwise, it should generate a new anchor. 
- fn register_function(&mut self, signature: String) -> u32; - - /// Consume the producer to generate the [Extensions] for the Substrait plan based on the - /// functions that have been registered - fn get_extensions(self) -> Extensions; - - // Logical Plan Methods - // There is one method per LogicalPlan to allow for easy overriding of producer behaviour. - // These methods have default implementations calling the common handler code, to allow for users - // to re-use common handling logic. - - fn handle_plan(&mut self, plan: &LogicalPlan) -> Result> { - to_substrait_rel(self, plan) - } - - fn handle_projection(&mut self, plan: &Projection) -> Result> { - from_projection(self, plan) - } - - fn handle_filter(&mut self, plan: &Filter) -> Result> { - from_filter(self, plan) - } - - fn handle_window(&mut self, plan: &Window) -> Result> { - from_window(self, plan) - } - - fn handle_aggregate(&mut self, plan: &Aggregate) -> Result> { - from_aggregate(self, plan) - } - - fn handle_sort(&mut self, plan: &Sort) -> Result> { - from_sort(self, plan) - } - - fn handle_join(&mut self, plan: &Join) -> Result> { - from_join(self, plan) - } - - fn handle_repartition(&mut self, plan: &Repartition) -> Result> { - from_repartition(self, plan) - } - - fn handle_union(&mut self, plan: &Union) -> Result> { - from_union(self, plan) - } - - fn handle_table_scan(&mut self, plan: &TableScan) -> Result> { - from_table_scan(self, plan) - } - - fn handle_empty_relation(&mut self, plan: &EmptyRelation) -> Result> { - from_empty_relation(plan) - } - - fn handle_subquery_alias(&mut self, plan: &SubqueryAlias) -> Result> { - from_subquery_alias(self, plan) - } - - fn handle_limit(&mut self, plan: &Limit) -> Result> { - from_limit(self, plan) - } - - fn handle_values(&mut self, plan: &Values) -> Result> { - from_values(self, plan) - } - - fn handle_distinct(&mut self, plan: &Distinct) -> Result> { - from_distinct(self, plan) - } - - fn handle_extension(&mut self, _plan: &Extension) -> Result> { - substrait_err!("Specify handling for LogicalPlan::Extension by implementing the SubstraitProducer trait") - } - - // Expression Methods - // There is one method per DataFusion Expr to allow for easy overriding of producer behaviour - // These methods have default implementations calling the common handler code, to allow for users - // to re-use common handling logic. 
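`register_function` above is expected to hand out one stable anchor per distinct function signature, which the extension declarations at the top of the plan then reference. A minimal sketch of such a registry; this is a simplified illustration using a plain `HashMap`, not the `Extensions` type the default producer delegates to:

```rust
use std::collections::HashMap;

// Hypothetical stand-in for the anchor bookkeeping a producer needs:
// the same signature always maps to the same anchor, new signatures get the next id.
struct FunctionAnchors {
    anchors: HashMap<String, u32>,
}

impl FunctionAnchors {
    fn register(&mut self, signature: String) -> u32 {
        let next = self.anchors.len() as u32;
        *self.anchors.entry(signature).or_insert(next)
    }
}

fn main() {
    let mut reg = FunctionAnchors { anchors: HashMap::new() };
    let add = reg.register("add".to_string());
    let eq = reg.register("equal".to_string());
    assert_eq!(reg.register("add".to_string()), add); // same signature, same anchor
    assert_ne!(add, eq);
}
```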
- - fn handle_expr(&mut self, expr: &Expr, schema: &DFSchemaRef) -> Result { - to_substrait_rex(self, expr, schema) - } - - fn handle_alias( - &mut self, - alias: &Alias, - schema: &DFSchemaRef, - ) -> Result { - from_alias(self, alias, schema) - } - - fn handle_column( - &mut self, - column: &Column, - schema: &DFSchemaRef, - ) -> Result { - from_column(column, schema) - } - - fn handle_literal(&mut self, value: &ScalarValue) -> Result { - from_literal(self, value) - } - - fn handle_binary_expr( - &mut self, - expr: &BinaryExpr, - schema: &DFSchemaRef, - ) -> Result { - from_binary_expr(self, expr, schema) - } - - fn handle_like(&mut self, like: &Like, schema: &DFSchemaRef) -> Result { - from_like(self, like, schema) - } - - /// For handling Not, IsNotNull, IsNull, IsTrue, IsFalse, IsUnknown, IsNotTrue, IsNotFalse, IsNotUnknown, Negative - fn handle_unary_expr( - &mut self, - expr: &Expr, - schema: &DFSchemaRef, - ) -> Result { - from_unary_expr(self, expr, schema) - } - - fn handle_between( - &mut self, - between: &Between, - schema: &DFSchemaRef, - ) -> Result { - from_between(self, between, schema) - } - - fn handle_case(&mut self, case: &Case, schema: &DFSchemaRef) -> Result { - from_case(self, case, schema) - } - - fn handle_cast(&mut self, cast: &Cast, schema: &DFSchemaRef) -> Result { - from_cast(self, cast, schema) - } - - fn handle_try_cast( - &mut self, - cast: &TryCast, - schema: &DFSchemaRef, - ) -> Result { - from_try_cast(self, cast, schema) - } - - fn handle_scalar_function( - &mut self, - scalar_fn: &expr::ScalarFunction, - schema: &DFSchemaRef, - ) -> Result { - from_scalar_function(self, scalar_fn, schema) - } - - fn handle_aggregate_function( - &mut self, - agg_fn: &expr::AggregateFunction, - schema: &DFSchemaRef, - ) -> Result { - from_aggregate_function(self, agg_fn, schema) - } - - fn handle_window_function( - &mut self, - window_fn: &WindowFunction, - schema: &DFSchemaRef, - ) -> Result { - from_window_function(self, window_fn, schema) - } - - fn handle_in_list( - &mut self, - in_list: &InList, - schema: &DFSchemaRef, - ) -> Result { - from_in_list(self, in_list, schema) - } - - fn handle_in_subquery( - &mut self, - in_subquery: &InSubquery, - schema: &DFSchemaRef, - ) -> Result { - from_in_subquery(self, in_subquery, schema) - } -} - -pub struct DefaultSubstraitProducer<'a> { - extensions: Extensions, - serializer_registry: &'a dyn SerializerRegistry, -} - -impl<'a> DefaultSubstraitProducer<'a> { - pub fn new(state: &'a SessionState) -> Self { - DefaultSubstraitProducer { - extensions: Extensions::default(), - serializer_registry: state.serializer_registry().as_ref(), - } - } -} - -impl SubstraitProducer for DefaultSubstraitProducer<'_> { - fn register_function(&mut self, fn_name: String) -> u32 { - self.extensions.register_function(fn_name) - } - - fn get_extensions(self) -> Extensions { - self.extensions - } - - fn handle_extension(&mut self, plan: &Extension) -> Result> { - let extension_bytes = self - .serializer_registry - .serialize_logical_plan(plan.node.as_ref())?; - let detail = ProtoAny { - type_url: plan.node.name().to_string(), - value: extension_bytes.into(), - }; - let mut inputs_rel = plan - .node - .inputs() - .into_iter() - .map(|plan| self.handle_plan(plan)) - .collect::>>()?; - let rel_type = match inputs_rel.len() { - 0 => RelType::ExtensionLeaf(ExtensionLeafRel { - common: None, - detail: Some(detail), - }), - 1 => RelType::ExtensionSingle(Box::new(ExtensionSingleRel { - common: None, - detail: Some(detail), - input: 
Some(inputs_rel.pop().unwrap()), - })), - _ => RelType::ExtensionMulti(ExtensionMultiRel { - common: None, - detail: Some(detail), - inputs: inputs_rel.into_iter().map(|r| *r).collect(), - }), - }; - Ok(Box::new(Rel { - rel_type: Some(rel_type), - })) - } -} - -/// Convert DataFusion LogicalPlan to Substrait Plan -pub fn to_substrait_plan(plan: &LogicalPlan, state: &SessionState) -> Result> { - // Parse relation nodes - // Generate PlanRel(s) - // Note: Only 1 relation tree is currently supported - - let mut producer: DefaultSubstraitProducer = DefaultSubstraitProducer::new(state); - let plan_rels = vec![PlanRel { - rel_type: Some(plan_rel::RelType::Root(RelRoot { - input: Some(*producer.handle_plan(plan)?), - names: to_substrait_named_struct(plan.schema())?.names, - })), - }]; - - // Return parsed plan - let extensions = producer.get_extensions(); - Ok(Box::new(Plan { - version: Some(version::version_with_producer("datafusion")), - extension_uris: vec![], - extensions: extensions.into(), - relations: plan_rels, - advanced_extensions: None, - expected_type_urls: vec![], - parameter_bindings: vec![], - })) -} - -/// Serializes a collection of expressions to a Substrait ExtendedExpression message -/// -/// The ExtendedExpression message is a top-level message that can be used to send -/// expressions (not plans) between systems. -/// -/// Each expression is also given names for the output type. These are provided as a -/// field and not a String (since the names may be nested, e.g. a struct). The data -/// type and nullability of this field is redundant (those can be determined by the -/// Expr) and will be ignored. -/// -/// Substrait also requires the input schema of the expressions to be included in the -/// message. The field names of the input schema will be serialized. 
-pub fn to_substrait_extended_expr( - exprs: &[(&Expr, &Field)], - schema: &DFSchemaRef, - state: &SessionState, -) -> Result> { - let mut producer = DefaultSubstraitProducer::new(state); - let substrait_exprs = exprs - .iter() - .map(|(expr, field)| { - let substrait_expr = producer.handle_expr(expr, schema)?; - let mut output_names = Vec::new(); - flatten_names(field, false, &mut output_names)?; - Ok(ExpressionReference { - output_names, - expr_type: Some(ExprType::Expression(substrait_expr)), - }) - }) - .collect::>>()?; - let substrait_schema = to_substrait_named_struct(schema)?; - - let extensions = producer.get_extensions(); - Ok(Box::new(ExtendedExpression { - advanced_extensions: None, - expected_type_urls: vec![], - extension_uris: vec![], - extensions: extensions.into(), - version: Some(version::version_with_producer("datafusion")), - referred_expr: substrait_exprs, - base_schema: Some(substrait_schema), - })) -} - -pub fn to_substrait_rel( - producer: &mut impl SubstraitProducer, - plan: &LogicalPlan, -) -> Result> { - match plan { - LogicalPlan::Projection(plan) => producer.handle_projection(plan), - LogicalPlan::Filter(plan) => producer.handle_filter(plan), - LogicalPlan::Window(plan) => producer.handle_window(plan), - LogicalPlan::Aggregate(plan) => producer.handle_aggregate(plan), - LogicalPlan::Sort(plan) => producer.handle_sort(plan), - LogicalPlan::Join(plan) => producer.handle_join(plan), - LogicalPlan::Repartition(plan) => producer.handle_repartition(plan), - LogicalPlan::Union(plan) => producer.handle_union(plan), - LogicalPlan::TableScan(plan) => producer.handle_table_scan(plan), - LogicalPlan::EmptyRelation(plan) => producer.handle_empty_relation(plan), - LogicalPlan::Subquery(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, - LogicalPlan::SubqueryAlias(plan) => producer.handle_subquery_alias(plan), - LogicalPlan::Limit(plan) => producer.handle_limit(plan), - LogicalPlan::Statement(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, - LogicalPlan::Values(plan) => producer.handle_values(plan), - LogicalPlan::Explain(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, - LogicalPlan::Analyze(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, - LogicalPlan::Extension(plan) => producer.handle_extension(plan), - LogicalPlan::Distinct(plan) => producer.handle_distinct(plan), - LogicalPlan::Dml(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, - LogicalPlan::Ddl(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, - LogicalPlan::Copy(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, - LogicalPlan::DescribeTable(plan) => { - not_impl_err!("Unsupported plan type: {plan:?}")? - } - LogicalPlan::Unnest(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, - LogicalPlan::RecursiveQuery(plan) => { - not_impl_err!("Unsupported plan type: {plan:?}")? 
- } - } -} - -pub fn from_table_scan( - producer: &mut impl SubstraitProducer, - scan: &TableScan, -) -> Result> { - let projection = scan.projection.as_ref().map(|p| { - p.iter() - .map(|i| StructItem { - field: *i as i32, - child: None, - }) - .collect() - }); - - let projection = projection.map(|struct_items| MaskExpression { - select: Some(StructSelect { struct_items }), - maintain_singular_struct: false, - }); - - let table_schema = scan.source.schema().to_dfschema_ref()?; - let base_schema = to_substrait_named_struct(&table_schema)?; - - let filter_option = if scan.filters.is_empty() { - None - } else { - let table_schema_qualified = Arc::new( - DFSchema::try_from_qualified_schema( - scan.table_name.clone(), - &(scan.source.schema()), - ) - .unwrap(), - ); - - let combined_expr = conjunction(scan.filters.clone()).unwrap(); - let filter_expr = - producer.handle_expr(&combined_expr, &table_schema_qualified)?; - Some(Box::new(filter_expr)) - }; - - Ok(Box::new(Rel { - rel_type: Some(RelType::Read(Box::new(ReadRel { - common: None, - base_schema: Some(base_schema), - filter: filter_option, - best_effort_filter: None, - projection, - advanced_extension: None, - read_type: Some(ReadType::NamedTable(NamedTable { - names: scan.table_name.to_vec(), - advanced_extension: None, - })), - }))), - })) -} - -pub fn from_empty_relation(e: &EmptyRelation) -> Result> { - if e.produce_one_row { - return not_impl_err!("Producing a row from empty relation is unsupported"); - } - #[allow(deprecated)] - Ok(Box::new(Rel { - rel_type: Some(RelType::Read(Box::new(ReadRel { - common: None, - base_schema: Some(to_substrait_named_struct(&e.schema)?), - filter: None, - best_effort_filter: None, - projection: None, - advanced_extension: None, - read_type: Some(ReadType::VirtualTable(VirtualTable { - values: vec![], - expressions: vec![], - })), - }))), - })) -} - -pub fn from_values( - producer: &mut impl SubstraitProducer, - v: &Values, -) -> Result> { - let values = v - .values - .iter() - .map(|row| { - let fields = row - .iter() - .map(|v| match v { - Expr::Literal(sv) => to_substrait_literal(producer, sv), - Expr::Alias(alias) => match alias.expr.as_ref() { - // The schema gives us the names, so we can skip aliases - Expr::Literal(sv) => to_substrait_literal(producer, sv), - _ => Err(substrait_datafusion_err!( - "Only literal types can be aliased in Virtual Tables, got: {}", alias.expr.variant_name() - )), - }, - _ => Err(substrait_datafusion_err!( - "Only literal types and aliases are supported in Virtual Tables, got: {}", v.variant_name() - )), - }) - .collect::>()?; - Ok(Struct { fields }) - }) - .collect::>()?; - #[allow(deprecated)] - Ok(Box::new(Rel { - rel_type: Some(RelType::Read(Box::new(ReadRel { - common: None, - base_schema: Some(to_substrait_named_struct(&v.schema)?), - filter: None, - best_effort_filter: None, - projection: None, - advanced_extension: None, - read_type: Some(ReadType::VirtualTable(VirtualTable { - values, - expressions: vec![], - })), - }))), - })) -} - -pub fn from_projection( - producer: &mut impl SubstraitProducer, - p: &Projection, -) -> Result> { - let expressions = p - .expr - .iter() - .map(|e| producer.handle_expr(e, p.input.schema())) - .collect::>>()?; - - let emit_kind = create_project_remapping( - expressions.len(), - p.input.as_ref().schema().fields().len(), - ); - let common = RelCommon { - emit_kind: Some(emit_kind), - hint: None, - advanced_extension: None, - }; - - Ok(Box::new(Rel { - rel_type: Some(RelType::Project(Box::new(ProjectRel { - common: Some(common), 
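The projection handling in `from_table_scan` above turns DataFusion's column indices into a Substrait `MaskExpression`. A small sketch of the structure produced for a scan that projects columns 0 and 2, built directly from the same substrait proto types the (removed) producer uses:

```rust
use substrait::proto::expression::mask_expression::{StructItem, StructSelect};
use substrait::proto::expression::MaskExpression;

fn main() {
    // Column indices selected by the table scan.
    let projection = [0usize, 2];

    let struct_items: Vec<StructItem> = projection
        .iter()
        .map(|i| StructItem {
            field: *i as i32,
            child: None, // no nested (struct-field) selection
        })
        .collect();

    let mask = MaskExpression {
        select: Some(StructSelect { struct_items }),
        maintain_singular_struct: false,
    };
    println!("{mask:?}");
}
```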
- input: Some(producer.handle_plan(p.input.as_ref())?), - expressions, - advanced_extension: None, - }))), - })) -} - -pub fn from_filter( - producer: &mut impl SubstraitProducer, - filter: &Filter, -) -> Result> { - let input = producer.handle_plan(filter.input.as_ref())?; - let filter_expr = producer.handle_expr(&filter.predicate, filter.input.schema())?; - Ok(Box::new(Rel { - rel_type: Some(RelType::Filter(Box::new(FilterRel { - common: None, - input: Some(input), - condition: Some(Box::new(filter_expr)), - advanced_extension: None, - }))), - })) -} - -pub fn from_limit( - producer: &mut impl SubstraitProducer, - limit: &Limit, -) -> Result> { - let input = producer.handle_plan(limit.input.as_ref())?; - let empty_schema = Arc::new(DFSchema::empty()); - let offset_mode = limit - .skip - .as_ref() - .map(|expr| producer.handle_expr(expr.as_ref(), &empty_schema)) - .transpose()? - .map(Box::new) - .map(fetch_rel::OffsetMode::OffsetExpr); - let count_mode = limit - .fetch - .as_ref() - .map(|expr| producer.handle_expr(expr.as_ref(), &empty_schema)) - .transpose()? - .map(Box::new) - .map(fetch_rel::CountMode::CountExpr); - Ok(Box::new(Rel { - rel_type: Some(RelType::Fetch(Box::new(FetchRel { - common: None, - input: Some(input), - offset_mode, - count_mode, - advanced_extension: None, - }))), - })) -} - -pub fn from_sort(producer: &mut impl SubstraitProducer, sort: &Sort) -> Result> { - let Sort { expr, input, fetch } = sort; - let sort_fields = expr - .iter() - .map(|e| substrait_sort_field(producer, e, input.schema())) - .collect::>>()?; - - let input = producer.handle_plan(input.as_ref())?; - - let sort_rel = Box::new(Rel { - rel_type: Some(RelType::Sort(Box::new(SortRel { - common: None, - input: Some(input), - sorts: sort_fields, - advanced_extension: None, - }))), - }); - - match fetch { - Some(amount) => { - let count_mode = - Some(fetch_rel::CountMode::CountExpr(Box::new(Expression { - rex_type: Some(RexType::Literal(Literal { - nullable: false, - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - literal_type: Some(LiteralType::I64(*amount as i64)), - })), - }))); - Ok(Box::new(Rel { - rel_type: Some(RelType::Fetch(Box::new(FetchRel { - common: None, - input: Some(sort_rel), - offset_mode: None, - count_mode, - advanced_extension: None, - }))), - })) - } - None => Ok(sort_rel), - } -} - -pub fn from_aggregate( - producer: &mut impl SubstraitProducer, - agg: &Aggregate, -) -> Result> { - let input = producer.handle_plan(agg.input.as_ref())?; - let (grouping_expressions, groupings) = - to_substrait_groupings(producer, &agg.group_expr, agg.input.schema())?; - let measures = agg - .aggr_expr - .iter() - .map(|e| to_substrait_agg_measure(producer, e, agg.input.schema())) - .collect::>>()?; - - Ok(Box::new(Rel { - rel_type: Some(RelType::Aggregate(Box::new(AggregateRel { - common: None, - input: Some(input), - grouping_expressions, - groupings, - measures, - advanced_extension: None, - }))), - })) -} - -pub fn from_distinct( - producer: &mut impl SubstraitProducer, - distinct: &Distinct, -) -> Result> { - match distinct { - Distinct::All(plan) => { - // Use Substrait's AggregateRel with empty measures to represent `select distinct` - let input = producer.handle_plan(plan.as_ref())?; - // Get grouping keys from the input relation's number of output fields - let grouping = (0..plan.schema().fields().len()) - .map(substrait_field_ref) - .collect::>>()?; - - #[allow(deprecated)] - Ok(Box::new(Rel { - rel_type: Some(RelType::Aggregate(Box::new(AggregateRel { - common: None, - input: 
Some(input), - grouping_expressions: vec![], - groupings: vec![Grouping { - grouping_expressions: grouping, - expression_references: vec![], - }], - measures: vec![], - advanced_extension: None, - }))), - })) - } - Distinct::On(_) => not_impl_err!("Cannot convert Distinct::On"), - } -} - -pub fn from_join(producer: &mut impl SubstraitProducer, join: &Join) -> Result> { - let left = producer.handle_plan(join.left.as_ref())?; - let right = producer.handle_plan(join.right.as_ref())?; - let join_type = to_substrait_jointype(join.join_type); - // we only support basic joins so return an error for anything not yet supported - match join.join_constraint { - JoinConstraint::On => {} - JoinConstraint::Using => return not_impl_err!("join constraint: `using`"), - } - let in_join_schema = Arc::new(join.left.schema().join(join.right.schema())?); - - // convert filter if present - let join_filter = match &join.filter { - Some(filter) => Some(producer.handle_expr(filter, &in_join_schema)?), - None => None, - }; - - // map the left and right columns to binary expressions in the form `l = r` - // build a single expression for the ON condition, such as `l.a = r.a AND l.b = r.b` - let eq_op = if join.null_equals_null { - Operator::IsNotDistinctFrom - } else { - Operator::Eq - }; - let join_on = to_substrait_join_expr(producer, &join.on, eq_op, &in_join_schema)?; - - // create conjunction between `join_on` and `join_filter` to embed all join conditions, - // whether equal or non-equal in a single expression - let join_expr = match &join_on { - Some(on_expr) => match &join_filter { - Some(filter) => Some(Box::new(make_binary_op_scalar_func( - producer, - on_expr, - filter, - Operator::And, - ))), - None => join_on.map(Box::new), // the join expression will only contain `join_on` if filter doesn't exist - }, - None => match &join_filter { - Some(_) => join_filter.map(Box::new), // the join expression will only contain `join_filter` if the `on` condition doesn't exist - None => None, - }, - }; - - Ok(Box::new(Rel { - rel_type: Some(RelType::Join(Box::new(JoinRel { - common: None, - left: Some(left), - right: Some(right), - r#type: join_type as i32, - expression: join_expr, - post_join_filter: None, - advanced_extension: None, - }))), - })) -} - -pub fn from_subquery_alias( - producer: &mut impl SubstraitProducer, - alias: &SubqueryAlias, -) -> Result> { - // Do nothing if encounters SubqueryAlias - // since there is no corresponding relation type in Substrait - producer.handle_plan(alias.input.as_ref()) -} - -pub fn from_union( - producer: &mut impl SubstraitProducer, - union: &Union, -) -> Result> { - let input_rels = union - .inputs - .iter() - .map(|input| producer.handle_plan(input.as_ref())) - .collect::>>()? 
- .into_iter() - .map(|ptr| *ptr) - .collect(); - Ok(Box::new(Rel { - rel_type: Some(RelType::Set(SetRel { - common: None, - inputs: input_rels, - op: set_rel::SetOp::UnionAll as i32, // UNION DISTINCT gets translated to AGGREGATION + UNION ALL - advanced_extension: None, - })), - })) -} - -pub fn from_window( - producer: &mut impl SubstraitProducer, - window: &Window, -) -> Result> { - let input = producer.handle_plan(window.input.as_ref())?; - - // create a field reference for each input field - let mut expressions = (0..window.input.schema().fields().len()) - .map(substrait_field_ref) - .collect::>>()?; - - // process and add each window function expression - for expr in &window.window_expr { - expressions.push(producer.handle_expr(expr, window.input.schema())?); - } - - let emit_kind = - create_project_remapping(expressions.len(), window.input.schema().fields().len()); - let common = RelCommon { - emit_kind: Some(emit_kind), - hint: None, - advanced_extension: None, - }; - let project_rel = Box::new(ProjectRel { - common: Some(common), - input: Some(input), - expressions, - advanced_extension: None, - }); - - Ok(Box::new(Rel { - rel_type: Some(RelType::Project(project_rel)), - })) -} - -pub fn from_repartition( - producer: &mut impl SubstraitProducer, - repartition: &Repartition, -) -> Result> { - let input = producer.handle_plan(repartition.input.as_ref())?; - let partition_count = match repartition.partitioning_scheme { - Partitioning::RoundRobinBatch(num) => num, - Partitioning::Hash(_, num) => num, - Partitioning::DistributeBy(_) => { - return not_impl_err!( - "Physical plan does not support DistributeBy partitioning" - ) - } - }; - // ref: https://substrait.io/relations/physical_relations/#exchange-types - let exchange_kind = match &repartition.partitioning_scheme { - Partitioning::RoundRobinBatch(_) => { - ExchangeKind::RoundRobin(RoundRobin::default()) - } - Partitioning::Hash(exprs, _) => { - let fields = exprs - .iter() - .map(|e| try_to_substrait_field_reference(e, repartition.input.schema())) - .collect::>>()?; - ExchangeKind::ScatterByFields(ScatterFields { fields }) - } - Partitioning::DistributeBy(_) => { - return not_impl_err!( - "Physical plan does not support DistributeBy partitioning" - ) - } - }; - let exchange_rel = ExchangeRel { - common: None, - input: Some(input), - exchange_kind: Some(exchange_kind), - advanced_extension: None, - partition_count: partition_count as i32, - targets: vec![], - }; - Ok(Box::new(Rel { - rel_type: Some(RelType::Exchange(Box::new(exchange_rel))), - })) -} - -/// By default, a Substrait Project outputs all input fields followed by all expressions. -/// A DataFusion Projection only outputs expressions. In order to keep the Substrait -/// plan consistent with DataFusion, we must apply an output mapping that skips the input -/// fields so that the Substrait Project will only output the expression fields. -fn create_project_remapping(expr_count: usize, input_field_count: usize) -> EmitKind { - let expression_field_start = input_field_count; - let expression_field_end = expression_field_start + expr_count; - let output_mapping = (expression_field_start..expression_field_end) - .map(|i| i as i32) - .collect(); - Emit(rel_common::Emit { output_mapping }) -} - -// Substrait wants a list of all field names, including nested fields from structs, -// also from within e.g. lists and maps. However, it does not want the list and map field names -// themselves - only proper structs fields are considered to have useful names. 
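The emit remapping in `create_project_remapping` above boils down to index arithmetic: skip the pass-through input columns and emit only the expression columns appended after them. A standalone sketch of that arithmetic using plain integers, without the proto types:

```rust
/// Output indices a Substrait ProjectRel should emit so that only the expression
/// columns (appended after the input columns) are visible, matching DataFusion's
/// Projection semantics.
fn project_output_mapping(expr_count: usize, input_field_count: usize) -> Vec<i32> {
    (input_field_count..input_field_count + expr_count)
        .map(|i| i as i32)
        .collect()
}

fn main() {
    // 3 input columns, 2 projected expressions: emit only fields 3 and 4.
    assert_eq!(project_output_mapping(2, 3), vec![3, 4]);
}
```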
-fn flatten_names(field: &Field, skip_self: bool, names: &mut Vec) -> Result<()> { - if !skip_self { - names.push(field.name().to_string()); - } - match field.data_type() { - DataType::Struct(fields) => { - for field in fields { - flatten_names(field, false, names)?; - } - Ok(()) - } - DataType::List(l) => flatten_names(l, true, names), - DataType::LargeList(l) => flatten_names(l, true, names), - DataType::Map(m, _) => match m.data_type() { - DataType::Struct(key_and_value) if key_and_value.len() == 2 => { - flatten_names(&key_and_value[0], true, names)?; - flatten_names(&key_and_value[1], true, names) - } - _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"), - }, - _ => Ok(()), - }?; - Ok(()) -} - -fn to_substrait_named_struct(schema: &DFSchemaRef) -> Result { - let mut names = Vec::with_capacity(schema.fields().len()); - for field in schema.fields() { - flatten_names(field, false, &mut names)?; - } - - let field_types = r#type::Struct { - types: schema - .fields() - .iter() - .map(|f| to_substrait_type(f.data_type(), f.is_nullable())) - .collect::>()?, - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability: r#type::Nullability::Required as i32, - }; - - Ok(NamedStruct { - names, - r#struct: Some(field_types), - }) -} - -fn to_substrait_join_expr( - producer: &mut impl SubstraitProducer, - join_conditions: &Vec<(Expr, Expr)>, - eq_op: Operator, - join_schema: &DFSchemaRef, -) -> Result> { - // Only support AND conjunction for each binary expression in join conditions - let mut exprs: Vec = vec![]; - for (left, right) in join_conditions { - let l = producer.handle_expr(left, join_schema)?; - let r = producer.handle_expr(right, join_schema)?; - // AND with existing expression - exprs.push(make_binary_op_scalar_func(producer, &l, &r, eq_op)); - } - - let join_expr: Option = - exprs.into_iter().reduce(|acc: Expression, e: Expression| { - make_binary_op_scalar_func(producer, &acc, &e, Operator::And) - }); - Ok(join_expr) -} - -fn to_substrait_jointype(join_type: JoinType) -> join_rel::JoinType { - match join_type { - JoinType::Inner => join_rel::JoinType::Inner, - JoinType::Left => join_rel::JoinType::Left, - JoinType::Right => join_rel::JoinType::Right, - JoinType::Full => join_rel::JoinType::Outer, - JoinType::LeftAnti => join_rel::JoinType::LeftAnti, - JoinType::LeftSemi => join_rel::JoinType::LeftSemi, - JoinType::LeftMark => join_rel::JoinType::LeftMark, - JoinType::RightAnti | JoinType::RightSemi => { - unimplemented!() - } - } -} - -pub fn operator_to_name(op: Operator) -> &'static str { - match op { - Operator::Eq => "equal", - Operator::NotEq => "not_equal", - Operator::Lt => "lt", - Operator::LtEq => "lte", - Operator::Gt => "gt", - Operator::GtEq => "gte", - Operator::Plus => "add", - Operator::Minus => "subtract", - Operator::Multiply => "multiply", - Operator::Divide => "divide", - Operator::Modulo => "modulus", - Operator::And => "and", - Operator::Or => "or", - Operator::IsDistinctFrom => "is_distinct_from", - Operator::IsNotDistinctFrom => "is_not_distinct_from", - Operator::RegexMatch => "regex_match", - Operator::RegexIMatch => "regex_imatch", - Operator::RegexNotMatch => "regex_not_match", - Operator::RegexNotIMatch => "regex_not_imatch", - Operator::LikeMatch => "like_match", - Operator::ILikeMatch => "like_imatch", - Operator::NotLikeMatch => "like_not_match", - Operator::NotILikeMatch => "like_not_imatch", - Operator::BitwiseAnd => "bitwise_and", - Operator::BitwiseOr => "bitwise_or", - Operator::StringConcat => "str_concat", - 
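The name flattening in `flatten_names` above only records proper struct field names: it descends through list and map wrappers but skips their synthetic "item"/"entries" names. A simplified standalone sketch covering just the struct and list cases (the full version above also handles maps):

```rust
use std::sync::Arc;
use datafusion::arrow::datatypes::{DataType, Field};

// Simplified restatement: record the field's own name unless told to skip it,
// then recurse into struct children, and through list wrappers without recording
// the list item name.
fn flatten(field: &Field, skip_self: bool, names: &mut Vec<String>) {
    if !skip_self {
        names.push(field.name().to_string());
    }
    match field.data_type() {
        DataType::Struct(children) => {
            for child in children {
                flatten(child, false, names);
            }
        }
        DataType::List(item) | DataType::LargeList(item) => flatten(item, true, names),
        _ => {}
    }
}

fn main() {
    let field = Field::new_list(
        "points",
        Arc::new(Field::new_struct(
            "item",
            vec![
                Field::new("x", DataType::Float64, false),
                Field::new("y", DataType::Float64, false),
            ],
            false,
        )),
        false,
    );
    let mut names = vec![];
    flatten(&field, false, &mut names);
    // The list wrapper's "item" name is skipped; struct field names are kept.
    assert_eq!(names, vec!["points", "x", "y"]);
}
```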
Operator::AtArrow => "at_arrow", - Operator::ArrowAt => "arrow_at", - Operator::Arrow => "arrow", - Operator::LongArrow => "long_arrow", - Operator::HashArrow => "hash_arrow", - Operator::HashLongArrow => "hash_long_arrow", - Operator::AtAt => "at_at", - Operator::IntegerDivide => "integer_divide", - Operator::HashMinus => "hash_minus", - Operator::AtQuestion => "at_question", - Operator::Question => "question", - Operator::QuestionAnd => "question_and", - Operator::QuestionPipe => "question_pipe", - Operator::BitwiseXor => "bitwise_xor", - Operator::BitwiseShiftRight => "bitwise_shift_right", - Operator::BitwiseShiftLeft => "bitwise_shift_left", - } -} - -pub fn parse_flat_grouping_exprs( - producer: &mut impl SubstraitProducer, - exprs: &[Expr], - schema: &DFSchemaRef, - ref_group_exprs: &mut Vec, -) -> Result { - let mut expression_references = vec![]; - let mut grouping_expressions = vec![]; - - for e in exprs { - let rex = producer.handle_expr(e, schema)?; - grouping_expressions.push(rex.clone()); - ref_group_exprs.push(rex); - expression_references.push((ref_group_exprs.len() - 1) as u32); - } - #[allow(deprecated)] - Ok(Grouping { - grouping_expressions, - expression_references, - }) -} - -pub fn to_substrait_groupings( - producer: &mut impl SubstraitProducer, - exprs: &[Expr], - schema: &DFSchemaRef, -) -> Result<(Vec, Vec)> { - let mut ref_group_exprs = vec![]; - let groupings = match exprs.len() { - 1 => match &exprs[0] { - Expr::GroupingSet(gs) => match gs { - GroupingSet::Cube(_) => Err(DataFusionError::NotImplemented( - "GroupingSet CUBE is not yet supported".to_string(), - )), - GroupingSet::GroupingSets(sets) => Ok(sets - .iter() - .map(|set| { - parse_flat_grouping_exprs( - producer, - set, - schema, - &mut ref_group_exprs, - ) - }) - .collect::>>()?), - GroupingSet::Rollup(set) => { - let mut sets: Vec> = vec![vec![]]; - for i in 0..set.len() { - sets.push(set[..=i].to_vec()); - } - Ok(sets - .iter() - .rev() - .map(|set| { - parse_flat_grouping_exprs( - producer, - set, - schema, - &mut ref_group_exprs, - ) - }) - .collect::>>()?) - } - }, - _ => Ok(vec![parse_flat_grouping_exprs( - producer, - exprs, - schema, - &mut ref_group_exprs, - )?]), - }, - _ => Ok(vec![parse_flat_grouping_exprs( - producer, - exprs, - schema, - &mut ref_group_exprs, - )?]), - }?; - Ok((ref_group_exprs, groupings)) -} - -pub fn from_aggregate_function( - producer: &mut impl SubstraitProducer, - agg_fn: &expr::AggregateFunction, - schema: &DFSchemaRef, -) -> Result { - let expr::AggregateFunction { - func, - params: - AggregateFunctionParams { - args, - distinct, - filter, - order_by, - null_treatment: _null_treatment, - }, - } = agg_fn; - let sorts = if let Some(order_by) = order_by { - order_by - .iter() - .map(|expr| to_substrait_sort_field(producer, expr, schema)) - .collect::>>()? 
- } else { - vec![] - }; - let mut arguments: Vec = vec![]; - for arg in args { - arguments.push(FunctionArgument { - arg_type: Some(ArgType::Value(producer.handle_expr(arg, schema)?)), - }); - } - let function_anchor = producer.register_function(func.name().to_string()); - #[allow(deprecated)] - Ok(Measure { - measure: Some(AggregateFunction { - function_reference: function_anchor, - arguments, - sorts, - output_type: None, - invocation: match distinct { - true => AggregationInvocation::Distinct as i32, - false => AggregationInvocation::All as i32, - }, - phase: AggregationPhase::Unspecified as i32, - args: vec![], - options: vec![], - }), - filter: match filter { - Some(f) => Some(producer.handle_expr(f, schema)?), - None => None, - }, - }) -} - -pub fn to_substrait_agg_measure( - producer: &mut impl SubstraitProducer, - expr: &Expr, - schema: &DFSchemaRef, -) -> Result { - match expr { - Expr::AggregateFunction(agg_fn) => from_aggregate_function(producer, agg_fn, schema), - Expr::Alias(Alias { expr, .. }) => { - to_substrait_agg_measure(producer, expr, schema) - } - _ => internal_err!( - "Expression must be compatible with aggregation. Unsupported expression: {:?}. ExpressionType: {:?}", - expr, - expr.variant_name() - ), - } -} - -/// Converts sort expression to corresponding substrait `SortField` -fn to_substrait_sort_field( - producer: &mut impl SubstraitProducer, - sort: &expr::Sort, - schema: &DFSchemaRef, -) -> Result { - let sort_kind = match (sort.asc, sort.nulls_first) { - (true, true) => SortDirection::AscNullsFirst, - (true, false) => SortDirection::AscNullsLast, - (false, true) => SortDirection::DescNullsFirst, - (false, false) => SortDirection::DescNullsLast, - }; - Ok(SortField { - expr: Some(producer.handle_expr(&sort.expr, schema)?), - sort_kind: Some(SortKind::Direction(sort_kind.into())), - }) -} - -/// Return Substrait scalar function with two arguments -pub fn make_binary_op_scalar_func( - producer: &mut impl SubstraitProducer, - lhs: &Expression, - rhs: &Expression, - op: Operator, -) -> Expression { - let function_anchor = producer.register_function(operator_to_name(op).to_string()); - #[allow(deprecated)] - Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments: vec![ - FunctionArgument { - arg_type: Some(ArgType::Value(lhs.clone())), - }, - FunctionArgument { - arg_type: Some(ArgType::Value(rhs.clone())), - }, - ], - output_type: None, - args: vec![], - options: vec![], - })), - } -} - -/// Convert DataFusion Expr to Substrait Rex -/// -/// # Arguments -/// * `producer` - SubstraitProducer implementation which the handles the actual conversion -/// * `expr` - DataFusion expression to convert into a Substrait expression -/// * `schema` - DataFusion input schema for looking up columns -pub fn to_substrait_rex( - producer: &mut impl SubstraitProducer, - expr: &Expr, - schema: &DFSchemaRef, -) -> Result { - match expr { - Expr::Alias(expr) => producer.handle_alias(expr, schema), - Expr::Column(expr) => producer.handle_column(expr, schema), - Expr::ScalarVariable(_, _) => { - not_impl_err!("Cannot convert {expr:?} to Substrait") - } - Expr::Literal(expr) => producer.handle_literal(expr), - Expr::BinaryExpr(expr) => producer.handle_binary_expr(expr, schema), - Expr::Like(expr) => producer.handle_like(expr, schema), - Expr::SimilarTo(_) => not_impl_err!("Cannot convert {expr:?} to Substrait"), - Expr::Not(_) => producer.handle_unary_expr(expr, schema), - Expr::IsNotNull(_) => 
producer.handle_unary_expr(expr, schema), - Expr::IsNull(_) => producer.handle_unary_expr(expr, schema), - Expr::IsTrue(_) => producer.handle_unary_expr(expr, schema), - Expr::IsFalse(_) => producer.handle_unary_expr(expr, schema), - Expr::IsUnknown(_) => producer.handle_unary_expr(expr, schema), - Expr::IsNotTrue(_) => producer.handle_unary_expr(expr, schema), - Expr::IsNotFalse(_) => producer.handle_unary_expr(expr, schema), - Expr::IsNotUnknown(_) => producer.handle_unary_expr(expr, schema), - Expr::Negative(_) => producer.handle_unary_expr(expr, schema), - Expr::Between(expr) => producer.handle_between(expr, schema), - Expr::Case(expr) => producer.handle_case(expr, schema), - Expr::Cast(expr) => producer.handle_cast(expr, schema), - Expr::TryCast(expr) => producer.handle_try_cast(expr, schema), - Expr::ScalarFunction(expr) => producer.handle_scalar_function(expr, schema), - Expr::AggregateFunction(_) => { - internal_err!( - "AggregateFunction should only be encountered as part of a LogicalPlan::Aggregate" - ) - } - Expr::WindowFunction(expr) => producer.handle_window_function(expr, schema), - Expr::InList(expr) => producer.handle_in_list(expr, schema), - Expr::Exists(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), - Expr::InSubquery(expr) => producer.handle_in_subquery(expr, schema), - Expr::ScalarSubquery(expr) => { - not_impl_err!("Cannot convert {expr:?} to Substrait") - } - #[expect(deprecated)] - Expr::Wildcard { .. } => not_impl_err!("Cannot convert {expr:?} to Substrait"), - Expr::GroupingSet(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), - Expr::Placeholder(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), - Expr::OuterReferenceColumn(_, _) => { - not_impl_err!("Cannot convert {expr:?} to Substrait") - } - Expr::Unnest(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), - } -} - -pub fn from_in_list( - producer: &mut impl SubstraitProducer, - in_list: &InList, - schema: &DFSchemaRef, -) -> Result { - let InList { - expr, - list, - negated, - } = in_list; - let substrait_list = list - .iter() - .map(|x| producer.handle_expr(x, schema)) - .collect::>>()?; - let substrait_expr = producer.handle_expr(expr, schema)?; - - let substrait_or_list = Expression { - rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList { - value: Some(Box::new(substrait_expr)), - options: substrait_list, - }))), - }; - - if *negated { - let function_anchor = producer.register_function("not".to_string()); - - #[allow(deprecated)] - Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments: vec![FunctionArgument { - arg_type: Some(ArgType::Value(substrait_or_list)), - }], - output_type: None, - args: vec![], - options: vec![], - })), - }) - } else { - Ok(substrait_or_list) - } -} - -pub fn from_scalar_function( - producer: &mut impl SubstraitProducer, - fun: &expr::ScalarFunction, - schema: &DFSchemaRef, -) -> Result { - let mut arguments: Vec = vec![]; - for arg in &fun.args { - arguments.push(FunctionArgument { - arg_type: Some(ArgType::Value(producer.handle_expr(arg, schema)?)), - }); - } - - let function_anchor = producer.register_function(fun.name().to_string()); - #[allow(deprecated)] - Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments, - output_type: None, - options: vec![], - args: vec![], - })), - }) -} - -pub fn from_between( - producer: &mut impl SubstraitProducer, - between: &Between, - 
schema: &DFSchemaRef, -) -> Result { - let Between { - expr, - negated, - low, - high, - } = between; - if *negated { - // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) - let substrait_expr = producer.handle_expr(expr.as_ref(), schema)?; - let substrait_low = producer.handle_expr(low.as_ref(), schema)?; - let substrait_high = producer.handle_expr(high.as_ref(), schema)?; - - let l_expr = make_binary_op_scalar_func( - producer, - &substrait_expr, - &substrait_low, - Operator::Lt, - ); - let r_expr = make_binary_op_scalar_func( - producer, - &substrait_high, - &substrait_expr, - Operator::Lt, - ); - - Ok(make_binary_op_scalar_func( - producer, - &l_expr, - &r_expr, - Operator::Or, - )) - } else { - // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high) - let substrait_expr = producer.handle_expr(expr.as_ref(), schema)?; - let substrait_low = producer.handle_expr(low.as_ref(), schema)?; - let substrait_high = producer.handle_expr(high.as_ref(), schema)?; - - let l_expr = make_binary_op_scalar_func( - producer, - &substrait_low, - &substrait_expr, - Operator::LtEq, - ); - let r_expr = make_binary_op_scalar_func( - producer, - &substrait_expr, - &substrait_high, - Operator::LtEq, - ); - - Ok(make_binary_op_scalar_func( - producer, - &l_expr, - &r_expr, - Operator::And, - )) - } -} -pub fn from_column(col: &Column, schema: &DFSchemaRef) -> Result { - let index = schema.index_of_column(col)?; - substrait_field_ref(index) -} - -pub fn from_binary_expr( - producer: &mut impl SubstraitProducer, - expr: &BinaryExpr, - schema: &DFSchemaRef, -) -> Result { - let BinaryExpr { left, op, right } = expr; - let l = producer.handle_expr(left, schema)?; - let r = producer.handle_expr(right, schema)?; - Ok(make_binary_op_scalar_func(producer, &l, &r, *op)) -} -pub fn from_case( - producer: &mut impl SubstraitProducer, - case: &Case, - schema: &DFSchemaRef, -) -> Result { - let Case { - expr, - when_then_expr, - else_expr, - } = case; - let mut ifs: Vec = vec![]; - // Parse base - if let Some(e) = expr { - // Base expression exists - ifs.push(IfClause { - r#if: Some(producer.handle_expr(e, schema)?), - then: None, - }); - } - // Parse `when`s - for (r#if, then) in when_then_expr { - ifs.push(IfClause { - r#if: Some(producer.handle_expr(r#if, schema)?), - then: Some(producer.handle_expr(then, schema)?), - }); - } - - // Parse outer `else` - let r#else: Option> = match else_expr { - Some(e) => Some(Box::new(producer.handle_expr(e, schema)?)), - None => None, - }; - - Ok(Expression { - rex_type: Some(RexType::IfThen(Box::new(IfThen { ifs, r#else }))), - }) -} - -pub fn from_cast( - producer: &mut impl SubstraitProducer, - cast: &Cast, - schema: &DFSchemaRef, -) -> Result { - let Cast { expr, data_type } = cast; - Ok(Expression { - rex_type: Some(RexType::Cast(Box::new( - substrait::proto::expression::Cast { - r#type: Some(to_substrait_type(data_type, true)?), - input: Some(Box::new(producer.handle_expr(expr, schema)?)), - failure_behavior: FailureBehavior::ThrowException.into(), - }, - ))), - }) -} - -pub fn from_try_cast( - producer: &mut impl SubstraitProducer, - cast: &TryCast, - schema: &DFSchemaRef, -) -> Result { - let TryCast { expr, data_type } = cast; - Ok(Expression { - rex_type: Some(RexType::Cast(Box::new( - substrait::proto::expression::Cast { - r#type: Some(to_substrait_type(data_type, true)?), - input: Some(Box::new(producer.handle_expr(expr, schema)?)), - failure_behavior: FailureBehavior::ReturnNull.into(), - }, - ))), - }) -} - 
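The `from_between` conversion above never emits a dedicated BETWEEN construct; it desugars into comparisons joined with AND/OR. A small self-contained check of that rewrite, with hypothetical helper names and plain `i32` values in place of Substrait expressions:

// `expr BETWEEN low AND high`     -> low <= expr AND expr <= high
// `expr NOT BETWEEN low AND high` -> expr < low OR high < expr
fn between(expr: i32, low: i32, high: i32) -> bool {
    low <= expr && expr <= high
}

fn not_between(expr: i32, low: i32, high: i32) -> bool {
    expr < low || high < expr
}

fn main() {
    for v in [-1, 0, 5, 10, 11] {
        // The two rewrites are exact complements of each other.
        assert_eq!(between(v, 0, 10), !not_between(v, 0, 10));
    }
    println!("BETWEEN rewrite holds for the sampled values");
}

Because the negated form is expressed directly as `expr < low OR high < expr`, no wrapping `not` scalar function is needed, unlike the negated IN-list and LIKE cases handled elsewhere in this file.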
-pub fn from_literal( - producer: &mut impl SubstraitProducer, - value: &ScalarValue, -) -> Result { - to_substrait_literal_expr(producer, value) -} - -pub fn from_alias( - producer: &mut impl SubstraitProducer, - alias: &Alias, - schema: &DFSchemaRef, -) -> Result { - producer.handle_expr(alias.expr.as_ref(), schema) -} - -pub fn from_window_function( - producer: &mut impl SubstraitProducer, - window_fn: &WindowFunction, - schema: &DFSchemaRef, -) -> Result { - let WindowFunction { - fun, - params: - WindowFunctionParams { - args, - partition_by, - order_by, - window_frame, - null_treatment: _, - }, - } = window_fn; - // function reference - let function_anchor = producer.register_function(fun.to_string()); - // arguments - let mut arguments: Vec = vec![]; - for arg in args { - arguments.push(FunctionArgument { - arg_type: Some(ArgType::Value(producer.handle_expr(arg, schema)?)), - }); - } - // partition by expressions - let partition_by = partition_by - .iter() - .map(|e| producer.handle_expr(e, schema)) - .collect::>>()?; - // order by expressions - let order_by = order_by - .iter() - .map(|e| substrait_sort_field(producer, e, schema)) - .collect::>>()?; - // window frame - let bounds = to_substrait_bounds(window_frame)?; - let bound_type = to_substrait_bound_type(window_frame)?; - Ok(make_substrait_window_function( - function_anchor, - arguments, - partition_by, - order_by, - bounds, - bound_type, - )) -} - -pub fn from_like( - producer: &mut impl SubstraitProducer, - like: &Like, - schema: &DFSchemaRef, -) -> Result { - let Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - } = like; - make_substrait_like_expr( - producer, - *case_insensitive, - *negated, - expr, - pattern, - *escape_char, - schema, - ) -} - -pub fn from_in_subquery( - producer: &mut impl SubstraitProducer, - subquery: &InSubquery, - schema: &DFSchemaRef, -) -> Result { - let InSubquery { - expr, - subquery, - negated, - } = subquery; - let substrait_expr = producer.handle_expr(expr, schema)?; - - let subquery_plan = producer.handle_plan(subquery.subquery.as_ref())?; - - let substrait_subquery = Expression { - rex_type: Some(RexType::Subquery(Box::new( - substrait::proto::expression::Subquery { - subquery_type: Some( - substrait::proto::expression::subquery::SubqueryType::InPredicate( - Box::new(InPredicate { - needles: (vec![substrait_expr]), - haystack: Some(subquery_plan), - }), - ), - ), - }, - ))), - }; - if *negated { - let function_anchor = producer.register_function("not".to_string()); - - #[allow(deprecated)] - Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments: vec![FunctionArgument { - arg_type: Some(ArgType::Value(substrait_subquery)), - }], - output_type: None, - args: vec![], - options: vec![], - })), - }) - } else { - Ok(substrait_subquery) - } -} - -pub fn from_unary_expr( - producer: &mut impl SubstraitProducer, - expr: &Expr, - schema: &DFSchemaRef, -) -> Result { - let (fn_name, arg) = match expr { - Expr::Not(arg) => ("not", arg), - Expr::IsNull(arg) => ("is_null", arg), - Expr::IsNotNull(arg) => ("is_not_null", arg), - Expr::IsTrue(arg) => ("is_true", arg), - Expr::IsFalse(arg) => ("is_false", arg), - Expr::IsUnknown(arg) => ("is_unknown", arg), - Expr::IsNotTrue(arg) => ("is_not_true", arg), - Expr::IsNotFalse(arg) => ("is_not_false", arg), - Expr::IsNotUnknown(arg) => ("is_not_unknown", arg), - Expr::Negative(arg) => ("negate", arg), - expr => not_impl_err!("Unsupported expression: {expr:?}")?, - }; 
- to_substrait_unary_scalar_fn(producer, fn_name, arg, schema) -} - -fn to_substrait_type(dt: &DataType, nullable: bool) -> Result { - let nullability = if nullable { - r#type::Nullability::Nullable as i32 - } else { - r#type::Nullability::Required as i32 - }; - match dt { - DataType::Null => internal_err!("Null cast is not valid"), - DataType::Boolean => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Bool(r#type::Boolean { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::Int8 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::I8(r#type::I8 { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::UInt8 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::I8(r#type::I8 { - type_variation_reference: UNSIGNED_INTEGER_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::Int16 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::I16(r#type::I16 { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::UInt16 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::I16(r#type::I16 { - type_variation_reference: UNSIGNED_INTEGER_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::Int32 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::I32(r#type::I32 { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::UInt32 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::I32(r#type::I32 { - type_variation_reference: UNSIGNED_INTEGER_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::Int64 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::I64(r#type::I64 { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::UInt64 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::I64(r#type::I64 { - type_variation_reference: UNSIGNED_INTEGER_TYPE_VARIATION_REF, - nullability, - })), - }), - // Float16 is not supported in Substrait - DataType::Float32 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Fp32(r#type::Fp32 { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::Float64 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Fp64(r#type::Fp64 { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::Timestamp(unit, tz) => { - let precision = match unit { - TimeUnit::Second => 0, - TimeUnit::Millisecond => 3, - TimeUnit::Microsecond => 6, - TimeUnit::Nanosecond => 9, - }; - let kind = match tz { - None => r#type::Kind::PrecisionTimestamp(r#type::PrecisionTimestamp { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - precision, - }), - Some(_) => { - // If timezone is present, no matter what the actual tz value is, it indicates the - // value of the timestamp is tied to UTC epoch. That's all that Substrait cares about. - // As the timezone is lost, this conversion may be lossy for downstream use of the value. 
- r#type::Kind::PrecisionTimestampTz(r#type::PrecisionTimestampTz { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - precision, - }) - } - }; - Ok(substrait::proto::Type { kind: Some(kind) }) - } - DataType::Date32 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Date(r#type::Date { - type_variation_reference: DATE_32_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::Date64 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Date(r#type::Date { - type_variation_reference: DATE_64_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::Interval(interval_unit) => { - match interval_unit { - IntervalUnit::YearMonth => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::IntervalYear(r#type::IntervalYear { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - })), - }), - IntervalUnit::DayTime => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::IntervalDay(r#type::IntervalDay { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - precision: Some(3), // DayTime precision is always milliseconds - })), - }), - IntervalUnit::MonthDayNano => { - Ok(substrait::proto::Type { - kind: Some(r#type::Kind::IntervalCompound( - r#type::IntervalCompound { - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - precision: 9, // nanos - }, - )), - }) - } - } - } - DataType::Binary => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Binary(r#type::Binary { - type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::FixedSizeBinary(length) => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::FixedBinary(r#type::FixedBinary { - length: *length, - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::LargeBinary => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Binary(r#type::Binary { - type_variation_reference: LARGE_CONTAINER_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::BinaryView => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Binary(r#type::Binary { - type_variation_reference: VIEW_CONTAINER_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::Utf8 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::String(r#type::String { - type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::LargeUtf8 => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::String(r#type::String { - type_variation_reference: LARGE_CONTAINER_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::Utf8View => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::String(r#type::String { - type_variation_reference: VIEW_CONTAINER_TYPE_VARIATION_REF, - nullability, - })), - }), - DataType::List(inner) => { - let inner_type = to_substrait_type(inner.data_type(), inner.is_nullable())?; - Ok(substrait::proto::Type { - kind: Some(r#type::Kind::List(Box::new(r#type::List { - r#type: Some(Box::new(inner_type)), - type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, - nullability, - }))), - }) - } - DataType::LargeList(inner) => { - let inner_type = to_substrait_type(inner.data_type(), inner.is_nullable())?; - Ok(substrait::proto::Type { - kind: Some(r#type::Kind::List(Box::new(r#type::List { - r#type: Some(Box::new(inner_type)), - type_variation_reference: LARGE_CONTAINER_TYPE_VARIATION_REF, - nullability, - }))), - }) - } - DataType::Map(inner, _) => match inner.data_type() { - DataType::Struct(key_and_value) if 
key_and_value.len() == 2 => { - let key_type = to_substrait_type( - key_and_value[0].data_type(), - key_and_value[0].is_nullable(), - )?; - let value_type = to_substrait_type( - key_and_value[1].data_type(), - key_and_value[1].is_nullable(), - )?; - Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Map(Box::new(r#type::Map { - key: Some(Box::new(key_type)), - value: Some(Box::new(value_type)), - type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, - nullability, - }))), - }) - } - _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"), - }, - DataType::Struct(fields) => { - let field_types = fields - .iter() - .map(|field| to_substrait_type(field.data_type(), field.is_nullable())) - .collect::>>()?; - Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Struct(r#type::Struct { - types: field_types, - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - nullability, - })), - }) - } - DataType::Decimal128(p, s) => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Decimal(r#type::Decimal { - type_variation_reference: DECIMAL_128_TYPE_VARIATION_REF, - nullability, - scale: *s as i32, - precision: *p as i32, - })), - }), - DataType::Decimal256(p, s) => Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Decimal(r#type::Decimal { - type_variation_reference: DECIMAL_256_TYPE_VARIATION_REF, - nullability, - scale: *s as i32, - precision: *p as i32, - })), - }), - _ => not_impl_err!("Unsupported cast type: {dt:?}"), - } -} - -fn make_substrait_window_function( - function_reference: u32, - arguments: Vec, - partitions: Vec, - sorts: Vec, - bounds: (Bound, Bound), - bounds_type: BoundsType, -) -> Expression { - #[allow(deprecated)] - Expression { - rex_type: Some(RexType::WindowFunction(SubstraitWindowFunction { - function_reference, - arguments, - partitions, - sorts, - options: vec![], - output_type: None, - phase: 0, // default to AGGREGATION_PHASE_UNSPECIFIED - invocation: 0, // TODO: fix - lower_bound: Some(bounds.0), - upper_bound: Some(bounds.1), - args: vec![], - bounds_type: bounds_type as i32, - })), - } -} - -fn make_substrait_like_expr( - producer: &mut impl SubstraitProducer, - ignore_case: bool, - negated: bool, - expr: &Expr, - pattern: &Expr, - escape_char: Option, - schema: &DFSchemaRef, -) -> Result { - let function_anchor = if ignore_case { - producer.register_function("ilike".to_string()) - } else { - producer.register_function("like".to_string()) - }; - let expr = producer.handle_expr(expr, schema)?; - let pattern = producer.handle_expr(pattern, schema)?; - let escape_char = to_substrait_literal_expr( - producer, - &ScalarValue::Utf8(escape_char.map(|c| c.to_string())), - )?; - let arguments = vec![ - FunctionArgument { - arg_type: Some(ArgType::Value(expr)), - }, - FunctionArgument { - arg_type: Some(ArgType::Value(pattern)), - }, - FunctionArgument { - arg_type: Some(ArgType::Value(escape_char)), - }, - ]; - - #[allow(deprecated)] - let substrait_like = Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments, - output_type: None, - args: vec![], - options: vec![], - })), - }; - - if negated { - let function_anchor = producer.register_function("not".to_string()); - - #[allow(deprecated)] - Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments: vec![FunctionArgument { - arg_type: Some(ArgType::Value(substrait_like)), - }], - output_type: None, - args: vec![], - options: vec![], - })), - }) - } else 
{ - Ok(substrait_like) - } -} - -fn to_substrait_bound_offset(value: &ScalarValue) -> Option { - match value { - ScalarValue::UInt8(Some(v)) => Some(*v as i64), - ScalarValue::UInt16(Some(v)) => Some(*v as i64), - ScalarValue::UInt32(Some(v)) => Some(*v as i64), - ScalarValue::UInt64(Some(v)) => Some(*v as i64), - ScalarValue::Int8(Some(v)) => Some(*v as i64), - ScalarValue::Int16(Some(v)) => Some(*v as i64), - ScalarValue::Int32(Some(v)) => Some(*v as i64), - ScalarValue::Int64(Some(v)) => Some(*v), - _ => None, - } -} - -fn to_substrait_bound(bound: &WindowFrameBound) -> Bound { - match bound { - WindowFrameBound::CurrentRow => Bound { - kind: Some(BoundKind::CurrentRow(SubstraitBound::CurrentRow {})), - }, - WindowFrameBound::Preceding(s) => match to_substrait_bound_offset(s) { - Some(offset) => Bound { - kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { offset })), - }, - None => Bound { - kind: Some(BoundKind::Unbounded(SubstraitBound::Unbounded {})), - }, - }, - WindowFrameBound::Following(s) => match to_substrait_bound_offset(s) { - Some(offset) => Bound { - kind: Some(BoundKind::Following(SubstraitBound::Following { offset })), - }, - None => Bound { - kind: Some(BoundKind::Unbounded(SubstraitBound::Unbounded {})), - }, - }, - } -} - -fn to_substrait_bound_type(window_frame: &WindowFrame) -> Result { - match window_frame.units { - WindowFrameUnits::Rows => Ok(BoundsType::Rows), // ROWS - WindowFrameUnits::Range => Ok(BoundsType::Range), // RANGE - // TODO: Support GROUPS - unit => not_impl_err!("Unsupported window frame unit: {unit:?}"), - } -} - -fn to_substrait_bounds(window_frame: &WindowFrame) -> Result<(Bound, Bound)> { - Ok(( - to_substrait_bound(&window_frame.start_bound), - to_substrait_bound(&window_frame.end_bound), - )) -} - -fn to_substrait_literal( - producer: &mut impl SubstraitProducer, - value: &ScalarValue, -) -> Result { - if value.is_null() { - return Ok(Literal { - nullable: true, - type_variation_reference: DEFAULT_TYPE_VARIATION_REF, - literal_type: Some(LiteralType::Null(to_substrait_type( - &value.data_type(), - true, - )?)), - }); - } - let (literal_type, type_variation_reference) = match value { - ScalarValue::Boolean(Some(b)) => { - (LiteralType::Boolean(*b), DEFAULT_TYPE_VARIATION_REF) - } - ScalarValue::Int8(Some(n)) => { - (LiteralType::I8(*n as i32), DEFAULT_TYPE_VARIATION_REF) - } - ScalarValue::UInt8(Some(n)) => ( - LiteralType::I8(*n as i32), - UNSIGNED_INTEGER_TYPE_VARIATION_REF, - ), - ScalarValue::Int16(Some(n)) => { - (LiteralType::I16(*n as i32), DEFAULT_TYPE_VARIATION_REF) - } - ScalarValue::UInt16(Some(n)) => ( - LiteralType::I16(*n as i32), - UNSIGNED_INTEGER_TYPE_VARIATION_REF, - ), - ScalarValue::Int32(Some(n)) => (LiteralType::I32(*n), DEFAULT_TYPE_VARIATION_REF), - ScalarValue::UInt32(Some(n)) => ( - LiteralType::I32(*n as i32), - UNSIGNED_INTEGER_TYPE_VARIATION_REF, - ), - ScalarValue::Int64(Some(n)) => (LiteralType::I64(*n), DEFAULT_TYPE_VARIATION_REF), - ScalarValue::UInt64(Some(n)) => ( - LiteralType::I64(*n as i64), - UNSIGNED_INTEGER_TYPE_VARIATION_REF, - ), - ScalarValue::Float32(Some(f)) => { - (LiteralType::Fp32(*f), DEFAULT_TYPE_VARIATION_REF) - } - ScalarValue::Float64(Some(f)) => { - (LiteralType::Fp64(*f), DEFAULT_TYPE_VARIATION_REF) - } - ScalarValue::TimestampSecond(Some(t), None) => ( - LiteralType::PrecisionTimestamp(PrecisionTimestamp { - precision: 0, - value: *t, - }), - DEFAULT_TYPE_VARIATION_REF, - ), - ScalarValue::TimestampMillisecond(Some(t), None) => ( - 
LiteralType::PrecisionTimestamp(PrecisionTimestamp { - precision: 3, - value: *t, - }), - DEFAULT_TYPE_VARIATION_REF, - ), - ScalarValue::TimestampMicrosecond(Some(t), None) => ( - LiteralType::PrecisionTimestamp(PrecisionTimestamp { - precision: 6, - value: *t, - }), - DEFAULT_TYPE_VARIATION_REF, - ), - ScalarValue::TimestampNanosecond(Some(t), None) => ( - LiteralType::PrecisionTimestamp(PrecisionTimestamp { - precision: 9, - value: *t, - }), - DEFAULT_TYPE_VARIATION_REF, - ), - // If timezone is present, no matter what the actual tz value is, it indicates the - // value of the timestamp is tied to UTC epoch. That's all that Substrait cares about. - // As the timezone is lost, this conversion may be lossy for downstream use of the value. - ScalarValue::TimestampSecond(Some(t), Some(_)) => ( - LiteralType::PrecisionTimestampTz(PrecisionTimestamp { - precision: 0, - value: *t, - }), - DEFAULT_TYPE_VARIATION_REF, - ), - ScalarValue::TimestampMillisecond(Some(t), Some(_)) => ( - LiteralType::PrecisionTimestampTz(PrecisionTimestamp { - precision: 3, - value: *t, - }), - DEFAULT_TYPE_VARIATION_REF, - ), - ScalarValue::TimestampMicrosecond(Some(t), Some(_)) => ( - LiteralType::PrecisionTimestampTz(PrecisionTimestamp { - precision: 6, - value: *t, - }), - DEFAULT_TYPE_VARIATION_REF, - ), - ScalarValue::TimestampNanosecond(Some(t), Some(_)) => ( - LiteralType::PrecisionTimestampTz(PrecisionTimestamp { - precision: 9, - value: *t, - }), - DEFAULT_TYPE_VARIATION_REF, - ), - ScalarValue::Date32(Some(d)) => { - (LiteralType::Date(*d), DATE_32_TYPE_VARIATION_REF) - } - // Date64 literal is not supported in Substrait - ScalarValue::IntervalYearMonth(Some(i)) => ( - LiteralType::IntervalYearToMonth(IntervalYearToMonth { - // DF only tracks total months, but there should always be 12 months in a year - years: *i / 12, - months: *i % 12, - }), - DEFAULT_TYPE_VARIATION_REF, - ), - ScalarValue::IntervalMonthDayNano(Some(i)) => ( - LiteralType::IntervalCompound(IntervalCompound { - interval_year_to_month: Some(IntervalYearToMonth { - years: i.months / 12, - months: i.months % 12, - }), - interval_day_to_second: Some(IntervalDayToSecond { - days: i.days, - seconds: (i.nanoseconds / NANOSECONDS) as i32, - subseconds: i.nanoseconds % NANOSECONDS, - precision_mode: Some(PrecisionMode::Precision(9)), // nanoseconds - }), - }), - DEFAULT_TYPE_VARIATION_REF, - ), - ScalarValue::IntervalDayTime(Some(i)) => ( - LiteralType::IntervalDayToSecond(IntervalDayToSecond { - days: i.days, - seconds: i.milliseconds / 1000, - subseconds: (i.milliseconds % 1000) as i64, - precision_mode: Some(PrecisionMode::Precision(3)), // 3 for milliseconds - }), - DEFAULT_TYPE_VARIATION_REF, - ), - ScalarValue::Binary(Some(b)) => ( - LiteralType::Binary(b.clone()), - DEFAULT_CONTAINER_TYPE_VARIATION_REF, - ), - ScalarValue::LargeBinary(Some(b)) => ( - LiteralType::Binary(b.clone()), - LARGE_CONTAINER_TYPE_VARIATION_REF, - ), - ScalarValue::BinaryView(Some(b)) => ( - LiteralType::Binary(b.clone()), - VIEW_CONTAINER_TYPE_VARIATION_REF, - ), - ScalarValue::FixedSizeBinary(_, Some(b)) => ( - LiteralType::FixedBinary(b.clone()), - DEFAULT_TYPE_VARIATION_REF, - ), - ScalarValue::Utf8(Some(s)) => ( - LiteralType::String(s.clone()), - DEFAULT_CONTAINER_TYPE_VARIATION_REF, - ), - ScalarValue::LargeUtf8(Some(s)) => ( - LiteralType::String(s.clone()), - LARGE_CONTAINER_TYPE_VARIATION_REF, - ), - ScalarValue::Utf8View(Some(s)) => ( - LiteralType::String(s.clone()), - VIEW_CONTAINER_TYPE_VARIATION_REF, - ), - ScalarValue::Decimal128(v, p, s) if 
v.is_some() => ( - LiteralType::Decimal(Decimal { - value: v.unwrap().to_le_bytes().to_vec(), - precision: *p as i32, - scale: *s as i32, - }), - DECIMAL_128_TYPE_VARIATION_REF, - ), - ScalarValue::List(l) => ( - convert_array_to_literal_list(producer, l)?, - DEFAULT_CONTAINER_TYPE_VARIATION_REF, - ), - ScalarValue::LargeList(l) => ( - convert_array_to_literal_list(producer, l)?, - LARGE_CONTAINER_TYPE_VARIATION_REF, - ), - ScalarValue::Map(m) => { - let map = if m.is_empty() || m.value(0).is_empty() { - let mt = to_substrait_type(m.data_type(), m.is_nullable())?; - let mt = match mt { - substrait::proto::Type { - kind: Some(r#type::Kind::Map(mt)), - } => Ok(mt.as_ref().to_owned()), - _ => exec_err!("Unexpected type for a map: {mt:?}"), - }?; - LiteralType::EmptyMap(mt) - } else { - let keys = (0..m.keys().len()) - .map(|i| { - to_substrait_literal( - producer, - &ScalarValue::try_from_array(&m.keys(), i)?, - ) - }) - .collect::>>()?; - let values = (0..m.values().len()) - .map(|i| { - to_substrait_literal( - producer, - &ScalarValue::try_from_array(&m.values(), i)?, - ) - }) - .collect::>>()?; - - let key_values = keys - .into_iter() - .zip(values.into_iter()) - .map(|(k, v)| { - Ok(KeyValue { - key: Some(k), - value: Some(v), - }) - }) - .collect::>>()?; - LiteralType::Map(Map { key_values }) - }; - (map, DEFAULT_CONTAINER_TYPE_VARIATION_REF) - } - ScalarValue::Struct(s) => ( - LiteralType::Struct(Struct { - fields: s - .columns() - .iter() - .map(|col| { - to_substrait_literal( - producer, - &ScalarValue::try_from_array(col, 0)?, - ) - }) - .collect::>>()?, - }), - DEFAULT_TYPE_VARIATION_REF, - ), - _ => ( - not_impl_err!("Unsupported literal: {value:?}")?, - DEFAULT_TYPE_VARIATION_REF, - ), - }; - - Ok(Literal { - nullable: false, - type_variation_reference, - literal_type: Some(literal_type), - }) -} - -fn convert_array_to_literal_list( - producer: &mut impl SubstraitProducer, - array: &GenericListArray, -) -> Result { - assert_eq!(array.len(), 1); - let nested_array = array.value(0); - - let values = (0..nested_array.len()) - .map(|i| { - to_substrait_literal( - producer, - &ScalarValue::try_from_array(&nested_array, i)?, - ) - }) - .collect::>>()?; - - if values.is_empty() { - let lt = match to_substrait_type(array.data_type(), array.is_nullable())? { - substrait::proto::Type { - kind: Some(r#type::Kind::List(lt)), - } => lt.as_ref().to_owned(), - _ => unreachable!(), - }; - Ok(LiteralType::EmptyList(lt)) - } else { - Ok(LiteralType::List(List { values })) - } -} - -fn to_substrait_literal_expr( - producer: &mut impl SubstraitProducer, - value: &ScalarValue, -) -> Result { - let literal = to_substrait_literal(producer, value)?; - Ok(Expression { - rex_type: Some(RexType::Literal(literal)), - }) -} - -/// Util to generate substrait [RexType::ScalarFunction] with one argument -fn to_substrait_unary_scalar_fn( - producer: &mut impl SubstraitProducer, - fn_name: &str, - arg: &Expr, - schema: &DFSchemaRef, -) -> Result { - let function_anchor = producer.register_function(fn_name.to_string()); - let substrait_expr = producer.handle_expr(arg, schema)?; - - Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments: vec![FunctionArgument { - arg_type: Some(ArgType::Value(substrait_expr)), - }], - output_type: None, - options: vec![], - ..Default::default() - })), - }) -} - -/// Try to convert an [Expr] to a [FieldReference]. -/// Returns `Err` if the [Expr] is not a [Expr::Column]. 
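Throughout `to_substrait_literal` the same Substrait literal kind is reused for several Arrow types, and the distinction is carried only by `type_variation_reference` (for example Utf8, LargeUtf8 and Utf8View all become `LiteralType::String`). A rough sketch of that mapping for the string flavours; the numeric values below are placeholders standing in for the constants defined in the crate's `variation_const` module, not their actual definitions:

// Placeholder values; the real constants live in `variation_const`.
const DEFAULT_CONTAINER: u32 = 0;
const LARGE_CONTAINER: u32 = 1;
const VIEW_CONTAINER: u32 = 2;

#[derive(Debug)]
enum StringFlavour {
    Utf8,
    LargeUtf8,
    Utf8View,
}

// All three flavours map to the same Substrait String kind; only the
// type variation reference distinguishes them on the wire.
fn string_variation(flavour: &StringFlavour) -> u32 {
    match flavour {
        StringFlavour::Utf8 => DEFAULT_CONTAINER,
        StringFlavour::LargeUtf8 => LARGE_CONTAINER,
        StringFlavour::Utf8View => VIEW_CONTAINER,
    }
}

fn main() {
    for f in [StringFlavour::Utf8, StringFlavour::LargeUtf8, StringFlavour::Utf8View] {
        println!("{f:?} -> type_variation_reference {}", string_variation(&f));
    }
}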
-fn try_to_substrait_field_reference( - expr: &Expr, - schema: &DFSchemaRef, -) -> Result { - match expr { - Expr::Column(col) => { - let index = schema.index_of_column(col)?; - Ok(FieldReference { - reference_type: Some(ReferenceType::DirectReference(ReferenceSegment { - reference_type: Some(reference_segment::ReferenceType::StructField( - Box::new(reference_segment::StructField { - field: index as i32, - child: None, - }), - )), - })), - root_type: Some(RootType::RootReference(RootReference {})), - }) - } - _ => substrait_err!("Expect a `Column` expr, but found {expr:?}"), - } -} - -fn substrait_sort_field( - producer: &mut impl SubstraitProducer, - sort: &SortExpr, - schema: &DFSchemaRef, -) -> Result { - let SortExpr { - expr, - asc, - nulls_first, - } = sort; - let e = producer.handle_expr(expr, schema)?; - let d = match (asc, nulls_first) { - (true, true) => SortDirection::AscNullsFirst, - (true, false) => SortDirection::AscNullsLast, - (false, true) => SortDirection::DescNullsFirst, - (false, false) => SortDirection::DescNullsLast, - }; - Ok(SortField { - expr: Some(e), - sort_kind: Some(SortKind::Direction(d as i32)), - }) -} - -fn substrait_field_ref(index: usize) -> Result { - Ok(Expression { - rex_type: Some(RexType::Selection(Box::new(FieldReference { - reference_type: Some(ReferenceType::DirectReference(ReferenceSegment { - reference_type: Some(reference_segment::ReferenceType::StructField( - Box::new(reference_segment::StructField { - field: index as i32, - child: None, - }), - )), - })), - root_type: Some(RootType::RootReference(RootReference {})), - }))), - }) -} - -#[cfg(test)] -mod test { - use super::*; - use crate::logical_plan::consumer::{ - from_substrait_extended_expr, from_substrait_literal_without_names, - from_substrait_named_struct, from_substrait_type_without_names, - DefaultSubstraitConsumer, - }; - use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano}; - use datafusion::arrow; - use datafusion::arrow::array::{ - GenericListArray, Int64Builder, MapBuilder, StringBuilder, - }; - use datafusion::arrow::datatypes::{Field, Fields, Schema}; - use datafusion::common::scalar::ScalarStructBuilder; - use datafusion::common::DFSchema; - use datafusion::execution::{SessionState, SessionStateBuilder}; - use datafusion::prelude::SessionContext; - use std::sync::LazyLock; - - static TEST_SESSION_STATE: LazyLock = - LazyLock::new(|| SessionContext::default().state()); - static TEST_EXTENSIONS: LazyLock = LazyLock::new(Extensions::default); - fn test_consumer() -> DefaultSubstraitConsumer<'static> { - let extensions = &TEST_EXTENSIONS; - let state = &TEST_SESSION_STATE; - DefaultSubstraitConsumer::new(extensions, state) - } - - #[test] - fn round_trip_literals() -> Result<()> { - round_trip_literal(ScalarValue::Boolean(None))?; - round_trip_literal(ScalarValue::Boolean(Some(true)))?; - round_trip_literal(ScalarValue::Boolean(Some(false)))?; - - round_trip_literal(ScalarValue::Int8(None))?; - round_trip_literal(ScalarValue::Int8(Some(i8::MIN)))?; - round_trip_literal(ScalarValue::Int8(Some(i8::MAX)))?; - round_trip_literal(ScalarValue::UInt8(None))?; - round_trip_literal(ScalarValue::UInt8(Some(u8::MIN)))?; - round_trip_literal(ScalarValue::UInt8(Some(u8::MAX)))?; - - round_trip_literal(ScalarValue::Int16(None))?; - round_trip_literal(ScalarValue::Int16(Some(i16::MIN)))?; - round_trip_literal(ScalarValue::Int16(Some(i16::MAX)))?; - round_trip_literal(ScalarValue::UInt16(None))?; - round_trip_literal(ScalarValue::UInt16(Some(u16::MIN)))?; - 
round_trip_literal(ScalarValue::UInt16(Some(u16::MAX)))?; - - round_trip_literal(ScalarValue::Int32(None))?; - round_trip_literal(ScalarValue::Int32(Some(i32::MIN)))?; - round_trip_literal(ScalarValue::Int32(Some(i32::MAX)))?; - round_trip_literal(ScalarValue::UInt32(None))?; - round_trip_literal(ScalarValue::UInt32(Some(u32::MIN)))?; - round_trip_literal(ScalarValue::UInt32(Some(u32::MAX)))?; - - round_trip_literal(ScalarValue::Int64(None))?; - round_trip_literal(ScalarValue::Int64(Some(i64::MIN)))?; - round_trip_literal(ScalarValue::Int64(Some(i64::MAX)))?; - round_trip_literal(ScalarValue::UInt64(None))?; - round_trip_literal(ScalarValue::UInt64(Some(u64::MIN)))?; - round_trip_literal(ScalarValue::UInt64(Some(u64::MAX)))?; - - for (ts, tz) in [ - (Some(12345), None), - (None, None), - (Some(12345), Some("UTC".into())), - (None, Some("UTC".into())), - ] { - round_trip_literal(ScalarValue::TimestampSecond(ts, tz.clone()))?; - round_trip_literal(ScalarValue::TimestampMillisecond(ts, tz.clone()))?; - round_trip_literal(ScalarValue::TimestampMicrosecond(ts, tz.clone()))?; - round_trip_literal(ScalarValue::TimestampNanosecond(ts, tz))?; - } - - round_trip_literal(ScalarValue::List(ScalarValue::new_list_nullable( - &[ScalarValue::Float32(Some(1.0))], - &DataType::Float32, - )))?; - round_trip_literal(ScalarValue::List(ScalarValue::new_list_nullable( - &[], - &DataType::Float32, - )))?; - round_trip_literal(ScalarValue::List(Arc::new(GenericListArray::new_null( - Field::new_list_field(DataType::Float32, true).into(), - 1, - ))))?; - round_trip_literal(ScalarValue::LargeList(ScalarValue::new_large_list( - &[ScalarValue::Float32(Some(1.0))], - &DataType::Float32, - )))?; - round_trip_literal(ScalarValue::LargeList(ScalarValue::new_large_list( - &[], - &DataType::Float32, - )))?; - round_trip_literal(ScalarValue::LargeList(Arc::new( - GenericListArray::new_null( - Field::new_list_field(DataType::Float32, true).into(), - 1, - ), - )))?; - - // Null map - let mut map_builder = - MapBuilder::new(None, StringBuilder::new(), Int64Builder::new()); - map_builder.append(false)?; - round_trip_literal(ScalarValue::Map(Arc::new(map_builder.finish())))?; - - // Empty map - let mut map_builder = - MapBuilder::new(None, StringBuilder::new(), Int64Builder::new()); - map_builder.append(true)?; - round_trip_literal(ScalarValue::Map(Arc::new(map_builder.finish())))?; - - // Valid map - let mut map_builder = - MapBuilder::new(None, StringBuilder::new(), Int64Builder::new()); - map_builder.keys().append_value("key1"); - map_builder.keys().append_value("key2"); - map_builder.values().append_value(1); - map_builder.values().append_value(2); - map_builder.append(true)?; - round_trip_literal(ScalarValue::Map(Arc::new(map_builder.finish())))?; - - let c0 = Field::new("c0", DataType::Boolean, true); - let c1 = Field::new("c1", DataType::Int32, true); - let c2 = Field::new("c2", DataType::Utf8, true); - round_trip_literal( - ScalarStructBuilder::new() - .with_scalar(c0.to_owned(), ScalarValue::Boolean(Some(true))) - .with_scalar(c1.to_owned(), ScalarValue::Int32(Some(1))) - .with_scalar(c2.to_owned(), ScalarValue::Utf8(None)) - .build()?, - )?; - round_trip_literal(ScalarStructBuilder::new_null(vec![c0, c1, c2]))?; - - round_trip_literal(ScalarValue::IntervalYearMonth(Some(17)))?; - round_trip_literal(ScalarValue::IntervalMonthDayNano(Some( - IntervalMonthDayNano::new(17, 25, 1234567890), - )))?; - round_trip_literal(ScalarValue::IntervalDayTime(Some(IntervalDayTime::new( - 57, 123456, - ))))?; - - Ok(()) - } - - fn 
round_trip_literal(scalar: ScalarValue) -> Result<()> { - println!("Checking round trip of {scalar:?}"); - let state = SessionContext::default().state(); - let mut producer = DefaultSubstraitProducer::new(&state); - let substrait_literal = to_substrait_literal(&mut producer, &scalar)?; - let roundtrip_scalar = - from_substrait_literal_without_names(&test_consumer(), &substrait_literal)?; - assert_eq!(scalar, roundtrip_scalar); - Ok(()) - } - - #[test] - fn round_trip_types() -> Result<()> { - round_trip_type(DataType::Boolean)?; - round_trip_type(DataType::Int8)?; - round_trip_type(DataType::UInt8)?; - round_trip_type(DataType::Int16)?; - round_trip_type(DataType::UInt16)?; - round_trip_type(DataType::Int32)?; - round_trip_type(DataType::UInt32)?; - round_trip_type(DataType::Int64)?; - round_trip_type(DataType::UInt64)?; - round_trip_type(DataType::Float32)?; - round_trip_type(DataType::Float64)?; - - for tz in [None, Some("UTC".into())] { - round_trip_type(DataType::Timestamp(TimeUnit::Second, tz.clone()))?; - round_trip_type(DataType::Timestamp(TimeUnit::Millisecond, tz.clone()))?; - round_trip_type(DataType::Timestamp(TimeUnit::Microsecond, tz.clone()))?; - round_trip_type(DataType::Timestamp(TimeUnit::Nanosecond, tz))?; - } - - round_trip_type(DataType::Date32)?; - round_trip_type(DataType::Date64)?; - round_trip_type(DataType::Binary)?; - round_trip_type(DataType::FixedSizeBinary(10))?; - round_trip_type(DataType::LargeBinary)?; - round_trip_type(DataType::BinaryView)?; - round_trip_type(DataType::Utf8)?; - round_trip_type(DataType::LargeUtf8)?; - round_trip_type(DataType::Utf8View)?; - round_trip_type(DataType::Decimal128(10, 2))?; - round_trip_type(DataType::Decimal256(30, 2))?; - - round_trip_type(DataType::List( - Field::new_list_field(DataType::Int32, true).into(), - ))?; - round_trip_type(DataType::LargeList( - Field::new_list_field(DataType::Int32, true).into(), - ))?; - - round_trip_type(DataType::Map( - Field::new_struct( - "entries", - [ - Field::new("key", DataType::Utf8, false).into(), - Field::new("value", DataType::Int32, true).into(), - ], - false, - ) - .into(), - false, - ))?; - - round_trip_type(DataType::Struct( - vec![ - Field::new("c0", DataType::Int32, true), - Field::new("c1", DataType::Utf8, true), - ] - .into(), - ))?; - - round_trip_type(DataType::Interval(IntervalUnit::YearMonth))?; - round_trip_type(DataType::Interval(IntervalUnit::MonthDayNano))?; - round_trip_type(DataType::Interval(IntervalUnit::DayTime))?; - - Ok(()) - } - - fn round_trip_type(dt: DataType) -> Result<()> { - println!("Checking round trip of {dt:?}"); - - // As DataFusion doesn't consider nullability as a property of the type, but field, - // it doesn't matter if we set nullability to true or false here. 
- let substrait = to_substrait_type(&dt, true)?; - let consumer = test_consumer(); - let roundtrip_dt = from_substrait_type_without_names(&consumer, &substrait)?; - assert_eq!(dt, roundtrip_dt); - Ok(()) - } - - #[test] - fn to_field_reference() -> Result<()> { - let expression = substrait_field_ref(2)?; - - match &expression.rex_type { - Some(RexType::Selection(field_ref)) => { - assert_eq!( - field_ref - .root_type - .clone() - .expect("root type should be set"), - RootType::RootReference(RootReference {}) - ); - } - - _ => panic!("Should not be anything other than field reference"), - } - Ok(()) - } - - #[test] - fn named_struct_names() -> Result<()> { - let schema = DFSchemaRef::new(DFSchema::try_from(Schema::new(vec![ - Field::new("int", DataType::Int32, true), - Field::new( - "struct", - DataType::Struct(Fields::from(vec![Field::new( - "inner", - DataType::List(Arc::new(Field::new_list_field(DataType::Utf8, true))), - true, - )])), - true, - ), - Field::new("trailer", DataType::Float64, true), - ]))?); - - let named_struct = to_substrait_named_struct(&schema)?; - - // Struct field names should be flattened DFS style - // List field names should be omitted - assert_eq!( - named_struct.names, - vec!["int", "struct", "inner", "trailer"] - ); - - let roundtrip_schema = - from_substrait_named_struct(&test_consumer(), &named_struct)?; - assert_eq!(schema.as_ref(), &roundtrip_schema); - Ok(()) - } - - #[tokio::test] - async fn extended_expressions() -> Result<()> { - let state = SessionStateBuilder::default().build(); - - // One expression, empty input schema - let expr = Expr::Literal(ScalarValue::Int32(Some(42))); - let field = Field::new("out", DataType::Int32, false); - let empty_schema = DFSchemaRef::new(DFSchema::empty()); - let substrait = - to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &state)?; - let roundtrip_expr = from_substrait_extended_expr(&state, &substrait).await?; - - assert_eq!(roundtrip_expr.input_schema, empty_schema); - assert_eq!(roundtrip_expr.exprs.len(), 1); - - let (rt_expr, rt_field) = roundtrip_expr.exprs.first().unwrap(); - assert_eq!(rt_field, &field); - assert_eq!(rt_expr, &expr); - - // Multiple expressions, with column references - let expr1 = Expr::Column("c0".into()); - let expr2 = Expr::Column("c1".into()); - let out1 = Field::new("out1", DataType::Int32, true); - let out2 = Field::new("out2", DataType::Utf8, true); - let input_schema = DFSchemaRef::new(DFSchema::try_from(Schema::new(vec![ - Field::new("c0", DataType::Int32, true), - Field::new("c1", DataType::Utf8, true), - ]))?); - - let substrait = to_substrait_extended_expr( - &[(&expr1, &out1), (&expr2, &out2)], - &input_schema, - &state, - )?; - let roundtrip_expr = from_substrait_extended_expr(&state, &substrait).await?; - - assert_eq!(roundtrip_expr.input_schema, input_schema); - assert_eq!(roundtrip_expr.exprs.len(), 2); - - let mut exprs = roundtrip_expr.exprs.into_iter(); - - let (rt_expr, rt_field) = exprs.next().unwrap(); - assert_eq!(rt_field, out1); - assert_eq!(rt_expr, expr1); - - let (rt_expr, rt_field) = exprs.next().unwrap(); - assert_eq!(rt_field, out2); - assert_eq!(rt_expr, expr2); - - Ok(()) - } - - #[tokio::test] - async fn invalid_extended_expression() { - let state = SessionStateBuilder::default().build(); - - // Not ok if input schema is missing field referenced by expr - let expr = Expr::Column("missing".into()); - let field = Field::new("out", DataType::Int32, false); - let empty_schema = DFSchemaRef::new(DFSchema::empty()); - - let err = 
to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &state); - - assert!(matches!(err, Err(DataFusionError::SchemaError(_, _)))); - } -} diff --git a/datafusion/substrait/src/logical_plan/producer/expr/aggregate_function.rs b/datafusion/substrait/src/logical_plan/producer/expr/aggregate_function.rs new file mode 100644 index 000000000000..0619b497532d --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/expr/aggregate_function.rs @@ -0,0 +1,99 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::SubstraitProducer; +use datafusion::common::DFSchemaRef; +use datafusion::logical_expr::expr; +use datafusion::logical_expr::expr::AggregateFunctionParams; +use substrait::proto::aggregate_function::AggregationInvocation; +use substrait::proto::aggregate_rel::Measure; +use substrait::proto::function_argument::ArgType; +use substrait::proto::sort_field::{SortDirection, SortKind}; +use substrait::proto::{ + AggregateFunction, AggregationPhase, FunctionArgument, SortField, +}; + +pub fn from_aggregate_function( + producer: &mut impl SubstraitProducer, + agg_fn: &expr::AggregateFunction, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let expr::AggregateFunction { + func, + params: + AggregateFunctionParams { + args, + distinct, + filter, + order_by, + null_treatment: _null_treatment, + }, + } = agg_fn; + let sorts = if let Some(order_by) = order_by { + order_by + .iter() + .map(|expr| to_substrait_sort_field(producer, expr, schema)) + .collect::>>()? 
+ } else { + vec![] + }; + let mut arguments: Vec = vec![]; + for arg in args { + arguments.push(FunctionArgument { + arg_type: Some(ArgType::Value(producer.handle_expr(arg, schema)?)), + }); + } + let function_anchor = producer.register_function(func.name().to_string()); + #[allow(deprecated)] + Ok(Measure { + measure: Some(AggregateFunction { + function_reference: function_anchor, + arguments, + sorts, + output_type: None, + invocation: match distinct { + true => AggregationInvocation::Distinct as i32, + false => AggregationInvocation::All as i32, + }, + phase: AggregationPhase::Unspecified as i32, + args: vec![], + options: vec![], + }), + filter: match filter { + Some(f) => Some(producer.handle_expr(f, schema)?), + None => None, + }, + }) +} + +/// Converts sort expression to corresponding substrait `SortField` +fn to_substrait_sort_field( + producer: &mut impl SubstraitProducer, + sort: &expr::Sort, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let sort_kind = match (sort.asc, sort.nulls_first) { + (true, true) => SortDirection::AscNullsFirst, + (true, false) => SortDirection::AscNullsLast, + (false, true) => SortDirection::DescNullsFirst, + (false, false) => SortDirection::DescNullsLast, + }; + Ok(SortField { + expr: Some(producer.handle_expr(&sort.expr, schema)?), + sort_kind: Some(SortKind::Direction(sort_kind.into())), + }) +} diff --git a/datafusion/substrait/src/logical_plan/producer/expr/cast.rs b/datafusion/substrait/src/logical_plan/producer/expr/cast.rs new file mode 100644 index 000000000000..9741dcdd1095 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/expr/cast.rs @@ -0,0 +1,154 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::{to_substrait_type, SubstraitProducer}; +use crate::variation_const::DEFAULT_TYPE_VARIATION_REF; +use datafusion::common::{DFSchemaRef, ScalarValue}; +use datafusion::logical_expr::{Cast, Expr, TryCast}; +use substrait::proto::expression::cast::FailureBehavior; +use substrait::proto::expression::literal::LiteralType; +use substrait::proto::expression::{Literal, RexType}; +use substrait::proto::Expression; + +pub fn from_cast( + producer: &mut impl SubstraitProducer, + cast: &Cast, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let Cast { expr, data_type } = cast; + // since substrait Null must be typed, so if we see a cast(null, dt), we make it a typed null + if let Expr::Literal(lit, _) = expr.as_ref() { + // only the untyped(a null scalar value) null literal need this special handling + // since all other kind of nulls are already typed and can be handled by substrait + // e.g. 
null:: or null:: + if matches!(lit, ScalarValue::Null) { + let lit = Literal { + nullable: true, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + literal_type: Some(LiteralType::Null(to_substrait_type( + data_type, true, + )?)), + }; + return Ok(Expression { + rex_type: Some(RexType::Literal(lit)), + }); + } + } + Ok(Expression { + rex_type: Some(RexType::Cast(Box::new( + substrait::proto::expression::Cast { + r#type: Some(to_substrait_type(data_type, true)?), + input: Some(Box::new(producer.handle_expr(expr, schema)?)), + failure_behavior: FailureBehavior::ThrowException.into(), + }, + ))), + }) +} + +pub fn from_try_cast( + producer: &mut impl SubstraitProducer, + cast: &TryCast, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let TryCast { expr, data_type } = cast; + Ok(Expression { + rex_type: Some(RexType::Cast(Box::new( + substrait::proto::expression::Cast { + r#type: Some(to_substrait_type(data_type, true)?), + input: Some(Box::new(producer.handle_expr(expr, schema)?)), + failure_behavior: FailureBehavior::ReturnNull.into(), + }, + ))), + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::logical_plan::producer::to_substrait_extended_expr; + use datafusion::arrow::datatypes::{DataType, Field}; + use datafusion::common::DFSchema; + use datafusion::execution::SessionStateBuilder; + use datafusion::logical_expr::ExprSchemable; + use substrait::proto::expression_reference::ExprType; + + #[tokio::test] + async fn fold_cast_null() { + let state = SessionStateBuilder::default().build(); + let empty_schema = DFSchemaRef::new(DFSchema::empty()); + let field = Field::new("out", DataType::Int32, false); + + let expr = Expr::Literal(ScalarValue::Null, None) + .cast_to(&DataType::Int32, &empty_schema) + .unwrap(); + + let typed_null = + to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &state) + .unwrap(); + + if let ExprType::Expression(expr) = + typed_null.referred_expr[0].expr_type.as_ref().unwrap() + { + let lit = Literal { + nullable: true, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + literal_type: Some(LiteralType::Null( + to_substrait_type(&DataType::Int32, true).unwrap(), + )), + }; + let expected = Expression { + rex_type: Some(RexType::Literal(lit)), + }; + assert_eq!(*expr, expected); + } else { + panic!("Expected expression type"); + } + + // a typed null should not be folded + let expr = Expr::Literal(ScalarValue::Int64(None), None) + .cast_to(&DataType::Int32, &empty_schema) + .unwrap(); + + let typed_null = + to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &state) + .unwrap(); + + if let ExprType::Expression(expr) = + typed_null.referred_expr[0].expr_type.as_ref().unwrap() + { + let cast_expr = substrait::proto::expression::Cast { + r#type: Some(to_substrait_type(&DataType::Int32, true).unwrap()), + input: Some(Box::new(Expression { + rex_type: Some(RexType::Literal(Literal { + nullable: true, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + literal_type: Some(LiteralType::Null( + to_substrait_type(&DataType::Int64, true).unwrap(), + )), + })), + })), + failure_behavior: FailureBehavior::ThrowException as i32, + }; + let expected = Expression { + rex_type: Some(RexType::Cast(Box::new(cast_expr))), + }; + assert_eq!(*expr, expected); + } else { + panic!("Expected expression type"); + } + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/expr/field_reference.rs b/datafusion/substrait/src/logical_plan/producer/expr/field_reference.rs new file mode 100644 index 000000000000..d1d80ca545ff 
--- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/expr/field_reference.rs @@ -0,0 +1,103 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::common::{substrait_err, Column, DFSchemaRef}; +use datafusion::logical_expr::Expr; +use substrait::proto::expression::field_reference::{ + ReferenceType, RootReference, RootType, +}; +use substrait::proto::expression::{ + reference_segment, FieldReference, ReferenceSegment, RexType, +}; +use substrait::proto::Expression; + +pub fn from_column( + col: &Column, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let index = schema.index_of_column(col)?; + substrait_field_ref(index) +} + +pub(crate) fn substrait_field_ref( + index: usize, +) -> datafusion::common::Result { + Ok(Expression { + rex_type: Some(RexType::Selection(Box::new(FieldReference { + reference_type: Some(ReferenceType::DirectReference(ReferenceSegment { + reference_type: Some(reference_segment::ReferenceType::StructField( + Box::new(reference_segment::StructField { + field: index as i32, + child: None, + }), + )), + })), + root_type: Some(RootType::RootReference(RootReference {})), + }))), + }) +} + +/// Try to convert an [Expr] to a [FieldReference]. +/// Returns `Err` if the [Expr] is not a [Expr::Column]. +pub(crate) fn try_to_substrait_field_reference( + expr: &Expr, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + match expr { + Expr::Column(col) => { + let index = schema.index_of_column(col)?; + Ok(FieldReference { + reference_type: Some(ReferenceType::DirectReference(ReferenceSegment { + reference_type: Some(reference_segment::ReferenceType::StructField( + Box::new(reference_segment::StructField { + field: index as i32, + child: None, + }), + )), + })), + root_type: Some(RootType::RootReference(RootReference {})), + }) + } + _ => substrait_err!("Expect a `Column` expr, but found {expr:?}"), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion::common::Result; + + #[test] + fn to_field_reference() -> Result<()> { + let expression = substrait_field_ref(2)?; + + match &expression.rex_type { + Some(RexType::Selection(field_ref)) => { + assert_eq!( + field_ref + .root_type + .clone() + .expect("root type should be set"), + RootType::RootReference(RootReference {}) + ); + } + + _ => panic!("Should not be anything other than field reference"), + } + Ok(()) + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/expr/if_then.rs b/datafusion/substrait/src/logical_plan/producer/expr/if_then.rs new file mode 100644 index 000000000000..a34959ead76d --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/expr/if_then.rs @@ -0,0 +1,61 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::SubstraitProducer; +use datafusion::common::DFSchemaRef; +use datafusion::logical_expr::Case; +use substrait::proto::expression::if_then::IfClause; +use substrait::proto::expression::{IfThen, RexType}; +use substrait::proto::Expression; + +pub fn from_case( + producer: &mut impl SubstraitProducer, + case: &Case, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let Case { + expr, + when_then_expr, + else_expr, + } = case; + let mut ifs: Vec = vec![]; + // Parse base + if let Some(e) = expr { + // Base expression exists + ifs.push(IfClause { + r#if: Some(producer.handle_expr(e, schema)?), + then: None, + }); + } + // Parse `when`s + for (r#if, then) in when_then_expr { + ifs.push(IfClause { + r#if: Some(producer.handle_expr(r#if, schema)?), + then: Some(producer.handle_expr(then, schema)?), + }); + } + + // Parse outer `else` + let r#else: Option> = match else_expr { + Some(e) => Some(Box::new(producer.handle_expr(e, schema)?)), + None => None, + }; + + Ok(Expression { + rex_type: Some(RexType::IfThen(Box::new(IfThen { ifs, r#else }))), + }) +} diff --git a/datafusion/substrait/src/logical_plan/producer/expr/literal.rs b/datafusion/substrait/src/logical_plan/producer/expr/literal.rs new file mode 100644 index 000000000000..31f4866bdc85 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/expr/literal.rs @@ -0,0 +1,483 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
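The from_case conversion above encodes an optional CASE base expression as a leading IfClause whose then is deliberately left empty, followed by one IfClause per WHEN/THEN pair. A rough sketch of the resulting message shape for CASE x WHEN 1 THEN 10 ELSE 0 END, with Expression::default() placeholders where the real producer would call handle_expr (illustrative only, not part of the patch):

use substrait::proto::expression::if_then::IfClause;
use substrait::proto::expression::{IfThen, RexType};
use substrait::proto::Expression;

fn main() {
    // Placeholders standing in for the converted `x`, `1`, `10` and `0` expressions.
    let base = Expression::default();
    let when_1 = Expression::default();
    let then_10 = Expression::default();
    let else_0 = Expression::default();

    let if_then = IfThen {
        ifs: vec![
            // CASE base expression: `if` is set, `then` is deliberately empty.
            IfClause { r#if: Some(base), then: None },
            // WHEN 1 THEN 10
            IfClause { r#if: Some(when_1), then: Some(then_10) },
        ],
        // ELSE 0
        r#else: Some(Box::new(else_0)),
    };

    let _expr = Expression {
        rex_type: Some(RexType::IfThen(Box::new(if_then))),
    };
}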
+ +use crate::logical_plan::producer::{to_substrait_type, SubstraitProducer}; +use crate::variation_const::{ + DATE_32_TYPE_VARIATION_REF, DECIMAL_128_TYPE_VARIATION_REF, + DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, + LARGE_CONTAINER_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, + VIEW_CONTAINER_TYPE_VARIATION_REF, +}; +use datafusion::arrow::array::{Array, GenericListArray, OffsetSizeTrait}; +use datafusion::arrow::temporal_conversions::NANOSECONDS; +use datafusion::common::{exec_err, not_impl_err, ScalarValue}; +use substrait::proto::expression::literal::interval_day_to_second::PrecisionMode; +use substrait::proto::expression::literal::map::KeyValue; +use substrait::proto::expression::literal::{ + Decimal, IntervalCompound, IntervalDayToSecond, IntervalYearToMonth, List, + LiteralType, Map, PrecisionTimestamp, Struct, +}; +use substrait::proto::expression::{Literal, RexType}; +use substrait::proto::{r#type, Expression}; + +pub fn from_literal( + producer: &mut impl SubstraitProducer, + value: &ScalarValue, +) -> datafusion::common::Result { + to_substrait_literal_expr(producer, value) +} + +pub(crate) fn to_substrait_literal_expr( + producer: &mut impl SubstraitProducer, + value: &ScalarValue, +) -> datafusion::common::Result { + let literal = to_substrait_literal(producer, value)?; + Ok(Expression { + rex_type: Some(RexType::Literal(literal)), + }) +} + +pub(crate) fn to_substrait_literal( + producer: &mut impl SubstraitProducer, + value: &ScalarValue, +) -> datafusion::common::Result { + if value.is_null() { + return Ok(Literal { + nullable: true, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + literal_type: Some(LiteralType::Null(to_substrait_type( + &value.data_type(), + true, + )?)), + }); + } + let (literal_type, type_variation_reference) = match value { + ScalarValue::Boolean(Some(b)) => { + (LiteralType::Boolean(*b), DEFAULT_TYPE_VARIATION_REF) + } + ScalarValue::Int8(Some(n)) => { + (LiteralType::I8(*n as i32), DEFAULT_TYPE_VARIATION_REF) + } + ScalarValue::UInt8(Some(n)) => ( + LiteralType::I8(*n as i32), + UNSIGNED_INTEGER_TYPE_VARIATION_REF, + ), + ScalarValue::Int16(Some(n)) => { + (LiteralType::I16(*n as i32), DEFAULT_TYPE_VARIATION_REF) + } + ScalarValue::UInt16(Some(n)) => ( + LiteralType::I16(*n as i32), + UNSIGNED_INTEGER_TYPE_VARIATION_REF, + ), + ScalarValue::Int32(Some(n)) => (LiteralType::I32(*n), DEFAULT_TYPE_VARIATION_REF), + ScalarValue::UInt32(Some(n)) => ( + LiteralType::I32(*n as i32), + UNSIGNED_INTEGER_TYPE_VARIATION_REF, + ), + ScalarValue::Int64(Some(n)) => (LiteralType::I64(*n), DEFAULT_TYPE_VARIATION_REF), + ScalarValue::UInt64(Some(n)) => ( + LiteralType::I64(*n as i64), + UNSIGNED_INTEGER_TYPE_VARIATION_REF, + ), + ScalarValue::Float32(Some(f)) => { + (LiteralType::Fp32(*f), DEFAULT_TYPE_VARIATION_REF) + } + ScalarValue::Float64(Some(f)) => { + (LiteralType::Fp64(*f), DEFAULT_TYPE_VARIATION_REF) + } + ScalarValue::TimestampSecond(Some(t), None) => ( + LiteralType::PrecisionTimestamp(PrecisionTimestamp { + precision: 0, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::TimestampMillisecond(Some(t), None) => ( + LiteralType::PrecisionTimestamp(PrecisionTimestamp { + precision: 3, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::TimestampMicrosecond(Some(t), None) => ( + LiteralType::PrecisionTimestamp(PrecisionTimestamp { + precision: 6, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::TimestampNanosecond(Some(t), None) => ( + 
LiteralType::PrecisionTimestamp(PrecisionTimestamp { + precision: 9, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + // If timezone is present, no matter what the actual tz value is, it indicates the + // value of the timestamp is tied to UTC epoch. That's all that Substrait cares about. + // As the timezone is lost, this conversion may be lossy for downstream use of the value. + ScalarValue::TimestampSecond(Some(t), Some(_)) => ( + LiteralType::PrecisionTimestampTz(PrecisionTimestamp { + precision: 0, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::TimestampMillisecond(Some(t), Some(_)) => ( + LiteralType::PrecisionTimestampTz(PrecisionTimestamp { + precision: 3, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::TimestampMicrosecond(Some(t), Some(_)) => ( + LiteralType::PrecisionTimestampTz(PrecisionTimestamp { + precision: 6, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::TimestampNanosecond(Some(t), Some(_)) => ( + LiteralType::PrecisionTimestampTz(PrecisionTimestamp { + precision: 9, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::Date32(Some(d)) => { + (LiteralType::Date(*d), DATE_32_TYPE_VARIATION_REF) + } + // Date64 literal is not supported in Substrait + ScalarValue::IntervalYearMonth(Some(i)) => ( + LiteralType::IntervalYearToMonth(IntervalYearToMonth { + // DF only tracks total months, but there should always be 12 months in a year + years: *i / 12, + months: *i % 12, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::IntervalMonthDayNano(Some(i)) => ( + LiteralType::IntervalCompound(IntervalCompound { + interval_year_to_month: Some(IntervalYearToMonth { + years: i.months / 12, + months: i.months % 12, + }), + interval_day_to_second: Some(IntervalDayToSecond { + days: i.days, + seconds: (i.nanoseconds / NANOSECONDS) as i32, + subseconds: i.nanoseconds % NANOSECONDS, + precision_mode: Some(PrecisionMode::Precision(9)), // nanoseconds + }), + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::IntervalDayTime(Some(i)) => ( + LiteralType::IntervalDayToSecond(IntervalDayToSecond { + days: i.days, + seconds: i.milliseconds / 1000, + subseconds: (i.milliseconds % 1000) as i64, + precision_mode: Some(PrecisionMode::Precision(3)), // 3 for milliseconds + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::Binary(Some(b)) => ( + LiteralType::Binary(b.clone()), + DEFAULT_CONTAINER_TYPE_VARIATION_REF, + ), + ScalarValue::LargeBinary(Some(b)) => ( + LiteralType::Binary(b.clone()), + LARGE_CONTAINER_TYPE_VARIATION_REF, + ), + ScalarValue::BinaryView(Some(b)) => ( + LiteralType::Binary(b.clone()), + VIEW_CONTAINER_TYPE_VARIATION_REF, + ), + ScalarValue::FixedSizeBinary(_, Some(b)) => ( + LiteralType::FixedBinary(b.clone()), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::Utf8(Some(s)) => ( + LiteralType::String(s.clone()), + DEFAULT_CONTAINER_TYPE_VARIATION_REF, + ), + ScalarValue::LargeUtf8(Some(s)) => ( + LiteralType::String(s.clone()), + LARGE_CONTAINER_TYPE_VARIATION_REF, + ), + ScalarValue::Utf8View(Some(s)) => ( + LiteralType::String(s.clone()), + VIEW_CONTAINER_TYPE_VARIATION_REF, + ), + ScalarValue::Decimal128(v, p, s) if v.is_some() => ( + LiteralType::Decimal(Decimal { + value: v.unwrap().to_le_bytes().to_vec(), + precision: *p as i32, + scale: *s as i32, + }), + DECIMAL_128_TYPE_VARIATION_REF, + ), + ScalarValue::List(l) => ( + convert_array_to_literal_list(producer, l)?, + DEFAULT_CONTAINER_TYPE_VARIATION_REF, + ), + ScalarValue::LargeList(l) => ( + 
convert_array_to_literal_list(producer, l)?, + LARGE_CONTAINER_TYPE_VARIATION_REF, + ), + ScalarValue::Map(m) => { + let map = if m.is_empty() || m.value(0).is_empty() { + let mt = to_substrait_type(m.data_type(), m.is_nullable())?; + let mt = match mt { + substrait::proto::Type { + kind: Some(r#type::Kind::Map(mt)), + } => Ok(mt.as_ref().to_owned()), + _ => exec_err!("Unexpected type for a map: {mt:?}"), + }?; + LiteralType::EmptyMap(mt) + } else { + let keys = (0..m.keys().len()) + .map(|i| { + to_substrait_literal( + producer, + &ScalarValue::try_from_array(&m.keys(), i)?, + ) + }) + .collect::>>()?; + let values = (0..m.values().len()) + .map(|i| { + to_substrait_literal( + producer, + &ScalarValue::try_from_array(&m.values(), i)?, + ) + }) + .collect::>>()?; + + let key_values = keys + .into_iter() + .zip(values.into_iter()) + .map(|(k, v)| { + Ok(KeyValue { + key: Some(k), + value: Some(v), + }) + }) + .collect::>>()?; + LiteralType::Map(Map { key_values }) + }; + (map, DEFAULT_CONTAINER_TYPE_VARIATION_REF) + } + ScalarValue::Struct(s) => ( + LiteralType::Struct(Struct { + fields: s + .columns() + .iter() + .map(|col| { + to_substrait_literal( + producer, + &ScalarValue::try_from_array(col, 0)?, + ) + }) + .collect::>>()?, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + _ => ( + not_impl_err!("Unsupported literal: {value:?}")?, + DEFAULT_TYPE_VARIATION_REF, + ), + }; + + Ok(Literal { + nullable: false, + type_variation_reference, + literal_type: Some(literal_type), + }) +} + +fn convert_array_to_literal_list( + producer: &mut impl SubstraitProducer, + array: &GenericListArray, +) -> datafusion::common::Result { + assert_eq!(array.len(), 1); + let nested_array = array.value(0); + + let values = (0..nested_array.len()) + .map(|i| { + to_substrait_literal( + producer, + &ScalarValue::try_from_array(&nested_array, i)?, + ) + }) + .collect::>>()?; + + if values.is_empty() { + let lt = match to_substrait_type(array.data_type(), array.is_nullable())? 
{ + substrait::proto::Type { + kind: Some(r#type::Kind::List(lt)), + } => lt.as_ref().to_owned(), + _ => unreachable!(), + }; + Ok(LiteralType::EmptyList(lt)) + } else { + Ok(LiteralType::List(List { values })) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::logical_plan::consumer::from_substrait_literal_without_names; + use crate::logical_plan::consumer::tests::test_consumer; + use crate::logical_plan::producer::DefaultSubstraitProducer; + use datafusion::arrow::array::{Int64Builder, MapBuilder, StringBuilder}; + use datafusion::arrow::datatypes::{ + DataType, Field, IntervalDayTime, IntervalMonthDayNano, + }; + use datafusion::common::scalar::ScalarStructBuilder; + use datafusion::common::Result; + use datafusion::prelude::SessionContext; + use std::sync::Arc; + + #[test] + fn round_trip_literals() -> Result<()> { + round_trip_literal(ScalarValue::Boolean(None))?; + round_trip_literal(ScalarValue::Boolean(Some(true)))?; + round_trip_literal(ScalarValue::Boolean(Some(false)))?; + + round_trip_literal(ScalarValue::Int8(None))?; + round_trip_literal(ScalarValue::Int8(Some(i8::MIN)))?; + round_trip_literal(ScalarValue::Int8(Some(i8::MAX)))?; + round_trip_literal(ScalarValue::UInt8(None))?; + round_trip_literal(ScalarValue::UInt8(Some(u8::MIN)))?; + round_trip_literal(ScalarValue::UInt8(Some(u8::MAX)))?; + + round_trip_literal(ScalarValue::Int16(None))?; + round_trip_literal(ScalarValue::Int16(Some(i16::MIN)))?; + round_trip_literal(ScalarValue::Int16(Some(i16::MAX)))?; + round_trip_literal(ScalarValue::UInt16(None))?; + round_trip_literal(ScalarValue::UInt16(Some(u16::MIN)))?; + round_trip_literal(ScalarValue::UInt16(Some(u16::MAX)))?; + + round_trip_literal(ScalarValue::Int32(None))?; + round_trip_literal(ScalarValue::Int32(Some(i32::MIN)))?; + round_trip_literal(ScalarValue::Int32(Some(i32::MAX)))?; + round_trip_literal(ScalarValue::UInt32(None))?; + round_trip_literal(ScalarValue::UInt32(Some(u32::MIN)))?; + round_trip_literal(ScalarValue::UInt32(Some(u32::MAX)))?; + + round_trip_literal(ScalarValue::Int64(None))?; + round_trip_literal(ScalarValue::Int64(Some(i64::MIN)))?; + round_trip_literal(ScalarValue::Int64(Some(i64::MAX)))?; + round_trip_literal(ScalarValue::UInt64(None))?; + round_trip_literal(ScalarValue::UInt64(Some(u64::MIN)))?; + round_trip_literal(ScalarValue::UInt64(Some(u64::MAX)))?; + + for (ts, tz) in [ + (Some(12345), None), + (None, None), + (Some(12345), Some("UTC".into())), + (None, Some("UTC".into())), + ] { + round_trip_literal(ScalarValue::TimestampSecond(ts, tz.clone()))?; + round_trip_literal(ScalarValue::TimestampMillisecond(ts, tz.clone()))?; + round_trip_literal(ScalarValue::TimestampMicrosecond(ts, tz.clone()))?; + round_trip_literal(ScalarValue::TimestampNanosecond(ts, tz))?; + } + + round_trip_literal(ScalarValue::List(ScalarValue::new_list_nullable( + &[ScalarValue::Float32(Some(1.0))], + &DataType::Float32, + )))?; + round_trip_literal(ScalarValue::List(ScalarValue::new_list_nullable( + &[], + &DataType::Float32, + )))?; + round_trip_literal(ScalarValue::List(Arc::new(GenericListArray::new_null( + Field::new_list_field(DataType::Float32, true).into(), + 1, + ))))?; + round_trip_literal(ScalarValue::LargeList(ScalarValue::new_large_list( + &[ScalarValue::Float32(Some(1.0))], + &DataType::Float32, + )))?; + round_trip_literal(ScalarValue::LargeList(ScalarValue::new_large_list( + &[], + &DataType::Float32, + )))?; + round_trip_literal(ScalarValue::LargeList(Arc::new( + GenericListArray::new_null( + Field::new_list_field(DataType::Float32, 
true).into(), + 1, + ), + )))?; + + // Null map + let mut map_builder = + MapBuilder::new(None, StringBuilder::new(), Int64Builder::new()); + map_builder.append(false)?; + round_trip_literal(ScalarValue::Map(Arc::new(map_builder.finish())))?; + + // Empty map + let mut map_builder = + MapBuilder::new(None, StringBuilder::new(), Int64Builder::new()); + map_builder.append(true)?; + round_trip_literal(ScalarValue::Map(Arc::new(map_builder.finish())))?; + + // Valid map + let mut map_builder = + MapBuilder::new(None, StringBuilder::new(), Int64Builder::new()); + map_builder.keys().append_value("key1"); + map_builder.keys().append_value("key2"); + map_builder.values().append_value(1); + map_builder.values().append_value(2); + map_builder.append(true)?; + round_trip_literal(ScalarValue::Map(Arc::new(map_builder.finish())))?; + + let c0 = Field::new("c0", DataType::Boolean, true); + let c1 = Field::new("c1", DataType::Int32, true); + let c2 = Field::new("c2", DataType::Utf8, true); + round_trip_literal( + ScalarStructBuilder::new() + .with_scalar(c0.to_owned(), ScalarValue::Boolean(Some(true))) + .with_scalar(c1.to_owned(), ScalarValue::Int32(Some(1))) + .with_scalar(c2.to_owned(), ScalarValue::Utf8(None)) + .build()?, + )?; + round_trip_literal(ScalarStructBuilder::new_null(vec![c0, c1, c2]))?; + + round_trip_literal(ScalarValue::IntervalYearMonth(Some(17)))?; + round_trip_literal(ScalarValue::IntervalMonthDayNano(Some( + IntervalMonthDayNano::new(17, 25, 1234567890), + )))?; + round_trip_literal(ScalarValue::IntervalDayTime(Some(IntervalDayTime::new( + 57, 123456, + ))))?; + + Ok(()) + } + + fn round_trip_literal(scalar: ScalarValue) -> Result<()> { + println!("Checking round trip of {scalar:?}"); + let state = SessionContext::default().state(); + let mut producer = DefaultSubstraitProducer::new(&state); + let substrait_literal = to_substrait_literal(&mut producer, &scalar)?; + let roundtrip_scalar = + from_substrait_literal_without_names(&test_consumer(), &substrait_literal)?; + assert_eq!(scalar, roundtrip_scalar); + Ok(()) + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs new file mode 100644 index 000000000000..42e1f962f1d1 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs @@ -0,0 +1,235 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
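The interval arms of to_substrait_literal above split DataFusion's flat interval representation (total months plus nanoseconds) into the year/month and second/subsecond pieces Substrait expects, which is also what the IntervalMonthDayNano round-trip test above exercises. A dependency-free sketch of that arithmetic; split_interval and NANOS_PER_SECOND are illustrative names, not the crate's:

const NANOS_PER_SECOND: i64 = 1_000_000_000;

/// Split a month count into whole years and leftover months, and a nanosecond
/// count into whole seconds and leftover nanoseconds, mirroring the
/// IntervalCompound construction in `to_substrait_literal`.
fn split_interval(months: i32, days: i32, nanos: i64) -> (i32, i32, i32, i32, i64) {
    let years = months / 12;
    let rem_months = months % 12;
    let seconds = (nanos / NANOS_PER_SECOND) as i32;
    let subseconds = nanos % NANOS_PER_SECOND; // precision 9 => nanoseconds
    (years, rem_months, days, seconds, subseconds)
}

fn main() {
    // 17 months, 25 days, 1_234_567_890 ns
    // -> 1 year, 5 months, 25 days, 1 second, 234_567_890 ns
    assert_eq!(
        split_interval(17, 25, 1_234_567_890),
        (1, 5, 25, 1, 234_567_890)
    );
}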
+ +mod aggregate_function; +mod cast; +mod field_reference; +mod if_then; +mod literal; +mod scalar_function; +mod singular_or_list; +mod subquery; +mod window_function; + +pub use aggregate_function::*; +pub use cast::*; +pub use field_reference::*; +pub use if_then::*; +pub use literal::*; +pub use scalar_function::*; +pub use singular_or_list::*; +pub use subquery::*; +pub use window_function::*; + +use crate::logical_plan::producer::utils::flatten_names; +use crate::logical_plan::producer::{ + to_substrait_named_struct, DefaultSubstraitProducer, SubstraitProducer, +}; +use datafusion::arrow::datatypes::Field; +use datafusion::common::{internal_err, not_impl_err, DFSchemaRef}; +use datafusion::execution::SessionState; +use datafusion::logical_expr::expr::Alias; +use datafusion::logical_expr::Expr; +use substrait::proto::expression_reference::ExprType; +use substrait::proto::{Expression, ExpressionReference, ExtendedExpression}; +use substrait::version; + +/// Serializes a collection of expressions to a Substrait ExtendedExpression message +/// +/// The ExtendedExpression message is a top-level message that can be used to send +/// expressions (not plans) between systems. +/// +/// Each expression is also given names for the output type. These are provided as a +/// field and not a String (since the names may be nested, e.g. a struct). The data +/// type and nullability of this field is redundant (those can be determined by the +/// Expr) and will be ignored. +/// +/// Substrait also requires the input schema of the expressions to be included in the +/// message. The field names of the input schema will be serialized. +pub fn to_substrait_extended_expr( + exprs: &[(&Expr, &Field)], + schema: &DFSchemaRef, + state: &SessionState, +) -> datafusion::common::Result> { + let mut producer = DefaultSubstraitProducer::new(state); + let substrait_exprs = exprs + .iter() + .map(|(expr, field)| { + let substrait_expr = producer.handle_expr(expr, schema)?; + let mut output_names = Vec::new(); + flatten_names(field, false, &mut output_names)?; + Ok(ExpressionReference { + output_names, + expr_type: Some(ExprType::Expression(substrait_expr)), + }) + }) + .collect::>>()?; + let substrait_schema = to_substrait_named_struct(schema)?; + + let extensions = producer.get_extensions(); + Ok(Box::new(ExtendedExpression { + advanced_extensions: None, + expected_type_urls: vec![], + extension_uris: vec![], + extensions: extensions.into(), + version: Some(version::version_with_producer("datafusion")), + referred_expr: substrait_exprs, + base_schema: Some(substrait_schema), + })) +} + +/// Convert DataFusion Expr to Substrait Rex +/// +/// # Arguments +/// * `producer` - SubstraitProducer implementation which the handles the actual conversion +/// * `expr` - DataFusion expression to convert into a Substrait expression +/// * `schema` - DataFusion input schema for looking up columns +pub fn to_substrait_rex( + producer: &mut impl SubstraitProducer, + expr: &Expr, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + match expr { + Expr::Alias(expr) => producer.handle_alias(expr, schema), + Expr::Column(expr) => producer.handle_column(expr, schema), + Expr::ScalarVariable(_, _) => { + not_impl_err!("Cannot convert {expr:?} to Substrait") + } + Expr::Literal(expr, _) => producer.handle_literal(expr), + Expr::BinaryExpr(expr) => producer.handle_binary_expr(expr, schema), + Expr::Like(expr) => producer.handle_like(expr, schema), + Expr::SimilarTo(_) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + 
Expr::Not(_) => producer.handle_unary_expr(expr, schema), + Expr::IsNotNull(_) => producer.handle_unary_expr(expr, schema), + Expr::IsNull(_) => producer.handle_unary_expr(expr, schema), + Expr::IsTrue(_) => producer.handle_unary_expr(expr, schema), + Expr::IsFalse(_) => producer.handle_unary_expr(expr, schema), + Expr::IsUnknown(_) => producer.handle_unary_expr(expr, schema), + Expr::IsNotTrue(_) => producer.handle_unary_expr(expr, schema), + Expr::IsNotFalse(_) => producer.handle_unary_expr(expr, schema), + Expr::IsNotUnknown(_) => producer.handle_unary_expr(expr, schema), + Expr::Negative(_) => producer.handle_unary_expr(expr, schema), + Expr::Between(expr) => producer.handle_between(expr, schema), + Expr::Case(expr) => producer.handle_case(expr, schema), + Expr::Cast(expr) => producer.handle_cast(expr, schema), + Expr::TryCast(expr) => producer.handle_try_cast(expr, schema), + Expr::ScalarFunction(expr) => producer.handle_scalar_function(expr, schema), + Expr::AggregateFunction(_) => { + internal_err!( + "AggregateFunction should only be encountered as part of a LogicalPlan::Aggregate" + ) + } + Expr::WindowFunction(expr) => producer.handle_window_function(expr, schema), + Expr::InList(expr) => producer.handle_in_list(expr, schema), + Expr::Exists(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::InSubquery(expr) => producer.handle_in_subquery(expr, schema), + Expr::ScalarSubquery(expr) => { + not_impl_err!("Cannot convert {expr:?} to Substrait") + } + #[expect(deprecated)] + Expr::Wildcard { .. } => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::GroupingSet(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::Placeholder(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::OuterReferenceColumn(_, _) => { + not_impl_err!("Cannot convert {expr:?} to Substrait") + } + Expr::Unnest(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + } +} + +pub fn from_alias( + producer: &mut impl SubstraitProducer, + alias: &Alias, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + producer.handle_expr(alias.expr.as_ref(), schema) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::logical_plan::consumer::from_substrait_extended_expr; + use datafusion::arrow::datatypes::{DataType, Schema}; + use datafusion::common::{DFSchema, DataFusionError, ScalarValue}; + use datafusion::execution::SessionStateBuilder; + + #[tokio::test] + async fn extended_expressions() -> datafusion::common::Result<()> { + let state = SessionStateBuilder::default().build(); + + // One expression, empty input schema + let expr = Expr::Literal(ScalarValue::Int32(Some(42)), None); + let field = Field::new("out", DataType::Int32, false); + let empty_schema = DFSchemaRef::new(DFSchema::empty()); + let substrait = + to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &state)?; + let roundtrip_expr = from_substrait_extended_expr(&state, &substrait).await?; + + assert_eq!(roundtrip_expr.input_schema, empty_schema); + assert_eq!(roundtrip_expr.exprs.len(), 1); + + let (rt_expr, rt_field) = roundtrip_expr.exprs.first().unwrap(); + assert_eq!(rt_field, &field); + assert_eq!(rt_expr, &expr); + + // Multiple expressions, with column references + let expr1 = Expr::Column("c0".into()); + let expr2 = Expr::Column("c1".into()); + let out1 = Field::new("out1", DataType::Int32, true); + let out2 = Field::new("out2", DataType::Utf8, true); + let input_schema = DFSchemaRef::new(DFSchema::try_from(Schema::new(vec![ + Field::new("c0", 
DataType::Int32, true), + Field::new("c1", DataType::Utf8, true), + ]))?); + + let substrait = to_substrait_extended_expr( + &[(&expr1, &out1), (&expr2, &out2)], + &input_schema, + &state, + )?; + let roundtrip_expr = from_substrait_extended_expr(&state, &substrait).await?; + + assert_eq!(roundtrip_expr.input_schema, input_schema); + assert_eq!(roundtrip_expr.exprs.len(), 2); + + let mut exprs = roundtrip_expr.exprs.into_iter(); + + let (rt_expr, rt_field) = exprs.next().unwrap(); + assert_eq!(rt_field, out1); + assert_eq!(rt_expr, expr1); + + let (rt_expr, rt_field) = exprs.next().unwrap(); + assert_eq!(rt_field, out2); + assert_eq!(rt_expr, expr2); + + Ok(()) + } + + #[tokio::test] + async fn invalid_extended_expression() { + let state = SessionStateBuilder::default().build(); + + // Not ok if input schema is missing field referenced by expr + let expr = Expr::Column("missing".into()); + let field = Field::new("out", DataType::Int32, false); + let empty_schema = DFSchemaRef::new(DFSchema::empty()); + + let err = to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &state); + + assert!(matches!(err, Err(DataFusionError::SchemaError(_, _)))); + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs b/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs new file mode 100644 index 000000000000..1172c43319c6 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs @@ -0,0 +1,327 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
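Substrait flattens nested output names into a single depth-first list, which is why to_substrait_extended_expr above collects a Vec of names per output field via flatten_names instead of a single string. The following is a guess at that flattening for a struct-typed field, not the crate's flatten_names:

use datafusion::arrow::datatypes::{DataType, Field, Fields};

/// Depth-first name flattening: the field's own name, then the names of any
/// nested struct children. Illustrative only.
fn flatten_names_sketch(field: &Field, out: &mut Vec<String>) {
    out.push(field.name().clone());
    if let DataType::Struct(children) = field.data_type() {
        for child in children.iter() {
            flatten_names_sketch(child, out);
        }
    }
}

fn main() {
    let nested = Field::new(
        "point",
        DataType::Struct(Fields::from(vec![
            Field::new("x", DataType::Float64, true),
            Field::new("y", DataType::Float64, true),
        ])),
        true,
    );
    let mut names = Vec::new();
    flatten_names_sketch(&nested, &mut names);
    assert_eq!(names, vec!["point", "x", "y"]);
}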
+ +use crate::logical_plan::producer::{to_substrait_literal_expr, SubstraitProducer}; +use datafusion::common::{not_impl_err, DFSchemaRef, ScalarValue}; +use datafusion::logical_expr::{expr, Between, BinaryExpr, Expr, Like, Operator}; +use substrait::proto::expression::{RexType, ScalarFunction}; +use substrait::proto::function_argument::ArgType; +use substrait::proto::{Expression, FunctionArgument}; + +pub fn from_scalar_function( + producer: &mut impl SubstraitProducer, + fun: &expr::ScalarFunction, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let mut arguments: Vec = vec![]; + for arg in &fun.args { + arguments.push(FunctionArgument { + arg_type: Some(ArgType::Value(producer.handle_expr(arg, schema)?)), + }); + } + + let function_anchor = producer.register_function(fun.name().to_string()); + #[allow(deprecated)] + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments, + output_type: None, + options: vec![], + args: vec![], + })), + }) +} + +pub fn from_unary_expr( + producer: &mut impl SubstraitProducer, + expr: &Expr, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let (fn_name, arg) = match expr { + Expr::Not(arg) => ("not", arg), + Expr::IsNull(arg) => ("is_null", arg), + Expr::IsNotNull(arg) => ("is_not_null", arg), + Expr::IsTrue(arg) => ("is_true", arg), + Expr::IsFalse(arg) => ("is_false", arg), + Expr::IsUnknown(arg) => ("is_unknown", arg), + Expr::IsNotTrue(arg) => ("is_not_true", arg), + Expr::IsNotFalse(arg) => ("is_not_false", arg), + Expr::IsNotUnknown(arg) => ("is_not_unknown", arg), + Expr::Negative(arg) => ("negate", arg), + expr => not_impl_err!("Unsupported expression: {expr:?}")?, + }; + to_substrait_unary_scalar_fn(producer, fn_name, arg, schema) +} + +pub fn from_binary_expr( + producer: &mut impl SubstraitProducer, + expr: &BinaryExpr, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let BinaryExpr { left, op, right } = expr; + let l = producer.handle_expr(left, schema)?; + let r = producer.handle_expr(right, schema)?; + Ok(make_binary_op_scalar_func(producer, &l, &r, *op)) +} + +pub fn from_like( + producer: &mut impl SubstraitProducer, + like: &Like, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + } = like; + make_substrait_like_expr( + producer, + *case_insensitive, + *negated, + expr, + pattern, + *escape_char, + schema, + ) +} + +fn make_substrait_like_expr( + producer: &mut impl SubstraitProducer, + ignore_case: bool, + negated: bool, + expr: &Expr, + pattern: &Expr, + escape_char: Option, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let function_anchor = if ignore_case { + producer.register_function("ilike".to_string()) + } else { + producer.register_function("like".to_string()) + }; + let expr = producer.handle_expr(expr, schema)?; + let pattern = producer.handle_expr(pattern, schema)?; + let escape_char = to_substrait_literal_expr( + producer, + &ScalarValue::Utf8(escape_char.map(|c| c.to_string())), + )?; + let arguments = vec![ + FunctionArgument { + arg_type: Some(ArgType::Value(expr)), + }, + FunctionArgument { + arg_type: Some(ArgType::Value(pattern)), + }, + FunctionArgument { + arg_type: Some(ArgType::Value(escape_char)), + }, + ]; + + #[allow(deprecated)] + let substrait_like = Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments, + output_type: None, + args: vec![], 
+ options: vec![], + })), + }; + + if negated { + let function_anchor = producer.register_function("not".to_string()); + + #[allow(deprecated)] + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments: vec![FunctionArgument { + arg_type: Some(ArgType::Value(substrait_like)), + }], + output_type: None, + args: vec![], + options: vec![], + })), + }) + } else { + Ok(substrait_like) + } +} + +/// Util to generate substrait [RexType::ScalarFunction] with one argument +fn to_substrait_unary_scalar_fn( + producer: &mut impl SubstraitProducer, + fn_name: &str, + arg: &Expr, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let function_anchor = producer.register_function(fn_name.to_string()); + let substrait_expr = producer.handle_expr(arg, schema)?; + + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments: vec![FunctionArgument { + arg_type: Some(ArgType::Value(substrait_expr)), + }], + output_type: None, + options: vec![], + ..Default::default() + })), + }) +} + +/// Return Substrait scalar function with two arguments +pub fn make_binary_op_scalar_func( + producer: &mut impl SubstraitProducer, + lhs: &Expression, + rhs: &Expression, + op: Operator, +) -> Expression { + let function_anchor = producer.register_function(operator_to_name(op).to_string()); + #[allow(deprecated)] + Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments: vec![ + FunctionArgument { + arg_type: Some(ArgType::Value(lhs.clone())), + }, + FunctionArgument { + arg_type: Some(ArgType::Value(rhs.clone())), + }, + ], + output_type: None, + args: vec![], + options: vec![], + })), + } +} + +pub fn from_between( + producer: &mut impl SubstraitProducer, + between: &Between, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let Between { + expr, + negated, + low, + high, + } = between; + if *negated { + // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) + let substrait_expr = producer.handle_expr(expr.as_ref(), schema)?; + let substrait_low = producer.handle_expr(low.as_ref(), schema)?; + let substrait_high = producer.handle_expr(high.as_ref(), schema)?; + + let l_expr = make_binary_op_scalar_func( + producer, + &substrait_expr, + &substrait_low, + Operator::Lt, + ); + let r_expr = make_binary_op_scalar_func( + producer, + &substrait_high, + &substrait_expr, + Operator::Lt, + ); + + Ok(make_binary_op_scalar_func( + producer, + &l_expr, + &r_expr, + Operator::Or, + )) + } else { + // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high) + let substrait_expr = producer.handle_expr(expr.as_ref(), schema)?; + let substrait_low = producer.handle_expr(low.as_ref(), schema)?; + let substrait_high = producer.handle_expr(high.as_ref(), schema)?; + + let l_expr = make_binary_op_scalar_func( + producer, + &substrait_low, + &substrait_expr, + Operator::LtEq, + ); + let r_expr = make_binary_op_scalar_func( + producer, + &substrait_expr, + &substrait_high, + Operator::LtEq, + ); + + Ok(make_binary_op_scalar_func( + producer, + &l_expr, + &r_expr, + Operator::And, + )) + } +} + +pub fn operator_to_name(op: Operator) -> &'static str { + match op { + Operator::Eq => "equal", + Operator::NotEq => "not_equal", + Operator::Lt => "lt", + Operator::LtEq => "lte", + Operator::Gt => "gt", + Operator::GtEq => "gte", + Operator::Plus => "add", + Operator::Minus => 
"subtract", + Operator::Multiply => "multiply", + Operator::Divide => "divide", + Operator::Modulo => "modulus", + Operator::And => "and", + Operator::Or => "or", + Operator::IsDistinctFrom => "is_distinct_from", + Operator::IsNotDistinctFrom => "is_not_distinct_from", + Operator::RegexMatch => "regex_match", + Operator::RegexIMatch => "regex_imatch", + Operator::RegexNotMatch => "regex_not_match", + Operator::RegexNotIMatch => "regex_not_imatch", + Operator::LikeMatch => "like_match", + Operator::ILikeMatch => "like_imatch", + Operator::NotLikeMatch => "like_not_match", + Operator::NotILikeMatch => "like_not_imatch", + Operator::BitwiseAnd => "bitwise_and", + Operator::BitwiseOr => "bitwise_or", + Operator::StringConcat => "str_concat", + Operator::AtArrow => "at_arrow", + Operator::ArrowAt => "arrow_at", + Operator::Arrow => "arrow", + Operator::LongArrow => "long_arrow", + Operator::HashArrow => "hash_arrow", + Operator::HashLongArrow => "hash_long_arrow", + Operator::AtAt => "at_at", + Operator::IntegerDivide => "integer_divide", + Operator::HashMinus => "hash_minus", + Operator::AtQuestion => "at_question", + Operator::Question => "question", + Operator::QuestionAnd => "question_and", + Operator::QuestionPipe => "question_pipe", + Operator::BitwiseXor => "bitwise_xor", + Operator::BitwiseShiftRight => "bitwise_shift_right", + Operator::BitwiseShiftLeft => "bitwise_shift_left", + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/expr/singular_or_list.rs b/datafusion/substrait/src/logical_plan/producer/expr/singular_or_list.rs new file mode 100644 index 000000000000..1c0b6dcc154b --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/expr/singular_or_list.rs @@ -0,0 +1,66 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::logical_plan::producer::SubstraitProducer; +use datafusion::common::DFSchemaRef; +use datafusion::logical_expr::expr::InList; +use substrait::proto::expression::{RexType, ScalarFunction, SingularOrList}; +use substrait::proto::function_argument::ArgType; +use substrait::proto::{Expression, FunctionArgument}; + +pub fn from_in_list( + producer: &mut impl SubstraitProducer, + in_list: &InList, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let InList { + expr, + list, + negated, + } = in_list; + let substrait_list = list + .iter() + .map(|x| producer.handle_expr(x, schema)) + .collect::>>()?; + let substrait_expr = producer.handle_expr(expr, schema)?; + + let substrait_or_list = Expression { + rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList { + value: Some(Box::new(substrait_expr)), + options: substrait_list, + }))), + }; + + if *negated { + let function_anchor = producer.register_function("not".to_string()); + + #[allow(deprecated)] + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments: vec![FunctionArgument { + arg_type: Some(ArgType::Value(substrait_or_list)), + }], + output_type: None, + args: vec![], + options: vec![], + })), + }) + } else { + Ok(substrait_or_list) + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/expr/subquery.rs b/datafusion/substrait/src/logical_plan/producer/expr/subquery.rs new file mode 100644 index 000000000000..c1ee78c68c25 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/expr/subquery.rs @@ -0,0 +1,72 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
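The negated branches of from_like and from_in_list above, and from_in_subquery in the next file, all wrap the positive expression in Substrait's not scalar function. A sketch of a shared helper capturing that pattern; wrap_in_not is hypothetical and not part of the patch, and the anchor argument stands in for the value returned by producer.register_function("not".to_string()):

use substrait::proto::expression::{RexType, ScalarFunction};
use substrait::proto::function_argument::ArgType;
use substrait::proto::{Expression, FunctionArgument};

/// Wrap an already-converted expression in Substrait's `not` scalar function.
#[allow(deprecated)]
fn wrap_in_not(not_anchor: u32, inner: Expression) -> Expression {
    Expression {
        rex_type: Some(RexType::ScalarFunction(ScalarFunction {
            function_reference: not_anchor,
            arguments: vec![FunctionArgument {
                arg_type: Some(ArgType::Value(inner)),
            }],
            output_type: None,
            args: vec![],
            options: vec![],
        })),
    }
}

fn main() {
    let negated = wrap_in_not(42, Expression::default());
    assert!(matches!(negated.rex_type, Some(RexType::ScalarFunction(_))));
}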
+ +use crate::logical_plan::producer::SubstraitProducer; +use datafusion::common::DFSchemaRef; +use datafusion::logical_expr::expr::InSubquery; +use substrait::proto::expression::subquery::InPredicate; +use substrait::proto::expression::{RexType, ScalarFunction}; +use substrait::proto::function_argument::ArgType; +use substrait::proto::{Expression, FunctionArgument}; + +pub fn from_in_subquery( + producer: &mut impl SubstraitProducer, + subquery: &InSubquery, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let InSubquery { + expr, + subquery, + negated, + } = subquery; + let substrait_expr = producer.handle_expr(expr, schema)?; + + let subquery_plan = producer.handle_plan(subquery.subquery.as_ref())?; + + let substrait_subquery = Expression { + rex_type: Some(RexType::Subquery(Box::new( + substrait::proto::expression::Subquery { + subquery_type: Some( + substrait::proto::expression::subquery::SubqueryType::InPredicate( + Box::new(InPredicate { + needles: (vec![substrait_expr]), + haystack: Some(subquery_plan), + }), + ), + ), + }, + ))), + }; + if *negated { + let function_anchor = producer.register_function("not".to_string()); + + #[allow(deprecated)] + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments: vec![FunctionArgument { + arg_type: Some(ArgType::Value(substrait_subquery)), + }], + output_type: None, + args: vec![], + options: vec![], + })), + }) + } else { + Ok(substrait_subquery) + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/expr/window_function.rs b/datafusion/substrait/src/logical_plan/producer/expr/window_function.rs new file mode 100644 index 000000000000..17e71f2d7c14 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/expr/window_function.rs @@ -0,0 +1,162 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::logical_plan::producer::utils::substrait_sort_field; +use crate::logical_plan::producer::SubstraitProducer; +use datafusion::common::{not_impl_err, DFSchemaRef, ScalarValue}; +use datafusion::logical_expr::expr::{WindowFunction, WindowFunctionParams}; +use datafusion::logical_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits}; +use substrait::proto::expression::window_function::bound as SubstraitBound; +use substrait::proto::expression::window_function::bound::Kind as BoundKind; +use substrait::proto::expression::window_function::{Bound, BoundsType}; +use substrait::proto::expression::RexType; +use substrait::proto::expression::WindowFunction as SubstraitWindowFunction; +use substrait::proto::function_argument::ArgType; +use substrait::proto::{Expression, FunctionArgument, SortField}; + +pub fn from_window_function( + producer: &mut impl SubstraitProducer, + window_fn: &WindowFunction, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let WindowFunction { + fun, + params: + WindowFunctionParams { + args, + partition_by, + order_by, + window_frame, + null_treatment: _, + }, + } = window_fn; + // function reference + let function_anchor = producer.register_function(fun.to_string()); + // arguments + let mut arguments: Vec = vec![]; + for arg in args { + arguments.push(FunctionArgument { + arg_type: Some(ArgType::Value(producer.handle_expr(arg, schema)?)), + }); + } + // partition by expressions + let partition_by = partition_by + .iter() + .map(|e| producer.handle_expr(e, schema)) + .collect::>>()?; + // order by expressions + let order_by = order_by + .iter() + .map(|e| substrait_sort_field(producer, e, schema)) + .collect::>>()?; + // window frame + let bounds = to_substrait_bounds(window_frame)?; + let bound_type = to_substrait_bound_type(window_frame)?; + Ok(make_substrait_window_function( + function_anchor, + arguments, + partition_by, + order_by, + bounds, + bound_type, + )) +} + +fn make_substrait_window_function( + function_reference: u32, + arguments: Vec, + partitions: Vec, + sorts: Vec, + bounds: (Bound, Bound), + bounds_type: BoundsType, +) -> Expression { + #[allow(deprecated)] + Expression { + rex_type: Some(RexType::WindowFunction(SubstraitWindowFunction { + function_reference, + arguments, + partitions, + sorts, + options: vec![], + output_type: None, + phase: 0, // default to AGGREGATION_PHASE_UNSPECIFIED + invocation: 0, // TODO: fix + lower_bound: Some(bounds.0), + upper_bound: Some(bounds.1), + args: vec![], + bounds_type: bounds_type as i32, + })), + } +} + +fn to_substrait_bound_type( + window_frame: &WindowFrame, +) -> datafusion::common::Result { + match window_frame.units { + WindowFrameUnits::Rows => Ok(BoundsType::Rows), // ROWS + WindowFrameUnits::Range => Ok(BoundsType::Range), // RANGE + // TODO: Support GROUPS + unit => not_impl_err!("Unsupported window frame unit: {unit:?}"), + } +} + +fn to_substrait_bounds( + window_frame: &WindowFrame, +) -> datafusion::common::Result<(Bound, Bound)> { + Ok(( + to_substrait_bound(&window_frame.start_bound), + to_substrait_bound(&window_frame.end_bound), + )) +} + +fn to_substrait_bound(bound: &WindowFrameBound) -> Bound { + match bound { + WindowFrameBound::CurrentRow => Bound { + kind: Some(BoundKind::CurrentRow(SubstraitBound::CurrentRow {})), + }, + WindowFrameBound::Preceding(s) => match to_substrait_bound_offset(s) { + Some(offset) => Bound { + kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { offset })), + }, + None => Bound { + kind: 
Some(BoundKind::Unbounded(SubstraitBound::Unbounded {})), + }, + }, + WindowFrameBound::Following(s) => match to_substrait_bound_offset(s) { + Some(offset) => Bound { + kind: Some(BoundKind::Following(SubstraitBound::Following { offset })), + }, + None => Bound { + kind: Some(BoundKind::Unbounded(SubstraitBound::Unbounded {})), + }, + }, + } +} + +fn to_substrait_bound_offset(value: &ScalarValue) -> Option { + match value { + ScalarValue::UInt8(Some(v)) => Some(*v as i64), + ScalarValue::UInt16(Some(v)) => Some(*v as i64), + ScalarValue::UInt32(Some(v)) => Some(*v as i64), + ScalarValue::UInt64(Some(v)) => Some(*v as i64), + ScalarValue::Int8(Some(v)) => Some(*v as i64), + ScalarValue::Int16(Some(v)) => Some(*v as i64), + ScalarValue::Int32(Some(v)) => Some(*v as i64), + ScalarValue::Int64(Some(v)) => Some(*v), + _ => None, + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/mod.rs b/datafusion/substrait/src/logical_plan/producer/mod.rs new file mode 100644 index 000000000000..fc4af94a25fe --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/mod.rs @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod expr; +mod plan; +mod rel; +mod substrait_producer; +mod types; +mod utils; + +pub use expr::*; +pub use plan::*; +pub use rel::*; +pub use substrait_producer::*; +pub(crate) use types::*; +pub(crate) use utils::*; diff --git a/datafusion/substrait/src/logical_plan/producer/plan.rs b/datafusion/substrait/src/logical_plan/producer/plan.rs new file mode 100644 index 000000000000..7d5b7754122d --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/plan.rs @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
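to_substrait_bound above maps DataFusion's window frame bounds onto Substrait's bound kinds, treating any offset that to_substrait_bound_offset cannot extract (a NULL scalar) as Unbounded. A dependency-free sketch of that decision table; FrameBound and BoundKind are simplified stand-ins, not the real types:

// Simplified stand-ins for WindowFrameBound and the Substrait bound kinds.
#[derive(Debug, PartialEq)]
enum FrameBound {
    CurrentRow,
    Preceding(Option<i64>),
    Following(Option<i64>),
}

#[derive(Debug, PartialEq)]
enum BoundKind {
    CurrentRow,
    Preceding(i64),
    Following(i64),
    Unbounded,
}

fn to_bound(bound: &FrameBound) -> BoundKind {
    match bound {
        FrameBound::CurrentRow => BoundKind::CurrentRow,
        // An offset that cannot be extracted (a NULL scalar) means "unbounded".
        FrameBound::Preceding(None) | FrameBound::Following(None) => BoundKind::Unbounded,
        FrameBound::Preceding(Some(n)) => BoundKind::Preceding(*n),
        FrameBound::Following(Some(n)) => BoundKind::Following(*n),
    }
}

fn main() {
    // ROWS BETWEEN 2 PRECEDING AND UNBOUNDED FOLLOWING
    assert_eq!(to_bound(&FrameBound::Preceding(Some(2))), BoundKind::Preceding(2));
    assert_eq!(to_bound(&FrameBound::Following(None)), BoundKind::Unbounded);
}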
+ +use crate::logical_plan::producer::{ + to_substrait_named_struct, DefaultSubstraitProducer, SubstraitProducer, +}; +use datafusion::execution::SessionState; +use datafusion::logical_expr::{LogicalPlan, SubqueryAlias}; +use substrait::proto::{plan_rel, Plan, PlanRel, Rel, RelRoot}; +use substrait::version; + +/// Convert DataFusion LogicalPlan to Substrait Plan +pub fn to_substrait_plan( + plan: &LogicalPlan, + state: &SessionState, +) -> datafusion::common::Result> { + // Parse relation nodes + // Generate PlanRel(s) + // Note: Only 1 relation tree is currently supported + + let mut producer: DefaultSubstraitProducer = DefaultSubstraitProducer::new(state); + let plan_rels = vec![PlanRel { + rel_type: Some(plan_rel::RelType::Root(RelRoot { + input: Some(*producer.handle_plan(plan)?), + names: to_substrait_named_struct(plan.schema())?.names, + })), + }]; + + // Return parsed plan + let extensions = producer.get_extensions(); + Ok(Box::new(Plan { + version: Some(version::version_with_producer("datafusion")), + extension_uris: vec![], + extensions: extensions.into(), + relations: plan_rels, + advanced_extensions: None, + expected_type_urls: vec![], + parameter_bindings: vec![], + })) +} + +pub fn from_subquery_alias( + producer: &mut impl SubstraitProducer, + alias: &SubqueryAlias, +) -> datafusion::common::Result> { + // Do nothing if encounters SubqueryAlias + // since there is no corresponding relation type in Substrait + producer.handle_plan(alias.input.as_ref()) +} diff --git a/datafusion/substrait/src/logical_plan/producer/rel/aggregate_rel.rs b/datafusion/substrait/src/logical_plan/producer/rel/aggregate_rel.rs new file mode 100644 index 000000000000..4abd283a7ee0 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/rel/aggregate_rel.rs @@ -0,0 +1,182 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
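A possible end-to-end use of to_substrait_plan above, going from SQL text to a Substrait Plan message. This is a sketch, not part of the patch; it assumes the function stays re-exported as datafusion_substrait::logical_plan::producer::to_substrait_plan and that a tokio runtime is available:

use datafusion::prelude::SessionContext;
use datafusion_substrait::logical_plan::producer::to_substrait_plan;

#[tokio::main]
async fn main() -> datafusion::common::Result<()> {
    let ctx = SessionContext::new();
    // Any LogicalPlan works; here one is built from a trivial query.
    let plan = ctx.sql("SELECT 1 AS a").await?.into_optimized_plan()?;

    let substrait_plan = to_substrait_plan(&plan, &ctx.state())?;
    // Only a single relation tree is produced, as noted in to_substrait_plan.
    assert_eq!(substrait_plan.relations.len(), 1);
    Ok(())
}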
+ +use crate::logical_plan::producer::{ + from_aggregate_function, substrait_field_ref, SubstraitProducer, +}; +use datafusion::common::{internal_err, not_impl_err, DFSchemaRef, DataFusionError}; +use datafusion::logical_expr::expr::Alias; +use datafusion::logical_expr::{Aggregate, Distinct, Expr, GroupingSet}; +use substrait::proto::aggregate_rel::{Grouping, Measure}; +use substrait::proto::rel::RelType; +use substrait::proto::{AggregateRel, Expression, Rel}; + +pub fn from_aggregate( + producer: &mut impl SubstraitProducer, + agg: &Aggregate, +) -> datafusion::common::Result> { + let input = producer.handle_plan(agg.input.as_ref())?; + let (grouping_expressions, groupings) = + to_substrait_groupings(producer, &agg.group_expr, agg.input.schema())?; + let measures = agg + .aggr_expr + .iter() + .map(|e| to_substrait_agg_measure(producer, e, agg.input.schema())) + .collect::>>()?; + + Ok(Box::new(Rel { + rel_type: Some(RelType::Aggregate(Box::new(AggregateRel { + common: None, + input: Some(input), + grouping_expressions, + groupings, + measures, + advanced_extension: None, + }))), + })) +} + +pub fn from_distinct( + producer: &mut impl SubstraitProducer, + distinct: &Distinct, +) -> datafusion::common::Result> { + match distinct { + Distinct::All(plan) => { + // Use Substrait's AggregateRel with empty measures to represent `select distinct` + let input = producer.handle_plan(plan.as_ref())?; + // Get grouping keys from the input relation's number of output fields + let grouping = (0..plan.schema().fields().len()) + .map(substrait_field_ref) + .collect::>>()?; + + #[allow(deprecated)] + Ok(Box::new(Rel { + rel_type: Some(RelType::Aggregate(Box::new(AggregateRel { + common: None, + input: Some(input), + grouping_expressions: vec![], + groupings: vec![Grouping { + grouping_expressions: grouping, + expression_references: vec![], + }], + measures: vec![], + advanced_extension: None, + }))), + })) + } + Distinct::On(_) => not_impl_err!("Cannot convert Distinct::On"), + } +} + +pub fn to_substrait_groupings( + producer: &mut impl SubstraitProducer, + exprs: &[Expr], + schema: &DFSchemaRef, +) -> datafusion::common::Result<(Vec, Vec)> { + let mut ref_group_exprs = vec![]; + let groupings = match exprs.len() { + 1 => match &exprs[0] { + Expr::GroupingSet(gs) => match gs { + GroupingSet::Cube(_) => Err(DataFusionError::NotImplemented( + "GroupingSet CUBE is not yet supported".to_string(), + )), + GroupingSet::GroupingSets(sets) => Ok(sets + .iter() + .map(|set| { + parse_flat_grouping_exprs( + producer, + set, + schema, + &mut ref_group_exprs, + ) + }) + .collect::>>()?), + GroupingSet::Rollup(set) => { + let mut sets: Vec> = vec![vec![]]; + for i in 0..set.len() { + sets.push(set[..=i].to_vec()); + } + Ok(sets + .iter() + .rev() + .map(|set| { + parse_flat_grouping_exprs( + producer, + set, + schema, + &mut ref_group_exprs, + ) + }) + .collect::>>()?) 
+                }
+            },
+            _ => Ok(vec![parse_flat_grouping_exprs(
+                producer,
+                exprs,
+                schema,
+                &mut ref_group_exprs,
+            )?]),
+        },
+        _ => Ok(vec![parse_flat_grouping_exprs(
+            producer,
+            exprs,
+            schema,
+            &mut ref_group_exprs,
+        )?]),
+    }?;
+    Ok((ref_group_exprs, groupings))
+}
+
+pub fn parse_flat_grouping_exprs(
+    producer: &mut impl SubstraitProducer,
+    exprs: &[Expr],
+    schema: &DFSchemaRef,
+    ref_group_exprs: &mut Vec<Expression>,
+) -> datafusion::common::Result<Grouping> {
+    let mut expression_references = vec![];
+    let mut grouping_expressions = vec![];
+
+    for e in exprs {
+        let rex = producer.handle_expr(e, schema)?;
+        grouping_expressions.push(rex.clone());
+        ref_group_exprs.push(rex);
+        expression_references.push((ref_group_exprs.len() - 1) as u32);
+    }
+    #[allow(deprecated)]
+    Ok(Grouping {
+        grouping_expressions,
+        expression_references,
+    })
+}
+
+pub fn to_substrait_agg_measure(
+    producer: &mut impl SubstraitProducer,
+    expr: &Expr,
+    schema: &DFSchemaRef,
+) -> datafusion::common::Result<Measure> {
+    match expr {
+        Expr::AggregateFunction(agg_fn) => from_aggregate_function(producer, agg_fn, schema),
+        Expr::Alias(Alias { expr, .. }) => {
+            to_substrait_agg_measure(producer, expr, schema)
+        }
+        _ => internal_err!(
+            "Expression must be compatible with aggregation. Unsupported expression: {:?}. ExpressionType: {:?}",
+            expr,
+            expr.variant_name()
+        ),
+    }
+}
diff --git a/datafusion/substrait/src/logical_plan/producer/rel/exchange_rel.rs b/datafusion/substrait/src/logical_plan/producer/rel/exchange_rel.rs
new file mode 100644
index 000000000000..9e0ef8905f43
--- /dev/null
+++ b/datafusion/substrait/src/logical_plan/producer/rel/exchange_rel.rs
@@ -0,0 +1,70 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
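Editorial sketch (not part of this patch) to make the ROLLUP handling in `to_substrait_groupings` above concrete: with plain strings standing in for `Expr`s, `ROLLUP(a, b)` expands into the grouping sets `[a, b]`, `[a]`, `[]`, in that order, mirroring the prefix-building loop followed by `.rev()` above.

```rust
// Standalone mirror of the ROLLUP expansion, using &str instead of Expr.
fn rollup_sets(cols: &[&str]) -> Vec<Vec<String>> {
    let mut sets: Vec<Vec<String>> = vec![vec![]];
    for i in 0..cols.len() {
        // Prefixes: [a], [a, b], ...
        sets.push(cols[..=i].iter().map(|s| s.to_string()).collect());
    }
    // Reversed, so the widest grouping set comes first and the empty set last.
    sets.into_iter().rev().collect()
}

fn main() {
    assert_eq!(
        rollup_sets(&["a", "b"]),
        vec![
            vec!["a".to_string(), "b".to_string()],
            vec!["a".to_string()],
            vec![]
        ]
    );
}
```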
+ +use crate::logical_plan::producer::{ + try_to_substrait_field_reference, SubstraitProducer, +}; +use datafusion::common::not_impl_err; +use datafusion::logical_expr::{Partitioning, Repartition}; +use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields}; +use substrait::proto::rel::RelType; +use substrait::proto::{ExchangeRel, Rel}; + +pub fn from_repartition( + producer: &mut impl SubstraitProducer, + repartition: &Repartition, +) -> datafusion::common::Result> { + let input = producer.handle_plan(repartition.input.as_ref())?; + let partition_count = match repartition.partitioning_scheme { + Partitioning::RoundRobinBatch(num) => num, + Partitioning::Hash(_, num) => num, + Partitioning::DistributeBy(_) => { + return not_impl_err!( + "Physical plan does not support DistributeBy partitioning" + ) + } + }; + // ref: https://substrait.io/relations/physical_relations/#exchange-types + let exchange_kind = match &repartition.partitioning_scheme { + Partitioning::RoundRobinBatch(_) => { + ExchangeKind::RoundRobin(RoundRobin::default()) + } + Partitioning::Hash(exprs, _) => { + let fields = exprs + .iter() + .map(|e| try_to_substrait_field_reference(e, repartition.input.schema())) + .collect::>>()?; + ExchangeKind::ScatterByFields(ScatterFields { fields }) + } + Partitioning::DistributeBy(_) => { + return not_impl_err!( + "Physical plan does not support DistributeBy partitioning" + ) + } + }; + let exchange_rel = ExchangeRel { + common: None, + input: Some(input), + exchange_kind: Some(exchange_kind), + advanced_extension: None, + partition_count: partition_count as i32, + targets: vec![], + }; + Ok(Box::new(Rel { + rel_type: Some(RelType::Exchange(Box::new(exchange_rel))), + })) +} diff --git a/datafusion/substrait/src/logical_plan/producer/rel/fetch_rel.rs b/datafusion/substrait/src/logical_plan/producer/rel/fetch_rel.rs new file mode 100644 index 000000000000..4706401d558e --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/rel/fetch_rel.rs @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::SubstraitProducer; +use datafusion::common::DFSchema; +use datafusion::logical_expr::Limit; +use std::sync::Arc; +use substrait::proto::rel::RelType; +use substrait::proto::{fetch_rel, FetchRel, Rel}; + +pub fn from_limit( + producer: &mut impl SubstraitProducer, + limit: &Limit, +) -> datafusion::common::Result> { + let input = producer.handle_plan(limit.input.as_ref())?; + let empty_schema = Arc::new(DFSchema::empty()); + let offset_mode = limit + .skip + .as_ref() + .map(|expr| producer.handle_expr(expr.as_ref(), &empty_schema)) + .transpose()? 
+ .map(Box::new) + .map(fetch_rel::OffsetMode::OffsetExpr); + let count_mode = limit + .fetch + .as_ref() + .map(|expr| producer.handle_expr(expr.as_ref(), &empty_schema)) + .transpose()? + .map(Box::new) + .map(fetch_rel::CountMode::CountExpr); + Ok(Box::new(Rel { + rel_type: Some(RelType::Fetch(Box::new(FetchRel { + common: None, + input: Some(input), + offset_mode, + count_mode, + advanced_extension: None, + }))), + })) +} diff --git a/datafusion/substrait/src/logical_plan/producer/rel/filter_rel.rs b/datafusion/substrait/src/logical_plan/producer/rel/filter_rel.rs new file mode 100644 index 000000000000..770696dfe1a9 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/rel/filter_rel.rs @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::SubstraitProducer; +use datafusion::logical_expr::Filter; +use substrait::proto::rel::RelType; +use substrait::proto::{FilterRel, Rel}; + +pub fn from_filter( + producer: &mut impl SubstraitProducer, + filter: &Filter, +) -> datafusion::common::Result> { + let input = producer.handle_plan(filter.input.as_ref())?; + let filter_expr = producer.handle_expr(&filter.predicate, filter.input.schema())?; + Ok(Box::new(Rel { + rel_type: Some(RelType::Filter(Box::new(FilterRel { + common: None, + input: Some(input), + condition: Some(Box::new(filter_expr)), + advanced_extension: None, + }))), + })) +} diff --git a/datafusion/substrait/src/logical_plan/producer/rel/join.rs b/datafusion/substrait/src/logical_plan/producer/rel/join.rs new file mode 100644 index 000000000000..3dbac636feed --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/rel/join.rs @@ -0,0 +1,122 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
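Editorial sketch (not part of this patch) of the `FetchRel` count shape that `from_limit` above is expected to produce for a constant `LIMIT 10`: the fetch expression surfaces as `CountMode::CountExpr` wrapping an i64 literal, the same shape `from_sort` builds for its `fetch` later in this patch. `count_literal` is a hypothetical helper; the proto field names follow the `substrait` crate used above.

```rust
use substrait::proto::expression::literal::LiteralType;
use substrait::proto::expression::{Literal, RexType};
use substrait::proto::{fetch_rel, Expression};

/// Hypothetical helper: build the `CountExpr` for a constant `LIMIT <n>`.
fn count_literal(limit: i64) -> fetch_rel::CountMode {
    fetch_rel::CountMode::CountExpr(Box::new(Expression {
        rex_type: Some(RexType::Literal(Literal {
            nullable: false,
            type_variation_reference: 0, // DEFAULT_TYPE_VARIATION_REF
            literal_type: Some(LiteralType::I64(limit)),
        })),
    }))
}

fn main() {
    // e.g. the count side of `LIMIT 10`
    let _count_mode = count_literal(10);
}
```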
+ +use crate::logical_plan::producer::{make_binary_op_scalar_func, SubstraitProducer}; +use datafusion::common::{ + not_impl_err, DFSchemaRef, JoinConstraint, JoinType, NullEquality, +}; +use datafusion::logical_expr::{Expr, Join, Operator}; +use std::sync::Arc; +use substrait::proto::rel::RelType; +use substrait::proto::{join_rel, Expression, JoinRel, Rel}; + +pub fn from_join( + producer: &mut impl SubstraitProducer, + join: &Join, +) -> datafusion::common::Result> { + let left = producer.handle_plan(join.left.as_ref())?; + let right = producer.handle_plan(join.right.as_ref())?; + let join_type = to_substrait_jointype(join.join_type); + // we only support basic joins so return an error for anything not yet supported + match join.join_constraint { + JoinConstraint::On => {} + JoinConstraint::Using => return not_impl_err!("join constraint: `using`"), + } + let in_join_schema = Arc::new(join.left.schema().join(join.right.schema())?); + + // convert filter if present + let join_filter = match &join.filter { + Some(filter) => Some(producer.handle_expr(filter, &in_join_schema)?), + None => None, + }; + + // map the left and right columns to binary expressions in the form `l = r` + // build a single expression for the ON condition, such as `l.a = r.a AND l.b = r.b` + let eq_op = match join.null_equality { + NullEquality::NullEqualsNothing => Operator::Eq, + NullEquality::NullEqualsNull => Operator::IsNotDistinctFrom, + }; + let join_on = to_substrait_join_expr(producer, &join.on, eq_op, &in_join_schema)?; + + // create conjunction between `join_on` and `join_filter` to embed all join conditions, + // whether equal or non-equal in a single expression + let join_expr = match &join_on { + Some(on_expr) => match &join_filter { + Some(filter) => Some(Box::new(make_binary_op_scalar_func( + producer, + on_expr, + filter, + Operator::And, + ))), + None => join_on.map(Box::new), // the join expression will only contain `join_on` if filter doesn't exist + }, + None => match &join_filter { + Some(_) => join_filter.map(Box::new), // the join expression will only contain `join_filter` if the `on` condition doesn't exist + None => None, + }, + }; + + Ok(Box::new(Rel { + rel_type: Some(RelType::Join(Box::new(JoinRel { + common: None, + left: Some(left), + right: Some(right), + r#type: join_type as i32, + expression: join_expr, + post_join_filter: None, + advanced_extension: None, + }))), + })) +} + +fn to_substrait_join_expr( + producer: &mut impl SubstraitProducer, + join_conditions: &Vec<(Expr, Expr)>, + eq_op: Operator, + join_schema: &DFSchemaRef, +) -> datafusion::common::Result> { + // Only support AND conjunction for each binary expression in join conditions + let mut exprs: Vec = vec![]; + for (left, right) in join_conditions { + let l = producer.handle_expr(left, join_schema)?; + let r = producer.handle_expr(right, join_schema)?; + // AND with existing expression + exprs.push(make_binary_op_scalar_func(producer, &l, &r, eq_op)); + } + + let join_expr: Option = + exprs.into_iter().reduce(|acc: Expression, e: Expression| { + make_binary_op_scalar_func(producer, &acc, &e, Operator::And) + }); + Ok(join_expr) +} + +fn to_substrait_jointype(join_type: JoinType) -> join_rel::JoinType { + match join_type { + JoinType::Inner => join_rel::JoinType::Inner, + JoinType::Left => join_rel::JoinType::Left, + JoinType::Right => join_rel::JoinType::Right, + JoinType::Full => join_rel::JoinType::Outer, + JoinType::LeftAnti => join_rel::JoinType::LeftAnti, + JoinType::LeftSemi => join_rel::JoinType::LeftSemi, + 
JoinType::LeftMark => join_rel::JoinType::LeftMark, + JoinType::RightMark => join_rel::JoinType::RightMark, + JoinType::RightAnti | JoinType::RightSemi => { + unimplemented!() + } + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/rel/mod.rs b/datafusion/substrait/src/logical_plan/producer/rel/mod.rs new file mode 100644 index 000000000000..c3599a2635ff --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/rel/mod.rs @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod aggregate_rel; +mod exchange_rel; +mod fetch_rel; +mod filter_rel; +mod join; +mod project_rel; +mod read_rel; +mod set_rel; +mod sort_rel; + +pub use aggregate_rel::*; +pub use exchange_rel::*; +pub use fetch_rel::*; +pub use filter_rel::*; +pub use join::*; +pub use project_rel::*; +pub use read_rel::*; +pub use set_rel::*; +pub use sort_rel::*; + +use crate::logical_plan::producer::SubstraitProducer; +use datafusion::common::not_impl_err; +use datafusion::logical_expr::LogicalPlan; +use substrait::proto::Rel; + +pub fn to_substrait_rel( + producer: &mut impl SubstraitProducer, + plan: &LogicalPlan, +) -> datafusion::common::Result> { + match plan { + LogicalPlan::Projection(plan) => producer.handle_projection(plan), + LogicalPlan::Filter(plan) => producer.handle_filter(plan), + LogicalPlan::Window(plan) => producer.handle_window(plan), + LogicalPlan::Aggregate(plan) => producer.handle_aggregate(plan), + LogicalPlan::Sort(plan) => producer.handle_sort(plan), + LogicalPlan::Join(plan) => producer.handle_join(plan), + LogicalPlan::Repartition(plan) => producer.handle_repartition(plan), + LogicalPlan::Union(plan) => producer.handle_union(plan), + LogicalPlan::TableScan(plan) => producer.handle_table_scan(plan), + LogicalPlan::EmptyRelation(plan) => producer.handle_empty_relation(plan), + LogicalPlan::Subquery(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::SubqueryAlias(plan) => producer.handle_subquery_alias(plan), + LogicalPlan::Limit(plan) => producer.handle_limit(plan), + LogicalPlan::Statement(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::Values(plan) => producer.handle_values(plan), + LogicalPlan::Explain(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::Analyze(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::Extension(plan) => producer.handle_extension(plan), + LogicalPlan::Distinct(plan) => producer.handle_distinct(plan), + LogicalPlan::Dml(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::Ddl(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::Copy(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::DescribeTable(plan) => { + 
not_impl_err!("Unsupported plan type: {plan:?}")? + } + LogicalPlan::Unnest(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, + LogicalPlan::RecursiveQuery(plan) => { + not_impl_err!("Unsupported plan type: {plan:?}")? + } + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/rel/project_rel.rs b/datafusion/substrait/src/logical_plan/producer/rel/project_rel.rs new file mode 100644 index 000000000000..0190dca12bf5 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/rel/project_rel.rs @@ -0,0 +1,101 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::{substrait_field_ref, SubstraitProducer}; +use datafusion::logical_expr::{Projection, Window}; +use substrait::proto::rel::RelType; +use substrait::proto::rel_common::EmitKind; +use substrait::proto::rel_common::EmitKind::Emit; +use substrait::proto::{rel_common, ProjectRel, Rel, RelCommon}; + +pub fn from_projection( + producer: &mut impl SubstraitProducer, + p: &Projection, +) -> datafusion::common::Result> { + let expressions = p + .expr + .iter() + .map(|e| producer.handle_expr(e, p.input.schema())) + .collect::>>()?; + + let emit_kind = create_project_remapping( + expressions.len(), + p.input.as_ref().schema().fields().len(), + ); + let common = RelCommon { + emit_kind: Some(emit_kind), + hint: None, + advanced_extension: None, + }; + + Ok(Box::new(Rel { + rel_type: Some(RelType::Project(Box::new(ProjectRel { + common: Some(common), + input: Some(producer.handle_plan(p.input.as_ref())?), + expressions, + advanced_extension: None, + }))), + })) +} + +pub fn from_window( + producer: &mut impl SubstraitProducer, + window: &Window, +) -> datafusion::common::Result> { + let input = producer.handle_plan(window.input.as_ref())?; + + // create a field reference for each input field + let mut expressions = (0..window.input.schema().fields().len()) + .map(substrait_field_ref) + .collect::>>()?; + + // process and add each window function expression + for expr in &window.window_expr { + expressions.push(producer.handle_expr(expr, window.input.schema())?); + } + + let emit_kind = + create_project_remapping(expressions.len(), window.input.schema().fields().len()); + let common = RelCommon { + emit_kind: Some(emit_kind), + hint: None, + advanced_extension: None, + }; + let project_rel = Box::new(ProjectRel { + common: Some(common), + input: Some(input), + expressions, + advanced_extension: None, + }); + + Ok(Box::new(Rel { + rel_type: Some(RelType::Project(project_rel)), + })) +} + +/// By default, a Substrait Project outputs all input fields followed by all expressions. +/// A DataFusion Projection only outputs expressions. 
In order to keep the Substrait +/// plan consistent with DataFusion, we must apply an output mapping that skips the input +/// fields so that the Substrait Project will only output the expression fields. +fn create_project_remapping(expr_count: usize, input_field_count: usize) -> EmitKind { + let expression_field_start = input_field_count; + let expression_field_end = expression_field_start + expr_count; + let output_mapping = (expression_field_start..expression_field_end) + .map(|i| i as i32) + .collect(); + Emit(rel_common::Emit { output_mapping }) +} diff --git a/datafusion/substrait/src/logical_plan/producer/rel/read_rel.rs b/datafusion/substrait/src/logical_plan/producer/rel/read_rel.rs new file mode 100644 index 000000000000..212874e7913b --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/rel/read_rel.rs @@ -0,0 +1,149 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::{ + to_substrait_literal, to_substrait_named_struct, SubstraitProducer, +}; +use datafusion::common::{not_impl_err, substrait_datafusion_err, DFSchema, ToDFSchema}; +use datafusion::logical_expr::utils::conjunction; +use datafusion::logical_expr::{EmptyRelation, Expr, TableScan, Values}; +use std::sync::Arc; +use substrait::proto::expression::literal::Struct; +use substrait::proto::expression::mask_expression::{StructItem, StructSelect}; +use substrait::proto::expression::MaskExpression; +use substrait::proto::read_rel::{NamedTable, ReadType, VirtualTable}; +use substrait::proto::rel::RelType; +use substrait::proto::{ReadRel, Rel}; + +pub fn from_table_scan( + producer: &mut impl SubstraitProducer, + scan: &TableScan, +) -> datafusion::common::Result> { + let projection = scan.projection.as_ref().map(|p| { + p.iter() + .map(|i| StructItem { + field: *i as i32, + child: None, + }) + .collect() + }); + + let projection = projection.map(|struct_items| MaskExpression { + select: Some(StructSelect { struct_items }), + maintain_singular_struct: false, + }); + + let table_schema = scan.source.schema().to_dfschema_ref()?; + let base_schema = to_substrait_named_struct(&table_schema)?; + + let filter_option = if scan.filters.is_empty() { + None + } else { + let table_schema_qualified = Arc::new( + DFSchema::try_from_qualified_schema( + scan.table_name.clone(), + &(scan.source.schema()), + ) + .unwrap(), + ); + + let combined_expr = conjunction(scan.filters.clone()).unwrap(); + let filter_expr = + producer.handle_expr(&combined_expr, &table_schema_qualified)?; + Some(Box::new(filter_expr)) + }; + + Ok(Box::new(Rel { + rel_type: Some(RelType::Read(Box::new(ReadRel { + common: None, + base_schema: Some(base_schema), + filter: filter_option, + best_effort_filter: None, + projection, + advanced_extension: None, + read_type: 
Some(ReadType::NamedTable(NamedTable { + names: scan.table_name.to_vec(), + advanced_extension: None, + })), + }))), + })) +} + +pub fn from_empty_relation(e: &EmptyRelation) -> datafusion::common::Result> { + if e.produce_one_row { + return not_impl_err!("Producing a row from empty relation is unsupported"); + } + #[allow(deprecated)] + Ok(Box::new(Rel { + rel_type: Some(RelType::Read(Box::new(ReadRel { + common: None, + base_schema: Some(to_substrait_named_struct(&e.schema)?), + filter: None, + best_effort_filter: None, + projection: None, + advanced_extension: None, + read_type: Some(ReadType::VirtualTable(VirtualTable { + values: vec![], + expressions: vec![], + })), + }))), + })) +} + +pub fn from_values( + producer: &mut impl SubstraitProducer, + v: &Values, +) -> datafusion::common::Result> { + let values = v + .values + .iter() + .map(|row| { + let fields = row + .iter() + .map(|v| match v { + Expr::Literal(sv, _) => to_substrait_literal(producer, sv), + Expr::Alias(alias) => match alias.expr.as_ref() { + // The schema gives us the names, so we can skip aliases + Expr::Literal(sv, _) => to_substrait_literal(producer, sv), + _ => Err(substrait_datafusion_err!( + "Only literal types can be aliased in Virtual Tables, got: {}", alias.expr.variant_name() + )), + }, + _ => Err(substrait_datafusion_err!( + "Only literal types and aliases are supported in Virtual Tables, got: {}", v.variant_name() + )), + }) + .collect::>()?; + Ok(Struct { fields }) + }) + .collect::>()?; + #[allow(deprecated)] + Ok(Box::new(Rel { + rel_type: Some(RelType::Read(Box::new(ReadRel { + common: None, + base_schema: Some(to_substrait_named_struct(&v.schema)?), + filter: None, + best_effort_filter: None, + projection: None, + advanced_extension: None, + read_type: Some(ReadType::VirtualTable(VirtualTable { + values, + expressions: vec![], + })), + }))), + })) +} diff --git a/datafusion/substrait/src/logical_plan/producer/rel/set_rel.rs b/datafusion/substrait/src/logical_plan/producer/rel/set_rel.rs new file mode 100644 index 000000000000..58ddfca3617a --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/rel/set_rel.rs @@ -0,0 +1,43 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::SubstraitProducer; +use datafusion::logical_expr::Union; +use substrait::proto::rel::RelType; +use substrait::proto::{set_rel, Rel, SetRel}; + +pub fn from_union( + producer: &mut impl SubstraitProducer, + union: &Union, +) -> datafusion::common::Result> { + let input_rels = union + .inputs + .iter() + .map(|input| producer.handle_plan(input.as_ref())) + .collect::>>()? 
+ .into_iter() + .map(|ptr| *ptr) + .collect(); + Ok(Box::new(Rel { + rel_type: Some(RelType::Set(SetRel { + common: None, + inputs: input_rels, + op: set_rel::SetOp::UnionAll as i32, // UNION DISTINCT gets translated to AGGREGATION + UNION ALL + advanced_extension: None, + })), + })) +} diff --git a/datafusion/substrait/src/logical_plan/producer/rel/sort_rel.rs b/datafusion/substrait/src/logical_plan/producer/rel/sort_rel.rs new file mode 100644 index 000000000000..aaa8be163560 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/rel/sort_rel.rs @@ -0,0 +1,69 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::{substrait_sort_field, SubstraitProducer}; +use crate::variation_const::DEFAULT_TYPE_VARIATION_REF; +use datafusion::logical_expr::Sort; +use substrait::proto::expression::literal::LiteralType; +use substrait::proto::expression::{Literal, RexType}; +use substrait::proto::rel::RelType; +use substrait::proto::{fetch_rel, Expression, FetchRel, Rel, SortRel}; + +pub fn from_sort( + producer: &mut impl SubstraitProducer, + sort: &Sort, +) -> datafusion::common::Result> { + let Sort { expr, input, fetch } = sort; + let sort_fields = expr + .iter() + .map(|e| substrait_sort_field(producer, e, input.schema())) + .collect::>>()?; + + let input = producer.handle_plan(input.as_ref())?; + + let sort_rel = Box::new(Rel { + rel_type: Some(RelType::Sort(Box::new(SortRel { + common: None, + input: Some(input), + sorts: sort_fields, + advanced_extension: None, + }))), + }); + + match fetch { + Some(amount) => { + let count_mode = + Some(fetch_rel::CountMode::CountExpr(Box::new(Expression { + rex_type: Some(RexType::Literal(Literal { + nullable: false, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + literal_type: Some(LiteralType::I64(*amount as i64)), + })), + }))); + Ok(Box::new(Rel { + rel_type: Some(RelType::Fetch(Box::new(FetchRel { + common: None, + input: Some(sort_rel), + offset_mode: None, + count_mode, + advanced_extension: None, + }))), + })) + } + None => Ok(sort_rel), + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs b/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs new file mode 100644 index 000000000000..56edfac5769c --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs @@ -0,0 +1,411 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::extensions::Extensions; +use crate::logical_plan::producer::{ + from_aggregate, from_aggregate_function, from_alias, from_between, from_binary_expr, + from_case, from_cast, from_column, from_distinct, from_empty_relation, from_filter, + from_in_list, from_in_subquery, from_join, from_like, from_limit, from_literal, + from_projection, from_repartition, from_scalar_function, from_sort, + from_subquery_alias, from_table_scan, from_try_cast, from_unary_expr, from_union, + from_values, from_window, from_window_function, to_substrait_rel, to_substrait_rex, +}; +use datafusion::common::{substrait_err, Column, DFSchemaRef, ScalarValue}; +use datafusion::execution::registry::SerializerRegistry; +use datafusion::execution::SessionState; +use datafusion::logical_expr::expr::{Alias, InList, InSubquery, WindowFunction}; +use datafusion::logical_expr::{ + expr, Aggregate, Between, BinaryExpr, Case, Cast, Distinct, EmptyRelation, Expr, + Extension, Filter, Join, Like, Limit, LogicalPlan, Projection, Repartition, Sort, + SubqueryAlias, TableScan, TryCast, Union, Values, Window, +}; +use pbjson_types::Any as ProtoAny; +use substrait::proto::aggregate_rel::Measure; +use substrait::proto::rel::RelType; +use substrait::proto::{ + Expression, ExtensionLeafRel, ExtensionMultiRel, ExtensionSingleRel, Rel, +}; + +/// This trait is used to produce Substrait plans, converting them from DataFusion Logical Plans. +/// It can be implemented by users to allow for custom handling of relations, expressions, etc. +/// +/// Combined with the [crate::logical_plan::consumer::SubstraitConsumer] this allows for fully +/// customizable Substrait serde. 
+/// +/// # Example Usage +/// +/// ``` +/// # use std::sync::Arc; +/// # use substrait::proto::{Expression, Rel}; +/// # use substrait::proto::rel::RelType; +/// # use datafusion::common::DFSchemaRef; +/// # use datafusion::error::Result; +/// # use datafusion::execution::SessionState; +/// # use datafusion::logical_expr::{Between, Extension, Projection}; +/// # use datafusion_substrait::extensions::Extensions; +/// # use datafusion_substrait::logical_plan::producer::{from_projection, SubstraitProducer}; +/// +/// struct CustomSubstraitProducer { +/// extensions: Extensions, +/// state: Arc, +/// } +/// +/// impl SubstraitProducer for CustomSubstraitProducer { +/// +/// fn register_function(&mut self, signature: String) -> u32 { +/// self.extensions.register_function(signature) +/// } +/// +/// fn get_extensions(self) -> Extensions { +/// self.extensions +/// } +/// +/// // You can set additional metadata on the Rels you produce +/// fn handle_projection(&mut self, plan: &Projection) -> Result> { +/// let mut rel = from_projection(self, plan)?; +/// match rel.rel_type { +/// Some(RelType::Project(mut project)) => { +/// let mut project = project.clone(); +/// // set common metadata or advanced extension +/// project.common = None; +/// project.advanced_extension = None; +/// Ok(Box::new(Rel { +/// rel_type: Some(RelType::Project(project)), +/// })) +/// } +/// rel_type => Ok(Box::new(Rel { rel_type })), +/// } +/// } +/// +/// // You can tweak how you convert expressions for your target system +/// fn handle_between(&mut self, between: &Between, schema: &DFSchemaRef) -> Result { +/// // add your own encoding for Between +/// todo!() +/// } +/// +/// // You can fully control how you convert UserDefinedLogicalNodes into Substrait +/// fn handle_extension(&mut self, _plan: &Extension) -> Result> { +/// // implement your own serializer into Substrait +/// todo!() +/// } +/// } +/// ``` +pub trait SubstraitProducer: Send + Sync + Sized { + /// Within a Substrait plan, functions are referenced using function anchors that are stored at + /// the top level of the [Plan](substrait::proto::Plan) within + /// [ExtensionFunction](substrait::proto::extensions::simple_extension_declaration::ExtensionFunction) + /// messages. + /// + /// When given a function signature, this method should return the existing anchor for it if + /// there is one. Otherwise, it should generate a new anchor. + fn register_function(&mut self, signature: String) -> u32; + + /// Consume the producer to generate the [Extensions] for the Substrait plan based on the + /// functions that have been registered + fn get_extensions(self) -> Extensions; + + // Logical Plan Methods + // There is one method per LogicalPlan to allow for easy overriding of producer behaviour. + // These methods have default implementations calling the common handler code, to allow for users + // to re-use common handling logic. 
+ + fn handle_plan( + &mut self, + plan: &LogicalPlan, + ) -> datafusion::common::Result> { + to_substrait_rel(self, plan) + } + + fn handle_projection( + &mut self, + plan: &Projection, + ) -> datafusion::common::Result> { + from_projection(self, plan) + } + + fn handle_filter(&mut self, plan: &Filter) -> datafusion::common::Result> { + from_filter(self, plan) + } + + fn handle_window(&mut self, plan: &Window) -> datafusion::common::Result> { + from_window(self, plan) + } + + fn handle_aggregate( + &mut self, + plan: &Aggregate, + ) -> datafusion::common::Result> { + from_aggregate(self, plan) + } + + fn handle_sort(&mut self, plan: &Sort) -> datafusion::common::Result> { + from_sort(self, plan) + } + + fn handle_join(&mut self, plan: &Join) -> datafusion::common::Result> { + from_join(self, plan) + } + + fn handle_repartition( + &mut self, + plan: &Repartition, + ) -> datafusion::common::Result> { + from_repartition(self, plan) + } + + fn handle_union(&mut self, plan: &Union) -> datafusion::common::Result> { + from_union(self, plan) + } + + fn handle_table_scan( + &mut self, + plan: &TableScan, + ) -> datafusion::common::Result> { + from_table_scan(self, plan) + } + + fn handle_empty_relation( + &mut self, + plan: &EmptyRelation, + ) -> datafusion::common::Result> { + from_empty_relation(plan) + } + + fn handle_subquery_alias( + &mut self, + plan: &SubqueryAlias, + ) -> datafusion::common::Result> { + from_subquery_alias(self, plan) + } + + fn handle_limit(&mut self, plan: &Limit) -> datafusion::common::Result> { + from_limit(self, plan) + } + + fn handle_values(&mut self, plan: &Values) -> datafusion::common::Result> { + from_values(self, plan) + } + + fn handle_distinct( + &mut self, + plan: &Distinct, + ) -> datafusion::common::Result> { + from_distinct(self, plan) + } + + fn handle_extension( + &mut self, + _plan: &Extension, + ) -> datafusion::common::Result> { + substrait_err!("Specify handling for LogicalPlan::Extension by implementing the SubstraitProducer trait") + } + + // Expression Methods + // There is one method per DataFusion Expr to allow for easy overriding of producer behaviour + // These methods have default implementations calling the common handler code, to allow for users + // to re-use common handling logic. 
+ + fn handle_expr( + &mut self, + expr: &Expr, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + to_substrait_rex(self, expr, schema) + } + + fn handle_alias( + &mut self, + alias: &Alias, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_alias(self, alias, schema) + } + + fn handle_column( + &mut self, + column: &Column, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_column(column, schema) + } + + fn handle_literal( + &mut self, + value: &ScalarValue, + ) -> datafusion::common::Result { + from_literal(self, value) + } + + fn handle_binary_expr( + &mut self, + expr: &BinaryExpr, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_binary_expr(self, expr, schema) + } + + fn handle_like( + &mut self, + like: &Like, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_like(self, like, schema) + } + + /// For handling Not, IsNotNull, IsNull, IsTrue, IsFalse, IsUnknown, IsNotTrue, IsNotFalse, IsNotUnknown, Negative + fn handle_unary_expr( + &mut self, + expr: &Expr, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_unary_expr(self, expr, schema) + } + + fn handle_between( + &mut self, + between: &Between, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_between(self, between, schema) + } + + fn handle_case( + &mut self, + case: &Case, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_case(self, case, schema) + } + + fn handle_cast( + &mut self, + cast: &Cast, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_cast(self, cast, schema) + } + + fn handle_try_cast( + &mut self, + cast: &TryCast, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_try_cast(self, cast, schema) + } + + fn handle_scalar_function( + &mut self, + scalar_fn: &expr::ScalarFunction, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_scalar_function(self, scalar_fn, schema) + } + + fn handle_aggregate_function( + &mut self, + agg_fn: &expr::AggregateFunction, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_aggregate_function(self, agg_fn, schema) + } + + fn handle_window_function( + &mut self, + window_fn: &WindowFunction, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_window_function(self, window_fn, schema) + } + + fn handle_in_list( + &mut self, + in_list: &InList, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_in_list(self, in_list, schema) + } + + fn handle_in_subquery( + &mut self, + in_subquery: &InSubquery, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_in_subquery(self, in_subquery, schema) + } +} + +pub struct DefaultSubstraitProducer<'a> { + extensions: Extensions, + serializer_registry: &'a dyn SerializerRegistry, +} + +impl<'a> DefaultSubstraitProducer<'a> { + pub fn new(state: &'a SessionState) -> Self { + DefaultSubstraitProducer { + extensions: Extensions::default(), + serializer_registry: state.serializer_registry().as_ref(), + } + } +} + +impl SubstraitProducer for DefaultSubstraitProducer<'_> { + fn register_function(&mut self, fn_name: String) -> u32 { + self.extensions.register_function(fn_name) + } + + fn get_extensions(self) -> Extensions { + self.extensions + } + + fn handle_extension( + &mut self, + plan: &Extension, + ) -> datafusion::common::Result> { + let extension_bytes = self + .serializer_registry + .serialize_logical_plan(plan.node.as_ref())?; + let detail = ProtoAny { + type_url: plan.node.name().to_string(), + value: 
extension_bytes.into(), + }; + let mut inputs_rel = plan + .node + .inputs() + .into_iter() + .map(|plan| self.handle_plan(plan)) + .collect::>>()?; + let rel_type = match inputs_rel.len() { + 0 => RelType::ExtensionLeaf(ExtensionLeafRel { + common: None, + detail: Some(detail), + }), + 1 => RelType::ExtensionSingle(Box::new(ExtensionSingleRel { + common: None, + detail: Some(detail), + input: Some(inputs_rel.pop().unwrap()), + })), + _ => RelType::ExtensionMulti(ExtensionMultiRel { + common: None, + detail: Some(detail), + inputs: inputs_rel.into_iter().map(|r| *r).collect(), + }), + }; + Ok(Box::new(Rel { + rel_type: Some(rel_type), + })) + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/types.rs b/datafusion/substrait/src/logical_plan/producer/types.rs new file mode 100644 index 000000000000..6a63bbef5d7d --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/types.rs @@ -0,0 +1,457 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::utils::flatten_names; +use crate::variation_const::{ + DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, + DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, + DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_INTERVAL_DAY_TYPE_VARIATION_REF, + DEFAULT_TYPE_VARIATION_REF, DURATION_INTERVAL_DAY_TYPE_VARIATION_REF, + LARGE_CONTAINER_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, + VIEW_CONTAINER_TYPE_VARIATION_REF, +}; +use datafusion::arrow::datatypes::{DataType, IntervalUnit, TimeUnit}; +use datafusion::common::{internal_err, not_impl_err, plan_err, DFSchemaRef}; +use substrait::proto::{r#type, NamedStruct}; + +pub(crate) fn to_substrait_type( + dt: &DataType, + nullable: bool, +) -> datafusion::common::Result { + let nullability = if nullable { + r#type::Nullability::Nullable as i32 + } else { + r#type::Nullability::Required as i32 + }; + match dt { + DataType::Null => internal_err!("Null cast is not valid"), + DataType::Boolean => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Bool(r#type::Boolean { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::Int8 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::I8(r#type::I8 { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::UInt8 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::I8(r#type::I8 { + type_variation_reference: UNSIGNED_INTEGER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::Int16 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::I16(r#type::I16 { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::UInt16 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::I16(r#type::I16 { + 
type_variation_reference: UNSIGNED_INTEGER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::Int32 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::I32(r#type::I32 { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::UInt32 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::I32(r#type::I32 { + type_variation_reference: UNSIGNED_INTEGER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::Int64 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::I64(r#type::I64 { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::UInt64 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::I64(r#type::I64 { + type_variation_reference: UNSIGNED_INTEGER_TYPE_VARIATION_REF, + nullability, + })), + }), + // Float16 is not supported in Substrait + DataType::Float32 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Fp32(r#type::Fp32 { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::Float64 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Fp64(r#type::Fp64 { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::Timestamp(unit, tz) => { + let precision = match unit { + TimeUnit::Second => 0, + TimeUnit::Millisecond => 3, + TimeUnit::Microsecond => 6, + TimeUnit::Nanosecond => 9, + }; + let kind = match tz { + None => r#type::Kind::PrecisionTimestamp(r#type::PrecisionTimestamp { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + precision, + }), + Some(_) => { + // If timezone is present, no matter what the actual tz value is, it indicates the + // value of the timestamp is tied to UTC epoch. That's all that Substrait cares about. + // As the timezone is lost, this conversion may be lossy for downstream use of the value. 
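+                    // For example, Timestamp(Nanosecond, Some("+05:30")) becomes a
+                    // PrecisionTimestampTz with precision 9; the "+05:30" offset string
+                    // itself is not carried over into the Substrait type.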
+ r#type::Kind::PrecisionTimestampTz(r#type::PrecisionTimestampTz { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + precision, + }) + } + }; + Ok(substrait::proto::Type { kind: Some(kind) }) + } + DataType::Date32 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Date(r#type::Date { + type_variation_reference: DATE_32_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::Date64 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Date(r#type::Date { + type_variation_reference: DATE_64_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::Interval(interval_unit) => { + match interval_unit { + IntervalUnit::YearMonth => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::IntervalYear(r#type::IntervalYear { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + })), + }), + IntervalUnit::DayTime => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::IntervalDay(r#type::IntervalDay { + type_variation_reference: DEFAULT_INTERVAL_DAY_TYPE_VARIATION_REF, + nullability, + precision: Some(3), // DayTime precision is always milliseconds + })), + }), + IntervalUnit::MonthDayNano => { + Ok(substrait::proto::Type { + kind: Some(r#type::Kind::IntervalCompound( + r#type::IntervalCompound { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + precision: 9, // nanos + }, + )), + }) + } + } + } + DataType::Duration(duration_unit) => { + let precision = match duration_unit { + TimeUnit::Second => 0, + TimeUnit::Millisecond => 3, + TimeUnit::Microsecond => 6, + TimeUnit::Nanosecond => 9, + }; + Ok(substrait::proto::Type { + kind: Some(r#type::Kind::IntervalDay(r#type::IntervalDay { + type_variation_reference: DURATION_INTERVAL_DAY_TYPE_VARIATION_REF, + nullability, + precision: Some(precision), + })), + }) + } + DataType::Binary => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Binary(r#type::Binary { + type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::FixedSizeBinary(length) => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::FixedBinary(r#type::FixedBinary { + length: *length, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::LargeBinary => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Binary(r#type::Binary { + type_variation_reference: LARGE_CONTAINER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::BinaryView => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Binary(r#type::Binary { + type_variation_reference: VIEW_CONTAINER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::Utf8 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::String(r#type::String { + type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::LargeUtf8 => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::String(r#type::String { + type_variation_reference: LARGE_CONTAINER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::Utf8View => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::String(r#type::String { + type_variation_reference: VIEW_CONTAINER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::List(inner) => { + let inner_type = to_substrait_type(inner.data_type(), inner.is_nullable())?; + Ok(substrait::proto::Type { + kind: Some(r#type::Kind::List(Box::new(r#type::List { + r#type: Some(Box::new(inner_type)), + type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, + nullability, + }))), + }) + } + 
DataType::LargeList(inner) => { + let inner_type = to_substrait_type(inner.data_type(), inner.is_nullable())?; + Ok(substrait::proto::Type { + kind: Some(r#type::Kind::List(Box::new(r#type::List { + r#type: Some(Box::new(inner_type)), + type_variation_reference: LARGE_CONTAINER_TYPE_VARIATION_REF, + nullability, + }))), + }) + } + DataType::Map(inner, _) => match inner.data_type() { + DataType::Struct(key_and_value) if key_and_value.len() == 2 => { + let key_type = to_substrait_type( + key_and_value[0].data_type(), + key_and_value[0].is_nullable(), + )?; + let value_type = to_substrait_type( + key_and_value[1].data_type(), + key_and_value[1].is_nullable(), + )?; + Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Map(Box::new(r#type::Map { + key: Some(Box::new(key_type)), + value: Some(Box::new(value_type)), + type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, + nullability, + }))), + }) + } + _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"), + }, + DataType::Struct(fields) => { + let field_types = fields + .iter() + .map(|field| to_substrait_type(field.data_type(), field.is_nullable())) + .collect::>>()?; + Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Struct(r#type::Struct { + types: field_types, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + })), + }) + } + DataType::Decimal128(p, s) => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Decimal(r#type::Decimal { + type_variation_reference: DECIMAL_128_TYPE_VARIATION_REF, + nullability, + scale: *s as i32, + precision: *p as i32, + })), + }), + DataType::Decimal256(p, s) => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Decimal(r#type::Decimal { + type_variation_reference: DECIMAL_256_TYPE_VARIATION_REF, + nullability, + scale: *s as i32, + precision: *p as i32, + })), + }), + _ => not_impl_err!("Unsupported cast type: {dt:?}"), + } +} + +pub(crate) fn to_substrait_named_struct( + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let mut names = Vec::with_capacity(schema.fields().len()); + for field in schema.fields() { + flatten_names(field, false, &mut names)?; + } + + let field_types = r#type::Struct { + types: schema + .fields() + .iter() + .map(|f| to_substrait_type(f.data_type(), f.is_nullable())) + .collect::>()?, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability: r#type::Nullability::Required as i32, + }; + + Ok(NamedStruct { + names, + r#struct: Some(field_types), + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::logical_plan::consumer::tests::test_consumer; + use crate::logical_plan::consumer::{ + from_substrait_named_struct, from_substrait_type_without_names, + }; + use datafusion::arrow::datatypes::{Field, Fields, Schema}; + use datafusion::common::{DFSchema, Result}; + use std::sync::Arc; + + #[test] + fn round_trip_types() -> Result<()> { + round_trip_type(DataType::Boolean)?; + round_trip_type(DataType::Int8)?; + round_trip_type(DataType::UInt8)?; + round_trip_type(DataType::Int16)?; + round_trip_type(DataType::UInt16)?; + round_trip_type(DataType::Int32)?; + round_trip_type(DataType::UInt32)?; + round_trip_type(DataType::Int64)?; + round_trip_type(DataType::UInt64)?; + round_trip_type(DataType::Float32)?; + round_trip_type(DataType::Float64)?; + + for tz in [None, Some("UTC".into())] { + round_trip_type(DataType::Timestamp(TimeUnit::Second, tz.clone()))?; + round_trip_type(DataType::Timestamp(TimeUnit::Millisecond, tz.clone()))?; + round_trip_type(DataType::Timestamp(TimeUnit::Microsecond, 
tz.clone()))?; + round_trip_type(DataType::Timestamp(TimeUnit::Nanosecond, tz))?; + } + + round_trip_type(DataType::Date32)?; + round_trip_type(DataType::Date64)?; + round_trip_type(DataType::Binary)?; + round_trip_type(DataType::FixedSizeBinary(10))?; + round_trip_type(DataType::LargeBinary)?; + round_trip_type(DataType::BinaryView)?; + round_trip_type(DataType::Utf8)?; + round_trip_type(DataType::LargeUtf8)?; + round_trip_type(DataType::Utf8View)?; + round_trip_type(DataType::Decimal128(10, 2))?; + round_trip_type(DataType::Decimal256(30, 2))?; + + round_trip_type(DataType::List( + Field::new_list_field(DataType::Int32, true).into(), + ))?; + round_trip_type(DataType::LargeList( + Field::new_list_field(DataType::Int32, true).into(), + ))?; + + round_trip_type(DataType::Map( + Field::new_struct( + "entries", + [ + Field::new("key", DataType::Utf8, false).into(), + Field::new("value", DataType::Int32, true).into(), + ], + false, + ) + .into(), + false, + ))?; + + round_trip_type(DataType::Struct( + vec![ + Field::new("c0", DataType::Int32, true), + Field::new("c1", DataType::Utf8, true), + ] + .into(), + ))?; + + round_trip_type(DataType::Interval(IntervalUnit::YearMonth))?; + round_trip_type(DataType::Interval(IntervalUnit::MonthDayNano))?; + round_trip_type(DataType::Interval(IntervalUnit::DayTime))?; + + round_trip_type(DataType::Duration(TimeUnit::Second))?; + round_trip_type(DataType::Duration(TimeUnit::Millisecond))?; + round_trip_type(DataType::Duration(TimeUnit::Microsecond))?; + round_trip_type(DataType::Duration(TimeUnit::Nanosecond))?; + + Ok(()) + } + + fn round_trip_type(dt: DataType) -> Result<()> { + println!("Checking round trip of {dt:?}"); + + // As DataFusion doesn't consider nullability as a property of the type, but field, + // it doesn't matter if we set nullability to true or false here. + let substrait = to_substrait_type(&dt, true)?; + let consumer = test_consumer(); + let roundtrip_dt = from_substrait_type_without_names(&consumer, &substrait)?; + assert_eq!(dt, roundtrip_dt); + Ok(()) + } + + #[test] + fn named_struct_names() -> Result<()> { + let schema = DFSchemaRef::new(DFSchema::try_from(Schema::new(vec![ + Field::new("int", DataType::Int32, true), + Field::new( + "struct", + DataType::Struct(Fields::from(vec![Field::new( + "inner", + DataType::List(Arc::new(Field::new_list_field(DataType::Utf8, true))), + true, + )])), + true, + ), + Field::new("trailer", DataType::Float64, true), + ]))?); + + let named_struct = to_substrait_named_struct(&schema)?; + + // Struct field names should be flattened DFS style + // List field names should be omitted + assert_eq!( + named_struct.names, + vec!["int", "struct", "inner", "trailer"] + ); + + let roundtrip_schema = + from_substrait_named_struct(&test_consumer(), &named_struct)?; + assert_eq!(schema.as_ref(), &roundtrip_schema); + Ok(()) + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/utils.rs b/datafusion/substrait/src/logical_plan/producer/utils.rs new file mode 100644 index 000000000000..5429e4a1ad88 --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/utils.rs @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::logical_plan::producer::SubstraitProducer;
+use datafusion::arrow::datatypes::{DataType, Field};
+use datafusion::common::{plan_err, DFSchemaRef};
+use datafusion::logical_expr::SortExpr;
+use substrait::proto::sort_field::{SortDirection, SortKind};
+use substrait::proto::SortField;
+
+// Substrait wants a list of all field names, including nested fields from structs,
+// also from within e.g. lists and maps. However, it does not want the list and map field names
+// themselves - only proper struct fields are considered to have useful names.
+pub(crate) fn flatten_names(
+    field: &Field,
+    skip_self: bool,
+    names: &mut Vec<String>,
+) -> datafusion::common::Result<()> {
+    if !skip_self {
+        names.push(field.name().to_string());
+    }
+    match field.data_type() {
+        DataType::Struct(fields) => {
+            for field in fields {
+                flatten_names(field, false, names)?;
+            }
+            Ok(())
+        }
+        DataType::List(l) => flatten_names(l, true, names),
+        DataType::LargeList(l) => flatten_names(l, true, names),
+        DataType::Map(m, _) => match m.data_type() {
+            DataType::Struct(key_and_value) if key_and_value.len() == 2 => {
+                flatten_names(&key_and_value[0], true, names)?;
+                flatten_names(&key_and_value[1], true, names)
+            }
+            _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"),
+        },
+        _ => Ok(()),
+    }?;
+    Ok(())
+}
+
+pub(crate) fn substrait_sort_field(
+    producer: &mut impl SubstraitProducer,
+    sort: &SortExpr,
+    schema: &DFSchemaRef,
+) -> datafusion::common::Result<SortField> {
+    let SortExpr {
+        expr,
+        asc,
+        nulls_first,
+    } = sort;
+    let e = producer.handle_expr(expr, schema)?;
+    let d = match (asc, nulls_first) {
+        (true, true) => SortDirection::AscNullsFirst,
+        (true, false) => SortDirection::AscNullsLast,
+        (false, true) => SortDirection::DescNullsFirst,
+        (false, false) => SortDirection::DescNullsLast,
+    };
+    Ok(SortField {
+        expr: Some(e),
+        sort_kind: Some(SortKind::Direction(d as i32)),
+    })
+}
diff --git a/datafusion/substrait/src/variation_const.rs b/datafusion/substrait/src/variation_const.rs
index e5bebf8e1181..efde8efe509e 100644
--- a/datafusion/substrait/src/variation_const.rs
+++ b/datafusion/substrait/src/variation_const.rs
@@ -55,6 +55,15 @@ pub const LARGE_CONTAINER_TYPE_VARIATION_REF: u32 = 1;
 pub const VIEW_CONTAINER_TYPE_VARIATION_REF: u32 = 2;
 pub const DECIMAL_128_TYPE_VARIATION_REF: u32 = 0;
 pub const DECIMAL_256_TYPE_VARIATION_REF: u32 = 1;
+/// Used for the arrow type [`DataType::Interval`] with [`IntervalUnit::DayTime`].
+///
+/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval
+/// [`IntervalUnit::DayTime`]: datafusion::arrow::datatypes::IntervalUnit::DayTime
+pub const DEFAULT_INTERVAL_DAY_TYPE_VARIATION_REF: u32 = 0;
+/// Used for the arrow type [`DataType::Duration`].
+///
+/// [`DataType::Duration`]: datafusion::arrow::datatypes::DataType::Duration
+pub const DURATION_INTERVAL_DAY_TYPE_VARIATION_REF: u32 = 1;
 // For [user-defined types](https://substrait.io/types/type_classes/#user-defined-types).
 /// For [`DataType::Interval`] with [`IntervalUnit::YearMonth`].
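Editor's note: the name-flattening rule documented above for `flatten_names` can be exercised on its own. The sketch below is illustrative only and is not part of the patch; it assumes the standalone `arrow` crate and uses a hypothetical `flatten` helper that mirrors the same DFS traversal, reproducing the expected names from the `named_struct_names` test.

    // Illustrative sketch only -- not part of the patch. Mirrors the rule that struct
    // children contribute their names while list/map element names are skipped.
    use arrow::datatypes::{DataType, Field, Fields};

    fn flatten(field: &Field, skip_self: bool, names: &mut Vec<String>) {
        if !skip_self {
            names.push(field.name().to_string());
        }
        match field.data_type() {
            DataType::Struct(children) => {
                // Struct fields have meaningful names; recurse into each of them.
                for child in children {
                    flatten(child, false, names);
                }
            }
            // The list element field is anonymous for Substrait purposes,
            // but anything nested inside it may still contribute names.
            DataType::List(inner) | DataType::LargeList(inner) => flatten(inner, true, names),
            DataType::Map(entries, _) => {
                // Skip the "entries"/"key"/"value" names, but recurse into key and value types.
                if let DataType::Struct(kv) = entries.data_type() {
                    flatten(&kv[0], true, names);
                    flatten(&kv[1], true, names);
                }
            }
            _ => {}
        }
    }

    fn main() {
        // Same shape as the `named_struct_names` test above:
        // int: Int32, struct: Struct { inner: List<Utf8> }, trailer: Float64
        let fields = vec![
            Field::new("int", DataType::Int32, true),
            Field::new(
                "struct",
                DataType::Struct(Fields::from(vec![Field::new(
                    "inner",
                    DataType::List(Field::new_list_field(DataType::Utf8, true).into()),
                    true,
                )])),
                true,
            ),
            Field::new("trailer", DataType::Float64, true),
        ];

        let mut names = Vec::new();
        for field in &fields {
            flatten(field, false, &mut names);
        }
        // "inner" is kept (it is a struct child); the list element name is omitted.
        assert_eq!(names, ["int", "struct", "inner", "trailer"]);
    }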
diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs index bdeeeb585c0c..4a121e41d27e 100644 --- a/datafusion/substrait/tests/cases/consumer_integration.rs +++ b/datafusion/substrait/tests/cases/consumer_integration.rs @@ -44,7 +44,7 @@ mod tests { let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto)?; let plan = from_substrait_plan(&ctx.state(), &proto).await?; ctx.state().create_physical_plan(&plan).await?; - Ok(format!("{}", plan)) + Ok(format!("{plan}")) } #[tokio::test] @@ -501,7 +501,7 @@ mod tests { let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto)?; let plan = from_substrait_plan(&ctx.state(), &proto).await?; ctx.state().create_physical_plan(&plan).await?; - Ok(format!("{}", plan)) + Ok(format!("{plan}")) } #[tokio::test] @@ -560,4 +560,28 @@ mod tests { ); Ok(()) } + + #[tokio::test] + async fn test_multiple_unions() -> Result<()> { + let plan_str = test_plan_to_string("multiple_unions.json").await?; + assert_snapshot!( + plan_str, + @r#" + Projection: Utf8("people") AS product_category, Utf8("people")__temp__0 AS product_type, product_key + Union + Projection: Utf8("people"), Utf8("people") AS Utf8("people")__temp__0, sales.product_key + Left Join: sales.product_key = food.@food_id + TableScan: sales + TableScan: food + Union + Projection: people.$f3, people.$f5, people.product_key0 + Left Join: people.product_key0 = food.@food_id + TableScan: people + TableScan: food + TableScan: more_products + "# + ); + + Ok(()) + } } diff --git a/datafusion/substrait/tests/cases/emit_kind_tests.rs b/datafusion/substrait/tests/cases/emit_kind_tests.rs index 88db2bc34d7f..e916b4cb0e1a 100644 --- a/datafusion/substrait/tests/cases/emit_kind_tests.rs +++ b/datafusion/substrait/tests/cases/emit_kind_tests.rs @@ -126,8 +126,8 @@ mod tests { let plan1str = format!("{plan}"); let plan2str = format!("{plan2}"); - println!("{}", plan1str); - println!("{}", plan2str); + println!("{plan1str}"); + println!("{plan2str}"); assert_eq!(plan1str, plan2str); Ok(()) diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 9a85f3e6c4dc..7a5cfeb39836 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -763,7 +763,7 @@ async fn simple_intersect() -> Result<()> { let expected_plan_str = format!( "Projection: count(Int64(1)) AS {syntax}\ \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]\ - \n Projection: \ + \n Projection:\ \n LeftSemi Join: data.a = data2.a\ \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ \n TableScan: data projection=[a]\ @@ -780,7 +780,7 @@ async fn simple_intersect() -> Result<()> { async fn check_constant(sql_syntax: &str, plan_expr: &str) -> Result<()> { let expected_plan_str = format!( "Aggregate: groupBy=[[]], aggr=[[{plan_expr}]]\ - \n Projection: \ + \n Projection:\ \n LeftSemi Join: data.a = data2.a\ \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ \n TableScan: data projection=[a]\ @@ -854,6 +854,22 @@ async fn aggregate_wo_projection_sorted_consume() -> Result<()> { Ok(()) } +#[tokio::test] +async fn aggregate_identical_grouping_expressions() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/aggregate_identical_grouping_expressions.substrait.json"); + + let plan = generate_plan_from_substrait(proto_plan).await?; + assert_snapshot!( + plan, + @r#" + Aggregate: groupBy=[[Int32(1) AS grouping_col_1, 
Int32(1) AS grouping_col_2]], aggr=[[]] + TableScan: data projection=[] + "# + ); + Ok(()) +} + #[tokio::test] async fn simple_intersect_consume() -> Result<()> { let proto_plan = read_json("tests/testdata/test_plans/intersect.substrait.json"); @@ -942,7 +958,7 @@ async fn simple_intersect_table_reuse() -> Result<()> { let expected_plan_str = format!( "Projection: count(Int64(1)) AS {syntax}\ \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]\ - \n Projection: \ + \n Projection:\ \n LeftSemi Join: left.a = right.a\ \n SubqueryAlias: left\ \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ @@ -961,7 +977,7 @@ async fn simple_intersect_table_reuse() -> Result<()> { async fn check_constant(sql_syntax: &str, plan_expr: &str) -> Result<()> { let expected_plan_str = format!( "Aggregate: groupBy=[[]], aggr=[[{plan_expr}]]\ - \n Projection: \ + \n Projection:\ \n LeftSemi Join: left.a = right.a\ \n SubqueryAlias: left\ \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ diff --git a/datafusion/substrait/tests/testdata/test_plans/aggregate_identical_grouping_expressions.substrait.json b/datafusion/substrait/tests/testdata/test_plans/aggregate_identical_grouping_expressions.substrait.json new file mode 100644 index 000000000000..15c0b0505fa6 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/aggregate_identical_grouping_expressions.substrait.json @@ -0,0 +1,53 @@ +{ + "extensionUris": [], + "extensions": [], + "relations": [ + { + "root": { + "input": { + "aggregate": { + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [], + "struct": { + "types": [], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": ["data"] + } + } + }, + "groupings": [ + { + "groupingExpressions": [ + { + "literal": { + "i32": 1 + } + }, + { + "literal": { + "i32": 1 + } + } + ] + } + ], + "measures": [] + } + }, + "names": ["grouping_col_1", "grouping_col_2"] + } + } + ], + "version": { + "minorNumber": 54, + "producer": "manual" + } +} diff --git a/datafusion/substrait/tests/testdata/test_plans/multiple_unions.json b/datafusion/substrait/tests/testdata/test_plans/multiple_unions.json new file mode 100644 index 000000000000..8b82d6eec755 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/multiple_unions.json @@ -0,0 +1,328 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 1, + "uri": "/functions_comparison.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "equal:any_any" + } + }], + "relations": [{ + "root": { + "input": { + "set": { + "common": { + "direct": { + } + }, + "inputs": [{ + "project": { + "common": { + "emit": { + "outputMapping": [2, 3, 4] + } + }, + "input": { + "join": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["product_key"], + "struct": { + "types": [{ + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "sales" + ] + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["@food_id"], + "struct": { + "types": [{ + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "food" + ] + } + } + }, + "expression": { + "scalarFunction": { + "functionReference": 0, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + 
"arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }] + } + }, + "type": "JOIN_TYPE_LEFT" + } + }, + "expressions": [{ + "literal": { + "string": "people" + } + }, { + "literal": { + "string": "people" + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }] + } + }, { + "set": { + "common": { + "direct": { + } + }, + "inputs": [{ + "project": { + "common": { + "emit": { + "outputMapping": [4, 5, 6] + } + }, + "input": { + "join": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["$f3", "$f5", "product_key0"], + "struct": { + "types": [{ + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "people" + ] + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["@food_id"], + "struct": { + "types": [{ + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "food" + ] + } + + } + }, + "expression": { + "scalarFunction": { + "functionReference": 0, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + } + }] + } + }, + "type": "JOIN_TYPE_LEFT" + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }] + } + }, { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["$f1000", "$f2000", "more_products_key0000"], + "struct": { + "types": [{ + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "more_products" + ] + } + + } + }], + "op": "SET_OP_UNION_ALL" + } + }], + "op": "SET_OP_UNION_ALL" + } + }, + "names": ["product_category", "product_type", "product_key"] + } + }] +} \ No newline at end of file diff --git a/datafusion/wasmtest/Cargo.toml b/datafusion/wasmtest/Cargo.toml index 10eab025734c..b43c34f19760 100644 --- a/datafusion/wasmtest/Cargo.toml +++ b/datafusion/wasmtest/Cargo.toml @@ -52,14 +52,13 @@ datafusion-expr = { workspace = true } datafusion-optimizer = { workspace = true, default-features = true } datafusion-physical-plan = { workspace = true } datafusion-sql = { workspace = true } -# getrandom must be compiled with js feature -getrandom = { version = "0.2.8", features = ["js"] } - +getrandom = { version = 
"0.3", features = ["wasm_js"] } wasm-bindgen = "0.2.99" [dev-dependencies] insta = { workspace = true } object_store = { workspace = true } +# needs to be compiled tokio = { workspace = true } url = { workspace = true } wasm-bindgen-test = "0.3.49" diff --git a/datafusion/wasmtest/datafusion-wasm-app/package-lock.json b/datafusion/wasmtest/datafusion-wasm-app/package-lock.json index 65d8bdbb5e93..ed1baff7412f 100644 --- a/datafusion/wasmtest/datafusion-wasm-app/package-lock.json +++ b/datafusion/wasmtest/datafusion-wasm-app/package-lock.json @@ -15,7 +15,7 @@ "copy-webpack-plugin": "12.0.2", "webpack": "5.94.0", "webpack-cli": "5.1.4", - "webpack-dev-server": "4.15.1" + "webpack-dev-server": "5.2.1" } }, "../pkg": { @@ -90,10 +90,11 @@ } }, "node_modules/@leichtgewicht/ip-codec": { - "version": "2.0.4", - "resolved": "https://registry.npmjs.org/@leichtgewicht/ip-codec/-/ip-codec-2.0.4.tgz", - "integrity": "sha512-Hcv+nVC0kZnQ3tD9GVu5xSMR4VVYOteQIr/hwFPVEvPdlXqgGEuRjiheChHgdM+JyqdgNcmzZOX/tnl0JOiI7A==", - "dev": true + "version": "2.0.5", + "resolved": "https://registry.npmjs.org/@leichtgewicht/ip-codec/-/ip-codec-2.0.5.tgz", + "integrity": "sha512-Vo+PSpZG2/fmgmiNzYK9qWRh8h/CHrwD0mo1h1DzL4yzHNSfWYujGTYsWGreD000gcgmZ7K4Ys6Tx9TxtsKdDw==", + "dev": true, + "license": "MIT" }, "node_modules/@nodelib/fs.scandir": { "version": "2.1.5", @@ -157,10 +158,11 @@ } }, "node_modules/@types/bonjour": { - "version": "3.5.11", - "resolved": "https://registry.npmjs.org/@types/bonjour/-/bonjour-3.5.11.tgz", - "integrity": "sha512-isGhjmBtLIxdHBDl2xGwUzEM8AOyOvWsADWq7rqirdi/ZQoHnLWErHvsThcEzTX8juDRiZtzp2Qkv5bgNh6mAg==", + "version": "3.5.13", + "resolved": "https://registry.npmjs.org/@types/bonjour/-/bonjour-3.5.13.tgz", + "integrity": "sha512-z9fJ5Im06zvUL548KvYNecEVlA7cVDkGUi6kZusb04mpyEFKCIZJvloCcmpmLaIahDpOQGHaHmG6imtPMmPXGQ==", "dev": true, + "license": "MIT", "dependencies": { "@types/node": "*" } @@ -175,10 +177,11 @@ } }, "node_modules/@types/connect-history-api-fallback": { - "version": "1.5.1", - "resolved": "https://registry.npmjs.org/@types/connect-history-api-fallback/-/connect-history-api-fallback-1.5.1.tgz", - "integrity": "sha512-iaQslNbARe8fctL5Lk+DsmgWOM83lM+7FzP0eQUJs1jd3kBE8NWqBTIT2S8SqQOJjxvt2eyIjpOuYeRXq2AdMw==", + "version": "1.5.4", + "resolved": "https://registry.npmjs.org/@types/connect-history-api-fallback/-/connect-history-api-fallback-1.5.4.tgz", + "integrity": "sha512-n6Cr2xS1h4uAulPRdlw6Jl6s1oG8KrVilPN2yUITEs+K48EzMJJ3W1xy8K5eWuFvjp3R74AOIGSmp2UfBJ8HFw==", "dev": true, + "license": "MIT", "dependencies": { "@types/express-serve-static-core": "*", "@types/node": "*" @@ -191,10 +194,11 @@ "dev": true }, "node_modules/@types/express": { - "version": "4.17.17", - "resolved": "https://registry.npmjs.org/@types/express/-/express-4.17.17.tgz", - "integrity": "sha512-Q4FmmuLGBG58btUnfS1c1r/NQdlp3DMfGDGig8WhfpA2YRUtEkxAjkZb0yvplJGYdF1fsQ81iMDcH24sSCNC/Q==", + "version": "4.17.22", + "resolved": "https://registry.npmjs.org/@types/express/-/express-4.17.22.tgz", + "integrity": "sha512-eZUmSnhRX9YRSkplpz0N+k6NljUUn5l3EWZIKZvYzhvMphEuNiyyy1viH/ejgt66JWgALwC/gtSUAeQKtSwW/w==", "dev": true, + "license": "MIT", "dependencies": { "@types/body-parser": "*", "@types/express-serve-static-core": "^4.17.33", @@ -247,6 +251,16 @@ "integrity": "sha512-HksnYH4Ljr4VQgEy2lTStbCKv/P590tmPe5HqOnv9Gprffgv5WXAY+Y5Gqniu0GGqeTCUdBnzC3QSrzPkBkAMA==", "dev": true }, + "node_modules/@types/node-forge": { + "version": "1.3.11", + "resolved": "https://registry.npmjs.org/@types/node-forge/-/node-forge-1.3.11.tgz", + 
"integrity": "sha512-FQx220y22OKNTqaByeBGqHWYz4cl94tpcxeFdvBo3wjG6XPBuZ0BNgNZRV5J5TFmmcsJ4IzsLkmGRiQbnYsBEQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/node": "*" + } + }, "node_modules/@types/qs": { "version": "6.9.8", "resolved": "https://registry.npmjs.org/@types/qs/-/qs-6.9.8.tgz", @@ -260,10 +274,11 @@ "dev": true }, "node_modules/@types/retry": { - "version": "0.12.0", - "resolved": "https://registry.npmjs.org/@types/retry/-/retry-0.12.0.tgz", - "integrity": "sha512-wWKOClTTiizcZhXnPY4wikVAwmdYHp8q6DmC+EJUzAMsycb7HB32Kh9RN4+0gExjmPmZSAQjgURXIGATPegAvA==", - "dev": true + "version": "0.12.2", + "resolved": "https://registry.npmjs.org/@types/retry/-/retry-0.12.2.tgz", + "integrity": "sha512-XISRgDJ2Tc5q4TRqvgJtzsRkFYNJzZrhTdtMoGVBttwzzQJkPnS3WWTFc7kuDRoPtPakl+T+OfdEUjYJj7Jbow==", + "dev": true, + "license": "MIT" }, "node_modules/@types/send": { "version": "0.17.1", @@ -276,39 +291,43 @@ } }, "node_modules/@types/serve-index": { - "version": "1.9.1", - "resolved": "https://registry.npmjs.org/@types/serve-index/-/serve-index-1.9.1.tgz", - "integrity": "sha512-d/Hs3nWDxNL2xAczmOVZNj92YZCS6RGxfBPjKzuu/XirCgXdpKEb88dYNbrYGint6IVWLNP+yonwVAuRC0T2Dg==", + "version": "1.9.4", + "resolved": "https://registry.npmjs.org/@types/serve-index/-/serve-index-1.9.4.tgz", + "integrity": "sha512-qLpGZ/c2fhSs5gnYsQxtDEq3Oy8SXPClIXkW5ghvAvsNuVSA8k+gCONcUCS/UjLEYvYps+e8uBtfgXgvhwfNug==", "dev": true, + "license": "MIT", "dependencies": { "@types/express": "*" } }, "node_modules/@types/serve-static": { - "version": "1.15.2", - "resolved": "https://registry.npmjs.org/@types/serve-static/-/serve-static-1.15.2.tgz", - "integrity": "sha512-J2LqtvFYCzaj8pVYKw8klQXrLLk7TBZmQ4ShlcdkELFKGwGMfevMLneMMRkMgZxotOD9wg497LpC7O8PcvAmfw==", + "version": "1.15.7", + "resolved": "https://registry.npmjs.org/@types/serve-static/-/serve-static-1.15.7.tgz", + "integrity": "sha512-W8Ym+h8nhuRwaKPaDw34QUkwsGi6Rc4yYqvKFo5rm2FUEhCFbzVWrxXUxuKK8TASjWsysJY0nsmNCGhCOIsrOw==", "dev": true, + "license": "MIT", "dependencies": { "@types/http-errors": "*", - "@types/mime": "*", - "@types/node": "*" + "@types/node": "*", + "@types/send": "*" } }, "node_modules/@types/sockjs": { - "version": "0.3.33", - "resolved": "https://registry.npmjs.org/@types/sockjs/-/sockjs-0.3.33.tgz", - "integrity": "sha512-f0KEEe05NvUnat+boPTZ0dgaLZ4SfSouXUgv5noUiefG2ajgKjmETo9ZJyuqsl7dfl2aHlLJUiki6B4ZYldiiw==", + "version": "0.3.36", + "resolved": "https://registry.npmjs.org/@types/sockjs/-/sockjs-0.3.36.tgz", + "integrity": "sha512-MK9V6NzAS1+Ud7JV9lJLFqW85VbC9dq3LmwZCuBe4wBDgKC0Kj/jd8Xl+nSviU+Qc3+m7umHHyHg//2KSa0a0Q==", "dev": true, + "license": "MIT", "dependencies": { "@types/node": "*" } }, "node_modules/@types/ws": { - "version": "8.5.5", - "resolved": "https://registry.npmjs.org/@types/ws/-/ws-8.5.5.tgz", - "integrity": "sha512-lwhs8hktwxSjf9UaZ9tG5M03PGogvFaH8gUgLNbN9HKIg0dvv6q+gkSuJ8HN4/VbyxkuLzCjlN7GquQ0gUJfIg==", + "version": "8.18.1", + "resolved": "https://registry.npmjs.org/@types/ws/-/ws-8.18.1.tgz", + "integrity": "sha512-ThVF6DCVhA8kUGy+aazFQ4kXQ7E1Ty7A3ypFOe0IcJV8O/M511G99AW24irKrW56Wt44yG9+ij8FaqoBGkuBXg==", "dev": true, + "license": "MIT", "dependencies": { "@types/node": "*" } @@ -630,6 +649,7 @@ "resolved": "https://registry.npmjs.org/anymatch/-/anymatch-3.1.3.tgz", "integrity": "sha512-KMReFUr0B4t+D+OBkjR3KYqvocp2XaSzO55UcB6mgQMd3KbcE+mWTyvVV7D/zsdEbNnV6acZUutkiHQXvTr1Rw==", "dev": true, + "license": "ISC", "dependencies": { "normalize-path": "^3.0.0", "picomatch": "^2.0.4" @@ -639,16 +659,11 @@ } }, 
"node_modules/array-flatten": { - "version": "2.1.2", - "resolved": "https://registry.npmjs.org/array-flatten/-/array-flatten-2.1.2.tgz", - "integrity": "sha512-hNfzcOV8W4NdualtqBFPyVO+54DSJuZGY9qT4pRroB6S9e3iiido2ISIC5h9R2sPJ8H3FHCIiEnsv1lPXO3KtQ==", - "dev": true - }, - "node_modules/balanced-match": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz", - "integrity": "sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==", - "dev": true + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/array-flatten/-/array-flatten-1.1.1.tgz", + "integrity": "sha512-PCVAQswWemu6UdxsDFFX/+gVeYqKAod3D3UVm91jHwynguOwAvYPhx8nNlM++NqRcK6CxxpUafjmhIdKiHibqg==", + "dev": true, + "license": "MIT" }, "node_modules/batch": { "version": "0.6.1", @@ -657,12 +672,16 @@ "dev": true }, "node_modules/binary-extensions": { - "version": "2.2.0", - "resolved": "https://registry.npmjs.org/binary-extensions/-/binary-extensions-2.2.0.tgz", - "integrity": "sha512-jDctJ/IVQbZoJykoeHbhXpOlNBqGNcwXJKJog42E5HDPUwQTSdjCHdihjj0DlnheQ7blbT6dHOafNAiS8ooQKA==", + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/binary-extensions/-/binary-extensions-2.3.0.tgz", + "integrity": "sha512-Ceh+7ox5qe7LJuLHoY0feh3pHuUDHAcRUeyL2VYghZwfpkNIy/+8Ocg0a3UuSoYzavmylwuLWQOf3hl0jjMMIw==", "dev": true, + "license": "MIT", "engines": { "node": ">=8" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" } }, "node_modules/body-parser": { @@ -670,6 +689,7 @@ "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.20.3.tgz", "integrity": "sha512-7rAxByjUMqQ3/bHJy7D6OGXvx/MMc4IqBn/X0fcM1QUcAItpZrBEYhWGem+tzXH90c+G01ypMcYJBO9Y30203g==", "dev": true, + "license": "MIT", "dependencies": { "bytes": "3.1.2", "content-type": "~1.0.5", @@ -694,6 +714,7 @@ "resolved": "https://registry.npmjs.org/bytes/-/bytes-3.1.2.tgz", "integrity": "sha512-/Nf7TyzTx6S3yRJObOAV7956r8cr2+Oj8AC5dt8wSP3BQAoeX58NoHyCU8P8zGkNXStjTSi6fzO6F0pBdcYbEg==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.8" } @@ -703,6 +724,7 @@ "resolved": "https://registry.npmjs.org/debug/-/debug-2.6.9.tgz", "integrity": "sha512-bC7ElrdJaJnPbAP+1EotYvqZsb3ecl5wi6Bfi6BJTUcNowp6cvspg0jXznRTKDjm/E7AdgFBVeAPVMNcKGsHMA==", "dev": true, + "license": "MIT", "dependencies": { "ms": "2.0.0" } @@ -712,32 +734,22 @@ "resolved": "https://registry.npmjs.org/depd/-/depd-2.0.0.tgz", "integrity": "sha512-g7nH6P6dyDioJogAAGprGpCtVImJhpPk/roCzdb3fIh61/s/nPsfR6onyMwkCAR/OlC3yBC0lESvUoQEAssIrw==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.8" } }, "node_modules/bonjour-service": { - "version": "1.1.1", - "resolved": "https://registry.npmjs.org/bonjour-service/-/bonjour-service-1.1.1.tgz", - "integrity": "sha512-Z/5lQRMOG9k7W+FkeGTNjh7htqn/2LMnfOvBZ8pynNZCM9MwkQkI3zeI4oz09uWdcgmgHugVvBqxGg4VQJ5PCg==", + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/bonjour-service/-/bonjour-service-1.3.0.tgz", + "integrity": "sha512-3YuAUiSkWykd+2Azjgyxei8OWf8thdn8AITIog2M4UICzoqfjlqr64WIjEXZllf/W6vK1goqleSR6brGomxQqA==", "dev": true, + "license": "MIT", "dependencies": { - "array-flatten": "^2.1.2", - "dns-equal": "^1.0.0", "fast-deep-equal": "^3.1.3", "multicast-dns": "^7.2.5" } }, - "node_modules/brace-expansion": { - "version": "1.1.11", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.11.tgz", - "integrity": "sha512-iCuPHDFgrHX7H2vEI/5xpz07zSHB00TpugqhmYtVmMO6518mCuRMoOYFldEBl0g187ufozdaHgWKcYFb61qGiA==", - 
"dev": true, - "dependencies": { - "balanced-match": "^1.0.0", - "concat-map": "0.0.1" - } - }, "node_modules/braces": { "version": "3.0.3", "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.3.tgz", @@ -788,6 +800,22 @@ "integrity": "sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==", "dev": true }, + "node_modules/bundle-name": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/bundle-name/-/bundle-name-4.1.0.tgz", + "integrity": "sha512-tjwM5exMg6BGRI+kNmTntNsvdZS1X8BFYS6tnJ2hdH0kVxM6/eVZ2xy+FqStSWvYmtfFMDLIxurorHwDKfDz5Q==", + "dev": true, + "license": "MIT", + "dependencies": { + "run-applescript": "^7.0.0" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, "node_modules/bytes": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/bytes/-/bytes-3.0.0.tgz", @@ -797,17 +825,29 @@ "node": ">= 0.8" } }, - "node_modules/call-bind": { - "version": "1.0.7", - "resolved": "https://registry.npmjs.org/call-bind/-/call-bind-1.0.7.tgz", - "integrity": "sha512-GHTSNSYICQ7scH7sZ+M2rFopRoLh8t2bLSW6BbgrtLsahOIB5iyAVJf9GjWK3cYTDaMj4XdBpM1cA6pIS0Kv2w==", + "node_modules/call-bind-apply-helpers": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/call-bind-apply-helpers/-/call-bind-apply-helpers-1.0.2.tgz", + "integrity": "sha512-Sp1ablJ0ivDkSzjcaJdxEunN5/XvksFJ2sMBFfq6x0ryhQV/2b/KwFe21cMpmHtPOSij8K99/wSfoEuTObmuMQ==", "dev": true, + "license": "MIT", "dependencies": { - "es-define-property": "^1.0.0", "es-errors": "^1.3.0", - "function-bind": "^1.1.2", - "get-intrinsic": "^1.2.4", - "set-function-length": "^1.2.1" + "function-bind": "^1.1.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/call-bound": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/call-bound/-/call-bound-1.0.4.tgz", + "integrity": "sha512-+ys997U96po4Kx/ABpBCqhA9EuxJaQWDQg7295H4hBphv3IZg0boBKuwYpt4YXp6MZ5AmZQnU/tyMTlRpaSejg==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind-apply-helpers": "^1.0.2", + "get-intrinsic": "^1.3.0" }, "engines": { "node": ">= 0.4" @@ -837,16 +877,11 @@ ] }, "node_modules/chokidar": { - "version": "3.5.3", - "resolved": "https://registry.npmjs.org/chokidar/-/chokidar-3.5.3.tgz", - "integrity": "sha512-Dr3sfKRP6oTcjf2JmUmFJfeVMvXBdegxB0iVQ5eb2V10uFJUCAS8OByZdVAyVb8xXNz3GjjTgj9kLWsZTqE6kw==", + "version": "3.6.0", + "resolved": "https://registry.npmjs.org/chokidar/-/chokidar-3.6.0.tgz", + "integrity": "sha512-7VT13fmjotKpGipCW9JEQAusEPE+Ei8nl6/g4FBAmIm0GOOLMua9NDDo/DWp0ZAxCr3cPq5ZpBqmPAQgDda2Pw==", "dev": true, - "funding": [ - { - "type": "individual", - "url": "https://paulmillr.com/funding/" - } - ], + "license": "MIT", "dependencies": { "anymatch": "~3.1.2", "braces": "~3.0.2", @@ -859,6 +894,9 @@ "engines": { "node": ">= 8.10.0" }, + "funding": { + "url": "https://paulmillr.com/funding/" + }, "optionalDependencies": { "fsevents": "~2.3.2" } @@ -940,12 +978,6 @@ "ms": "2.0.0" } }, - "node_modules/concat-map": { - "version": "0.0.1", - "resolved": "https://registry.npmjs.org/concat-map/-/concat-map-0.0.1.tgz", - "integrity": "sha512-/Srv4dswyQNBfohGpz9o6Yb3Gz3SrUDqBH5rTuhGR7ahtlbYKnVxw2bCFMRljaA7EXHaXZ8wsHdodFvbkhKmqg==", - "dev": true - }, "node_modules/connect-history-api-fallback": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/connect-history-api-fallback/-/connect-history-api-fallback-2.0.0.tgz", @@ -960,6 +992,7 @@ "resolved": 
"https://registry.npmjs.org/content-disposition/-/content-disposition-0.5.4.tgz", "integrity": "sha512-FveZTNuGw04cxlAiWbzi6zTAL/lhehaWbTtgluJh4/E95DqMwTmha3KZN1aAWA8cFIhHzMZUvLevkw5Rqk+tSQ==", "dev": true, + "license": "MIT", "dependencies": { "safe-buffer": "5.2.1" }, @@ -985,13 +1018,15 @@ "type": "consulting", "url": "https://feross.org/support" } - ] + ], + "license": "MIT" }, "node_modules/content-type": { "version": "1.0.5", "resolved": "https://registry.npmjs.org/content-type/-/content-type-1.0.5.tgz", "integrity": "sha512-nTjqfcBFEipKdXCv4YDQWCfmcLZKm81ldF0pAopTvyrFGVbcR6P/VAAd5G7N+0tTr8QqiU0tFadD6FK4NtJwOA==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.6" } @@ -1001,6 +1036,7 @@ "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.7.1.tgz", "integrity": "sha512-6DnInpx7SJ2AK3+CTUE/ZM0vWTUboZCegxhC2xiIydHR9jNuTAASBrfEpHhiGOZw/nX51bHt6YQl8jsGo4y/0w==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.6" } @@ -1009,7 +1045,8 @@ "version": "1.0.6", "resolved": "https://registry.npmjs.org/cookie-signature/-/cookie-signature-1.0.6.tgz", "integrity": "sha512-QADzlaHc8icV8I7vbaJXJwod9HWYp8uCqf1xa4OfNu1T7JVxQIrUgOWtHdNDtPiywmFbiS12VjotIXLrKM3orQ==", - "dev": true + "dev": true, + "license": "MIT" }, "node_modules/copy-webpack-plugin": { "version": "12.0.2", @@ -1146,42 +1183,47 @@ "integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==", "dev": true }, - "node_modules/default-gateway": { - "version": "6.0.3", - "resolved": "https://registry.npmjs.org/default-gateway/-/default-gateway-6.0.3.tgz", - "integrity": "sha512-fwSOJsbbNzZ/CUFpqFBqYfYNLj1NbMPm8MMCIzHjC83iSJRBEGmDUxU+WP661BaBQImeC2yHwXtz+P/O9o+XEg==", + "node_modules/default-browser": { + "version": "5.2.1", + "resolved": "https://registry.npmjs.org/default-browser/-/default-browser-5.2.1.tgz", + "integrity": "sha512-WY/3TUME0x3KPYdRRxEJJvXRHV4PyPoUsxtZa78lwItwRQRHhd2U9xOscaT/YTf8uCXIAjeJOFBVEh/7FtD8Xg==", "dev": true, + "license": "MIT", "dependencies": { - "execa": "^5.0.0" + "bundle-name": "^4.1.0", + "default-browser-id": "^5.0.0" }, "engines": { - "node": ">= 10" + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/define-data-property": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/define-data-property/-/define-data-property-1.1.4.tgz", - "integrity": "sha512-rBMvIzlpA8v6E+SJZoo++HAYqsLrkg7MSfIinMPFhmkorw7X+dOXVJQs+QT69zGkzMyfDnIMN2Wid1+NbL3T+A==", + "node_modules/default-browser-id": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/default-browser-id/-/default-browser-id-5.0.0.tgz", + "integrity": "sha512-A6p/pu/6fyBcA1TRz/GqWYPViplrftcW2gZC9q79ngNCKAeR/X3gcEdXQHl4KNXV+3wgIJ1CPkJQ3IHM6lcsyA==", "dev": true, - "dependencies": { - "es-define-property": "^1.0.0", - "es-errors": "^1.3.0", - "gopd": "^1.0.1" - }, + "license": "MIT", "engines": { - "node": ">= 0.4" + "node": ">=18" }, "funding": { - "url": "https://github.com/sponsors/ljharb" + "url": "https://github.com/sponsors/sindresorhus" } }, "node_modules/define-lazy-prop": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/define-lazy-prop/-/define-lazy-prop-2.0.0.tgz", - "integrity": "sha512-Ds09qNh8yw3khSjiJjiUInaGX9xlqZDY7JVryGxdxV7NPeuqQfplOpQ66yJFZut3jLa5zOwkXw1g9EI2uKh4Og==", + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/define-lazy-prop/-/define-lazy-prop-3.0.0.tgz", + "integrity": 
"sha512-N+MeXYoqr3pOgn8xfyRPREN7gHakLYjhsHhWGT3fWAiL4IkAt0iDw14QiiEm2bE30c5XX5q0FtAA3CK5f9/BUg==", "dev": true, + "license": "MIT", "engines": { - "node": ">=8" + "node": ">=12" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" } }, "node_modules/depd": { @@ -1198,6 +1240,7 @@ "resolved": "https://registry.npmjs.org/destroy/-/destroy-1.2.0.tgz", "integrity": "sha512-2sJGJTaXIIaR1w4iJSNoN0hnMY7Gpc/n8D4qSCJw8QqFWXf7cuAgnEHxBpweaVcPevC2l3KpjYCx3NypQQgaJg==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.8", "npm": "1.2.8000 || >= 1.4.16" @@ -1209,17 +1252,12 @@ "integrity": "sha512-ZIzRpLJrOj7jjP2miAtgqIfmzbxa4ZOr5jJc601zklsfEx9oTzmmj2nVpIPRpNlRTIh8lc1kyViIY7BWSGNmKw==", "dev": true }, - "node_modules/dns-equal": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/dns-equal/-/dns-equal-1.0.0.tgz", - "integrity": "sha512-z+paD6YUQsk+AbGCEM4PrOXSss5gd66QfcVBFTKR/HpFL9jCqikS94HYwKww6fQyO7IxrIIyUu+g0Ka9tUS2Cg==", - "dev": true - }, "node_modules/dns-packet": { "version": "5.6.1", "resolved": "https://registry.npmjs.org/dns-packet/-/dns-packet-5.6.1.tgz", "integrity": "sha512-l4gcSouhcgIKRvyy99RNVOgxXiicE+2jZoNmaNmZ6JXiGajBOJAesk1OBlJuM5k2c+eudGdLxDqXuPCKIj6kpw==", "dev": true, + "license": "MIT", "dependencies": { "@leichtgewicht/ip-codec": "^2.0.1" }, @@ -1227,11 +1265,27 @@ "node": ">=6" } }, + "node_modules/dunder-proto": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/dunder-proto/-/dunder-proto-1.0.1.tgz", + "integrity": "sha512-KIN/nDJBQRcXw0MLVhZE9iQHmG68qAVIBg9CqmUYjmQIhgij9U5MFvrqkUL5FbtyyzZuOeOt0zdeRe4UY7ct+A==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bind-apply-helpers": "^1.0.1", + "es-errors": "^1.3.0", + "gopd": "^1.2.0" + }, + "engines": { + "node": ">= 0.4" + } + }, "node_modules/ee-first": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/ee-first/-/ee-first-1.1.1.tgz", "integrity": "sha512-WMwm9LhRUo+WUaRN+vRuETqG89IgZphVSNkdFgeb6sS/E4OrDIN7t48CAewSHXc6C8lefD8KKfr5vY61brQlow==", - "dev": true + "dev": true, + "license": "MIT" }, "node_modules/electron-to-chromium": { "version": "1.4.528", @@ -1240,10 +1294,11 @@ "dev": true }, "node_modules/encodeurl": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-1.0.2.tgz", - "integrity": "sha512-TPJXq8JqFaVYm2CWmPvnP2Iyo4ZSM7/QKcSmuMLDObfpH5fi7RUGmd/rTDf+rut/saiDiQEeVTNgAmJEdAOx0w==", + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-2.0.0.tgz", + "integrity": "sha512-Q0n9HRi4m6JuGIV1eFlmvJB7ZEVxu93IrMyiMsGC0lrMJMWzRgx6WGquyfQgZVb31vhGgXnfmPNNXmxnOkRBrg==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.8" } @@ -1274,13 +1329,11 @@ } }, "node_modules/es-define-property": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/es-define-property/-/es-define-property-1.0.0.tgz", - "integrity": "sha512-jxayLKShrEqqzJ0eumQbVhTYQM27CfT1T35+gCgDFoL82JLsXqTJ76zv6A0YLOgEnLUMvLzsDsGIrl8NFpT2gQ==", + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/es-define-property/-/es-define-property-1.0.1.tgz", + "integrity": "sha512-e3nRfgfUZ4rNGL232gUgX06QNyyez04KdjFrF+LTRoOXmrOgFKDg4BCdsjW8EnT69eqdYGmRpJwiPVYNrCaW3g==", "dev": true, - "dependencies": { - "get-intrinsic": "^1.2.4" - }, + "license": "MIT", "engines": { "node": ">= 0.4" } @@ -1290,6 +1343,7 @@ "resolved": "https://registry.npmjs.org/es-errors/-/es-errors-1.3.0.tgz", "integrity": "sha512-Zf5H2Kxt2xjTvbJvP2ZWLEICxA6j+hAmMzIlypy4xcBg1vKVnx89Wy0GbS+kf5cwCVFFzdCFh2XSCFNULS6csw==", "dev": 
true, + "license": "MIT", "engines": { "node": ">= 0.4" } @@ -1300,6 +1354,19 @@ "integrity": "sha512-JUFAyicQV9mXc3YRxPnDlrfBKpqt6hUYzz9/boprUJHs4e4KVr3XwOF70doO6gwXUor6EWZJAyWAfKki84t20Q==", "dev": true }, + "node_modules/es-object-atoms": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/es-object-atoms/-/es-object-atoms-1.1.1.tgz", + "integrity": "sha512-FGgH2h8zKNim9ljj7dankFPcICIK9Cp5bm+c2gQSYePhpaG5+esrLODihIorn+Pe6FGJzWhXQotPv73jTaldXA==", + "dev": true, + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0" + }, + "engines": { + "node": ">= 0.4" + } + }, "node_modules/escalade": { "version": "3.1.1", "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.1.1.tgz", @@ -1363,6 +1430,7 @@ "resolved": "https://registry.npmjs.org/etag/-/etag-1.8.1.tgz", "integrity": "sha512-aIL5Fx7mawVa300al2BnEE4iNvo1qETxLrPI/o05L7z6go7fCw1J6EQmbK4FmJ2AS7kgVF/KEZWufBfdClMcPg==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.6" } @@ -1382,34 +1450,12 @@ "node": ">=0.8.x" } }, - "node_modules/execa": { - "version": "5.1.1", - "resolved": "https://registry.npmjs.org/execa/-/execa-5.1.1.tgz", - "integrity": "sha512-8uSpZZocAZRBAPIEINJj3Lo9HyGitllczc27Eh5YYojjMFMn8yHMDMaUHE2Jqfq05D/wucwI4JGURyXt1vchyg==", - "dev": true, - "dependencies": { - "cross-spawn": "^7.0.3", - "get-stream": "^6.0.0", - "human-signals": "^2.1.0", - "is-stream": "^2.0.0", - "merge-stream": "^2.0.0", - "npm-run-path": "^4.0.1", - "onetime": "^5.1.2", - "signal-exit": "^3.0.3", - "strip-final-newline": "^2.0.0" - }, - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/sindresorhus/execa?sponsor=1" - } - }, "node_modules/express": { - "version": "4.21.1", - "resolved": "https://registry.npmjs.org/express/-/express-4.21.1.tgz", - "integrity": "sha512-YSFlK1Ee0/GC8QaO91tHcDxJiE/X4FbpAyQWkxAvG6AXCuR65YzK8ua6D9hvi/TzUfZMpc+BwuM1IPw8fmQBiQ==", + "version": "4.21.2", + "resolved": "https://registry.npmjs.org/express/-/express-4.21.2.tgz", + "integrity": "sha512-28HqgMZAmih1Czt9ny7qr6ek2qddF4FclbMzwhCREB6OFfH+rXAnuNCwo1/wFvrtbgsQDb4kSbX9de9lFbrXnA==", "dev": true, + "license": "MIT", "dependencies": { "accepts": "~1.3.8", "array-flatten": "1.1.1", @@ -1430,7 +1476,7 @@ "methods": "~1.1.2", "on-finished": "2.4.1", "parseurl": "~1.3.3", - "path-to-regexp": "0.1.10", + "path-to-regexp": "0.1.12", "proxy-addr": "~2.0.7", "qs": "6.13.0", "range-parser": "~1.2.1", @@ -1445,19 +1491,18 @@ }, "engines": { "node": ">= 0.10.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/express" } }, - "node_modules/express/node_modules/array-flatten": { - "version": "1.1.1", - "resolved": "https://registry.npmjs.org/array-flatten/-/array-flatten-1.1.1.tgz", - "integrity": "sha512-PCVAQswWemu6UdxsDFFX/+gVeYqKAod3D3UVm91jHwynguOwAvYPhx8nNlM++NqRcK6CxxpUafjmhIdKiHibqg==", - "dev": true - }, "node_modules/express/node_modules/debug": { "version": "2.6.9", "resolved": "https://registry.npmjs.org/debug/-/debug-2.6.9.tgz", "integrity": "sha512-bC7ElrdJaJnPbAP+1EotYvqZsb3ecl5wi6Bfi6BJTUcNowp6cvspg0jXznRTKDjm/E7AdgFBVeAPVMNcKGsHMA==", "dev": true, + "license": "MIT", "dependencies": { "ms": "2.0.0" } @@ -1467,15 +1512,7 @@ "resolved": "https://registry.npmjs.org/depd/-/depd-2.0.0.tgz", "integrity": "sha512-g7nH6P6dyDioJogAAGprGpCtVImJhpPk/roCzdb3fIh61/s/nPsfR6onyMwkCAR/OlC3yBC0lESvUoQEAssIrw==", "dev": true, - "engines": { - "node": ">= 0.8" - } - }, - "node_modules/express/node_modules/encodeurl": { - "version": "2.0.0", - "resolved": 
"https://registry.npmjs.org/encodeurl/-/encodeurl-2.0.0.tgz", - "integrity": "sha512-Q0n9HRi4m6JuGIV1eFlmvJB7ZEVxu93IrMyiMsGC0lrMJMWzRgx6WGquyfQgZVb31vhGgXnfmPNNXmxnOkRBrg==", - "dev": true, + "license": "MIT", "engines": { "node": ">= 0.8" } @@ -1498,13 +1535,15 @@ "type": "consulting", "url": "https://feross.org/support" } - ] + ], + "license": "MIT" }, "node_modules/express/node_modules/statuses": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/statuses/-/statuses-2.0.1.tgz", "integrity": "sha512-RwNA9Z/7PrK06rYLIzFMlaF+l73iwpzsqRIFgbMLbTcLD6cOao82TaWefPXQvB2fOC4AjuYSEndS7N/mTCbkdQ==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.8" } @@ -1603,6 +1642,7 @@ "resolved": "https://registry.npmjs.org/finalhandler/-/finalhandler-1.3.1.tgz", "integrity": "sha512-6BN9trH7bp3qvnrRyzsBz+g3lZxTNZTbVO2EV1CS0WIcDbawYVdYvGflME/9QP0h0pYlCDBCTjYa9nZzMDpyxQ==", "dev": true, + "license": "MIT", "dependencies": { "debug": "2.6.9", "encodeurl": "~2.0.0", @@ -1621,24 +1661,17 @@ "resolved": "https://registry.npmjs.org/debug/-/debug-2.6.9.tgz", "integrity": "sha512-bC7ElrdJaJnPbAP+1EotYvqZsb3ecl5wi6Bfi6BJTUcNowp6cvspg0jXznRTKDjm/E7AdgFBVeAPVMNcKGsHMA==", "dev": true, + "license": "MIT", "dependencies": { "ms": "2.0.0" } }, - "node_modules/finalhandler/node_modules/encodeurl": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-2.0.0.tgz", - "integrity": "sha512-Q0n9HRi4m6JuGIV1eFlmvJB7ZEVxu93IrMyiMsGC0lrMJMWzRgx6WGquyfQgZVb31vhGgXnfmPNNXmxnOkRBrg==", - "dev": true, - "engines": { - "node": ">= 0.8" - } - }, "node_modules/finalhandler/node_modules/statuses": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/statuses/-/statuses-2.0.1.tgz", "integrity": "sha512-RwNA9Z/7PrK06rYLIzFMlaF+l73iwpzsqRIFgbMLbTcLD6cOao82TaWefPXQvB2fOC4AjuYSEndS7N/mTCbkdQ==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.8" } @@ -1681,6 +1714,7 @@ "resolved": "https://registry.npmjs.org/forwarded/-/forwarded-0.2.0.tgz", "integrity": "sha512-buRG0fpBtRHSTCOASe6hD258tEubFoRLb4ZNA6NxMVHNw2gOcwHo9wyablzMzOA5z9xA9L1KNjk/Nt6MT9aYow==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.6" } @@ -1690,28 +1724,18 @@ "resolved": "https://registry.npmjs.org/fresh/-/fresh-0.5.2.tgz", "integrity": "sha512-zJ2mQYM18rEFOudeV4GShTGIQ7RbzA7ozbU9I/XBpm7kqgMywgmylMwXHxZJmkVoYkna9d2pVXVXPdYTP9ej8Q==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.6" } }, - "node_modules/fs-monkey": { - "version": "1.0.4", - "resolved": "https://registry.npmjs.org/fs-monkey/-/fs-monkey-1.0.4.tgz", - "integrity": "sha512-INM/fWAxMICjttnD0DX1rBvinKskj5G1w+oy/pnm9u/tSlnBrzFonJMcalKJ30P8RRsPzKcCG7Q8l0jx5Fh9YQ==", - "dev": true - }, - "node_modules/fs.realpath": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/fs.realpath/-/fs.realpath-1.0.0.tgz", - "integrity": "sha512-OO0pH2lK6a0hZnAdau5ItzHPI6pUlvI7jMVnxUQRtw4owF2wk8lOSabtGDCTP4Ggrg2MbGnWO9X8K1t4+fGMDw==", - "dev": true - }, "node_modules/fsevents": { "version": "2.3.3", "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz", "integrity": "sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==", "dev": true, "hasInstallScript": true, + "license": "MIT", "optional": true, "os": [ "darwin" @@ -1730,16 +1754,22 @@ } }, "node_modules/get-intrinsic": { - "version": "1.2.4", - "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.2.4.tgz", - "integrity": 
"sha512-5uYhsJH8VJBTv7oslg4BznJYhDoRI6waYCxMmCdnTrcCrHA/fCFKoTFz2JKKE0HdDFUF7/oQuhzumXJK7paBRQ==", + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.3.0.tgz", + "integrity": "sha512-9fSjSaos/fRIVIp+xSJlE6lfwhES7LNtKaCBIamHsjr2na1BiABJPo0mOjjz8GJDURarmCPGqaiVg5mfjb98CQ==", "dev": true, + "license": "MIT", "dependencies": { + "call-bind-apply-helpers": "^1.0.2", + "es-define-property": "^1.0.1", "es-errors": "^1.3.0", + "es-object-atoms": "^1.1.1", "function-bind": "^1.1.2", - "has-proto": "^1.0.1", - "has-symbols": "^1.0.3", - "hasown": "^2.0.0" + "get-proto": "^1.0.1", + "gopd": "^1.2.0", + "has-symbols": "^1.1.0", + "hasown": "^2.0.2", + "math-intrinsics": "^1.1.0" }, "engines": { "node": ">= 0.4" @@ -1748,36 +1778,18 @@ "url": "https://github.com/sponsors/ljharb" } }, - "node_modules/get-stream": { - "version": "6.0.1", - "resolved": "https://registry.npmjs.org/get-stream/-/get-stream-6.0.1.tgz", - "integrity": "sha512-ts6Wi+2j3jQjqi70w5AlN8DFnkSwC+MqmxEzdEALB2qXZYV3X/b1CTfgPLGJNMeAWxdPfU8FO1ms3NUfaHCPYg==", - "dev": true, - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/glob": { - "version": "7.2.3", - "resolved": "https://registry.npmjs.org/glob/-/glob-7.2.3.tgz", - "integrity": "sha512-nFR0zLpU2YCaRxwoCJvL6UvCH2JFyFVIvwTLsIf21AuHlMskA1hhTdk+LlYJtOlYt9v6dvszD2BGRqBL+iQK9Q==", + "node_modules/get-proto": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/get-proto/-/get-proto-1.0.1.tgz", + "integrity": "sha512-sTSfBjoXBp89JvIKIefqw7U2CCebsc74kiY6awiGogKtoSGbgjYE/G/+l9sF3MWFPNc9IcoOC4ODfKHfxFmp0g==", "dev": true, + "license": "MIT", "dependencies": { - "fs.realpath": "^1.0.0", - "inflight": "^1.0.4", - "inherits": "2", - "minimatch": "^3.1.1", - "once": "^1.3.0", - "path-is-absolute": "^1.0.0" + "dunder-proto": "^1.0.1", + "es-object-atoms": "^1.0.0" }, "engines": { - "node": "*" - }, - "funding": { - "url": "https://github.com/sponsors/isaacs" + "node": ">= 0.4" } }, "node_modules/glob-parent": { @@ -1820,12 +1832,13 @@ } }, "node_modules/gopd": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/gopd/-/gopd-1.0.1.tgz", - "integrity": "sha512-d65bNlIadxvpb/A2abVdlqKqV563juRnZ1Wtk6s1sIR8uNsXR70xqIzVqxVf1eTqDunwT2MkczEeaezCKTZhwA==", + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/gopd/-/gopd-1.2.0.tgz", + "integrity": "sha512-ZUKRh6/kUFoAiTAtTYPZJ3hw9wNxx+BIBOijnlG9PnrJsCcSjs1wyyD6vJpaYtgnzDrKYRSqf3OO6Rfa93xsRg==", "dev": true, - "dependencies": { - "get-intrinsic": "^1.1.3" + "license": "MIT", + "engines": { + "node": ">= 0.4" }, "funding": { "url": "https://github.com/sponsors/ljharb" @@ -1864,35 +1877,12 @@ "node": ">=8" } }, - "node_modules/has-property-descriptors": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/has-property-descriptors/-/has-property-descriptors-1.0.2.tgz", - "integrity": "sha512-55JNKuIW+vq4Ke1BjOTjM2YctQIvCT7GFzHwmfZPGo5wnrgkid0YQtnAleFSqumZm4az3n2BS+erby5ipJdgrg==", - "dev": true, - "dependencies": { - "es-define-property": "^1.0.0" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/has-proto": { - "version": "1.0.3", - "resolved": "https://registry.npmjs.org/has-proto/-/has-proto-1.0.3.tgz", - "integrity": "sha512-SJ1amZAJUiZS+PhsVLf5tGydlaVB8EdFpaSO4gmiUKUOxk8qzn5AIy4ZeJUmh22znIdk/uMAUT2pl3FxzVUH+Q==", - "dev": true, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - 
"node_modules/has-symbols": { - "version": "1.0.3", - "resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.0.3.tgz", - "integrity": "sha512-l3LCuF6MgDNwTDKkdYGEihYjt5pRPbEg46rtlmnSPlUbgmB8LOIrKJbYYFBSbnPaJexMKtiPO8hmeRjRz2Td+A==", + "node_modules/has-symbols": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.1.0.tgz", + "integrity": "sha512-1cDNdwJ2Jaohmb3sg4OmKaMBwuC48sYni5HUw2DvsC8LjGTLK9h+eb1X6RyuOHe4hT0ULCW68iomhjUoKUqlPQ==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.4" }, @@ -1905,6 +1895,7 @@ "resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.2.tgz", "integrity": "sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==", "dev": true, + "license": "MIT", "dependencies": { "function-bind": "^1.1.2" }, @@ -1924,22 +1915,6 @@ "wbuf": "^1.1.0" } }, - "node_modules/html-entities": { - "version": "2.4.0", - "resolved": "https://registry.npmjs.org/html-entities/-/html-entities-2.4.0.tgz", - "integrity": "sha512-igBTJcNNNhvZFRtm8uA6xMY6xYleeDwn3PeBCkDz7tHttv4F2hsDI2aPgNERWzvRcNYHNT3ymRaQzllmXj4YsQ==", - "dev": true, - "funding": [ - { - "type": "github", - "url": "https://github.com/sponsors/mdevils" - }, - { - "type": "patreon", - "url": "https://patreon.com/mdevils" - } - ] - }, "node_modules/http-deceiver": { "version": "1.2.7", "resolved": "https://registry.npmjs.org/http-deceiver/-/http-deceiver-1.2.7.tgz", @@ -1951,6 +1926,7 @@ "resolved": "https://registry.npmjs.org/http-errors/-/http-errors-2.0.0.tgz", "integrity": "sha512-FtwrG/euBzaEjYeRqOgly7G0qviiXoJWnvEH2Z1plBdXgbyjv34pHTSb9zoeHMyDy33+DWy5Wt9Wo+TURtOYSQ==", "dev": true, + "license": "MIT", "dependencies": { "depd": "2.0.0", "inherits": "2.0.4", @@ -1967,6 +1943,7 @@ "resolved": "https://registry.npmjs.org/depd/-/depd-2.0.0.tgz", "integrity": "sha512-g7nH6P6dyDioJogAAGprGpCtVImJhpPk/roCzdb3fIh61/s/nPsfR6onyMwkCAR/OlC3yBC0lESvUoQEAssIrw==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.8" } @@ -1975,13 +1952,15 @@ "version": "2.0.4", "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.4.tgz", "integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==", - "dev": true + "dev": true, + "license": "ISC" }, "node_modules/http-errors/node_modules/statuses": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/statuses/-/statuses-2.0.1.tgz", "integrity": "sha512-RwNA9Z/7PrK06rYLIzFMlaF+l73iwpzsqRIFgbMLbTcLD6cOao82TaWefPXQvB2fOC4AjuYSEndS7N/mTCbkdQ==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.8" } @@ -2007,10 +1986,11 @@ } }, "node_modules/http-proxy-middleware": { - "version": "2.0.6", - "resolved": "https://registry.npmjs.org/http-proxy-middleware/-/http-proxy-middleware-2.0.6.tgz", - "integrity": "sha512-ya/UeJ6HVBYxrgYotAZo1KvPWlgB48kUJLDePFeneHsVujFaW5WNj2NgWCAE//B1Dl02BIfYlpNgBy8Kf8Rjmw==", + "version": "2.0.9", + "resolved": "https://registry.npmjs.org/http-proxy-middleware/-/http-proxy-middleware-2.0.9.tgz", + "integrity": "sha512-c1IyJYLYppU574+YI7R4QyX2ystMtVXZwIdzazUIPIJsHuWNd+mho2j+bKoHftndicGj9yh+xjd+l0yj7VeT1Q==", "dev": true, + "license": "MIT", "dependencies": { "@types/http-proxy": "^1.17.8", "http-proxy": "^1.18.1", @@ -2030,13 +2010,14 @@ } } }, - "node_modules/human-signals": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/human-signals/-/human-signals-2.1.0.tgz", - "integrity": 
"sha512-B4FFZ6q/T2jhhksgkbEW3HBvWIfDW85snkQgawt07S7J5QXTk6BkNV+0yAeZrM5QpMAdYlocGoljn0sJ/WQkFw==", + "node_modules/hyperdyperid": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/hyperdyperid/-/hyperdyperid-1.2.0.tgz", + "integrity": "sha512-Y93lCzHYgGWdrJ66yIktxiaGULYc6oGiABxhcO5AufBeOyoIdZF7bIfLaOrbM0iGIOXQQgxxRrFEnb+Y6w1n4A==", "dev": true, + "license": "MIT", "engines": { - "node": ">=10.17.0" + "node": ">=10.18" } }, "node_modules/iconv-lite": { @@ -2044,6 +2025,7 @@ "resolved": "https://registry.npmjs.org/iconv-lite/-/iconv-lite-0.4.24.tgz", "integrity": "sha512-v3MXnZAcvnywkTUEZomIActle7RXXeedOR31wwl7VlyoXO4Qi9arvSenNQWne1TcRwhCL1HwLI21bEqdpj8/rA==", "dev": true, + "license": "MIT", "dependencies": { "safer-buffer": ">= 2.1.2 < 3" }, @@ -2080,16 +2062,6 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/inflight": { - "version": "1.0.6", - "resolved": "https://registry.npmjs.org/inflight/-/inflight-1.0.6.tgz", - "integrity": "sha512-k92I/b08q4wvFscXCLvqfsHCrjrF7yiXsQuIVvVE7N82W3+aqpzuUdBbfhWcy/FZR3/4IgflMgKLOsvPDrGCJA==", - "dev": true, - "dependencies": { - "once": "^1.3.0", - "wrappy": "1" - } - }, "node_modules/inherits": { "version": "2.0.3", "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.3.tgz", @@ -2119,6 +2091,7 @@ "resolved": "https://registry.npmjs.org/is-binary-path/-/is-binary-path-2.1.0.tgz", "integrity": "sha512-ZMERYes6pDydyuGidse7OsHxtbI7WVeUEozgR/g7rd0xUimYNlvZRE/K2MgZTjWy725IfelLeVcEM97mmtRGXw==", "dev": true, + "license": "MIT", "dependencies": { "binary-extensions": "^2.0.0" }, @@ -2139,15 +2112,16 @@ } }, "node_modules/is-docker": { - "version": "2.2.1", - "resolved": "https://registry.npmjs.org/is-docker/-/is-docker-2.2.1.tgz", - "integrity": "sha512-F+i2BKsFrH66iaUFc0woD8sLy8getkwTwtOBjvs56Cx4CgJDeKQeqfz8wAYiSb8JOprWhHH5p77PbmYCvvUuXQ==", + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/is-docker/-/is-docker-3.0.0.tgz", + "integrity": "sha512-eljcgEDlEns/7AXFosB5K/2nCM4P7FQPkGc/DWLy5rmFEWvZayGrik1d9/QIY5nJ4f9YsVvBkA6kJpHn9rISdQ==", "dev": true, + "license": "MIT", "bin": { "is-docker": "cli.js" }, "engines": { - "node": ">=8" + "node": "^12.20.0 || ^14.13.1 || >=16.0.0" }, "funding": { "url": "https://github.com/sponsors/sindresorhus" @@ -2174,6 +2148,38 @@ "node": ">=0.10.0" } }, + "node_modules/is-inside-container": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/is-inside-container/-/is-inside-container-1.0.0.tgz", + "integrity": "sha512-KIYLCCJghfHZxqjYBE7rEy0OBuTd5xCHS7tHVgvCLkx7StIoaxwNW3hCALgEUjFfeRk+MG/Qxmp/vtETEF3tRA==", + "dev": true, + "license": "MIT", + "dependencies": { + "is-docker": "^3.0.0" + }, + "bin": { + "is-inside-container": "cli.js" + }, + "engines": { + "node": ">=14.16" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/is-network-error": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/is-network-error/-/is-network-error-1.1.0.tgz", + "integrity": "sha512-tUdRRAnhT+OtCZR/LxZelH/C7QtjtFrTu5tXCA8pl55eTUElUHT+GPYV8MBMBvea/j+NxQqVt3LbWMRir7Gx9g==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=16" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, "node_modules/is-number": { "version": "7.0.0", "resolved": "https://registry.npmjs.org/is-number/-/is-number-7.0.0.tgz", @@ -2207,28 +2213,20 @@ "node": ">=0.10.0" } }, - "node_modules/is-stream": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/is-stream/-/is-stream-2.0.1.tgz", - 
"integrity": "sha512-hFoiJiTl63nn+kstHGBtewWSKnQLpyb155KHheA1l39uvtO9nWIop1p3udqPcUd/xbF1VLMO4n7OI6p7RbngDg==", - "dev": true, - "engines": { - "node": ">=8" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, "node_modules/is-wsl": { - "version": "2.2.0", - "resolved": "https://registry.npmjs.org/is-wsl/-/is-wsl-2.2.0.tgz", - "integrity": "sha512-fKzAra0rGJUUBwGBgNkHZuToZcn+TtXHpeCgmkMJMMYx1sQDYaCSyjJBSCa2nH1DGm7s3n1oBnohoVTBaN7Lww==", + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/is-wsl/-/is-wsl-3.1.0.tgz", + "integrity": "sha512-UcVfVfaK4Sc4m7X3dUSoHoozQGBEFeDC+zVo06t98xe8CzHSZZBekNXH+tu0NalHolcJ/QAGqS46Hef7QXBIMw==", "dev": true, + "license": "MIT", "dependencies": { - "is-docker": "^2.0.0" + "is-inside-container": "^1.0.0" }, "engines": { - "node": ">=8" + "node": ">=16" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" } }, "node_modules/isarray": { @@ -2288,13 +2286,14 @@ } }, "node_modules/launch-editor": { - "version": "2.6.0", - "resolved": "https://registry.npmjs.org/launch-editor/-/launch-editor-2.6.0.tgz", - "integrity": "sha512-JpDCcQnyAAzZZaZ7vEiSqL690w7dAEyLao+KC96zBplnYbJS7TYNjvM3M7y3dGz+v7aIsJk3hllWuc0kWAjyRQ==", + "version": "2.10.0", + "resolved": "https://registry.npmjs.org/launch-editor/-/launch-editor-2.10.0.tgz", + "integrity": "sha512-D7dBRJo/qcGX9xlvt/6wUYzQxjh5G1RvZPgPv8vi4KRU99DVQL/oW7tnVOCCTm2HGeo3C5HvGE5Yrh6UBoZ0vA==", "dev": true, + "license": "MIT", "dependencies": { "picocolors": "^1.0.0", - "shell-quote": "^1.7.3" + "shell-quote": "^1.8.1" } }, "node_modules/loader-runner": { @@ -2318,32 +2317,146 @@ "node": ">=8" } }, + "node_modules/math-intrinsics": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/math-intrinsics/-/math-intrinsics-1.1.0.tgz", + "integrity": "sha512-/IXtbwEk5HTPyEwyKX6hGkYXxM9nbj64B+ilVJnC/R6B0pH5G4V3b0pVbL7DBj4tkhBAppbQUlf6F6Xl9LHu1g==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.4" + } + }, "node_modules/media-typer": { "version": "0.3.0", "resolved": "https://registry.npmjs.org/media-typer/-/media-typer-0.3.0.tgz", "integrity": "sha512-dq+qelQ9akHpcOl/gUVRTxVIOkAJ1wR3QAvb4RsVjS8oVoFjDGTc679wJYmUmknUF5HwMLOgb5O+a3KxfWapPQ==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.6" } }, "node_modules/memfs": { - "version": "3.5.3", - "resolved": "https://registry.npmjs.org/memfs/-/memfs-3.5.3.tgz", - "integrity": "sha512-UERzLsxzllchadvbPs5aolHh65ISpKpM+ccLbOJ8/vvpBKmAWf+la7dXFy7Mr0ySHbdHrFv5kGFCUHHe6GFEmw==", + "version": "4.17.2", + "resolved": "https://registry.npmjs.org/memfs/-/memfs-4.17.2.tgz", + "integrity": "sha512-NgYhCOWgovOXSzvYgUW0LQ7Qy72rWQMGGFJDoWg4G30RHd3z77VbYdtJ4fembJXBy8pMIUA31XNAupobOQlwdg==", "dev": true, + "license": "Apache-2.0", "dependencies": { - "fs-monkey": "^1.0.4" + "@jsonjoy.com/json-pack": "^1.0.3", + "@jsonjoy.com/util": "^1.3.0", + "tree-dump": "^1.0.1", + "tslib": "^2.0.0" }, "engines": { "node": ">= 4.0.0" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/streamich" } }, + "node_modules/memfs/node_modules/@jsonjoy.com/base64": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/@jsonjoy.com/base64/-/base64-1.1.2.tgz", + "integrity": "sha512-q6XAnWQDIMA3+FTiOYajoYqySkO+JSat0ytXGSuRdq9uXE7o92gzuQwQM14xaCRlBLGq3v5miDGC4vkVTn54xA==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=10.0" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/streamich" + }, + "peerDependencies": { + "tslib": "2" + } + }, + 
"node_modules/memfs/node_modules/@jsonjoy.com/json-pack": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/@jsonjoy.com/json-pack/-/json-pack-1.2.0.tgz", + "integrity": "sha512-io1zEbbYcElht3tdlqEOFxZ0dMTYrHz9iMf0gqn1pPjZFTCgM5R4R5IMA20Chb2UPYYsxjzs8CgZ7Nb5n2K2rA==", + "dev": true, + "license": "Apache-2.0", + "dependencies": { + "@jsonjoy.com/base64": "^1.1.1", + "@jsonjoy.com/util": "^1.1.2", + "hyperdyperid": "^1.2.0", + "thingies": "^1.20.0" + }, + "engines": { + "node": ">=10.0" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/streamich" + }, + "peerDependencies": { + "tslib": "2" + } + }, + "node_modules/memfs/node_modules/@jsonjoy.com/util": { + "version": "1.6.0", + "resolved": "https://registry.npmjs.org/@jsonjoy.com/util/-/util-1.6.0.tgz", + "integrity": "sha512-sw/RMbehRhN68WRtcKCpQOPfnH6lLP4GJfqzi3iYej8tnzpZUDr6UkZYJjcjjC0FWEJOJbyM3PTIwxucUmDG2A==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=10.0" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/streamich" + }, + "peerDependencies": { + "tslib": "2" + } + }, + "node_modules/memfs/node_modules/thingies": { + "version": "1.21.0", + "resolved": "https://registry.npmjs.org/thingies/-/thingies-1.21.0.tgz", + "integrity": "sha512-hsqsJsFMsV+aD4s3CWKk85ep/3I9XzYV/IXaSouJMYIoDlgyi11cBhsqYe9/geRfB0YIikBQg6raRaM+nIMP9g==", + "dev": true, + "license": "Unlicense", + "engines": { + "node": ">=10.18" + }, + "peerDependencies": { + "tslib": "^2" + } + }, + "node_modules/memfs/node_modules/tree-dump": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/tree-dump/-/tree-dump-1.0.3.tgz", + "integrity": "sha512-il+Cv80yVHFBwokQSfd4bldvr1Md951DpgAGfmhydt04L+YzHgubm2tQ7zueWDcGENKHq0ZvGFR/hjvNXilHEg==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": ">=10.0" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/streamich" + }, + "peerDependencies": { + "tslib": "2" + } + }, + "node_modules/memfs/node_modules/tslib": { + "version": "2.8.1", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz", + "integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==", + "dev": true, + "license": "0BSD" + }, "node_modules/merge-descriptors": { "version": "1.0.3", "resolved": "https://registry.npmjs.org/merge-descriptors/-/merge-descriptors-1.0.3.tgz", "integrity": "sha512-gaNvAS7TZ897/rVaZ0nMtAyxNyi/pdbjbAwUpFQpN70GqnVfOiXpeUUMKRBmzXaSQ8DdTX4/0ms62r2K+hE6mQ==", "dev": true, + "license": "MIT", "funding": { "url": "https://github.com/sponsors/sindresorhus" } @@ -2369,6 +2482,7 @@ "resolved": "https://registry.npmjs.org/methods/-/methods-1.1.2.tgz", "integrity": "sha512-iclAHeNqNm68zFtnZ0e+1L2yUIdvzNoauKU4WBA3VvH/vPFieF7qfRlwUZU+DA9P9bPXIS90ulxoUoCH23sV2w==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.6" } @@ -2392,6 +2506,7 @@ "resolved": "https://registry.npmjs.org/mime/-/mime-1.6.0.tgz", "integrity": "sha512-x0Vn8spI+wuJ1O6S7gnbaQg8Pxh4NNHb7KSINmEWKiPE4RKOplvijn+NkmYmmRgP68mc70j2EbeTFRsrswaQeg==", "dev": true, + "license": "MIT", "bin": { "mime": "cli.js" }, @@ -2420,33 +2535,12 @@ "node": ">= 0.6" } }, - "node_modules/mimic-fn": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/mimic-fn/-/mimic-fn-2.1.0.tgz", - "integrity": "sha512-OqbOk5oEQeAZ8WXWydlu9HJjz9WVdEIvamMCcXmuqUYjTknH/sqsWvhQ3vgwKFRR1HpjvNBKQ37nbJgYzGqGcg==", - "dev": true, - "engines": { - "node": ">=6" - } - }, 
"node_modules/minimalistic-assert": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/minimalistic-assert/-/minimalistic-assert-1.0.1.tgz", "integrity": "sha512-UtJcAD4yEaGtjPezWuO9wC4nwUnVH/8/Im3yEHQP4b67cXlD/Qr9hdITCU1xDbSEXg2XKNaP8jsReV7vQd00/A==", "dev": true }, - "node_modules/minimatch": { - "version": "3.1.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", - "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", - "dev": true, - "dependencies": { - "brace-expansion": "^1.1.7" - }, - "engines": { - "node": "*" - } - }, "node_modules/ms": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/ms/-/ms-2.0.0.tgz", @@ -2458,6 +2552,7 @@ "resolved": "https://registry.npmjs.org/multicast-dns/-/multicast-dns-7.2.5.tgz", "integrity": "sha512-2eznPJP8z2BFLX50tf0LuODrpINqP1RVIm/CObbTcBRITQgmC/TjcREF1NeTBzIcR5XO/ukWo+YHOjBbFwIupg==", "dev": true, + "license": "MIT", "dependencies": { "dns-packet": "^5.2.2", "thunky": "^1.0.2" @@ -2486,6 +2581,7 @@ "resolved": "https://registry.npmjs.org/node-forge/-/node-forge-1.3.1.tgz", "integrity": "sha512-dPEtOeMvF9VMcYV/1Wb8CPoVAXtp6MKMlcbAt4ddqmGqUJ6fQZFXkNZNkNlfevtNkGtaSoXf/vNNNSvgrdXwtA==", "dev": true, + "license": "(BSD-3-Clause OR GPL-2.0)", "engines": { "node": ">= 6.13.0" } @@ -2505,23 +2601,12 @@ "node": ">=0.10.0" } }, - "node_modules/npm-run-path": { - "version": "4.0.1", - "resolved": "https://registry.npmjs.org/npm-run-path/-/npm-run-path-4.0.1.tgz", - "integrity": "sha512-S48WzZW777zhNIrn7gxOlISNAqi9ZC/uQFnRdbeIHhZhCA6UqpkOT8T1G7BvfdgP4Er8gF4sUbaS0i7QvIfCWw==", - "dev": true, - "dependencies": { - "path-key": "^3.0.0" - }, - "engines": { - "node": ">=8" - } - }, "node_modules/object-inspect": { - "version": "1.13.2", - "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.2.tgz", - "integrity": "sha512-IRZSRuzJiynemAXPYtPe5BoI/RESNYR7TYm50MC5Mqbd3Jmw5y790sErYw3V6SryFJD64b74qQQs9wn5Bg/k3g==", + "version": "1.13.4", + "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.4.tgz", + "integrity": "sha512-W67iLl4J2EXEGTbfeHCffrjDfitvLANg0UlX3wFUUSTx92KXRFegMHUVgSqE+wvhAbi4WqjGg9czysTV2Epbew==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.4" }, @@ -2540,6 +2625,7 @@ "resolved": "https://registry.npmjs.org/on-finished/-/on-finished-2.4.1.tgz", "integrity": "sha512-oVlzkg3ENAhCk2zdv7IJwd/QUD4z2RxRwpkcGY8psCVcCYZNq4wYnVWALHM+brtuJjePWiYF/ClmuDr8Ch5+kg==", "dev": true, + "license": "MIT", "dependencies": { "ee-first": "1.1.1" }, @@ -2556,42 +2642,20 @@ "node": ">= 0.8" } }, - "node_modules/once": { - "version": "1.4.0", - "resolved": "https://registry.npmjs.org/once/-/once-1.4.0.tgz", - "integrity": "sha512-lNaJgI+2Q5URQBkccEKHTQOPaXdUxnZZElQTZY0MFUAuaEqe1E+Nyvgdz/aIyNi6Z9MzO5dv1H8n58/GELp3+w==", - "dev": true, - "dependencies": { - "wrappy": "1" - } - }, - "node_modules/onetime": { - "version": "5.1.2", - "resolved": "https://registry.npmjs.org/onetime/-/onetime-5.1.2.tgz", - "integrity": "sha512-kbpaSSGJTWdAY5KPVeMOKXSrPtr8C8C7wodJbcsd51jRnmD+GZu8Y0VoU6Dm5Z4vWr0Ig/1NKuWRKf7j5aaYSg==", - "dev": true, - "dependencies": { - "mimic-fn": "^2.1.0" - }, - "engines": { - "node": ">=6" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, "node_modules/open": { - "version": "8.4.2", - "resolved": "https://registry.npmjs.org/open/-/open-8.4.2.tgz", - "integrity": "sha512-7x81NCL719oNbsq/3mh+hVrAWmFuEYUqrq/Iw3kUzH8ReypT9QQ0BLoJS7/G9k6N81XjW4qHWtjWwe/9eLy1EQ==", + 
"version": "10.1.2", + "resolved": "https://registry.npmjs.org/open/-/open-10.1.2.tgz", + "integrity": "sha512-cxN6aIDPz6rm8hbebcP7vrQNhvRcveZoJU72Y7vskh4oIm+BZwBECnx5nTmrlres1Qapvx27Qo1Auukpf8PKXw==", "dev": true, + "license": "MIT", "dependencies": { - "define-lazy-prop": "^2.0.0", - "is-docker": "^2.1.1", - "is-wsl": "^2.2.0" + "default-browser": "^5.2.1", + "define-lazy-prop": "^3.0.0", + "is-inside-container": "^1.0.0", + "is-wsl": "^3.1.0" }, "engines": { - "node": ">=12" + "node": ">=18" }, "funding": { "url": "https://github.com/sponsors/sindresorhus" @@ -2625,16 +2689,21 @@ } }, "node_modules/p-retry": { - "version": "4.6.2", - "resolved": "https://registry.npmjs.org/p-retry/-/p-retry-4.6.2.tgz", - "integrity": "sha512-312Id396EbJdvRONlngUx0NydfrIQ5lsYu0znKVUzVvArzEIt08V1qhtyESbGVd1FGX7UKtiFp5uwKZdM8wIuQ==", + "version": "6.2.1", + "resolved": "https://registry.npmjs.org/p-retry/-/p-retry-6.2.1.tgz", + "integrity": "sha512-hEt02O4hUct5wtwg4H4KcWgDdm+l1bOaEy/hWzd8xtXB9BqxTWBBhb+2ImAtH4Cv4rPjV76xN3Zumqk3k3AhhQ==", "dev": true, + "license": "MIT", "dependencies": { - "@types/retry": "0.12.0", + "@types/retry": "0.12.2", + "is-network-error": "^1.0.0", "retry": "^0.13.1" }, "engines": { - "node": ">=8" + "node": ">=16.17" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" } }, "node_modules/p-try": { @@ -2664,15 +2733,6 @@ "node": ">=8" } }, - "node_modules/path-is-absolute": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/path-is-absolute/-/path-is-absolute-1.0.1.tgz", - "integrity": "sha512-AVbw3UJ2e9bq64vSaS9Am0fje1Pa8pbGqTTsmXfaIiMpnr5DlDhfJOuLj9Sf95ZPVDAUerDfEk88MPmPe7UCQg==", - "dev": true, - "engines": { - "node": ">=0.10.0" - } - }, "node_modules/path-key": { "version": "3.1.1", "resolved": "https://registry.npmjs.org/path-key/-/path-key-3.1.1.tgz", @@ -2689,10 +2749,11 @@ "dev": true }, "node_modules/path-to-regexp": { - "version": "0.1.10", - "resolved": "https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-0.1.10.tgz", - "integrity": "sha512-7lf7qcQidTku0Gu3YDPc8DJ1q7OOucfa/BSsIwjuh56VU7katFvuM8hULfkwB3Fns/rsVF7PwPKVw1sl5KQS9w==", - "dev": true + "version": "0.1.12", + "resolved": "https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-0.1.12.tgz", + "integrity": "sha512-RA1GjUVMnvYFxuqovrEqZoxxW5NUZqbwKtYz/Tt7nXerk0LbLblQmrsgdeOxV5SFHf0UDggjS/bSeOZwt1pmEQ==", + "dev": true, + "license": "MIT" }, "node_modules/path-type": { "version": "6.0.0", @@ -2748,6 +2809,7 @@ "resolved": "https://registry.npmjs.org/proxy-addr/-/proxy-addr-2.0.7.tgz", "integrity": "sha512-llQsMLSUDUPT44jdrU/O37qlnifitDP+ZwrmmZcoSKyLKvtZxpyV0n2/bD/N4tBAAZ/gJEdZU7KMraoK1+XYAg==", "dev": true, + "license": "MIT", "dependencies": { "forwarded": "0.2.0", "ipaddr.js": "1.9.1" @@ -2761,6 +2823,7 @@ "resolved": "https://registry.npmjs.org/ipaddr.js/-/ipaddr.js-1.9.1.tgz", "integrity": "sha512-0KI/607xoxSToH7GjN1FfSbLoU0+btTicjsQSWQlh/hZykN8KpmMf7uYwPW3R+akZ6R/w18ZlXSHBYXiYUPO3g==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.10" } @@ -2779,6 +2842,7 @@ "resolved": "https://registry.npmjs.org/qs/-/qs-6.13.0.tgz", "integrity": "sha512-+38qI9SOr8tfZ4QmJNplMUxqjbe7LKvvZgWdExBOmd+egZTtjLB67Gu0HRX3u/XOq7UU2Nx6nsjvS16Z9uwfpg==", "dev": true, + "license": "BSD-3-Clause", "dependencies": { "side-channel": "^1.0.6" }, @@ -2824,6 +2888,7 @@ "resolved": "https://registry.npmjs.org/range-parser/-/range-parser-1.2.1.tgz", "integrity": "sha512-Hrgsx+orqoygnmhFbKaHE6c296J+HTAQXoxEF6gNupROmmGJRoyzfG3ccAveqCBrwr/2yxQ5BVd/GTl5agOwSg==", "dev": true, + "license": 
"MIT", "engines": { "node": ">= 0.6" } @@ -2833,6 +2898,7 @@ "resolved": "https://registry.npmjs.org/raw-body/-/raw-body-2.5.2.tgz", "integrity": "sha512-8zGqypfENjCIqGhgXToC8aB2r7YrBX+AQAfIPs/Mlk+BtPTztOvTS01NRW/3Eh60J+a48lt8qsCzirQ6loCVfA==", "dev": true, + "license": "MIT", "dependencies": { "bytes": "3.1.2", "http-errors": "2.0.0", @@ -2848,6 +2914,7 @@ "resolved": "https://registry.npmjs.org/bytes/-/bytes-3.1.2.tgz", "integrity": "sha512-/Nf7TyzTx6S3yRJObOAV7956r8cr2+Oj8AC5dt8wSP3BQAoeX58NoHyCU8P8zGkNXStjTSi6fzO6F0pBdcYbEg==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.8" } @@ -2872,6 +2939,7 @@ "resolved": "https://registry.npmjs.org/readdirp/-/readdirp-3.6.0.tgz", "integrity": "sha512-hOS089on8RduqdbhvQ5Z37A0ESjsqz6qnRcffsMU3495FuTdqSm+7bhJ29JvIOsBDEEnan5DPu9t3To9VRlMzA==", "dev": true, + "license": "MIT", "dependencies": { "picomatch": "^2.2.1" }, @@ -2949,6 +3017,7 @@ "resolved": "https://registry.npmjs.org/retry/-/retry-0.13.1.tgz", "integrity": "sha512-XQBQ3I8W1Cge0Seh+6gjj03LbmRFWuoszgK9ooCpwYIrhhoO80pfq4cUkU5DkknwfOfFteRwlZ56PYOGYyFWdg==", "dev": true, + "license": "MIT", "engines": { "node": ">= 4" } @@ -2964,19 +3033,17 @@ "node": ">=0.10.0" } }, - "node_modules/rimraf": { - "version": "3.0.2", - "resolved": "https://registry.npmjs.org/rimraf/-/rimraf-3.0.2.tgz", - "integrity": "sha512-JZkJMZkAGFFPP2YqXZXPbMlMBgsxzE8ILs4lMIX/2o0L9UBw9O/Y3o6wFw/i9YLapcUJWwqbi3kdxIPdC62TIA==", + "node_modules/run-applescript": { + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/run-applescript/-/run-applescript-7.0.0.tgz", + "integrity": "sha512-9by4Ij99JUr/MCFBUkDKLWK3G9HVXmabKz9U5MlIAIuvuzkiOicRYs8XJLxX+xahD+mLiiCYDqF9dKAgtzKP1A==", "dev": true, - "dependencies": { - "glob": "^7.1.3" - }, - "bin": { - "rimraf": "bin.js" + "license": "MIT", + "engines": { + "node": ">=18" }, "funding": { - "url": "https://github.com/sponsors/isaacs" + "url": "https://github.com/sponsors/sindresorhus" } }, "node_modules/run-parallel": { @@ -3013,7 +3080,8 @@ "version": "2.1.2", "resolved": "https://registry.npmjs.org/safer-buffer/-/safer-buffer-2.1.2.tgz", "integrity": "sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg==", - "dev": true + "dev": true, + "license": "MIT" }, "node_modules/schema-utils": { "version": "3.3.0", @@ -3040,11 +3108,13 @@ "dev": true }, "node_modules/selfsigned": { - "version": "2.1.1", - "resolved": "https://registry.npmjs.org/selfsigned/-/selfsigned-2.1.1.tgz", - "integrity": "sha512-GSL3aowiF7wa/WtSFwnUrludWFoNhftq8bUkH9pkzjpN2XSPOAYEgg6e0sS9s0rZwgJzJiQRPU18A6clnoW5wQ==", + "version": "2.4.1", + "resolved": "https://registry.npmjs.org/selfsigned/-/selfsigned-2.4.1.tgz", + "integrity": "sha512-th5B4L2U+eGLq1TVh7zNRGBapioSORUeymIydxgFpwww9d2qyKvtuPU2jJuHvYAwwqi2Y596QBL3eEqcPEYL8Q==", "dev": true, + "license": "MIT", "dependencies": { + "@types/node-forge": "^1.3.0", "node-forge": "^1" }, "engines": { @@ -3056,6 +3126,7 @@ "resolved": "https://registry.npmjs.org/send/-/send-0.19.0.tgz", "integrity": "sha512-dW41u5VfLXu8SJh5bwRmyYUbAoSB3c9uQh6L8h/KtsFREPWpbX1lrljJo186Jc4nmci/sGUZ9a0a0J2zgfq2hw==", "dev": true, + "license": "MIT", "dependencies": { "debug": "2.6.9", "depd": "2.0.0", @@ -3080,6 +3151,7 @@ "resolved": "https://registry.npmjs.org/debug/-/debug-2.6.9.tgz", "integrity": "sha512-bC7ElrdJaJnPbAP+1EotYvqZsb3ecl5wi6Bfi6BJTUcNowp6cvspg0jXznRTKDjm/E7AdgFBVeAPVMNcKGsHMA==", "dev": true, + "license": "MIT", "dependencies": { "ms": "2.0.0" } @@ -3088,13 +3160,25 @@ "version": "2.0.0", "resolved": 
"https://registry.npmjs.org/ms/-/ms-2.0.0.tgz", "integrity": "sha512-Tpp60P6IUJDTuOq/5Z8cdskzJujfwqfOTkrwIwj7IRISpnkJnT6SyJ4PCPnGMoFjC9ddhal5KVIYtAt97ix05A==", - "dev": true + "dev": true, + "license": "MIT" }, "node_modules/send/node_modules/depd": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/depd/-/depd-2.0.0.tgz", "integrity": "sha512-g7nH6P6dyDioJogAAGprGpCtVImJhpPk/roCzdb3fIh61/s/nPsfR6onyMwkCAR/OlC3yBC0lESvUoQEAssIrw==", "dev": true, + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/send/node_modules/encodeurl": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-1.0.2.tgz", + "integrity": "sha512-TPJXq8JqFaVYm2CWmPvnP2Iyo4ZSM7/QKcSmuMLDObfpH5fi7RUGmd/rTDf+rut/saiDiQEeVTNgAmJEdAOx0w==", + "dev": true, + "license": "MIT", "engines": { "node": ">= 0.8" } @@ -3103,13 +3187,15 @@ "version": "2.1.3", "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==", - "dev": true + "dev": true, + "license": "MIT" }, "node_modules/send/node_modules/statuses": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/statuses/-/statuses-2.0.1.tgz", "integrity": "sha512-RwNA9Z/7PrK06rYLIzFMlaF+l73iwpzsqRIFgbMLbTcLD6cOao82TaWefPXQvB2fOC4AjuYSEndS7N/mTCbkdQ==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.8" } @@ -3177,6 +3263,7 @@ "resolved": "https://registry.npmjs.org/serve-static/-/serve-static-1.16.2.tgz", "integrity": "sha512-VqpjJZKadQB/PEbEwvFdO43Ax5dFBZ2UECszz8bQ7pi7wt//PWe1P6MN7eCnjsatYtBT6EuiClbjSWP2WrIoTw==", "dev": true, + "license": "MIT", "dependencies": { "encodeurl": "~2.0.0", "escape-html": "~1.0.3", @@ -3187,37 +3274,12 @@ "node": ">= 0.8.0" } }, - "node_modules/serve-static/node_modules/encodeurl": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-2.0.0.tgz", - "integrity": "sha512-Q0n9HRi4m6JuGIV1eFlmvJB7ZEVxu93IrMyiMsGC0lrMJMWzRgx6WGquyfQgZVb31vhGgXnfmPNNXmxnOkRBrg==", - "dev": true, - "engines": { - "node": ">= 0.8" - } - }, - "node_modules/set-function-length": { - "version": "1.2.2", - "resolved": "https://registry.npmjs.org/set-function-length/-/set-function-length-1.2.2.tgz", - "integrity": "sha512-pgRc4hJ4/sNjWCSS9AmnS40x3bNMDTknHgL5UaMBTMyJnU90EgWh1Rz+MC9eFu4BuN/UwZjKQuY/1v3rM7HMfg==", - "dev": true, - "dependencies": { - "define-data-property": "^1.1.4", - "es-errors": "^1.3.0", - "function-bind": "^1.1.2", - "get-intrinsic": "^1.2.4", - "gopd": "^1.0.1", - "has-property-descriptors": "^1.0.2" - }, - "engines": { - "node": ">= 0.4" - } - }, "node_modules/setprototypeof": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/setprototypeof/-/setprototypeof-1.2.0.tgz", "integrity": "sha512-E5LDX7Wrp85Kil5bhZv46j8jOeboKq5JMmYM3gVGdGH8xFpPWXUMsNrlODCrkoxMEeNi/XZIwuRvY4XNwYMJpw==", - "dev": true + "dev": true, + "license": "ISC" }, "node_modules/shallow-clone": { "version": "3.0.1", @@ -3253,24 +3315,30 @@ } }, "node_modules/shell-quote": { - "version": "1.8.1", - "resolved": "https://registry.npmjs.org/shell-quote/-/shell-quote-1.8.1.tgz", - "integrity": "sha512-6j1W9l1iAs/4xYBI1SYOVZyFcCis9b4KCLQ8fgAGG07QvzaRLVVRQvAy85yNmmZSjYjg4MWh4gNvlPujU/5LpA==", + "version": "1.8.3", + "resolved": "https://registry.npmjs.org/shell-quote/-/shell-quote-1.8.3.tgz", + "integrity": "sha512-ObmnIF4hXNg1BqhnHmgbDETF8dLPCggZWBjkQfhZpbszZnYur5DUljTcCHii5LC3J5E0yeO/1LIMyH+UvHQgyw==", "dev": true, + "license": "MIT", + "engines": { + "node": 
">= 0.4" + }, "funding": { "url": "https://github.com/sponsors/ljharb" } }, "node_modules/side-channel": { - "version": "1.0.6", - "resolved": "https://registry.npmjs.org/side-channel/-/side-channel-1.0.6.tgz", - "integrity": "sha512-fDW/EZ6Q9RiO8eFG8Hj+7u/oW+XrPTIChwCOM2+th2A6OblDtYYIpve9m+KvI9Z4C9qSEXlaGR6bTEYHReuglA==", + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/side-channel/-/side-channel-1.1.0.tgz", + "integrity": "sha512-ZX99e6tRweoUXqR+VBrslhda51Nh5MTQwou5tnUDgbtyM0dBgmhEDtWGP/xbKn6hqfPRHujUNwz5fy/wbbhnpw==", "dev": true, + "license": "MIT", "dependencies": { - "call-bind": "^1.0.7", "es-errors": "^1.3.0", - "get-intrinsic": "^1.2.4", - "object-inspect": "^1.13.1" + "object-inspect": "^1.13.3", + "side-channel-list": "^1.0.0", + "side-channel-map": "^1.0.1", + "side-channel-weakmap": "^1.0.2" }, "engines": { "node": ">= 0.4" @@ -3279,11 +3347,61 @@ "url": "https://github.com/sponsors/ljharb" } }, - "node_modules/signal-exit": { - "version": "3.0.7", - "resolved": "https://registry.npmjs.org/signal-exit/-/signal-exit-3.0.7.tgz", - "integrity": "sha512-wnD2ZE+l+SPC/uoS0vXeE9L1+0wuaMqKlfz9AMUo38JsyLSBWSFcHR1Rri62LZc12vLr1gb3jl7iwQhgwpAbGQ==", - "dev": true + "node_modules/side-channel-list": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/side-channel-list/-/side-channel-list-1.0.0.tgz", + "integrity": "sha512-FCLHtRD/gnpCiCHEiJLOwdmFP+wzCmDEkc9y7NsYxeF4u7Btsn1ZuwgwJGxImImHicJArLP4R0yX4c2KCrMrTA==", + "dev": true, + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "object-inspect": "^1.13.3" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/side-channel-map": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/side-channel-map/-/side-channel-map-1.0.1.tgz", + "integrity": "sha512-VCjCNfgMsby3tTdo02nbjtM/ewra6jPHmpThenkTYh8pG9ucZ/1P8So4u4FGBek/BjpOVsDCMoLA/iuBKIFXRA==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.2", + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.5", + "object-inspect": "^1.13.3" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/side-channel-weakmap": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/side-channel-weakmap/-/side-channel-weakmap-1.0.2.tgz", + "integrity": "sha512-WPS/HvHQTYnHisLo9McqBHOJk2FkHO/tlpvldyrnem4aeQp4hai3gythswg6p01oSoTl58rcpiFAjF2br2Ak2A==", + "dev": true, + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.2", + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.5", + "object-inspect": "^1.13.3", + "side-channel-map": "^1.0.1" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } }, "node_modules/slash": { "version": "5.1.0", @@ -3390,15 +3508,6 @@ "safe-buffer": "~5.1.0" } }, - "node_modules/strip-final-newline": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/strip-final-newline/-/strip-final-newline-2.0.0.tgz", - "integrity": "sha512-BrpvfNAE3dcvq7ll3xVumzjKjZQ5tI1sEUIKr3Uoks0XUl45St3FlatVqef9prk4jRDzhW6WZg+3bk93y6pLjA==", - "dev": true, - "engines": { - "node": ">=6" - } - }, "node_modules/supports-color": { "version": "8.1.1", "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-8.1.1.tgz", @@ -3491,7 +3600,8 @@ "version": "1.1.0", "resolved": "https://registry.npmjs.org/thunky/-/thunky-1.1.0.tgz", "integrity": 
"sha512-eHY7nBftgThBqOyHGVN+l8gF0BucP09fMo0oO/Lb0w1OF80dJv+lDVpXG60WMQvkcxAkNybKsrEIE3ZtKGmPrA==", - "dev": true + "dev": true, + "license": "MIT" }, "node_modules/to-regex-range": { "version": "5.0.1", @@ -3510,6 +3620,7 @@ "resolved": "https://registry.npmjs.org/toidentifier/-/toidentifier-1.0.1.tgz", "integrity": "sha512-o5sSPKEkg/DIQNmH43V0/uerLrpzVedkUh8tGNvaeXpfpuwjKenlSox/2O/BTlZUtEe+JG7s5YhEz608PlAHRA==", "dev": true, + "license": "MIT", "engines": { "node": ">=0.6" } @@ -3525,6 +3636,7 @@ "resolved": "https://registry.npmjs.org/type-is/-/type-is-1.6.18.tgz", "integrity": "sha512-TkRKr9sUTxEH8MdfuCSP7VizJyzRNMjj2J2do2Jr3Kym598JVdEksuzPQCnlFPW4ky9Q+iA+ma9BGm06XQBy8g==", "dev": true, + "license": "MIT", "dependencies": { "media-typer": "0.3.0", "mime-types": "~2.1.24" @@ -3551,6 +3663,7 @@ "resolved": "https://registry.npmjs.org/unpipe/-/unpipe-1.0.0.tgz", "integrity": "sha512-pjy2bYhSsufwWlKwPc+l3cN7+wuJlK6uz0YdJEOlQDbl6jo/YlPi4mb8agUkVC8BF7V8NuzeyPNqRksA3hztKQ==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.8" } @@ -3605,6 +3718,7 @@ "resolved": "https://registry.npmjs.org/utils-merge/-/utils-merge-1.0.1.tgz", "integrity": "sha512-pMZTvIkT1d+TFGvDOqodOclx0QWkkgi6Tdoa8gC8ffGAAqz9pzPTZWAybbsHHoED/ztMtkv/VoYTYyShUn81hA==", "dev": true, + "license": "MIT", "engines": { "node": ">= 0.4.0" } @@ -3750,38 +3864,46 @@ } }, "node_modules/webpack-dev-middleware": { - "version": "5.3.4", - "resolved": "https://registry.npmjs.org/webpack-dev-middleware/-/webpack-dev-middleware-5.3.4.tgz", - "integrity": "sha512-BVdTqhhs+0IfoeAf7EoH5WE+exCmqGerHfDM0IL096Px60Tq2Mn9MAbnaGUe6HiMa41KMCYF19gyzZmBcq/o4Q==", + "version": "7.4.2", + "resolved": "https://registry.npmjs.org/webpack-dev-middleware/-/webpack-dev-middleware-7.4.2.tgz", + "integrity": "sha512-xOO8n6eggxnwYpy1NlzUKpvrjfJTvae5/D6WOK0S2LSo7vjmo5gCM1DbLUmFqrMTJP+W/0YZNctm7jasWvLuBA==", "dev": true, + "license": "MIT", "dependencies": { "colorette": "^2.0.10", - "memfs": "^3.4.3", + "memfs": "^4.6.0", "mime-types": "^2.1.31", + "on-finished": "^2.4.1", "range-parser": "^1.2.1", "schema-utils": "^4.0.0" }, "engines": { - "node": ">= 12.13.0" + "node": ">= 18.12.0" }, "funding": { "type": "opencollective", "url": "https://opencollective.com/webpack" }, "peerDependencies": { - "webpack": "^4.0.0 || ^5.0.0" + "webpack": "^5.0.0" + }, + "peerDependenciesMeta": { + "webpack": { + "optional": true + } } }, "node_modules/webpack-dev-middleware/node_modules/ajv": { - "version": "8.12.0", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.12.0.tgz", - "integrity": "sha512-sRu1kpcO9yLtYxBKvqfTeh9KzZEwO3STyX1HT+4CaDzC6HpTGYhIhPIzj9XuKU7KYDwnaeh5hcOwjy1QuJzBPA==", + "version": "8.17.1", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz", + "integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==", "dev": true, + "license": "MIT", "dependencies": { - "fast-deep-equal": "^3.1.1", + "fast-deep-equal": "^3.1.3", + "fast-uri": "^3.0.1", "json-schema-traverse": "^1.0.0", - "require-from-string": "^2.0.2", - "uri-js": "^4.2.2" + "require-from-string": "^2.0.2" }, "funding": { "type": "github", @@ -3793,6 +3915,7 @@ "resolved": "https://registry.npmjs.org/ajv-keywords/-/ajv-keywords-5.1.0.tgz", "integrity": "sha512-YCS/JNFAUyr5vAuhk1DWm1CBxRHW9LbJ2ozWeemrIqpbsqKjHVxYPyi5GC0rjZIT5JxJ3virVTS8wk4i/Z+krw==", "dev": true, + "license": "MIT", "dependencies": { "fast-deep-equal": "^3.1.3" }, @@ -3804,13 +3927,15 @@ "version": "1.0.0", "resolved": 
"https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", - "dev": true + "dev": true, + "license": "MIT" }, "node_modules/webpack-dev-middleware/node_modules/schema-utils": { - "version": "4.2.0", - "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.2.0.tgz", - "integrity": "sha512-L0jRsrPpjdckP3oPug3/VxNKt2trR8TcabrM6FOAAlvC/9Phcmm+cuAgTlxBqdBR1WJx7Naj9WHw+aOmheSVbw==", + "version": "4.3.2", + "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.3.2.tgz", + "integrity": "sha512-Gn/JaSk/Mt9gYubxTtSn/QCV4em9mpAPiR1rqy/Ocu19u/G9J5WWdNoUT4SiV6mFC3y6cxyFcFwdzPM3FgxGAQ==", "dev": true, + "license": "MIT", "dependencies": { "@types/json-schema": "^7.0.9", "ajv": "^8.9.0", @@ -3818,7 +3943,7 @@ "ajv-keywords": "^5.1.0" }, "engines": { - "node": ">= 12.13.0" + "node": ">= 10.13.0" }, "funding": { "type": "opencollective", @@ -3826,54 +3951,53 @@ } }, "node_modules/webpack-dev-server": { - "version": "4.15.1", - "resolved": "https://registry.npmjs.org/webpack-dev-server/-/webpack-dev-server-4.15.1.tgz", - "integrity": "sha512-5hbAst3h3C3L8w6W4P96L5vaV0PxSmJhxZvWKYIdgxOQm8pNZ5dEOmmSLBVpP85ReeyRt6AS1QJNyo/oFFPeVA==", - "dev": true, - "dependencies": { - "@types/bonjour": "^3.5.9", - "@types/connect-history-api-fallback": "^1.3.5", - "@types/express": "^4.17.13", - "@types/serve-index": "^1.9.1", - "@types/serve-static": "^1.13.10", - "@types/sockjs": "^0.3.33", - "@types/ws": "^8.5.5", + "version": "5.2.1", + "resolved": "https://registry.npmjs.org/webpack-dev-server/-/webpack-dev-server-5.2.1.tgz", + "integrity": "sha512-ml/0HIj9NLpVKOMq+SuBPLHcmbG+TGIjXRHsYfZwocUBIqEvws8NnS/V9AFQ5FKP+tgn5adwVwRrTEpGL33QFQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/bonjour": "^3.5.13", + "@types/connect-history-api-fallback": "^1.5.4", + "@types/express": "^4.17.21", + "@types/express-serve-static-core": "^4.17.21", + "@types/serve-index": "^1.9.4", + "@types/serve-static": "^1.15.5", + "@types/sockjs": "^0.3.36", + "@types/ws": "^8.5.10", "ansi-html-community": "^0.0.8", - "bonjour-service": "^1.0.11", - "chokidar": "^3.5.3", + "bonjour-service": "^1.2.1", + "chokidar": "^3.6.0", "colorette": "^2.0.10", "compression": "^1.7.4", "connect-history-api-fallback": "^2.0.0", - "default-gateway": "^6.0.3", - "express": "^4.17.3", + "express": "^4.21.2", "graceful-fs": "^4.2.6", - "html-entities": "^2.3.2", - "http-proxy-middleware": "^2.0.3", - "ipaddr.js": "^2.0.1", - "launch-editor": "^2.6.0", - "open": "^8.0.9", - "p-retry": "^4.5.0", - "rimraf": "^3.0.2", - "schema-utils": "^4.0.0", - "selfsigned": "^2.1.1", + "http-proxy-middleware": "^2.0.7", + "ipaddr.js": "^2.1.0", + "launch-editor": "^2.6.1", + "open": "^10.0.3", + "p-retry": "^6.2.0", + "schema-utils": "^4.2.0", + "selfsigned": "^2.4.1", "serve-index": "^1.9.1", "sockjs": "^0.3.24", "spdy": "^4.0.2", - "webpack-dev-middleware": "^5.3.1", - "ws": "^8.13.0" + "webpack-dev-middleware": "^7.4.2", + "ws": "^8.18.0" }, "bin": { "webpack-dev-server": "bin/webpack-dev-server.js" }, "engines": { - "node": ">= 12.13.0" + "node": ">= 18.12.0" }, "funding": { "type": "opencollective", "url": "https://opencollective.com/webpack" }, "peerDependencies": { - "webpack": "^4.37.0 || ^5.0.0" + "webpack": "^5.0.0" }, "peerDependenciesMeta": { "webpack": { @@ -4003,17 +4127,12 @@ "integrity": 
"sha512-CC1bOL87PIWSBhDcTrdeLo6eGT7mCFtrg0uIJtqJUFyK+eJnzl8A1niH56uu7KMa5XFrtiV+AQuHO3n7DsHnLQ==", "dev": true }, - "node_modules/wrappy": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/wrappy/-/wrappy-1.0.2.tgz", - "integrity": "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==", - "dev": true - }, "node_modules/ws": { - "version": "8.17.1", - "resolved": "https://registry.npmjs.org/ws/-/ws-8.17.1.tgz", - "integrity": "sha512-6XQFvXTkbfUOZOKKILFG1PDK2NDQs4azKQl26T0YS5CxqWLgXajbPZ+h4gZekJyRqFU8pvnbAbbs/3TgRPy+GQ==", + "version": "8.18.2", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.18.2.tgz", + "integrity": "sha512-DMricUmwGZUVr++AEAe2uiVM7UoO9MAVZMDu05UQOaUII0lp+zOzLLU4Xqh/JvTqklB1T4uELaaPBKyjE1r4fQ==", "dev": true, + "license": "MIT", "engines": { "node": ">=10.0.0" }, @@ -4088,9 +4207,9 @@ } }, "@leichtgewicht/ip-codec": { - "version": "2.0.4", - "resolved": "https://registry.npmjs.org/@leichtgewicht/ip-codec/-/ip-codec-2.0.4.tgz", - "integrity": "sha512-Hcv+nVC0kZnQ3tD9GVu5xSMR4VVYOteQIr/hwFPVEvPdlXqgGEuRjiheChHgdM+JyqdgNcmzZOX/tnl0JOiI7A==", + "version": "2.0.5", + "resolved": "https://registry.npmjs.org/@leichtgewicht/ip-codec/-/ip-codec-2.0.5.tgz", + "integrity": "sha512-Vo+PSpZG2/fmgmiNzYK9qWRh8h/CHrwD0mo1h1DzL4yzHNSfWYujGTYsWGreD000gcgmZ7K4Ys6Tx9TxtsKdDw==", "dev": true }, "@nodelib/fs.scandir": { @@ -4136,9 +4255,9 @@ } }, "@types/bonjour": { - "version": "3.5.11", - "resolved": "https://registry.npmjs.org/@types/bonjour/-/bonjour-3.5.11.tgz", - "integrity": "sha512-isGhjmBtLIxdHBDl2xGwUzEM8AOyOvWsADWq7rqirdi/ZQoHnLWErHvsThcEzTX8juDRiZtzp2Qkv5bgNh6mAg==", + "version": "3.5.13", + "resolved": "https://registry.npmjs.org/@types/bonjour/-/bonjour-3.5.13.tgz", + "integrity": "sha512-z9fJ5Im06zvUL548KvYNecEVlA7cVDkGUi6kZusb04mpyEFKCIZJvloCcmpmLaIahDpOQGHaHmG6imtPMmPXGQ==", "dev": true, "requires": { "@types/node": "*" @@ -4154,9 +4273,9 @@ } }, "@types/connect-history-api-fallback": { - "version": "1.5.1", - "resolved": "https://registry.npmjs.org/@types/connect-history-api-fallback/-/connect-history-api-fallback-1.5.1.tgz", - "integrity": "sha512-iaQslNbARe8fctL5Lk+DsmgWOM83lM+7FzP0eQUJs1jd3kBE8NWqBTIT2S8SqQOJjxvt2eyIjpOuYeRXq2AdMw==", + "version": "1.5.4", + "resolved": "https://registry.npmjs.org/@types/connect-history-api-fallback/-/connect-history-api-fallback-1.5.4.tgz", + "integrity": "sha512-n6Cr2xS1h4uAulPRdlw6Jl6s1oG8KrVilPN2yUITEs+K48EzMJJ3W1xy8K5eWuFvjp3R74AOIGSmp2UfBJ8HFw==", "dev": true, "requires": { "@types/express-serve-static-core": "*", @@ -4170,9 +4289,9 @@ "dev": true }, "@types/express": { - "version": "4.17.17", - "resolved": "https://registry.npmjs.org/@types/express/-/express-4.17.17.tgz", - "integrity": "sha512-Q4FmmuLGBG58btUnfS1c1r/NQdlp3DMfGDGig8WhfpA2YRUtEkxAjkZb0yvplJGYdF1fsQ81iMDcH24sSCNC/Q==", + "version": "4.17.22", + "resolved": "https://registry.npmjs.org/@types/express/-/express-4.17.22.tgz", + "integrity": "sha512-eZUmSnhRX9YRSkplpz0N+k6NljUUn5l3EWZIKZvYzhvMphEuNiyyy1viH/ejgt66JWgALwC/gtSUAeQKtSwW/w==", "dev": true, "requires": { "@types/body-parser": "*", @@ -4226,6 +4345,15 @@ "integrity": "sha512-HksnYH4Ljr4VQgEy2lTStbCKv/P590tmPe5HqOnv9Gprffgv5WXAY+Y5Gqniu0GGqeTCUdBnzC3QSrzPkBkAMA==", "dev": true }, + "@types/node-forge": { + "version": "1.3.11", + "resolved": "https://registry.npmjs.org/@types/node-forge/-/node-forge-1.3.11.tgz", + "integrity": "sha512-FQx220y22OKNTqaByeBGqHWYz4cl94tpcxeFdvBo3wjG6XPBuZ0BNgNZRV5J5TFmmcsJ4IzsLkmGRiQbnYsBEQ==", + "dev": true, + "requires": 
{ + "@types/node": "*" + } + }, "@types/qs": { "version": "6.9.8", "resolved": "https://registry.npmjs.org/@types/qs/-/qs-6.9.8.tgz", @@ -4239,9 +4367,9 @@ "dev": true }, "@types/retry": { - "version": "0.12.0", - "resolved": "https://registry.npmjs.org/@types/retry/-/retry-0.12.0.tgz", - "integrity": "sha512-wWKOClTTiizcZhXnPY4wikVAwmdYHp8q6DmC+EJUzAMsycb7HB32Kh9RN4+0gExjmPmZSAQjgURXIGATPegAvA==", + "version": "0.12.2", + "resolved": "https://registry.npmjs.org/@types/retry/-/retry-0.12.2.tgz", + "integrity": "sha512-XISRgDJ2Tc5q4TRqvgJtzsRkFYNJzZrhTdtMoGVBttwzzQJkPnS3WWTFc7kuDRoPtPakl+T+OfdEUjYJj7Jbow==", "dev": true }, "@types/send": { @@ -4255,38 +4383,38 @@ } }, "@types/serve-index": { - "version": "1.9.1", - "resolved": "https://registry.npmjs.org/@types/serve-index/-/serve-index-1.9.1.tgz", - "integrity": "sha512-d/Hs3nWDxNL2xAczmOVZNj92YZCS6RGxfBPjKzuu/XirCgXdpKEb88dYNbrYGint6IVWLNP+yonwVAuRC0T2Dg==", + "version": "1.9.4", + "resolved": "https://registry.npmjs.org/@types/serve-index/-/serve-index-1.9.4.tgz", + "integrity": "sha512-qLpGZ/c2fhSs5gnYsQxtDEq3Oy8SXPClIXkW5ghvAvsNuVSA8k+gCONcUCS/UjLEYvYps+e8uBtfgXgvhwfNug==", "dev": true, "requires": { "@types/express": "*" } }, "@types/serve-static": { - "version": "1.15.2", - "resolved": "https://registry.npmjs.org/@types/serve-static/-/serve-static-1.15.2.tgz", - "integrity": "sha512-J2LqtvFYCzaj8pVYKw8klQXrLLk7TBZmQ4ShlcdkELFKGwGMfevMLneMMRkMgZxotOD9wg497LpC7O8PcvAmfw==", + "version": "1.15.7", + "resolved": "https://registry.npmjs.org/@types/serve-static/-/serve-static-1.15.7.tgz", + "integrity": "sha512-W8Ym+h8nhuRwaKPaDw34QUkwsGi6Rc4yYqvKFo5rm2FUEhCFbzVWrxXUxuKK8TASjWsysJY0nsmNCGhCOIsrOw==", "dev": true, "requires": { "@types/http-errors": "*", - "@types/mime": "*", - "@types/node": "*" + "@types/node": "*", + "@types/send": "*" } }, "@types/sockjs": { - "version": "0.3.33", - "resolved": "https://registry.npmjs.org/@types/sockjs/-/sockjs-0.3.33.tgz", - "integrity": "sha512-f0KEEe05NvUnat+boPTZ0dgaLZ4SfSouXUgv5noUiefG2ajgKjmETo9ZJyuqsl7dfl2aHlLJUiki6B4ZYldiiw==", + "version": "0.3.36", + "resolved": "https://registry.npmjs.org/@types/sockjs/-/sockjs-0.3.36.tgz", + "integrity": "sha512-MK9V6NzAS1+Ud7JV9lJLFqW85VbC9dq3LmwZCuBe4wBDgKC0Kj/jd8Xl+nSviU+Qc3+m7umHHyHg//2KSa0a0Q==", "dev": true, "requires": { "@types/node": "*" } }, "@types/ws": { - "version": "8.5.5", - "resolved": "https://registry.npmjs.org/@types/ws/-/ws-8.5.5.tgz", - "integrity": "sha512-lwhs8hktwxSjf9UaZ9tG5M03PGogvFaH8gUgLNbN9HKIg0dvv6q+gkSuJ8HN4/VbyxkuLzCjlN7GquQ0gUJfIg==", + "version": "8.18.1", + "resolved": "https://registry.npmjs.org/@types/ws/-/ws-8.18.1.tgz", + "integrity": "sha512-ThVF6DCVhA8kUGy+aazFQ4kXQ7E1Ty7A3ypFOe0IcJV8O/M511G99AW24irKrW56Wt44yG9+ij8FaqoBGkuBXg==", "dev": true, "requires": { "@types/node": "*" @@ -4559,15 +4687,9 @@ } }, "array-flatten": { - "version": "2.1.2", - "resolved": "https://registry.npmjs.org/array-flatten/-/array-flatten-2.1.2.tgz", - "integrity": "sha512-hNfzcOV8W4NdualtqBFPyVO+54DSJuZGY9qT4pRroB6S9e3iiido2ISIC5h9R2sPJ8H3FHCIiEnsv1lPXO3KtQ==", - "dev": true - }, - "balanced-match": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz", - "integrity": "sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==", + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/array-flatten/-/array-flatten-1.1.1.tgz", + "integrity": "sha512-PCVAQswWemu6UdxsDFFX/+gVeYqKAod3D3UVm91jHwynguOwAvYPhx8nNlM++NqRcK6CxxpUafjmhIdKiHibqg==", "dev": 
true }, "batch": { @@ -4577,9 +4699,9 @@ "dev": true }, "binary-extensions": { - "version": "2.2.0", - "resolved": "https://registry.npmjs.org/binary-extensions/-/binary-extensions-2.2.0.tgz", - "integrity": "sha512-jDctJ/IVQbZoJykoeHbhXpOlNBqGNcwXJKJog42E5HDPUwQTSdjCHdihjj0DlnheQ7blbT6dHOafNAiS8ooQKA==", + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/binary-extensions/-/binary-extensions-2.3.0.tgz", + "integrity": "sha512-Ceh+7ox5qe7LJuLHoY0feh3pHuUDHAcRUeyL2VYghZwfpkNIy/+8Ocg0a3UuSoYzavmylwuLWQOf3hl0jjMMIw==", "dev": true }, "body-parser": { @@ -4626,27 +4748,15 @@ } }, "bonjour-service": { - "version": "1.1.1", - "resolved": "https://registry.npmjs.org/bonjour-service/-/bonjour-service-1.1.1.tgz", - "integrity": "sha512-Z/5lQRMOG9k7W+FkeGTNjh7htqn/2LMnfOvBZ8pynNZCM9MwkQkI3zeI4oz09uWdcgmgHugVvBqxGg4VQJ5PCg==", + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/bonjour-service/-/bonjour-service-1.3.0.tgz", + "integrity": "sha512-3YuAUiSkWykd+2Azjgyxei8OWf8thdn8AITIog2M4UICzoqfjlqr64WIjEXZllf/W6vK1goqleSR6brGomxQqA==", "dev": true, "requires": { - "array-flatten": "^2.1.2", - "dns-equal": "^1.0.0", "fast-deep-equal": "^3.1.3", "multicast-dns": "^7.2.5" } }, - "brace-expansion": { - "version": "1.1.11", - "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.11.tgz", - "integrity": "sha512-iCuPHDFgrHX7H2vEI/5xpz07zSHB00TpugqhmYtVmMO6518mCuRMoOYFldEBl0g187ufozdaHgWKcYFb61qGiA==", - "dev": true, - "requires": { - "balanced-match": "^1.0.0", - "concat-map": "0.0.1" - } - }, "braces": { "version": "3.0.3", "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.3.tgz", @@ -4674,23 +4784,39 @@ "integrity": "sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==", "dev": true }, + "bundle-name": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/bundle-name/-/bundle-name-4.1.0.tgz", + "integrity": "sha512-tjwM5exMg6BGRI+kNmTntNsvdZS1X8BFYS6tnJ2hdH0kVxM6/eVZ2xy+FqStSWvYmtfFMDLIxurorHwDKfDz5Q==", + "dev": true, + "requires": { + "run-applescript": "^7.0.0" + } + }, "bytes": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/bytes/-/bytes-3.0.0.tgz", "integrity": "sha1-0ygVQE1olpn4Wk6k+odV3ROpYEg=", "dev": true }, - "call-bind": { - "version": "1.0.7", - "resolved": "https://registry.npmjs.org/call-bind/-/call-bind-1.0.7.tgz", - "integrity": "sha512-GHTSNSYICQ7scH7sZ+M2rFopRoLh8t2bLSW6BbgrtLsahOIB5iyAVJf9GjWK3cYTDaMj4XdBpM1cA6pIS0Kv2w==", + "call-bind-apply-helpers": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/call-bind-apply-helpers/-/call-bind-apply-helpers-1.0.2.tgz", + "integrity": "sha512-Sp1ablJ0ivDkSzjcaJdxEunN5/XvksFJ2sMBFfq6x0ryhQV/2b/KwFe21cMpmHtPOSij8K99/wSfoEuTObmuMQ==", "dev": true, "requires": { - "es-define-property": "^1.0.0", "es-errors": "^1.3.0", - "function-bind": "^1.1.2", - "get-intrinsic": "^1.2.4", - "set-function-length": "^1.2.1" + "function-bind": "^1.1.2" + } + }, + "call-bound": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/call-bound/-/call-bound-1.0.4.tgz", + "integrity": "sha512-+ys997U96po4Kx/ABpBCqhA9EuxJaQWDQg7295H4hBphv3IZg0boBKuwYpt4YXp6MZ5AmZQnU/tyMTlRpaSejg==", + "dev": true, + "requires": { + "call-bind-apply-helpers": "^1.0.2", + "get-intrinsic": "^1.3.0" } }, "caniuse-lite": { @@ -4700,9 +4826,9 @@ "dev": true }, "chokidar": { - "version": "3.5.3", - "resolved": "https://registry.npmjs.org/chokidar/-/chokidar-3.5.3.tgz", - "integrity": 
"sha512-Dr3sfKRP6oTcjf2JmUmFJfeVMvXBdegxB0iVQ5eb2V10uFJUCAS8OByZdVAyVb8xXNz3GjjTgj9kLWsZTqE6kw==", + "version": "3.6.0", + "resolved": "https://registry.npmjs.org/chokidar/-/chokidar-3.6.0.tgz", + "integrity": "sha512-7VT13fmjotKpGipCW9JEQAusEPE+Ei8nl6/g4FBAmIm0GOOLMua9NDDo/DWp0ZAxCr3cPq5ZpBqmPAQgDda2Pw==", "dev": true, "requires": { "anymatch": "~3.1.2", @@ -4782,12 +4908,6 @@ } } }, - "concat-map": { - "version": "0.0.1", - "resolved": "https://registry.npmjs.org/concat-map/-/concat-map-0.0.1.tgz", - "integrity": "sha512-/Srv4dswyQNBfohGpz9o6Yb3Gz3SrUDqBH5rTuhGR7ahtlbYKnVxw2bCFMRljaA7EXHaXZ8wsHdodFvbkhKmqg==", - "dev": true - }, "connect-history-api-fallback": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/connect-history-api-fallback/-/connect-history-api-fallback-2.0.0.tgz", @@ -4930,30 +5050,26 @@ } } }, - "default-gateway": { - "version": "6.0.3", - "resolved": "https://registry.npmjs.org/default-gateway/-/default-gateway-6.0.3.tgz", - "integrity": "sha512-fwSOJsbbNzZ/CUFpqFBqYfYNLj1NbMPm8MMCIzHjC83iSJRBEGmDUxU+WP661BaBQImeC2yHwXtz+P/O9o+XEg==", + "default-browser": { + "version": "5.2.1", + "resolved": "https://registry.npmjs.org/default-browser/-/default-browser-5.2.1.tgz", + "integrity": "sha512-WY/3TUME0x3KPYdRRxEJJvXRHV4PyPoUsxtZa78lwItwRQRHhd2U9xOscaT/YTf8uCXIAjeJOFBVEh/7FtD8Xg==", "dev": true, "requires": { - "execa": "^5.0.0" + "bundle-name": "^4.1.0", + "default-browser-id": "^5.0.0" } }, - "define-data-property": { - "version": "1.1.4", - "resolved": "https://registry.npmjs.org/define-data-property/-/define-data-property-1.1.4.tgz", - "integrity": "sha512-rBMvIzlpA8v6E+SJZoo++HAYqsLrkg7MSfIinMPFhmkorw7X+dOXVJQs+QT69zGkzMyfDnIMN2Wid1+NbL3T+A==", - "dev": true, - "requires": { - "es-define-property": "^1.0.0", - "es-errors": "^1.3.0", - "gopd": "^1.0.1" - } + "default-browser-id": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/default-browser-id/-/default-browser-id-5.0.0.tgz", + "integrity": "sha512-A6p/pu/6fyBcA1TRz/GqWYPViplrftcW2gZC9q79ngNCKAeR/X3gcEdXQHl4KNXV+3wgIJ1CPkJQ3IHM6lcsyA==", + "dev": true }, "define-lazy-prop": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/define-lazy-prop/-/define-lazy-prop-2.0.0.tgz", - "integrity": "sha512-Ds09qNh8yw3khSjiJjiUInaGX9xlqZDY7JVryGxdxV7NPeuqQfplOpQ66yJFZut3jLa5zOwkXw1g9EI2uKh4Og==", + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/define-lazy-prop/-/define-lazy-prop-3.0.0.tgz", + "integrity": "sha512-N+MeXYoqr3pOgn8xfyRPREN7gHakLYjhsHhWGT3fWAiL4IkAt0iDw14QiiEm2bE30c5XX5q0FtAA3CK5f9/BUg==", "dev": true }, "depd": { @@ -4974,12 +5090,6 @@ "integrity": "sha512-ZIzRpLJrOj7jjP2miAtgqIfmzbxa4ZOr5jJc601zklsfEx9oTzmmj2nVpIPRpNlRTIh8lc1kyViIY7BWSGNmKw==", "dev": true }, - "dns-equal": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/dns-equal/-/dns-equal-1.0.0.tgz", - "integrity": "sha512-z+paD6YUQsk+AbGCEM4PrOXSss5gd66QfcVBFTKR/HpFL9jCqikS94HYwKww6fQyO7IxrIIyUu+g0Ka9tUS2Cg==", - "dev": true - }, "dns-packet": { "version": "5.6.1", "resolved": "https://registry.npmjs.org/dns-packet/-/dns-packet-5.6.1.tgz", @@ -4989,6 +5099,17 @@ "@leichtgewicht/ip-codec": "^2.0.1" } }, + "dunder-proto": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/dunder-proto/-/dunder-proto-1.0.1.tgz", + "integrity": "sha512-KIN/nDJBQRcXw0MLVhZE9iQHmG68qAVIBg9CqmUYjmQIhgij9U5MFvrqkUL5FbtyyzZuOeOt0zdeRe4UY7ct+A==", + "dev": true, + "requires": { + "call-bind-apply-helpers": "^1.0.1", + "es-errors": "^1.3.0", + "gopd": "^1.2.0" + } + }, "ee-first": { "version": 
"1.1.1", "resolved": "https://registry.npmjs.org/ee-first/-/ee-first-1.1.1.tgz", @@ -5002,9 +5123,9 @@ "dev": true }, "encodeurl": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-1.0.2.tgz", - "integrity": "sha512-TPJXq8JqFaVYm2CWmPvnP2Iyo4ZSM7/QKcSmuMLDObfpH5fi7RUGmd/rTDf+rut/saiDiQEeVTNgAmJEdAOx0w==", + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-2.0.0.tgz", + "integrity": "sha512-Q0n9HRi4m6JuGIV1eFlmvJB7ZEVxu93IrMyiMsGC0lrMJMWzRgx6WGquyfQgZVb31vhGgXnfmPNNXmxnOkRBrg==", "dev": true }, "enhanced-resolve": { @@ -5024,13 +5145,10 @@ "dev": true }, "es-define-property": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/es-define-property/-/es-define-property-1.0.0.tgz", - "integrity": "sha512-jxayLKShrEqqzJ0eumQbVhTYQM27CfT1T35+gCgDFoL82JLsXqTJ76zv6A0YLOgEnLUMvLzsDsGIrl8NFpT2gQ==", - "dev": true, - "requires": { - "get-intrinsic": "^1.2.4" - } + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/es-define-property/-/es-define-property-1.0.1.tgz", + "integrity": "sha512-e3nRfgfUZ4rNGL232gUgX06QNyyez04KdjFrF+LTRoOXmrOgFKDg4BCdsjW8EnT69eqdYGmRpJwiPVYNrCaW3g==", + "dev": true }, "es-errors": { "version": "1.3.0", @@ -5044,6 +5162,15 @@ "integrity": "sha512-JUFAyicQV9mXc3YRxPnDlrfBKpqt6hUYzz9/boprUJHs4e4KVr3XwOF70doO6gwXUor6EWZJAyWAfKki84t20Q==", "dev": true }, + "es-object-atoms": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/es-object-atoms/-/es-object-atoms-1.1.1.tgz", + "integrity": "sha512-FGgH2h8zKNim9ljj7dankFPcICIK9Cp5bm+c2gQSYePhpaG5+esrLODihIorn+Pe6FGJzWhXQotPv73jTaldXA==", + "dev": true, + "requires": { + "es-errors": "^1.3.0" + } + }, "escalade": { "version": "3.1.1", "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.1.1.tgz", @@ -5107,27 +5234,10 @@ "integrity": "sha512-mQw+2fkQbALzQ7V0MY0IqdnXNOeTtP4r0lN9z7AAawCXgqea7bDii20AYrIBrFd/Hx0M2Ocz6S111CaFkUcb0Q==", "dev": true }, - "execa": { - "version": "5.1.1", - "resolved": "https://registry.npmjs.org/execa/-/execa-5.1.1.tgz", - "integrity": "sha512-8uSpZZocAZRBAPIEINJj3Lo9HyGitllczc27Eh5YYojjMFMn8yHMDMaUHE2Jqfq05D/wucwI4JGURyXt1vchyg==", - "dev": true, - "requires": { - "cross-spawn": "^7.0.3", - "get-stream": "^6.0.0", - "human-signals": "^2.1.0", - "is-stream": "^2.0.0", - "merge-stream": "^2.0.0", - "npm-run-path": "^4.0.1", - "onetime": "^5.1.2", - "signal-exit": "^3.0.3", - "strip-final-newline": "^2.0.0" - } - }, "express": { - "version": "4.21.1", - "resolved": "https://registry.npmjs.org/express/-/express-4.21.1.tgz", - "integrity": "sha512-YSFlK1Ee0/GC8QaO91tHcDxJiE/X4FbpAyQWkxAvG6AXCuR65YzK8ua6D9hvi/TzUfZMpc+BwuM1IPw8fmQBiQ==", + "version": "4.21.2", + "resolved": "https://registry.npmjs.org/express/-/express-4.21.2.tgz", + "integrity": "sha512-28HqgMZAmih1Czt9ny7qr6ek2qddF4FclbMzwhCREB6OFfH+rXAnuNCwo1/wFvrtbgsQDb4kSbX9de9lFbrXnA==", "dev": true, "requires": { "accepts": "~1.3.8", @@ -5149,7 +5259,7 @@ "methods": "~1.1.2", "on-finished": "2.4.1", "parseurl": "~1.3.3", - "path-to-regexp": "0.1.10", + "path-to-regexp": "0.1.12", "proxy-addr": "~2.0.7", "qs": "6.13.0", "range-parser": "~1.2.1", @@ -5163,12 +5273,6 @@ "vary": "~1.1.2" }, "dependencies": { - "array-flatten": { - "version": "1.1.1", - "resolved": "https://registry.npmjs.org/array-flatten/-/array-flatten-1.1.1.tgz", - "integrity": "sha512-PCVAQswWemu6UdxsDFFX/+gVeYqKAod3D3UVm91jHwynguOwAvYPhx8nNlM++NqRcK6CxxpUafjmhIdKiHibqg==", - "dev": true - }, "debug": { "version": "2.6.9", "resolved": 
"https://registry.npmjs.org/debug/-/debug-2.6.9.tgz", @@ -5184,12 +5288,6 @@ "integrity": "sha512-g7nH6P6dyDioJogAAGprGpCtVImJhpPk/roCzdb3fIh61/s/nPsfR6onyMwkCAR/OlC3yBC0lESvUoQEAssIrw==", "dev": true }, - "encodeurl": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-2.0.0.tgz", - "integrity": "sha512-Q0n9HRi4m6JuGIV1eFlmvJB7ZEVxu93IrMyiMsGC0lrMJMWzRgx6WGquyfQgZVb31vhGgXnfmPNNXmxnOkRBrg==", - "dev": true - }, "safe-buffer": { "version": "5.2.1", "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz", @@ -5292,12 +5390,6 @@ "ms": "2.0.0" } }, - "encodeurl": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-2.0.0.tgz", - "integrity": "sha512-Q0n9HRi4m6JuGIV1eFlmvJB7ZEVxu93IrMyiMsGC0lrMJMWzRgx6WGquyfQgZVb31vhGgXnfmPNNXmxnOkRBrg==", - "dev": true - }, "statuses": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/statuses/-/statuses-2.0.1.tgz", @@ -5334,18 +5426,6 @@ "integrity": "sha512-zJ2mQYM18rEFOudeV4GShTGIQ7RbzA7ozbU9I/XBpm7kqgMywgmylMwXHxZJmkVoYkna9d2pVXVXPdYTP9ej8Q==", "dev": true }, - "fs-monkey": { - "version": "1.0.4", - "resolved": "https://registry.npmjs.org/fs-monkey/-/fs-monkey-1.0.4.tgz", - "integrity": "sha512-INM/fWAxMICjttnD0DX1rBvinKskj5G1w+oy/pnm9u/tSlnBrzFonJMcalKJ30P8RRsPzKcCG7Q8l0jx5Fh9YQ==", - "dev": true - }, - "fs.realpath": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/fs.realpath/-/fs.realpath-1.0.0.tgz", - "integrity": "sha512-OO0pH2lK6a0hZnAdau5ItzHPI6pUlvI7jMVnxUQRtw4owF2wk8lOSabtGDCTP4Ggrg2MbGnWO9X8K1t4+fGMDw==", - "dev": true - }, "fsevents": { "version": "2.3.3", "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz", @@ -5360,36 +5440,31 @@ "dev": true }, "get-intrinsic": { - "version": "1.2.4", - "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.2.4.tgz", - "integrity": "sha512-5uYhsJH8VJBTv7oslg4BznJYhDoRI6waYCxMmCdnTrcCrHA/fCFKoTFz2JKKE0HdDFUF7/oQuhzumXJK7paBRQ==", + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.3.0.tgz", + "integrity": "sha512-9fSjSaos/fRIVIp+xSJlE6lfwhES7LNtKaCBIamHsjr2na1BiABJPo0mOjjz8GJDURarmCPGqaiVg5mfjb98CQ==", "dev": true, "requires": { + "call-bind-apply-helpers": "^1.0.2", + "es-define-property": "^1.0.1", "es-errors": "^1.3.0", + "es-object-atoms": "^1.1.1", "function-bind": "^1.1.2", - "has-proto": "^1.0.1", - "has-symbols": "^1.0.3", - "hasown": "^2.0.0" + "get-proto": "^1.0.1", + "gopd": "^1.2.0", + "has-symbols": "^1.1.0", + "hasown": "^2.0.2", + "math-intrinsics": "^1.1.0" } }, - "get-stream": { - "version": "6.0.1", - "resolved": "https://registry.npmjs.org/get-stream/-/get-stream-6.0.1.tgz", - "integrity": "sha512-ts6Wi+2j3jQjqi70w5AlN8DFnkSwC+MqmxEzdEALB2qXZYV3X/b1CTfgPLGJNMeAWxdPfU8FO1ms3NUfaHCPYg==", - "dev": true - }, - "glob": { - "version": "7.2.3", - "resolved": "https://registry.npmjs.org/glob/-/glob-7.2.3.tgz", - "integrity": "sha512-nFR0zLpU2YCaRxwoCJvL6UvCH2JFyFVIvwTLsIf21AuHlMskA1hhTdk+LlYJtOlYt9v6dvszD2BGRqBL+iQK9Q==", + "get-proto": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/get-proto/-/get-proto-1.0.1.tgz", + "integrity": "sha512-sTSfBjoXBp89JvIKIefqw7U2CCebsc74kiY6awiGogKtoSGbgjYE/G/+l9sF3MWFPNc9IcoOC4ODfKHfxFmp0g==", "dev": true, "requires": { - "fs.realpath": "^1.0.0", - "inflight": "^1.0.4", - "inherits": "2", - "minimatch": "^3.1.1", - "once": "^1.3.0", - "path-is-absolute": "^1.0.0" + "dunder-proto": "^1.0.1", + "es-object-atoms": "^1.0.0" } }, "glob-parent": { @@ 
-5422,13 +5497,10 @@ } }, "gopd": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/gopd/-/gopd-1.0.1.tgz", - "integrity": "sha512-d65bNlIadxvpb/A2abVdlqKqV563juRnZ1Wtk6s1sIR8uNsXR70xqIzVqxVf1eTqDunwT2MkczEeaezCKTZhwA==", - "dev": true, - "requires": { - "get-intrinsic": "^1.1.3" - } + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/gopd/-/gopd-1.2.0.tgz", + "integrity": "sha512-ZUKRh6/kUFoAiTAtTYPZJ3hw9wNxx+BIBOijnlG9PnrJsCcSjs1wyyD6vJpaYtgnzDrKYRSqf3OO6Rfa93xsRg==", + "dev": true }, "graceful-fs": { "version": "4.2.11", @@ -5457,25 +5529,10 @@ "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", "dev": true }, - "has-property-descriptors": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/has-property-descriptors/-/has-property-descriptors-1.0.2.tgz", - "integrity": "sha512-55JNKuIW+vq4Ke1BjOTjM2YctQIvCT7GFzHwmfZPGo5wnrgkid0YQtnAleFSqumZm4az3n2BS+erby5ipJdgrg==", - "dev": true, - "requires": { - "es-define-property": "^1.0.0" - } - }, - "has-proto": { - "version": "1.0.3", - "resolved": "https://registry.npmjs.org/has-proto/-/has-proto-1.0.3.tgz", - "integrity": "sha512-SJ1amZAJUiZS+PhsVLf5tGydlaVB8EdFpaSO4gmiUKUOxk8qzn5AIy4ZeJUmh22znIdk/uMAUT2pl3FxzVUH+Q==", - "dev": true - }, "has-symbols": { - "version": "1.0.3", - "resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.0.3.tgz", - "integrity": "sha512-l3LCuF6MgDNwTDKkdYGEihYjt5pRPbEg46rtlmnSPlUbgmB8LOIrKJbYYFBSbnPaJexMKtiPO8hmeRjRz2Td+A==", + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.1.0.tgz", + "integrity": "sha512-1cDNdwJ2Jaohmb3sg4OmKaMBwuC48sYni5HUw2DvsC8LjGTLK9h+eb1X6RyuOHe4hT0ULCW68iomhjUoKUqlPQ==", "dev": true }, "hasown": { @@ -5499,12 +5556,6 @@ "wbuf": "^1.1.0" } }, - "html-entities": { - "version": "2.4.0", - "resolved": "https://registry.npmjs.org/html-entities/-/html-entities-2.4.0.tgz", - "integrity": "sha512-igBTJcNNNhvZFRtm8uA6xMY6xYleeDwn3PeBCkDz7tHttv4F2hsDI2aPgNERWzvRcNYHNT3ymRaQzllmXj4YsQ==", - "dev": true - }, "http-deceiver": { "version": "1.2.7", "resolved": "https://registry.npmjs.org/http-deceiver/-/http-deceiver-1.2.7.tgz", @@ -5562,9 +5613,9 @@ } }, "http-proxy-middleware": { - "version": "2.0.6", - "resolved": "https://registry.npmjs.org/http-proxy-middleware/-/http-proxy-middleware-2.0.6.tgz", - "integrity": "sha512-ya/UeJ6HVBYxrgYotAZo1KvPWlgB48kUJLDePFeneHsVujFaW5WNj2NgWCAE//B1Dl02BIfYlpNgBy8Kf8Rjmw==", + "version": "2.0.9", + "resolved": "https://registry.npmjs.org/http-proxy-middleware/-/http-proxy-middleware-2.0.9.tgz", + "integrity": "sha512-c1IyJYLYppU574+YI7R4QyX2ystMtVXZwIdzazUIPIJsHuWNd+mho2j+bKoHftndicGj9yh+xjd+l0yj7VeT1Q==", "dev": true, "requires": { "@types/http-proxy": "^1.17.8", @@ -5574,10 +5625,10 @@ "micromatch": "^4.0.2" } }, - "human-signals": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/human-signals/-/human-signals-2.1.0.tgz", - "integrity": "sha512-B4FFZ6q/T2jhhksgkbEW3HBvWIfDW85snkQgawt07S7J5QXTk6BkNV+0yAeZrM5QpMAdYlocGoljn0sJ/WQkFw==", + "hyperdyperid": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/hyperdyperid/-/hyperdyperid-1.2.0.tgz", + "integrity": "sha512-Y93lCzHYgGWdrJ66yIktxiaGULYc6oGiABxhcO5AufBeOyoIdZF7bIfLaOrbM0iGIOXQQgxxRrFEnb+Y6w1n4A==", "dev": true }, "iconv-lite": { @@ -5605,16 +5656,6 @@ "resolve-cwd": "^3.0.0" } }, - "inflight": { - "version": "1.0.6", - "resolved": "https://registry.npmjs.org/inflight/-/inflight-1.0.6.tgz", - "integrity": 
"sha512-k92I/b08q4wvFscXCLvqfsHCrjrF7yiXsQuIVvVE7N82W3+aqpzuUdBbfhWcy/FZR3/4IgflMgKLOsvPDrGCJA==", - "dev": true, - "requires": { - "once": "^1.3.0", - "wrappy": "1" - } - }, "inherits": { "version": "2.0.3", "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.3.tgz", @@ -5652,9 +5693,9 @@ } }, "is-docker": { - "version": "2.2.1", - "resolved": "https://registry.npmjs.org/is-docker/-/is-docker-2.2.1.tgz", - "integrity": "sha512-F+i2BKsFrH66iaUFc0woD8sLy8getkwTwtOBjvs56Cx4CgJDeKQeqfz8wAYiSb8JOprWhHH5p77PbmYCvvUuXQ==", + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/is-docker/-/is-docker-3.0.0.tgz", + "integrity": "sha512-eljcgEDlEns/7AXFosB5K/2nCM4P7FQPkGc/DWLy5rmFEWvZayGrik1d9/QIY5nJ4f9YsVvBkA6kJpHn9rISdQ==", "dev": true }, "is-extglob": { @@ -5672,6 +5713,21 @@ "is-extglob": "^2.1.1" } }, + "is-inside-container": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/is-inside-container/-/is-inside-container-1.0.0.tgz", + "integrity": "sha512-KIYLCCJghfHZxqjYBE7rEy0OBuTd5xCHS7tHVgvCLkx7StIoaxwNW3hCALgEUjFfeRk+MG/Qxmp/vtETEF3tRA==", + "dev": true, + "requires": { + "is-docker": "^3.0.0" + } + }, + "is-network-error": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/is-network-error/-/is-network-error-1.1.0.tgz", + "integrity": "sha512-tUdRRAnhT+OtCZR/LxZelH/C7QtjtFrTu5tXCA8pl55eTUElUHT+GPYV8MBMBvea/j+NxQqVt3LbWMRir7Gx9g==", + "dev": true + }, "is-number": { "version": "7.0.0", "resolved": "https://registry.npmjs.org/is-number/-/is-number-7.0.0.tgz", @@ -5693,19 +5749,13 @@ "isobject": "^3.0.1" } }, - "is-stream": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/is-stream/-/is-stream-2.0.1.tgz", - "integrity": "sha512-hFoiJiTl63nn+kstHGBtewWSKnQLpyb155KHheA1l39uvtO9nWIop1p3udqPcUd/xbF1VLMO4n7OI6p7RbngDg==", - "dev": true - }, "is-wsl": { - "version": "2.2.0", - "resolved": "https://registry.npmjs.org/is-wsl/-/is-wsl-2.2.0.tgz", - "integrity": "sha512-fKzAra0rGJUUBwGBgNkHZuToZcn+TtXHpeCgmkMJMMYx1sQDYaCSyjJBSCa2nH1DGm7s3n1oBnohoVTBaN7Lww==", + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/is-wsl/-/is-wsl-3.1.0.tgz", + "integrity": "sha512-UcVfVfaK4Sc4m7X3dUSoHoozQGBEFeDC+zVo06t98xe8CzHSZZBekNXH+tu0NalHolcJ/QAGqS46Hef7QXBIMw==", "dev": true, "requires": { - "is-docker": "^2.0.0" + "is-inside-container": "^1.0.0" } }, "isarray": { @@ -5756,13 +5806,13 @@ "dev": true }, "launch-editor": { - "version": "2.6.0", - "resolved": "https://registry.npmjs.org/launch-editor/-/launch-editor-2.6.0.tgz", - "integrity": "sha512-JpDCcQnyAAzZZaZ7vEiSqL690w7dAEyLao+KC96zBplnYbJS7TYNjvM3M7y3dGz+v7aIsJk3hllWuc0kWAjyRQ==", + "version": "2.10.0", + "resolved": "https://registry.npmjs.org/launch-editor/-/launch-editor-2.10.0.tgz", + "integrity": "sha512-D7dBRJo/qcGX9xlvt/6wUYzQxjh5G1RvZPgPv8vi4KRU99DVQL/oW7tnVOCCTm2HGeo3C5HvGE5Yrh6UBoZ0vA==", "dev": true, "requires": { "picocolors": "^1.0.0", - "shell-quote": "^1.7.3" + "shell-quote": "^1.8.1" } }, "loader-runner": { @@ -5780,6 +5830,12 @@ "p-locate": "^4.1.0" } }, + "math-intrinsics": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/math-intrinsics/-/math-intrinsics-1.1.0.tgz", + "integrity": "sha512-/IXtbwEk5HTPyEwyKX6hGkYXxM9nbj64B+ilVJnC/R6B0pH5G4V3b0pVbL7DBj4tkhBAppbQUlf6F6Xl9LHu1g==", + "dev": true + }, "media-typer": { "version": "0.3.0", "resolved": "https://registry.npmjs.org/media-typer/-/media-typer-0.3.0.tgz", @@ -5787,12 +5843,63 @@ "dev": true }, "memfs": { - "version": "3.5.3", - "resolved": "https://registry.npmjs.org/memfs/-/memfs-3.5.3.tgz", - 
"integrity": "sha512-UERzLsxzllchadvbPs5aolHh65ISpKpM+ccLbOJ8/vvpBKmAWf+la7dXFy7Mr0ySHbdHrFv5kGFCUHHe6GFEmw==", + "version": "4.17.2", + "resolved": "https://registry.npmjs.org/memfs/-/memfs-4.17.2.tgz", + "integrity": "sha512-NgYhCOWgovOXSzvYgUW0LQ7Qy72rWQMGGFJDoWg4G30RHd3z77VbYdtJ4fembJXBy8pMIUA31XNAupobOQlwdg==", "dev": true, "requires": { - "fs-monkey": "^1.0.4" + "@jsonjoy.com/json-pack": "^1.0.3", + "@jsonjoy.com/util": "^1.3.0", + "tree-dump": "^1.0.1", + "tslib": "^2.0.0" + }, + "dependencies": { + "@jsonjoy.com/base64": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/@jsonjoy.com/base64/-/base64-1.1.2.tgz", + "integrity": "sha512-q6XAnWQDIMA3+FTiOYajoYqySkO+JSat0ytXGSuRdq9uXE7o92gzuQwQM14xaCRlBLGq3v5miDGC4vkVTn54xA==", + "dev": true, + "requires": {} + }, + "@jsonjoy.com/json-pack": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/@jsonjoy.com/json-pack/-/json-pack-1.2.0.tgz", + "integrity": "sha512-io1zEbbYcElht3tdlqEOFxZ0dMTYrHz9iMf0gqn1pPjZFTCgM5R4R5IMA20Chb2UPYYsxjzs8CgZ7Nb5n2K2rA==", + "dev": true, + "requires": { + "@jsonjoy.com/base64": "^1.1.1", + "@jsonjoy.com/util": "^1.1.2", + "hyperdyperid": "^1.2.0", + "thingies": "^1.20.0" + } + }, + "@jsonjoy.com/util": { + "version": "1.6.0", + "resolved": "https://registry.npmjs.org/@jsonjoy.com/util/-/util-1.6.0.tgz", + "integrity": "sha512-sw/RMbehRhN68WRtcKCpQOPfnH6lLP4GJfqzi3iYej8tnzpZUDr6UkZYJjcjjC0FWEJOJbyM3PTIwxucUmDG2A==", + "dev": true, + "requires": {} + }, + "thingies": { + "version": "1.21.0", + "resolved": "https://registry.npmjs.org/thingies/-/thingies-1.21.0.tgz", + "integrity": "sha512-hsqsJsFMsV+aD4s3CWKk85ep/3I9XzYV/IXaSouJMYIoDlgyi11cBhsqYe9/geRfB0YIikBQg6raRaM+nIMP9g==", + "dev": true, + "requires": {} + }, + "tree-dump": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/tree-dump/-/tree-dump-1.0.3.tgz", + "integrity": "sha512-il+Cv80yVHFBwokQSfd4bldvr1Md951DpgAGfmhydt04L+YzHgubm2tQ7zueWDcGENKHq0ZvGFR/hjvNXilHEg==", + "dev": true, + "requires": {} + }, + "tslib": { + "version": "2.8.1", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz", + "integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==", + "dev": true + } } }, "merge-descriptors": { @@ -5850,27 +5957,12 @@ "mime-db": "1.52.0" } }, - "mimic-fn": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/mimic-fn/-/mimic-fn-2.1.0.tgz", - "integrity": "sha512-OqbOk5oEQeAZ8WXWydlu9HJjz9WVdEIvamMCcXmuqUYjTknH/sqsWvhQ3vgwKFRR1HpjvNBKQ37nbJgYzGqGcg==", - "dev": true - }, "minimalistic-assert": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/minimalistic-assert/-/minimalistic-assert-1.0.1.tgz", "integrity": "sha512-UtJcAD4yEaGtjPezWuO9wC4nwUnVH/8/Im3yEHQP4b67cXlD/Qr9hdITCU1xDbSEXg2XKNaP8jsReV7vQd00/A==", "dev": true }, - "minimatch": { - "version": "3.1.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", - "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", - "dev": true, - "requires": { - "brace-expansion": "^1.1.7" - } - }, "ms": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/ms/-/ms-2.0.0.tgz", @@ -5917,19 +6009,10 @@ "integrity": "sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA==", "dev": true }, - "npm-run-path": { - "version": "4.0.1", - "resolved": "https://registry.npmjs.org/npm-run-path/-/npm-run-path-4.0.1.tgz", - "integrity": 
"sha512-S48WzZW777zhNIrn7gxOlISNAqi9ZC/uQFnRdbeIHhZhCA6UqpkOT8T1G7BvfdgP4Er8gF4sUbaS0i7QvIfCWw==", - "dev": true, - "requires": { - "path-key": "^3.0.0" - } - }, "object-inspect": { - "version": "1.13.2", - "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.2.tgz", - "integrity": "sha512-IRZSRuzJiynemAXPYtPe5BoI/RESNYR7TYm50MC5Mqbd3Jmw5y790sErYw3V6SryFJD64b74qQQs9wn5Bg/k3g==", + "version": "1.13.4", + "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.4.tgz", + "integrity": "sha512-W67iLl4J2EXEGTbfeHCffrjDfitvLANg0UlX3wFUUSTx92KXRFegMHUVgSqE+wvhAbi4WqjGg9czysTV2Epbew==", "dev": true }, "obuf": { @@ -5953,33 +6036,16 @@ "integrity": "sha512-pZAE+FJLoyITytdqK0U5s+FIpjN0JP3OzFi/u8Rx+EV5/W+JTWGXG8xFzevE7AjBfDqHv/8vL8qQsIhHnqRkrA==", "dev": true }, - "once": { - "version": "1.4.0", - "resolved": "https://registry.npmjs.org/once/-/once-1.4.0.tgz", - "integrity": "sha512-lNaJgI+2Q5URQBkccEKHTQOPaXdUxnZZElQTZY0MFUAuaEqe1E+Nyvgdz/aIyNi6Z9MzO5dv1H8n58/GELp3+w==", - "dev": true, - "requires": { - "wrappy": "1" - } - }, - "onetime": { - "version": "5.1.2", - "resolved": "https://registry.npmjs.org/onetime/-/onetime-5.1.2.tgz", - "integrity": "sha512-kbpaSSGJTWdAY5KPVeMOKXSrPtr8C8C7wodJbcsd51jRnmD+GZu8Y0VoU6Dm5Z4vWr0Ig/1NKuWRKf7j5aaYSg==", - "dev": true, - "requires": { - "mimic-fn": "^2.1.0" - } - }, "open": { - "version": "8.4.2", - "resolved": "https://registry.npmjs.org/open/-/open-8.4.2.tgz", - "integrity": "sha512-7x81NCL719oNbsq/3mh+hVrAWmFuEYUqrq/Iw3kUzH8ReypT9QQ0BLoJS7/G9k6N81XjW4qHWtjWwe/9eLy1EQ==", + "version": "10.1.2", + "resolved": "https://registry.npmjs.org/open/-/open-10.1.2.tgz", + "integrity": "sha512-cxN6aIDPz6rm8hbebcP7vrQNhvRcveZoJU72Y7vskh4oIm+BZwBECnx5nTmrlres1Qapvx27Qo1Auukpf8PKXw==", "dev": true, "requires": { - "define-lazy-prop": "^2.0.0", - "is-docker": "^2.1.1", - "is-wsl": "^2.2.0" + "default-browser": "^5.2.1", + "define-lazy-prop": "^3.0.0", + "is-inside-container": "^1.0.0", + "is-wsl": "^3.1.0" } }, "p-locate": { @@ -6003,12 +6069,13 @@ } }, "p-retry": { - "version": "4.6.2", - "resolved": "https://registry.npmjs.org/p-retry/-/p-retry-4.6.2.tgz", - "integrity": "sha512-312Id396EbJdvRONlngUx0NydfrIQ5lsYu0znKVUzVvArzEIt08V1qhtyESbGVd1FGX7UKtiFp5uwKZdM8wIuQ==", + "version": "6.2.1", + "resolved": "https://registry.npmjs.org/p-retry/-/p-retry-6.2.1.tgz", + "integrity": "sha512-hEt02O4hUct5wtwg4H4KcWgDdm+l1bOaEy/hWzd8xtXB9BqxTWBBhb+2ImAtH4Cv4rPjV76xN3Zumqk3k3AhhQ==", "dev": true, "requires": { - "@types/retry": "0.12.0", + "@types/retry": "0.12.2", + "is-network-error": "^1.0.0", "retry": "^0.13.1" } }, @@ -6030,12 +6097,6 @@ "integrity": "sha512-ak9Qy5Q7jYb2Wwcey5Fpvg2KoAc/ZIhLSLOSBmRmygPsGwkVVt0fZa0qrtMz+m6tJTAHfZQ8FnmB4MG4LWy7/w==", "dev": true }, - "path-is-absolute": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/path-is-absolute/-/path-is-absolute-1.0.1.tgz", - "integrity": "sha512-AVbw3UJ2e9bq64vSaS9Am0fje1Pa8pbGqTTsmXfaIiMpnr5DlDhfJOuLj9Sf95ZPVDAUerDfEk88MPmPe7UCQg==", - "dev": true - }, "path-key": { "version": "3.1.1", "resolved": "https://registry.npmjs.org/path-key/-/path-key-3.1.1.tgz", @@ -6049,9 +6110,9 @@ "dev": true }, "path-to-regexp": { - "version": "0.1.10", - "resolved": "https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-0.1.10.tgz", - "integrity": "sha512-7lf7qcQidTku0Gu3YDPc8DJ1q7OOucfa/BSsIwjuh56VU7katFvuM8hULfkwB3Fns/rsVF7PwPKVw1sl5KQS9w==", + "version": "0.1.12", + "resolved": "https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-0.1.12.tgz", + "integrity": 
"sha512-RA1GjUVMnvYFxuqovrEqZoxxW5NUZqbwKtYz/Tt7nXerk0LbLblQmrsgdeOxV5SFHf0UDggjS/bSeOZwt1pmEQ==", "dev": true }, "path-type": { @@ -6244,14 +6305,11 @@ "integrity": "sha512-U9nH88a3fc/ekCF1l0/UP1IosiuIjyTh7hBvXVMHYgVcfGvt897Xguj2UOLDeI5BG2m7/uwyaLVT6fbtCwTyzw==", "dev": true }, - "rimraf": { - "version": "3.0.2", - "resolved": "https://registry.npmjs.org/rimraf/-/rimraf-3.0.2.tgz", - "integrity": "sha512-JZkJMZkAGFFPP2YqXZXPbMlMBgsxzE8ILs4lMIX/2o0L9UBw9O/Y3o6wFw/i9YLapcUJWwqbi3kdxIPdC62TIA==", - "dev": true, - "requires": { - "glob": "^7.1.3" - } + "run-applescript": { + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/run-applescript/-/run-applescript-7.0.0.tgz", + "integrity": "sha512-9by4Ij99JUr/MCFBUkDKLWK3G9HVXmabKz9U5MlIAIuvuzkiOicRYs8XJLxX+xahD+mLiiCYDqF9dKAgtzKP1A==", + "dev": true }, "run-parallel": { "version": "1.2.0", @@ -6292,11 +6350,12 @@ "dev": true }, "selfsigned": { - "version": "2.1.1", - "resolved": "https://registry.npmjs.org/selfsigned/-/selfsigned-2.1.1.tgz", - "integrity": "sha512-GSL3aowiF7wa/WtSFwnUrludWFoNhftq8bUkH9pkzjpN2XSPOAYEgg6e0sS9s0rZwgJzJiQRPU18A6clnoW5wQ==", + "version": "2.4.1", + "resolved": "https://registry.npmjs.org/selfsigned/-/selfsigned-2.4.1.tgz", + "integrity": "sha512-th5B4L2U+eGLq1TVh7zNRGBapioSORUeymIydxgFpwww9d2qyKvtuPU2jJuHvYAwwqi2Y596QBL3eEqcPEYL8Q==", "dev": true, "requires": { + "@types/node-forge": "^1.3.0", "node-forge": "^1" } }, @@ -6344,6 +6403,12 @@ "integrity": "sha512-g7nH6P6dyDioJogAAGprGpCtVImJhpPk/roCzdb3fIh61/s/nPsfR6onyMwkCAR/OlC3yBC0lESvUoQEAssIrw==", "dev": true }, + "encodeurl": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-1.0.2.tgz", + "integrity": "sha512-TPJXq8JqFaVYm2CWmPvnP2Iyo4ZSM7/QKcSmuMLDObfpH5fi7RUGmd/rTDf+rut/saiDiQEeVTNgAmJEdAOx0w==", + "dev": true + }, "ms": { "version": "2.1.3", "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", @@ -6421,28 +6486,6 @@ "escape-html": "~1.0.3", "parseurl": "~1.3.3", "send": "0.19.0" - }, - "dependencies": { - "encodeurl": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-2.0.0.tgz", - "integrity": "sha512-Q0n9HRi4m6JuGIV1eFlmvJB7ZEVxu93IrMyiMsGC0lrMJMWzRgx6WGquyfQgZVb31vhGgXnfmPNNXmxnOkRBrg==", - "dev": true - } - } - }, - "set-function-length": { - "version": "1.2.2", - "resolved": "https://registry.npmjs.org/set-function-length/-/set-function-length-1.2.2.tgz", - "integrity": "sha512-pgRc4hJ4/sNjWCSS9AmnS40x3bNMDTknHgL5UaMBTMyJnU90EgWh1Rz+MC9eFu4BuN/UwZjKQuY/1v3rM7HMfg==", - "dev": true, - "requires": { - "define-data-property": "^1.1.4", - "es-errors": "^1.3.0", - "function-bind": "^1.1.2", - "get-intrinsic": "^1.2.4", - "gopd": "^1.0.1", - "has-property-descriptors": "^1.0.2" } }, "setprototypeof": { @@ -6476,28 +6519,58 @@ "dev": true }, "shell-quote": { - "version": "1.8.1", - "resolved": "https://registry.npmjs.org/shell-quote/-/shell-quote-1.8.1.tgz", - "integrity": "sha512-6j1W9l1iAs/4xYBI1SYOVZyFcCis9b4KCLQ8fgAGG07QvzaRLVVRQvAy85yNmmZSjYjg4MWh4gNvlPujU/5LpA==", + "version": "1.8.3", + "resolved": "https://registry.npmjs.org/shell-quote/-/shell-quote-1.8.3.tgz", + "integrity": "sha512-ObmnIF4hXNg1BqhnHmgbDETF8dLPCggZWBjkQfhZpbszZnYur5DUljTcCHii5LC3J5E0yeO/1LIMyH+UvHQgyw==", "dev": true }, "side-channel": { - "version": "1.0.6", - "resolved": "https://registry.npmjs.org/side-channel/-/side-channel-1.0.6.tgz", - "integrity": "sha512-fDW/EZ6Q9RiO8eFG8Hj+7u/oW+XrPTIChwCOM2+th2A6OblDtYYIpve9m+KvI9Z4C9qSEXlaGR6bTEYHReuglA==", + "version": "1.1.0", + "resolved": 
"https://registry.npmjs.org/side-channel/-/side-channel-1.1.0.tgz", + "integrity": "sha512-ZX99e6tRweoUXqR+VBrslhda51Nh5MTQwou5tnUDgbtyM0dBgmhEDtWGP/xbKn6hqfPRHujUNwz5fy/wbbhnpw==", "dev": true, "requires": { - "call-bind": "^1.0.7", "es-errors": "^1.3.0", - "get-intrinsic": "^1.2.4", - "object-inspect": "^1.13.1" + "object-inspect": "^1.13.3", + "side-channel-list": "^1.0.0", + "side-channel-map": "^1.0.1", + "side-channel-weakmap": "^1.0.2" } }, - "signal-exit": { - "version": "3.0.7", - "resolved": "https://registry.npmjs.org/signal-exit/-/signal-exit-3.0.7.tgz", - "integrity": "sha512-wnD2ZE+l+SPC/uoS0vXeE9L1+0wuaMqKlfz9AMUo38JsyLSBWSFcHR1Rri62LZc12vLr1gb3jl7iwQhgwpAbGQ==", - "dev": true + "side-channel-list": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/side-channel-list/-/side-channel-list-1.0.0.tgz", + "integrity": "sha512-FCLHtRD/gnpCiCHEiJLOwdmFP+wzCmDEkc9y7NsYxeF4u7Btsn1ZuwgwJGxImImHicJArLP4R0yX4c2KCrMrTA==", + "dev": true, + "requires": { + "es-errors": "^1.3.0", + "object-inspect": "^1.13.3" + } + }, + "side-channel-map": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/side-channel-map/-/side-channel-map-1.0.1.tgz", + "integrity": "sha512-VCjCNfgMsby3tTdo02nbjtM/ewra6jPHmpThenkTYh8pG9ucZ/1P8So4u4FGBek/BjpOVsDCMoLA/iuBKIFXRA==", + "dev": true, + "requires": { + "call-bound": "^1.0.2", + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.5", + "object-inspect": "^1.13.3" + } + }, + "side-channel-weakmap": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/side-channel-weakmap/-/side-channel-weakmap-1.0.2.tgz", + "integrity": "sha512-WPS/HvHQTYnHisLo9McqBHOJk2FkHO/tlpvldyrnem4aeQp4hai3gythswg6p01oSoTl58rcpiFAjF2br2Ak2A==", + "dev": true, + "requires": { + "call-bound": "^1.0.2", + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.5", + "object-inspect": "^1.13.3", + "side-channel-map": "^1.0.1" + } }, "slash": { "version": "5.1.0", @@ -6587,12 +6660,6 @@ "safe-buffer": "~5.1.0" } }, - "strip-final-newline": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/strip-final-newline/-/strip-final-newline-2.0.0.tgz", - "integrity": "sha512-BrpvfNAE3dcvq7ll3xVumzjKjZQ5tI1sEUIKr3Uoks0XUl45St3FlatVqef9prk4jRDzhW6WZg+3bk93y6pLjA==", - "dev": true - }, "supports-color": { "version": "8.1.1", "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-8.1.1.tgz", @@ -6819,28 +6886,29 @@ } }, "webpack-dev-middleware": { - "version": "5.3.4", - "resolved": "https://registry.npmjs.org/webpack-dev-middleware/-/webpack-dev-middleware-5.3.4.tgz", - "integrity": "sha512-BVdTqhhs+0IfoeAf7EoH5WE+exCmqGerHfDM0IL096Px60Tq2Mn9MAbnaGUe6HiMa41KMCYF19gyzZmBcq/o4Q==", + "version": "7.4.2", + "resolved": "https://registry.npmjs.org/webpack-dev-middleware/-/webpack-dev-middleware-7.4.2.tgz", + "integrity": "sha512-xOO8n6eggxnwYpy1NlzUKpvrjfJTvae5/D6WOK0S2LSo7vjmo5gCM1DbLUmFqrMTJP+W/0YZNctm7jasWvLuBA==", "dev": true, "requires": { "colorette": "^2.0.10", - "memfs": "^3.4.3", + "memfs": "^4.6.0", "mime-types": "^2.1.31", + "on-finished": "^2.4.1", "range-parser": "^1.2.1", "schema-utils": "^4.0.0" }, "dependencies": { "ajv": { - "version": "8.12.0", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.12.0.tgz", - "integrity": "sha512-sRu1kpcO9yLtYxBKvqfTeh9KzZEwO3STyX1HT+4CaDzC6HpTGYhIhPIzj9XuKU7KYDwnaeh5hcOwjy1QuJzBPA==", + "version": "8.17.1", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz", + "integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==", "dev": true, 
"requires": { - "fast-deep-equal": "^3.1.1", + "fast-deep-equal": "^3.1.3", + "fast-uri": "^3.0.1", "json-schema-traverse": "^1.0.0", - "require-from-string": "^2.0.2", - "uri-js": "^4.2.2" + "require-from-string": "^2.0.2" } }, "ajv-keywords": { @@ -6859,9 +6927,9 @@ "dev": true }, "schema-utils": { - "version": "4.2.0", - "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.2.0.tgz", - "integrity": "sha512-L0jRsrPpjdckP3oPug3/VxNKt2trR8TcabrM6FOAAlvC/9Phcmm+cuAgTlxBqdBR1WJx7Naj9WHw+aOmheSVbw==", + "version": "4.3.2", + "resolved": "https://registry.npmjs.org/schema-utils/-/schema-utils-4.3.2.tgz", + "integrity": "sha512-Gn/JaSk/Mt9gYubxTtSn/QCV4em9mpAPiR1rqy/Ocu19u/G9J5WWdNoUT4SiV6mFC3y6cxyFcFwdzPM3FgxGAQ==", "dev": true, "requires": { "@types/json-schema": "^7.0.9", @@ -6873,41 +6941,39 @@ } }, "webpack-dev-server": { - "version": "4.15.1", - "resolved": "https://registry.npmjs.org/webpack-dev-server/-/webpack-dev-server-4.15.1.tgz", - "integrity": "sha512-5hbAst3h3C3L8w6W4P96L5vaV0PxSmJhxZvWKYIdgxOQm8pNZ5dEOmmSLBVpP85ReeyRt6AS1QJNyo/oFFPeVA==", - "dev": true, - "requires": { - "@types/bonjour": "^3.5.9", - "@types/connect-history-api-fallback": "^1.3.5", - "@types/express": "^4.17.13", - "@types/serve-index": "^1.9.1", - "@types/serve-static": "^1.13.10", - "@types/sockjs": "^0.3.33", - "@types/ws": "^8.5.5", + "version": "5.2.1", + "resolved": "https://registry.npmjs.org/webpack-dev-server/-/webpack-dev-server-5.2.1.tgz", + "integrity": "sha512-ml/0HIj9NLpVKOMq+SuBPLHcmbG+TGIjXRHsYfZwocUBIqEvws8NnS/V9AFQ5FKP+tgn5adwVwRrTEpGL33QFQ==", + "dev": true, + "requires": { + "@types/bonjour": "^3.5.13", + "@types/connect-history-api-fallback": "^1.5.4", + "@types/express": "^4.17.21", + "@types/express-serve-static-core": "^4.17.21", + "@types/serve-index": "^1.9.4", + "@types/serve-static": "^1.15.5", + "@types/sockjs": "^0.3.36", + "@types/ws": "^8.5.10", "ansi-html-community": "^0.0.8", - "bonjour-service": "^1.0.11", - "chokidar": "^3.5.3", + "bonjour-service": "^1.2.1", + "chokidar": "^3.6.0", "colorette": "^2.0.10", "compression": "^1.7.4", "connect-history-api-fallback": "^2.0.0", - "default-gateway": "^6.0.3", - "express": "^4.17.3", + "express": "^4.21.2", "graceful-fs": "^4.2.6", - "html-entities": "^2.3.2", - "http-proxy-middleware": "^2.0.3", - "ipaddr.js": "^2.0.1", - "launch-editor": "^2.6.0", - "open": "^8.0.9", - "p-retry": "^4.5.0", - "rimraf": "^3.0.2", - "schema-utils": "^4.0.0", - "selfsigned": "^2.1.1", + "http-proxy-middleware": "^2.0.7", + "ipaddr.js": "^2.1.0", + "launch-editor": "^2.6.1", + "open": "^10.0.3", + "p-retry": "^6.2.0", + "schema-utils": "^4.2.0", + "selfsigned": "^2.4.1", "serve-index": "^1.9.1", "sockjs": "^0.3.24", "spdy": "^4.0.2", - "webpack-dev-middleware": "^5.3.1", - "ws": "^8.13.0" + "webpack-dev-middleware": "^7.4.2", + "ws": "^8.18.0" }, "dependencies": { "ajv": { @@ -6993,16 +7059,10 @@ "integrity": "sha512-CC1bOL87PIWSBhDcTrdeLo6eGT7mCFtrg0uIJtqJUFyK+eJnzl8A1niH56uu7KMa5XFrtiV+AQuHO3n7DsHnLQ==", "dev": true }, - "wrappy": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/wrappy/-/wrappy-1.0.2.tgz", - "integrity": "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==", - "dev": true - }, "ws": { - "version": "8.17.1", - "resolved": "https://registry.npmjs.org/ws/-/ws-8.17.1.tgz", - "integrity": "sha512-6XQFvXTkbfUOZOKKILFG1PDK2NDQs4azKQl26T0YS5CxqWLgXajbPZ+h4gZekJyRqFU8pvnbAbbs/3TgRPy+GQ==", + "version": "8.18.2", + "resolved": 
"https://registry.npmjs.org/ws/-/ws-8.18.2.tgz", + "integrity": "sha512-DMricUmwGZUVr++AEAe2uiVM7UoO9MAVZMDu05UQOaUII0lp+zOzLLU4Xqh/JvTqklB1T4uELaaPBKyjE1r4fQ==", "dev": true, "requires": {} } diff --git a/datafusion/wasmtest/datafusion-wasm-app/package.json b/datafusion/wasmtest/datafusion-wasm-app/package.json index 5a2262400cfd..b46993de77d9 100644 --- a/datafusion/wasmtest/datafusion-wasm-app/package.json +++ b/datafusion/wasmtest/datafusion-wasm-app/package.json @@ -29,7 +29,7 @@ "devDependencies": { "webpack": "5.94.0", "webpack-cli": "5.1.4", - "webpack-dev-server": "4.15.1", + "webpack-dev-server": "5.2.1", "copy-webpack-plugin": "12.0.2" } } diff --git a/datafusion/wasmtest/src/lib.rs b/datafusion/wasmtest/src/lib.rs index 0a7e546b4b18..e30a1046ab27 100644 --- a/datafusion/wasmtest/src/lib.rs +++ b/datafusion/wasmtest/src/lib.rs @@ -92,7 +92,8 @@ mod test { }; use datafusion_common::test_util::batches_to_string; use datafusion_execution::{ - config::SessionConfig, disk_manager::DiskManagerConfig, + config::SessionConfig, + disk_manager::{DiskManagerBuilder, DiskManagerMode}, runtime_env::RuntimeEnvBuilder, }; use datafusion_physical_plan::collect; @@ -112,7 +113,9 @@ mod test { fn get_ctx() -> Arc { let rt = RuntimeEnvBuilder::new() - .with_disk_manager(DiskManagerConfig::Disabled) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::Disabled), + ) .build_arc() .unwrap(); let session_config = SessionConfig::new().with_target_partitions(1); diff --git a/dev/changelog/47.0.0.md b/dev/changelog/47.0.0.md new file mode 100644 index 000000000000..64ca2e157a9e --- /dev/null +++ b/dev/changelog/47.0.0.md @@ -0,0 +1,506 @@ + + +# Apache DataFusion 47.0.0 Changelog + +This release consists of 364 commits from 94 contributors. See credits at the end of this changelog for more information. 
+ +**Breaking changes:** + +- chore: cleanup deprecated API since `version <= 40` [#15027](https://github.com/apache/datafusion/pull/15027) (qazxcdswe123) +- fix: mark ScalarUDFImpl::invoke_batch as deprecated [#15049](https://github.com/apache/datafusion/pull/15049) (Blizzara) +- feat: support customize metadata in alias for dataframe api [#15120](https://github.com/apache/datafusion/pull/15120) (chenkovsky) +- Refactor: add `FileGroup` structure for `Vec` [#15379](https://github.com/apache/datafusion/pull/15379) (xudong963) +- Change default `EXPLAIN` format in `datafusion-cli` to `tree` format [#15427](https://github.com/apache/datafusion/pull/15427) (alamb) +- Support computing statistics for FileGroup [#15432](https://github.com/apache/datafusion/pull/15432) (xudong963) +- Remove redundant statistics from FileScanConfig [#14955](https://github.com/apache/datafusion/pull/14955) (Standing-Man) +- parquet reader: move pruning predicate creation from ParquetSource to ParquetOpener [#15561](https://github.com/apache/datafusion/pull/15561) (adriangb) +- feat: Add unique id for every memory consumer [#15613](https://github.com/apache/datafusion/pull/15613) (EmilyMatt) + +**Performance related:** + +- Fix sequential metadata fetching in ListingTable causing high latency [#14918](https://github.com/apache/datafusion/pull/14918) (geoffreyclaude) +- Implement GroupsAccumulator for min/max Duration [#15322](https://github.com/apache/datafusion/pull/15322) (shruti2522) +- [Minor] Remove/reorder logical plan rules [#15421](https://github.com/apache/datafusion/pull/15421) (Dandandan) +- Improve performance of `first_value` by implementing special `GroupsAccumulator` [#15266](https://github.com/apache/datafusion/pull/15266) (UBarney) +- perf: unwrap cast for comparing ints =/!= strings [#15110](https://github.com/apache/datafusion/pull/15110) (alan910127) +- Improve performance sort TPCH q3 with Utf8Vew ( Sort-preserving mergi… [#15447](https://github.com/apache/datafusion/pull/15447) (zhuqi-lucas) +- perf: Reuse row converter during sort [#15302](https://github.com/apache/datafusion/pull/15302) (2010YOUY01) +- perf: Add TopK benchmarks as variation over the `sort_tpch` benchmarks [#15560](https://github.com/apache/datafusion/pull/15560) (geoffreyclaude) +- Perf: remove `clone` on `uninitiated_partitions` in SortPreservingMergeStream [#15562](https://github.com/apache/datafusion/pull/15562) (rluvaton) +- Add short circuit evaluation for `AND` and `OR` [#15462](https://github.com/apache/datafusion/pull/15462) (acking-you) +- perf: Introduce sort prefix computation for early TopK exit optimization on partially sorted input (10x speedup on top10 bench) [#15563](https://github.com/apache/datafusion/pull/15563) (geoffreyclaude) +- Improve performance of `last_value` by implementing special `GroupsAccumulator` [#15542](https://github.com/apache/datafusion/pull/15542) (UBarney) +- Enhance: simplify `x=x` --> `x IS NOT NULL OR NULL` [#15589](https://github.com/apache/datafusion/pull/15589) (ding-young) + +**Implemented enhancements:** + +- feat: Add `tree` / pretty explain mode [#14677](https://github.com/apache/datafusion/pull/14677) (irenjj) +- feat: Add `array_max` function support [#14470](https://github.com/apache/datafusion/pull/14470) (erenavsarogullari) +- feat: implement tree explain for `ProjectionExec` [#15082](https://github.com/apache/datafusion/pull/15082) (Standing-Man) +- feat: support ApproxDistinct with utf8view [#15200](https://github.com/apache/datafusion/pull/15200) (zhuqi-lucas) +- 
feat: Attach `Diagnostic` to more than one column errors in scalar_subquery and in_subquery [#15143](https://github.com/apache/datafusion/pull/15143) (changsun20) +- feat: topk functionality for aggregates should support utf8view and largeutf8 [#15152](https://github.com/apache/datafusion/pull/15152) (zhuqi-lucas) +- feat: Native support utf8view for regex string operators [#15275](https://github.com/apache/datafusion/pull/15275) (zhuqi-lucas) +- feat: introduce `JoinSetTracer` trait for tracing context propagation in spawned tasks [#14547](https://github.com/apache/datafusion/pull/14547) (geoffreyclaude) +- feat: Support serde for JsonSource PhysicalPlan [#15311](https://github.com/apache/datafusion/pull/15311) (westhide) +- feat: Support serde for FileScanConfig `batch_size` [#15335](https://github.com/apache/datafusion/pull/15335) (westhide) +- feat: simplify regex wildcard pattern [#15299](https://github.com/apache/datafusion/pull/15299) (waynexia) +- feat: Add union_by_name, union_by_name_distinct to DataFrame api [#15489](https://github.com/apache/datafusion/pull/15489) (Omega359) +- feat: Add config `max_temp_directory_size` to limit max disk usage for spilling queries [#15520](https://github.com/apache/datafusion/pull/15520) (2010YOUY01) +- feat: Add tracing regression tests [#15673](https://github.com/apache/datafusion/pull/15673) (geoffreyclaude) + +**Fixed bugs:** + +- fix: External sort failing on an edge case [#15017](https://github.com/apache/datafusion/pull/15017) (2010YOUY01) +- fix: graceful NULL and type error handling in array functions [#14737](https://github.com/apache/datafusion/pull/14737) (alan910127) +- fix: Support datatype cast for insert api same as insert into sql [#15091](https://github.com/apache/datafusion/pull/15091) (zhuqi-lucas) +- fix: unparse for subqueryalias [#15068](https://github.com/apache/datafusion/pull/15068) (chenkovsky) +- fix: date_trunc bench broken by #15049 [#15169](https://github.com/apache/datafusion/pull/15169) (Blizzara) +- fix: compound_field_access doesn't identifier qualifier. 
[#15153](https://github.com/apache/datafusion/pull/15153) (chenkovsky) +- fix: unparsing left/ right semi/mark join [#15212](https://github.com/apache/datafusion/pull/15212) (chenkovsky) +- fix: handle duplicate WindowFunction expressions in Substrait consumer [#15211](https://github.com/apache/datafusion/pull/15211) (Blizzara) +- fix: write hive partitions for any int/uint/float [#15337](https://github.com/apache/datafusion/pull/15337) (christophermcdermott) +- fix: `core_expressions` feature flag broken, move `overlay` into `core` functions [#15217](https://github.com/apache/datafusion/pull/15217) (shruti2522) +- fix: Redundant files spilled during external sort + introduce `SpillManager` [#15355](https://github.com/apache/datafusion/pull/15355) (2010YOUY01) +- fix: typo of DropFunction [#15434](https://github.com/apache/datafusion/pull/15434) (chenkovsky) +- fix: Unconditionally wrap UNION BY NAME input nodes w/ `Projection` [#15242](https://github.com/apache/datafusion/pull/15242) (rkrishn7) +- fix: the average time for clickbench query compute should use new vec to make it compute for each query [#15472](https://github.com/apache/datafusion/pull/15472) (zhuqi-lucas) +- fix: Assertion fail in external sort [#15469](https://github.com/apache/datafusion/pull/15469) (2010YOUY01) +- fix: aggregation corner case [#15457](https://github.com/apache/datafusion/pull/15457) (chenkovsky) +- fix: update group by columns for merge phase after spill [#15531](https://github.com/apache/datafusion/pull/15531) (rluvaton) +- fix: Queries similar to `count-bug` produce incorrect results [#15281](https://github.com/apache/datafusion/pull/15281) (suibianwanwank) +- fix: ffi aggregation [#15576](https://github.com/apache/datafusion/pull/15576) (chenkovsky) +- fix: nested window function [#15033](https://github.com/apache/datafusion/pull/15033) (chenkovsky) +- fix: dictionary encoded column to partition column casting bug [#15652](https://github.com/apache/datafusion/pull/15652) (haruband) +- fix: recursion protection for physical plan node [#15600](https://github.com/apache/datafusion/pull/15600) (chenkovsky) +- fix: add map coercion for binary ops [#15551](https://github.com/apache/datafusion/pull/15551) (alexwilcoxson-rel) +- fix: Rewrite `date_trunc` and `from_unixtime` for the SQLite unparser [#15630](https://github.com/apache/datafusion/pull/15630) (peasee) +- fix(substrait): fix regressed edge case in renaming inner struct fields [#15634](https://github.com/apache/datafusion/pull/15634) (Blizzara) +- fix: normalize window ident [#15639](https://github.com/apache/datafusion/pull/15639) (chenkovsky) +- fix: unparse join without projection [#15693](https://github.com/apache/datafusion/pull/15693) (chenkovsky) + +**Documentation updates:** + +- MINOR fix(docs): set the proper link for dev-env setup in contrib guide [#14960](https://github.com/apache/datafusion/pull/14960) (clflushopt) +- Add Upgrade Guide for DataFusion 46.0.0 [#14891](https://github.com/apache/datafusion/pull/14891) (alamb) +- Improve `SessionStateBuilder::new` documentation [#14980](https://github.com/apache/datafusion/pull/14980) (alamb) +- Minor: Replace Star and Fork buttons in docs with static versions [#14988](https://github.com/apache/datafusion/pull/14988) (amoeba) +- Fix documentation warnings and error if anymore occur [#14952](https://github.com/apache/datafusion/pull/14952) (AmosAidoo) +- docs: Improve docs on AggregateFunctionExpr construction [#15044](https://github.com/apache/datafusion/pull/15044) (ctsk) +- Minor: More 
comment to aggregation fuzzer [#15048](https://github.com/apache/datafusion/pull/15048) (2010YOUY01) +- Improve benchmark documentation [#15054](https://github.com/apache/datafusion/pull/15054) (carols10cents) +- doc: update RecordBatchReceiverStreamBuilder::spawn_blocking task behaviour [#14995](https://github.com/apache/datafusion/pull/14995) (shruti2522) +- doc: Correct benchmark command [#15094](https://github.com/apache/datafusion/pull/15094) (qazxcdswe123) +- Add `insta` / snapshot testing to CLI & set up AWS mock [#13672](https://github.com/apache/datafusion/pull/13672) (blaginin) +- Config: Add support default sql varchar to view types [#15104](https://github.com/apache/datafusion/pull/15104) (zhuqi-lucas) +- Support `EXPLAIN ... FORMAT ...` [#15166](https://github.com/apache/datafusion/pull/15166) (alamb) +- Update version to 46.0.1, add CHANGELOG (#15243) [#15244](https://github.com/apache/datafusion/pull/15244) (xudong963) +- docs: update documentation for Final GroupBy in accumulator.rs [#15279](https://github.com/apache/datafusion/pull/15279) (qazxcdswe123) +- minor: fix `data/sqlite` link [#15286](https://github.com/apache/datafusion/pull/15286) (sdht0) +- Add upgrade notes for array signatures [#15237](https://github.com/apache/datafusion/pull/15237) (jkosh44) +- Add doc for the `statistics_from_parquet_meta_calc method` [#15330](https://github.com/apache/datafusion/pull/15330) (xudong963) +- added explaination for Schema and DFSchema to documentation [#15329](https://github.com/apache/datafusion/pull/15329) (Jiashu-Hu) +- Documentation: Plan custom expressions [#15353](https://github.com/apache/datafusion/pull/15353) (Jiashu-Hu) +- Update concepts-readings-events.md [#15440](https://github.com/apache/datafusion/pull/15440) (berkaysynnada) +- Add support for DISTINCT + ORDER BY in `ARRAY_AGG` [#14413](https://github.com/apache/datafusion/pull/14413) (gabotechs) +- Update the copyright year [#15453](https://github.com/apache/datafusion/pull/15453) (omkenge) +- Docs: Formatting and Added Extra resources [#15450](https://github.com/apache/datafusion/pull/15450) (2SpaceMasterRace) +- Add documentation for `Run extended tests` command [#15463](https://github.com/apache/datafusion/pull/15463) (alamb) +- bench: Document how to use cross platform Samply profiler [#15481](https://github.com/apache/datafusion/pull/15481) (comphead) +- Update user guide to note decimal is not experimental anymore [#15515](https://github.com/apache/datafusion/pull/15515) (Jiashu-Hu) +- datafusion-cli: document reading partitioned parquet [#15505](https://github.com/apache/datafusion/pull/15505) (marvelshan) +- Update concepts-readings-events.md [#15541](https://github.com/apache/datafusion/pull/15541) (oznur-synnada) +- Add documentation example for `AggregateExprBuilder` [#15504](https://github.com/apache/datafusion/pull/15504) (Shreyaskr1409) +- Docs : Added Sql examples for window Functions : `nth_val` , etc [#15555](https://github.com/apache/datafusion/pull/15555) (Adez017) +- Add disk usage limit configuration to datafusion-cli [#15586](https://github.com/apache/datafusion/pull/15586) (jsai28) +- Bug fix : fix the bug in docs in 'cum_dist()' Example [#15618](https://github.com/apache/datafusion/pull/15618) (Adez017) +- Make tree the Default EXPLAIN Format and Reorder Documentation Sections [#15706](https://github.com/apache/datafusion/pull/15706) (kosiew) +- Add coerce int96 option for Parquet to support different TimeUnits, test int96_from_spark.parquet from parquet-testing 
[#15537](https://github.com/apache/datafusion/pull/15537) (mbutrovich) +- STRING_AGG missing functionality [#14412](https://github.com/apache/datafusion/pull/14412) (gabotechs) +- doc : update RepartitionExec display tree [#15710](https://github.com/apache/datafusion/pull/15710) (getChan) +- Update version to 47.0.0, add CHANGELOG [#15731](https://github.com/apache/datafusion/pull/15731) (xudong963) + +**Other:** + +- Improve documentation for `DataSourceExec`, `FileScanConfig`, `DataSource` etc [#14941](https://github.com/apache/datafusion/pull/14941) (alamb) +- Do not swap with projection when file is partitioned [#14956](https://github.com/apache/datafusion/pull/14956) (blaginin) +- Minor: Add more projection pushdown tests, clarify comments [#14963](https://github.com/apache/datafusion/pull/14963) (alamb) +- Update labeler components [#14942](https://github.com/apache/datafusion/pull/14942) (alamb) +- Deprecate `Expr::Wildcard` [#14959](https://github.com/apache/datafusion/pull/14959) (linhr) +- Minor: use FileScanConfig builder API in some tests [#14938](https://github.com/apache/datafusion/pull/14938) (alamb) +- Minor: improve documentation of `AggregateMode` [#14946](https://github.com/apache/datafusion/pull/14946) (alamb) +- chore(deps): bump thiserror from 2.0.11 to 2.0.12 [#14971](https://github.com/apache/datafusion/pull/14971) (dependabot[bot]) +- chore(deps): bump pyo3 from 0.23.4 to 0.23.5 [#14972](https://github.com/apache/datafusion/pull/14972) (dependabot[bot]) +- chore(deps): bump async-trait from 0.1.86 to 0.1.87 [#14973](https://github.com/apache/datafusion/pull/14973) (dependabot[bot]) +- Fix verification script and extended tests due to `rustup` changes [#14990](https://github.com/apache/datafusion/pull/14990) (alamb) +- Split out avro, parquet, json and csv into individual crates [#14951](https://github.com/apache/datafusion/pull/14951) (AdamGS) +- Minor: Add `backtrace` feature in datafusion-cli [#14997](https://github.com/apache/datafusion/pull/14997) (2010YOUY01) +- chore: Update `SessionStateBuilder::with_default_features` does not replace existing features [#14935](https://github.com/apache/datafusion/pull/14935) (irenjj) +- Make `create_ordering` pub and add doc for it [#14996](https://github.com/apache/datafusion/pull/14996) (xudong963) +- Simplify Between expression to Eq [#14994](https://github.com/apache/datafusion/pull/14994) (jayzhan211) +- Count wildcard alias [#14927](https://github.com/apache/datafusion/pull/14927) (jayzhan211) +- replace TypeSignature::String with TypeSignature::Coercible [#14917](https://github.com/apache/datafusion/pull/14917) (zjregee) +- Minor: Add indentation to EnforceDistribution test plans. 
[#15007](https://github.com/apache/datafusion/pull/15007) (wiedld) +- Minor: add method `SessionStateBuilder::new_with_default_features()` [#14998](https://github.com/apache/datafusion/pull/14998) (shruti2522) +- Implement `tree` explain for FilterExec [#15001](https://github.com/apache/datafusion/pull/15001) (alamb) +- Unparser add `AtArrow` and `ArrowAt` conversion to BinaryOperator [#14968](https://github.com/apache/datafusion/pull/14968) (cetra3) +- Add dependency checks to verify-release-candidate script [#15009](https://github.com/apache/datafusion/pull/15009) (waynexia) +- Fix: to_char Function Now Correctly Handles DATE Values in DataFusion [#14970](https://github.com/apache/datafusion/pull/14970) (kosiew) +- Make Substrait Schema Structs always non-nullable [#15011](https://github.com/apache/datafusion/pull/15011) (amoeba) +- Adjust physical optimizer rule order, put `ProjectionPushdown` at last [#15040](https://github.com/apache/datafusion/pull/15040) (xudong963) +- Move `UnwrapCastInComparison` into `Simplifier` [#15012](https://github.com/apache/datafusion/pull/15012) (jayzhan211) +- chore(deps): bump aws-config from 1.5.17 to 1.5.18 [#15041](https://github.com/apache/datafusion/pull/15041) (dependabot[bot]) +- chore(deps): bump bytes from 1.10.0 to 1.10.1 [#15042](https://github.com/apache/datafusion/pull/15042) (dependabot[bot]) +- Minor: Deprecate `ScalarValue::raw_data` [#15016](https://github.com/apache/datafusion/pull/15016) (qazxcdswe123) +- Implement tree explain for `DataSourceExec` [#15029](https://github.com/apache/datafusion/pull/15029) (alamb) +- Refactor test suite in EnforceDistribution, to use standard test config. [#15010](https://github.com/apache/datafusion/pull/15010) (wiedld) +- Update ring to v0.17.13 [#15063](https://github.com/apache/datafusion/pull/15063) (alamb) +- Remove deprecated function `OptimizerRule::try_optimize` [#15051](https://github.com/apache/datafusion/pull/15051) (qazxcdswe123) +- Minor: fix CI to make the sqllogic testing result consistent [#15059](https://github.com/apache/datafusion/pull/15059) (zhuqi-lucas) +- Refactor SortPushdown using the standard top-down visitor and using `EquivalenceProperties` [#14821](https://github.com/apache/datafusion/pull/14821) (wiedld) +- Improve explain tree formatting for longer lines / word wrap [#15031](https://github.com/apache/datafusion/pull/15031) (irenjj) +- chore(deps): bump sqllogictest from 0.27.2 to 0.28.0 [#15060](https://github.com/apache/datafusion/pull/15060) (dependabot[bot]) +- chore(deps): bump async-compression from 0.4.18 to 0.4.19 [#15061](https://github.com/apache/datafusion/pull/15061) (dependabot[bot]) +- Handle columns in with_new_exprs with a Join [#15055](https://github.com/apache/datafusion/pull/15055) (delamarch3) +- Minor: Improve documentation of `need_handle_count_bug` [#15050](https://github.com/apache/datafusion/pull/15050) (suibianwanwank) +- Implement `tree` explain for `HashJoinExec` [#15079](https://github.com/apache/datafusion/pull/15079) (irenjj) +- Implement tree explain for PartialSortExec [#15066](https://github.com/apache/datafusion/pull/15066) (irenjj) +- Implement `tree` explain for `SortExec` [#15077](https://github.com/apache/datafusion/pull/15077) (irenjj) +- Minor: final `46.0.0` release tweaks: changelog + instructions [#15073](https://github.com/apache/datafusion/pull/15073) (alamb) +- Implement tree explain for `NestedLoopJoinExec`, `CrossJoinExec`, `So… [#15081](https://github.com/apache/datafusion/pull/15081) (irenjj) +- Implement `tree` explain 
for `BoundedWindowAggExec` and `WindowAggExec` [#15084](https://github.com/apache/datafusion/pull/15084) (irenjj) +- implement tree rendering for StreamingTableExec [#15085](https://github.com/apache/datafusion/pull/15085) (Standing-Man) +- chore(deps): bump semver from 1.0.25 to 1.0.26 [#15116](https://github.com/apache/datafusion/pull/15116) (dependabot[bot]) +- chore(deps): bump clap from 4.5.30 to 4.5.31 [#15115](https://github.com/apache/datafusion/pull/15115) (dependabot[bot]) +- implement tree explain for GlobalLimitExec [#15100](https://github.com/apache/datafusion/pull/15100) (zjregee) +- Minor: Cleanup useless/duplicated code in gen tools [#15113](https://github.com/apache/datafusion/pull/15113) (lewiszlw) +- Refactor EnforceDistribution test cases to demonstrate dependencies across optimizer runs. [#15074](https://github.com/apache/datafusion/pull/15074) (wiedld) +- Improve parsing `extra_info` in tree explain [#15125](https://github.com/apache/datafusion/pull/15125) (irenjj) +- Add tests for simplification and coercion of `SessionContext::create_physical_expr` [#15034](https://github.com/apache/datafusion/pull/15034) (alamb) +- Minor: Fix invalid query in test [#15131](https://github.com/apache/datafusion/pull/15131) (alamb) +- Do not display logical_plan win explain `tree` mode 🧹 [#15132](https://github.com/apache/datafusion/pull/15132) (alamb) +- Substrait support for propagating TableScan.filters to Substrait ReadRel.filter [#14194](https://github.com/apache/datafusion/pull/14194) (jamxia155) +- Fix wasm32 build on version 46 [#15102](https://github.com/apache/datafusion/pull/15102) (XiangpengHao) +- Fix broken `serde` feature [#15124](https://github.com/apache/datafusion/pull/15124) (vadimpiven) +- chore(deps): bump tempfile from 3.17.1 to 3.18.0 [#15146](https://github.com/apache/datafusion/pull/15146) (dependabot[bot]) +- chore(deps): bump syn from 2.0.98 to 2.0.100 [#15147](https://github.com/apache/datafusion/pull/15147) (dependabot[bot]) +- Implement tree explain for AggregateExec [#15103](https://github.com/apache/datafusion/pull/15103) (zebsme) +- Implement tree explain for `RepartitionExec` and `WorkTableExec` [#15137](https://github.com/apache/datafusion/pull/15137) (Standing-Man) +- Expand wildcard to actual expressions in `prepare_select_exprs` [#15090](https://github.com/apache/datafusion/pull/15090) (jayzhan211) +- fixed PushDownFilter bug [15047] [#15142](https://github.com/apache/datafusion/pull/15142) (Jiashu-Hu) +- Bump `env_logger` from `0.11.6` to `0.11.7` [#15148](https://github.com/apache/datafusion/pull/15148) (mbrobbel) +- Minor: fix extend sqllogical consistent with main test [#15145](https://github.com/apache/datafusion/pull/15145) (zhuqi-lucas) +- Implement tree rendering for `SortPreservingMergeExec` [#15140](https://github.com/apache/datafusion/pull/15140) (Standing-Man) +- Remove expand wildcard rule [#15170](https://github.com/apache/datafusion/pull/15170) (jayzhan211) +- chore: remove ScalarUDFImpl::return_type_from_exprs [#15130](https://github.com/apache/datafusion/pull/15130) (Blizzara) +- chore(deps): bump libc from 0.2.170 to 0.2.171 [#15176](https://github.com/apache/datafusion/pull/15176) (dependabot[bot]) +- chore(deps): bump serde_json from 1.0.139 to 1.0.140 [#15175](https://github.com/apache/datafusion/pull/15175) (dependabot[bot]) +- chore(deps): bump substrait from 0.53.2 to 0.54.0 [#15043](https://github.com/apache/datafusion/pull/15043) (dependabot[bot]) +- Minor: split EXPLAIN and ANALYZE planning into different functions 
[#15188](https://github.com/apache/datafusion/pull/15188) (alamb) +- Implement `tree` explain for `JsonSink` [#15185](https://github.com/apache/datafusion/pull/15185) (irenjj) +- Split out `datafusion-substrait` and `datafusion-proto` CI feature checks, increase coverage [#15156](https://github.com/apache/datafusion/pull/15156) (alamb) +- Remove unused wildcard expanding methods [#15180](https://github.com/apache/datafusion/pull/15180) (goldmedal) +- #15108 issue: "Non Panic Task error" is not an internal error [#15109](https://github.com/apache/datafusion/pull/15109) (Satyam018) +- Implement tree explain for LazyMemoryExec [#15187](https://github.com/apache/datafusion/pull/15187) (zebsme) +- implement tree explain for CoalesceBatchesExec [#15194](https://github.com/apache/datafusion/pull/15194) (Standing-Man) +- Implement `tree` explain for `CsvSink` [#15204](https://github.com/apache/datafusion/pull/15204) (irenjj) +- chore(deps): bump blake3 from 1.6.0 to 1.6.1 [#15198](https://github.com/apache/datafusion/pull/15198) (dependabot[bot]) +- chore(deps): bump clap from 4.5.31 to 4.5.32 [#15199](https://github.com/apache/datafusion/pull/15199) (dependabot[bot]) +- chore(deps): bump serde from 1.0.218 to 1.0.219 [#15197](https://github.com/apache/datafusion/pull/15197) (dependabot[bot]) +- Fix datafusion proto crate `json` feature [#15172](https://github.com/apache/datafusion/pull/15172) (Owen-CH-Leung) +- Add blog link to `EquivalenceProperties` docs [#15215](https://github.com/apache/datafusion/pull/15215) (alamb) +- Minor: split datafusion-cli testing into its own CI job [#15075](https://github.com/apache/datafusion/pull/15075) (alamb) +- Implement tree explain for InterleaveExec [#15219](https://github.com/apache/datafusion/pull/15219) (zebsme) +- Move catalog_common out of core [#15193](https://github.com/apache/datafusion/pull/15193) (logan-keede) +- chore(deps): bump tokio-util from 0.7.13 to 0.7.14 [#15223](https://github.com/apache/datafusion/pull/15223) (dependabot[bot]) +- chore(deps): bump aws-config from 1.5.18 to 1.6.0 [#15222](https://github.com/apache/datafusion/pull/15222) (dependabot[bot]) +- chore(deps): bump bzip2 from 0.5.1 to 0.5.2 [#15221](https://github.com/apache/datafusion/pull/15221) (dependabot[bot]) +- Document guidelines for physical operator yielding [#15030](https://github.com/apache/datafusion/pull/15030) (carols10cents) +- Implement `tree` explain for `ArrowFileSink`, fix original URL [#15206](https://github.com/apache/datafusion/pull/15206) (irenjj) +- Implement tree explain for `LocalLimitExec` [#15232](https://github.com/apache/datafusion/pull/15232) (shruti2522) +- Use insta for `DataFrame` tests [#15165](https://github.com/apache/datafusion/pull/15165) (blaginin) +- Re-enable github discussion [#15241](https://github.com/apache/datafusion/pull/15241) (2010YOUY01) +- Minor: exclude datafusion-cli testing for mac [#15240](https://github.com/apache/datafusion/pull/15240) (zhuqi-lucas) +- Implement tree explain for CoalescePartitionsExec [#15225](https://github.com/apache/datafusion/pull/15225) (Shreyaskr1409) +- Enable `used_underscore_binding` clippy lint [#15189](https://github.com/apache/datafusion/pull/15189) (Shreyaskr1409) +- Simpler to see expressions in explain `tree` mode [#15163](https://github.com/apache/datafusion/pull/15163) (irenjj) +- Fix invalid schema for unions in ViewTables [#15135](https://github.com/apache/datafusion/pull/15135) (Friede80) +- Make `ListingTableUrl::try_new` public 
[#15250](https://github.com/apache/datafusion/pull/15250) (linhr) +- Fix wildcard dataframe case [#15230](https://github.com/apache/datafusion/pull/15230) (jayzhan211) +- Simplify the printing of all plans containing `expr` in `tree` mode [#15249](https://github.com/apache/datafusion/pull/15249) (irenjj) +- Support utf8view datatype for window [#15257](https://github.com/apache/datafusion/pull/15257) (zhuqi-lucas) +- chore: remove deprecated variants of UDF's invoke (invoke, invoke_no_args, invoke_batch) [#15123](https://github.com/apache/datafusion/pull/15123) (Blizzara) +- Improve feature flag CI coverage `datafusion` and `datafusion-functions` [#15203](https://github.com/apache/datafusion/pull/15203) (alamb) +- Add debug logging for default catalog overwrite in SessionState build [#15251](https://github.com/apache/datafusion/pull/15251) (byte-sourcerer) +- Implement tree explain for PlaceholderRowExec [#15270](https://github.com/apache/datafusion/pull/15270) (zebsme) +- Implement tree explain for UnionExec [#15278](https://github.com/apache/datafusion/pull/15278) (zebsme) +- Migrate dataframe tests to `insta` [#15262](https://github.com/apache/datafusion/pull/15262) (jsai28) +- Minor: consistently apply `clippy::clone_on_ref_ptr` in all crates [#15284](https://github.com/apache/datafusion/pull/15284) (alamb) +- chore(deps): bump async-trait from 0.1.87 to 0.1.88 [#15294](https://github.com/apache/datafusion/pull/15294) (dependabot[bot]) +- chore(deps): bump uuid from 1.15.1 to 1.16.0 [#15292](https://github.com/apache/datafusion/pull/15292) (dependabot[bot]) +- Add CatalogProvider and SchemaProvider to FFI Crate [#15280](https://github.com/apache/datafusion/pull/15280) (timsaucer) +- Refactor file schema type coercions [#15268](https://github.com/apache/datafusion/pull/15268) (xudong963) +- chore(deps): bump rust_decimal from 1.36.0 to 1.37.0 [#15293](https://github.com/apache/datafusion/pull/15293) (dependabot[bot]) +- chore: Attach Diagnostic to "incompatible type in unary expression" error [#15209](https://github.com/apache/datafusion/pull/15209) (onlyjackfrost) +- Support logic optimize rule to pass the case that Utf8view datatype combined with Utf8 datatype [#15239](https://github.com/apache/datafusion/pull/15239) (zhuqi-lucas) +- Migrate user_defined tests to insta [#15255](https://github.com/apache/datafusion/pull/15255) (shruti2522) +- Remove inline table scan analyzer rule [#15201](https://github.com/apache/datafusion/pull/15201) (jayzhan211) +- CI Red: Fix union in view table test [#15300](https://github.com/apache/datafusion/pull/15300) (jayzhan211) +- refactor: Move view and stream from `datasource` to `catalog`, deprecate `View::try_new` [#15260](https://github.com/apache/datafusion/pull/15260) (logan-keede) +- chore(deps): bump substrait from 0.54.0 to 0.55.0 [#15305](https://github.com/apache/datafusion/pull/15305) (dependabot[bot]) +- chore(deps): bump half from 2.4.1 to 2.5.0 [#15303](https://github.com/apache/datafusion/pull/15303) (dependabot[bot]) +- chore(deps): bump mimalloc from 0.1.43 to 0.1.44 [#15304](https://github.com/apache/datafusion/pull/15304) (dependabot[bot]) +- Fix predicate pushdown for custom SchemaAdapters [#15263](https://github.com/apache/datafusion/pull/15263) (adriangb) +- Fix extended tests by restore datafusion-testing submodule [#15318](https://github.com/apache/datafusion/pull/15318) (alamb) +- Support Duration in min/max agg functions [#15310](https://github.com/apache/datafusion/pull/15310) (svranesevic) +- Migrate tests to insta 
[#15288](https://github.com/apache/datafusion/pull/15288) (jsai28) +- chore(deps): bump quote from 1.0.38 to 1.0.40 [#15332](https://github.com/apache/datafusion/pull/15332) (dependabot[bot]) +- chore(deps): bump blake3 from 1.6.1 to 1.7.0 [#15331](https://github.com/apache/datafusion/pull/15331) (dependabot[bot]) +- Simplify display format of `AggregateFunctionExpr`, add `Expr::sql_name` [#15253](https://github.com/apache/datafusion/pull/15253) (irenjj) +- chore(deps): bump indexmap from 2.7.1 to 2.8.0 [#15333](https://github.com/apache/datafusion/pull/15333) (dependabot[bot]) +- chore(deps): bump tokio from 1.43.0 to 1.44.1 [#15347](https://github.com/apache/datafusion/pull/15347) (dependabot[bot]) +- chore(deps): bump tempfile from 3.18.0 to 3.19.1 [#15346](https://github.com/apache/datafusion/pull/15346) (dependabot[bot]) +- Minor: Keep debug symbols for `release-nonlto` build [#15350](https://github.com/apache/datafusion/pull/15350) (2010YOUY01) +- Use `any` instead of `for_each` [#15289](https://github.com/apache/datafusion/pull/15289) (xudong963) +- refactor: move `CteWorkTable`, `default_table_source` a bunch of files out of core [#15316](https://github.com/apache/datafusion/pull/15316) (logan-keede) +- Fix empty aggregation function count() in Substrait [#15345](https://github.com/apache/datafusion/pull/15345) (gabotechs) +- Improved error for expand wildcard rule [#15287](https://github.com/apache/datafusion/pull/15287) (Jiashu-Hu) +- Added tests with are writing into parquet files in memory for issue #… [#15325](https://github.com/apache/datafusion/pull/15325) (pranavJibhakate) +- Migrate physical plan tests to `insta` (Part-1) [#15313](https://github.com/apache/datafusion/pull/15313) (Shreyaskr1409) +- Fix array_has_all and array_has_any with empty array [#15039](https://github.com/apache/datafusion/pull/15039) (LuQQiu) +- Update datafusion-testing pin to fix extended tests [#15368](https://github.com/apache/datafusion/pull/15368) (alamb) +- chore(deps): Update sqlparser to 0.55.0 [#15183](https://github.com/apache/datafusion/pull/15183) (PokIsemaine) +- Only unnest source for `EmptyRelation` [#15159](https://github.com/apache/datafusion/pull/15159) (blaginin) +- chore(deps): bump rust_decimal from 1.37.0 to 1.37.1 [#15378](https://github.com/apache/datafusion/pull/15378) (dependabot[bot]) +- chore(deps): bump chrono-tz from 0.10.1 to 0.10.2 [#15377](https://github.com/apache/datafusion/pull/15377) (dependabot[bot]) +- remove the duplicate test for unparser [#15385](https://github.com/apache/datafusion/pull/15385) (goldmedal) +- Minor: add average time for clickbench benchmark query [#15381](https://github.com/apache/datafusion/pull/15381) (zhuqi-lucas) +- include some BinaryOperator from sqlparser [#15327](https://github.com/apache/datafusion/pull/15327) (waynexia) +- Add "end to end parquet reading test" for WASM [#15362](https://github.com/apache/datafusion/pull/15362) (jsai28) +- Migrate physical plan tests to `insta` (Part-2) [#15364](https://github.com/apache/datafusion/pull/15364) (Shreyaskr1409) +- Migrate physical plan tests to `insta` (Part-3 / Final) [#15399](https://github.com/apache/datafusion/pull/15399) (Shreyaskr1409) +- Restore lazy evaluation of fallible CASE [#15390](https://github.com/apache/datafusion/pull/15390) (findepi) +- chore(deps): bump log from 0.4.26 to 0.4.27 [#15410](https://github.com/apache/datafusion/pull/15410) (dependabot[bot]) +- chore(deps): bump chrono-tz from 0.10.2 to 0.10.3 [#15412](https://github.com/apache/datafusion/pull/15412) 
(dependabot[bot]) +- Perf: Support Utf8View datatype single column comparisons for SortPreservingMergeStream [#15348](https://github.com/apache/datafusion/pull/15348) (zhuqi-lucas) +- Enforce JOIN plan to require condition [#15334](https://github.com/apache/datafusion/pull/15334) (goldmedal) +- Fix type coercion for unsigned and signed integers (`Int64` vs `UInt64`, etc) [#15341](https://github.com/apache/datafusion/pull/15341) (Omega359) +- simplify `array_has` UDF to `InList` expr when haystack is constant [#15354](https://github.com/apache/datafusion/pull/15354) (davidhewitt) +- Move `DataSink` to `datasource` and add session crate [#15371](https://github.com/apache/datafusion/pull/15371) (jayzhan-synnada) +- refactor: SpillManager into a separate file [#15407](https://github.com/apache/datafusion/pull/15407) (Weijun-H) +- Always use `PartitionMode::Auto` in planner [#15339](https://github.com/apache/datafusion/pull/15339) (Dandandan) +- Fix link to Volcano paper [#15437](https://github.com/apache/datafusion/pull/15437) (JackKelly) +- minor: Add new crates to labeler [#15426](https://github.com/apache/datafusion/pull/15426) (logan-keede) +- refactor: Use SpillManager for all spilling scenarios [#15405](https://github.com/apache/datafusion/pull/15405) (2010YOUY01) +- refactor(hash_join): Move JoinHashMap to separate mod [#15419](https://github.com/apache/datafusion/pull/15419) (ctsk) +- Migrate datasource tests to insta [#15258](https://github.com/apache/datafusion/pull/15258) (shruti2522) +- Add `downcast_to_source` method for `DataSourceExec` [#15416](https://github.com/apache/datafusion/pull/15416) (xudong963) +- refactor: use TypeSignature::Coercible for crypto functions [#14826](https://github.com/apache/datafusion/pull/14826) (Chen-Yuan-Lai) +- Minor: fix doc for `FileGroupPartitioner` [#15448](https://github.com/apache/datafusion/pull/15448) (xudong963) +- chore(deps): bump clap from 4.5.32 to 4.5.34 [#15452](https://github.com/apache/datafusion/pull/15452) (dependabot[bot]) +- Fix roundtrip bug with empty projection in DataSourceExec [#15449](https://github.com/apache/datafusion/pull/15449) (XiangpengHao) +- Triggering extended tests through PR comment: `Run extended tests` [#15101](https://github.com/apache/datafusion/pull/15101) (danila-b) +- Use `equals_datatype` to compare type when type coercion [#15366](https://github.com/apache/datafusion/pull/15366) (goldmedal) +- Fix no effect metrics bug in ParquetSource [#15460](https://github.com/apache/datafusion/pull/15460) (XiangpengHao) +- chore(deps): bump aws-config from 1.6.0 to 1.6.1 [#15470](https://github.com/apache/datafusion/pull/15470) (dependabot[bot]) +- minor: Allow to run TPCH bench for a specific query [#15467](https://github.com/apache/datafusion/pull/15467) (comphead) +- Migrate subtraits tests to insta, part1 [#15444](https://github.com/apache/datafusion/pull/15444) (qstommyshu) +- Add `FileScanConfigBuilder` [#15352](https://github.com/apache/datafusion/pull/15352) (blaginin) +- Update ClickBench queries to avoid to_timestamp_seconds [#15475](https://github.com/apache/datafusion/pull/15475) (acking-you) +- Remove CoalescePartitions insertion from HashJoinExec [#15476](https://github.com/apache/datafusion/pull/15476) (ctsk) +- Migrate-substrait-tests-to-insta, part2 [#15480](https://github.com/apache/datafusion/pull/15480) (qstommyshu) +- Revert #15476 to fix the datafusion-examples CI fail [#15496](https://github.com/apache/datafusion/pull/15496) (goldmedal) +- Migrate datafusion/sql tests to insta, part1 
[#15497](https://github.com/apache/datafusion/pull/15497) (qstommyshu) +- Allow type coersion of zero input arrays to nullary [#15487](https://github.com/apache/datafusion/pull/15487) (timsaucer) +- Decimal type support for `to_timestamp` [#15486](https://github.com/apache/datafusion/pull/15486) (jatin510) +- refactor: Move `Memtable` to catalog [#15459](https://github.com/apache/datafusion/pull/15459) (logan-keede) +- Migrate optimizer tests to insta [#15446](https://github.com/apache/datafusion/pull/15446) (qstommyshu) +- FIX : some benchmarks are failing [#15367](https://github.com/apache/datafusion/pull/15367) (getChan) +- Add query to extended clickbench suite for "complex filter" [#15500](https://github.com/apache/datafusion/pull/15500) (acking-you) +- Extract tokio runtime creation from hot loop in benchmarks [#15508](https://github.com/apache/datafusion/pull/15508) (Omega359) +- chore(deps): bump blake3 from 1.7.0 to 1.8.0 [#15502](https://github.com/apache/datafusion/pull/15502) (dependabot[bot]) +- Minor: clone and debug for FileSinkConfig [#15516](https://github.com/apache/datafusion/pull/15516) (jayzhan211) +- use state machine to refactor the `get_files_with_limit` method [#15521](https://github.com/apache/datafusion/pull/15521) (xudong963) +- Migrate `datafusion/sql` tests to insta, part2 [#15499](https://github.com/apache/datafusion/pull/15499) (qstommyshu) +- Disable sccache action to fix gh cache issue [#15536](https://github.com/apache/datafusion/pull/15536) (Omega359) +- refactor: Cleanup unused `fetch` field inside `ExternalSorter` [#15525](https://github.com/apache/datafusion/pull/15525) (2010YOUY01) +- Fix duplicate unqualified Field name (schema error) on join queries [#15438](https://github.com/apache/datafusion/pull/15438) (LiaCastaneda) +- Add utf8view benchmark for aggregate topk [#15518](https://github.com/apache/datafusion/pull/15518) (zhuqi-lucas) +- ArraySort: support structs [#15527](https://github.com/apache/datafusion/pull/15527) (cht42) +- Migrate datafusion/sql tests to insta, part3 [#15533](https://github.com/apache/datafusion/pull/15533) (qstommyshu) +- Migrate datafusion/sql tests to insta, part4 [#15548](https://github.com/apache/datafusion/pull/15548) (qstommyshu) +- Add topk information into tree explain plans [#15547](https://github.com/apache/datafusion/pull/15547) (kumarlokesh) +- Minor: add Arc for statistics in FileGroup [#15564](https://github.com/apache/datafusion/pull/15564) (xudong963) +- Test: configuration fuzzer for (external) sort queries [#15501](https://github.com/apache/datafusion/pull/15501) (2010YOUY01) +- minor: Organize fields inside SortMergeJoinStream [#15557](https://github.com/apache/datafusion/pull/15557) (suibianwanwank) +- Minor: rm session downcast [#15575](https://github.com/apache/datafusion/pull/15575) (jayzhan211) +- Migrate datafusion/sql tests to insta, part5 [#15567](https://github.com/apache/datafusion/pull/15567) (qstommyshu) +- Add SQL logic tests for compound field access in JOIN conditions [#15556](https://github.com/apache/datafusion/pull/15556) (kosiew) +- Run audit CI check on all pushes to main [#15572](https://github.com/apache/datafusion/pull/15572) (alamb) +- Introduce load-balanced `split_groups_by_statistics` method [#15473](https://github.com/apache/datafusion/pull/15473) (xudong963) +- chore: update clickbench [#15574](https://github.com/apache/datafusion/pull/15574) (chenkovsky) +- Improve spill performance: Disable re-validation of spilled files 
[#15454](https://github.com/apache/datafusion/pull/15454) (zebsme) +- chore: rm duplicated `JoinOn` type [#15590](https://github.com/apache/datafusion/pull/15590) (jayzhan211) +- Chore: Call arrow's methods `row_count` and `skipped_row_count` [#15587](https://github.com/apache/datafusion/pull/15587) (jayzhan211) +- Actually run wasm test in ci [#15595](https://github.com/apache/datafusion/pull/15595) (XiangpengHao) +- Migrate datafusion/sql tests to insta, part6 [#15578](https://github.com/apache/datafusion/pull/15578) (qstommyshu) +- Add test case for new casting feature from date to tz-aware timestamps [#15609](https://github.com/apache/datafusion/pull/15609) (friendlymatthew) +- Remove CoalescePartitions insertion from Joins [#15570](https://github.com/apache/datafusion/pull/15570) (ctsk) +- fix doc and broken api [#15602](https://github.com/apache/datafusion/pull/15602) (logan-keede) +- Migrate datafusion/sql tests to insta, part7 [#15621](https://github.com/apache/datafusion/pull/15621) (qstommyshu) +- ignore security_audit CI check proc-macro-error warning [#15626](https://github.com/apache/datafusion/pull/15626) (Jiashu-Hu) +- chore(deps): bump tokio from 1.44.1 to 1.44.2 [#15627](https://github.com/apache/datafusion/pull/15627) (dependabot[bot]) +- Upgrade toolchain to Rust-1.86 [#15625](https://github.com/apache/datafusion/pull/15625) (jsai28) +- chore(deps): bump bigdecimal from 0.4.7 to 0.4.8 [#15523](https://github.com/apache/datafusion/pull/15523) (dependabot[bot]) +- chore(deps): bump the arrow-parquet group across 1 directory with 7 updates [#15593](https://github.com/apache/datafusion/pull/15593) (dependabot[bot]) +- chore: improve RepartitionExec display tree [#15606](https://github.com/apache/datafusion/pull/15606) (getChan) +- Move back schema not matching check and workaround [#15580](https://github.com/apache/datafusion/pull/15580) (LiaCastaneda) +- Minor: refine comments for statistics compution [#15647](https://github.com/apache/datafusion/pull/15647) (xudong963) +- Remove uneeded binary_op benchmarks [#15632](https://github.com/apache/datafusion/pull/15632) (alamb) +- chore(deps): bump blake3 from 1.8.0 to 1.8.1 [#15650](https://github.com/apache/datafusion/pull/15650) (dependabot[bot]) +- chore(deps): bump mimalloc from 0.1.44 to 0.1.46 [#15651](https://github.com/apache/datafusion/pull/15651) (dependabot[bot]) +- chore: avoid erroneuous warning for FFI table operation (only not default value) [#15579](https://github.com/apache/datafusion/pull/15579) (chenkovsky) +- Update datafusion-testing pin (to fix extended test on main) [#15655](https://github.com/apache/datafusion/pull/15655) (alamb) +- Ignore false positive only_used_in_recursion Clippy warning [#15635](https://github.com/apache/datafusion/pull/15635) (DerGut) +- chore: Rename protobuf Java package [#15658](https://github.com/apache/datafusion/pull/15658) (andygrove) +- Remove redundant `Precision` combination code in favor of `Precision::min/max/add` [#15659](https://github.com/apache/datafusion/pull/15659) (alamb) +- Introduce DynamicFilterSource and DynamicPhysicalExpr [#15568](https://github.com/apache/datafusion/pull/15568) (adriangb) +- Public some projected methods in `FileScanConfig` [#15671](https://github.com/apache/datafusion/pull/15671) (xudong963) +- fix decimal precision issue in simplify expression optimize rule [#15588](https://github.com/apache/datafusion/pull/15588) (jayzhan211) +- Implement Future for SpawnedTask. 
[#15653](https://github.com/apache/datafusion/pull/15653) (ashdnazg) +- chore(deps): bump crossbeam-channel from 0.5.14 to 0.5.15 [#15674](https://github.com/apache/datafusion/pull/15674) (dependabot[bot]) +- chore(deps): bump clap from 4.5.34 to 4.5.35 [#15668](https://github.com/apache/datafusion/pull/15668) (dependabot[bot]) +- [Minor] Use interleave_record_batch in TopK implementation [#15677](https://github.com/apache/datafusion/pull/15677) (Dandandan) +- Consolidate statistics merging code (try 2) [#15661](https://github.com/apache/datafusion/pull/15661) (alamb) +- Add Table Functions to FFI Crate [#15581](https://github.com/apache/datafusion/pull/15581) (timsaucer) +- Remove waits from blocking threads reading spill files. [#15654](https://github.com/apache/datafusion/pull/15654) (ashdnazg) +- chore(deps): bump sysinfo from 0.33.1 to 0.34.2 [#15682](https://github.com/apache/datafusion/pull/15682) (dependabot[bot]) +- Minor: add order by arg for last value [#15695](https://github.com/apache/datafusion/pull/15695) (jayzhan211) +- Upgrade to arrow/parquet 55, and `object_store` to `0.12.0` and pyo3 to `0.24.0` [#15466](https://github.com/apache/datafusion/pull/15466) (alamb) +- tests: only refresh the minimum sysinfo in mem limit tests. [#15702](https://github.com/apache/datafusion/pull/15702) (ashdnazg) +- ci: fix workflow triggering extended tests from pr comments. [#15704](https://github.com/apache/datafusion/pull/15704) (ashdnazg) +- chore(deps): bump flate2 from 1.1.0 to 1.1.1 [#15703](https://github.com/apache/datafusion/pull/15703) (dependabot[bot]) +- Fix internal error in sort when hitting memory limit [#15692](https://github.com/apache/datafusion/pull/15692) (DerGut) +- Update checked in Cargo.lock file to get clean CI [#15725](https://github.com/apache/datafusion/pull/15725) (alamb) +- chore(deps): bump indexmap from 2.8.0 to 2.9.0 [#15732](https://github.com/apache/datafusion/pull/15732) (dependabot[bot]) +- Minor: include output partition count of `RepartitionExec` to tree explain [#15717](https://github.com/apache/datafusion/pull/15717) (2010YOUY01) + +## Credits + +Thank you to everyone who contributed to this release. Here is a breakdown of commits (PRs merged) per contributor. 
+ +``` + 48 dependabot[bot] + 34 Andrew Lamb + 16 xudong.w + 15 Jay Zhan + 15 Qi Zhu + 15 irenjj + 13 Chen Chongchen + 13 Yongting You + 10 Tommy shu + 7 Shruti Sharma + 6 Alan Tang + 6 Arttu + 6 Jiashu Hu + 6 Shreyas (Lua) + 6 logan-keede + 6 zeb + 5 Dmitrii Blaginin + 5 Geoffrey Claude + 5 Jax Liu + 5 YuNing Chen + 4 Bruce Ritchie + 4 Christian + 4 Eshed Schacham + 4 Xiangpeng Hao + 4 wiedld + 3 Adrian Garcia Badaracco + 3 Daniël Heres + 3 Gabriel + 3 LB7666 + 3 Namgung Chan + 3 Ruihang Xia + 3 Tim Saucer + 3 jsai28 + 3 kosiew + 3 suibianwanwan + 2 Bryce Mecum + 2 Carol (Nichols || Goulding) + 2 Heran Lin + 2 Jannik Steinmann + 2 Jyotir Sai + 2 Li-Lun Lin + 2 Lía Adriana + 2 Oleks V + 2 Raz Luvaton + 2 UBarney + 2 aditya singh rathore + 2 westhide + 2 zjregee + 1 @clflushopt + 1 Adam Gutglick + 1 Alex Huang + 1 Alex Wilcoxson + 1 Amos Aidoo + 1 Andy Grove + 1 Andy Yen + 1 Berkay Şahin + 1 Chang + 1 Danila Baklazhenko + 1 David Hewitt + 1 Emily Matheys + 1 Eren Avsarogullari + 1 Hari Varsha + 1 Ian Lai + 1 Jack Kelly + 1 Jagdish Parihar + 1 Joseph Koshakow + 1 Lokesh + 1 LuQQiu + 1 Matt Butrovich + 1 Matt Friede + 1 Matthew Kim + 1 Matthijs Brobbel + 1 Om Kenge + 1 Owen Leung + 1 Peter L + 1 Piotr Findeisen + 1 Rohan Krishnaswamy + 1 Satyam018 + 1 Sava Vranešević + 1 Siddhartha Sahu + 1 Sile Zhou + 1 Vadim Piven + 1 Zaki + 1 christophermcdermott + 1 cht42 + 1 cjw + 1 delamarch3 + 1 ding-young + 1 haruband + 1 jamxia155 + 1 oznur-synnada + 1 peasee + 1 pranavJibhakate + 1 张林伟 +``` + +Thank you also to everyone who contributed in other ways such as filing issues, reviewing PRs, and providing feedback on this release. diff --git a/dev/changelog/48.0.0.md b/dev/changelog/48.0.0.md new file mode 100644 index 000000000000..9cf6c03b7acf --- /dev/null +++ b/dev/changelog/48.0.0.md @@ -0,0 +1,405 @@ + + +# Apache DataFusion 48.0.0 Changelog + +This release consists of 267 commits from 89 contributors. See credits at the end of this changelog for more information. 
+ +**Breaking changes:** + +- Attach Diagnostic to syntax errors [#15680](https://github.com/apache/datafusion/pull/15680) (logan-keede) +- Change `flatten` so it does only a level, not recursively [#15160](https://github.com/apache/datafusion/pull/15160) (delamarch3) +- Improve `simplify_expressions` rule [#15735](https://github.com/apache/datafusion/pull/15735) (xudong963) +- Support WITHIN GROUP syntax to standardize certain existing aggregate functions [#13511](https://github.com/apache/datafusion/pull/13511) (Garamda) +- Add Extension Type / Metadata support for Scalar UDFs [#15646](https://github.com/apache/datafusion/pull/15646) (timsaucer) +- chore: fix clippy::large_enum_variant for DataFusionError [#15861](https://github.com/apache/datafusion/pull/15861) (rroelke) +- Feat: introduce `ExecutionPlan::partition_statistics` API [#15852](https://github.com/apache/datafusion/pull/15852) (xudong963) +- refactor: remove deprecated `ParquetExec` [#15973](https://github.com/apache/datafusion/pull/15973) (miroim) +- refactor: remove deprecated `ArrowExec` [#16006](https://github.com/apache/datafusion/pull/16006) (miroim) +- refactor: remove deprecated `MemoryExec` [#16007](https://github.com/apache/datafusion/pull/16007) (miroim) +- refactor: remove deprecated `JsonExec` [#16005](https://github.com/apache/datafusion/pull/16005) (miroim) +- feat: metadata handling for aggregates and window functions [#15911](https://github.com/apache/datafusion/pull/15911) (timsaucer) +- Remove `Filter::having` field [#16154](https://github.com/apache/datafusion/pull/16154) (findepi) +- Shift from Field to FieldRef for all user defined functions [#16122](https://github.com/apache/datafusion/pull/16122) (timsaucer) +- Change default SQL mapping for `VARCAHR` from `Utf8` to `Utf8View` [#16142](https://github.com/apache/datafusion/pull/16142) (zhuqi-lucas) +- Minor: remove unused IPCWriter [#16215](https://github.com/apache/datafusion/pull/16215) (alamb) +- Reduce size of `Expr` struct [#16207](https://github.com/apache/datafusion/pull/16207) (hendrikmakait) + +**Performance related:** + +- Apply pre-selection and computation skipping to short-circuit optimization [#15694](https://github.com/apache/datafusion/pull/15694) (acking-you) +- Add a fast path for `optimize_projection` [#15746](https://github.com/apache/datafusion/pull/15746) (xudong963) +- Speed up `optimize_projection` by improving `is_projection_unnecessary` [#15761](https://github.com/apache/datafusion/pull/15761) (xudong963) +- Speed up `optimize_projection` [#15787](https://github.com/apache/datafusion/pull/15787) (xudong963) +- Support `GroupsAccumulator` for Avg duration [#15748](https://github.com/apache/datafusion/pull/15748) (shruti2522) +- Optimize performance of `string::ascii` function [#16087](https://github.com/apache/datafusion/pull/16087) (tlm365) + +**Implemented enhancements:** + +- Set DataFusion runtime configurations through SQL interface [#15594](https://github.com/apache/datafusion/pull/15594) (kumarlokesh) +- feat: Add option to adjust writer buffer size for query output [#15747](https://github.com/apache/datafusion/pull/15747) (m09526) +- feat: Add `datafusion-spark` crate [#15168](https://github.com/apache/datafusion/pull/15168) (shehabgamin) +- feat: create helpers to set the max_temp_directory_size [#15919](https://github.com/apache/datafusion/pull/15919) (jdrouet) +- feat: ORDER BY ALL [#15772](https://github.com/apache/datafusion/pull/15772) (PokIsemaine) +- feat: support min/max for struct 
[#15667](https://github.com/apache/datafusion/pull/15667) (chenkovsky) +- feat(proto): udf decoding fallback [#15997](https://github.com/apache/datafusion/pull/15997) (leoyvens) +- feat: make error handling in indent explain consistent with that in tree [#16097](https://github.com/apache/datafusion/pull/16097) (chenkovsky) +- feat: coerce to/from fixed size binary to binary view [#16110](https://github.com/apache/datafusion/pull/16110) (chenkovsky) +- feat: array_length for fixed size list [#16167](https://github.com/apache/datafusion/pull/16167) (chenkovsky) +- feat: ADD sha2 spark function [#16168](https://github.com/apache/datafusion/pull/16168) (getChan) +- feat: create builder for disk manager [#16191](https://github.com/apache/datafusion/pull/16191) (jdrouet) +- feat: Add Aggregate UDF to FFI crate [#14775](https://github.com/apache/datafusion/pull/14775) (timsaucer) +- feat(small): Add `BaselineMetrics` to `generate_series()` table function [#16255](https://github.com/apache/datafusion/pull/16255) (2010YOUY01) +- feat: Add Window UDFs to FFI Crate [#16261](https://github.com/apache/datafusion/pull/16261) (timsaucer) + +**Fixed bugs:** + +- fix: serialize listing table without partition column [#15737](https://github.com/apache/datafusion/pull/15737) (chenkovsky) +- fix: describe Parquet schema with coerce_int96 [#15750](https://github.com/apache/datafusion/pull/15750) (chenkovsky) +- fix: clickbench type err [#15773](https://github.com/apache/datafusion/pull/15773) (chenkovsky) +- Fix: fetch is missing in `replace_order_preserving_variants` method during `EnforceDistribution` optimizer [#15808](https://github.com/apache/datafusion/pull/15808) (xudong963) +- Fix: fetch is missing in `EnforceSorting` optimizer (two places) [#15822](https://github.com/apache/datafusion/pull/15822) (xudong963) +- fix: Avoid mistaken ILike to string equality optimization [#15836](https://github.com/apache/datafusion/pull/15836) (srh) +- Map file-level column statistics to the table-level [#15865](https://github.com/apache/datafusion/pull/15865) (xudong963) +- fix(avro): Respect projection order in Avro reader [#15840](https://github.com/apache/datafusion/pull/15840) (nantunes) +- fix: correctly specify the nullability of `map_values` return type [#15901](https://github.com/apache/datafusion/pull/15901) (rluvaton) +- Fix CI in main [#15917](https://github.com/apache/datafusion/pull/15917) (blaginin) +- fix: sqllogictest on Windows [#15932](https://github.com/apache/datafusion/pull/15932) (nuno-faria) +- fix: fold cast null to substrait typed null [#15854](https://github.com/apache/datafusion/pull/15854) (discord9) +- Fix: `build_predicate_expression` method doesn't process `false` expr correctly [#15995](https://github.com/apache/datafusion/pull/15995) (xudong963) +- fix: add an "expr_planners" method to SessionState [#15119](https://github.com/apache/datafusion/pull/15119) (niebayes) +- fix: overcounting of memory in first/last. 
[#15924](https://github.com/apache/datafusion/pull/15924) (ashdnazg) +- fix: track timing for coalescer's in execution time [#16048](https://github.com/apache/datafusion/pull/16048) (waynexia) +- fix: stack overflow for substrait functions with large argument lists that translate to DataFusion binary operators [#16031](https://github.com/apache/datafusion/pull/16031) (fmonjalet) +- fix: coerce int96 resolution inside of list, struct, and map types [#16058](https://github.com/apache/datafusion/pull/16058) (mbutrovich) +- fix: Add coercion rules for Float16 types [#15816](https://github.com/apache/datafusion/pull/15816) (etseidl) +- fix: describe escaped quoted identifiers [#16082](https://github.com/apache/datafusion/pull/16082) (jfahne) +- fix: Remove trailing whitespace in `Display` for `LogicalPlan::Projection` [#16164](https://github.com/apache/datafusion/pull/16164) (atahanyorganci) +- fix: metadata of join schema [#16221](https://github.com/apache/datafusion/pull/16221) (chenkovsky) +- fix: add missing row count limits to TPC-H queries [#16230](https://github.com/apache/datafusion/pull/16230) (0ax1) +- fix: NaN semantics in GROUP BY [#16256](https://github.com/apache/datafusion/pull/16256) (chenkovsky) + +**Documentation updates:** + +- Add DataFusion 47.0.0 Upgrade Guide [#15749](https://github.com/apache/datafusion/pull/15749) (alamb) +- Improve documentation for format `OPTIONS` clause [#15708](https://github.com/apache/datafusion/pull/15708) (marvelshan) +- doc: Adding Feldera as known user [#15799](https://github.com/apache/datafusion/pull/15799) (comphead) +- docs: add ArkFlow [#15826](https://github.com/apache/datafusion/pull/15826) (chenquan) +- Fix `from_unixtime` function documentation [#15844](https://github.com/apache/datafusion/pull/15844) (Viicos) +- Upgrade-guide: Downgrade "FileScanConfig –> FileScanConfigBuilder" headline [#15883](https://github.com/apache/datafusion/pull/15883) (simonvandel) +- doc: Update known users docs [#15895](https://github.com/apache/datafusion/pull/15895) (comphead) +- Add `union_tag` scalar function [#14687](https://github.com/apache/datafusion/pull/14687) (gstvg) +- Fix typo in introduction.md [#15910](https://github.com/apache/datafusion/pull/15910) (tom-mont) +- Add `FormatOptions` to Config [#15793](https://github.com/apache/datafusion/pull/15793) (blaginin) +- docs: Label `bloom_filter_on_read` as a reading config [#15933](https://github.com/apache/datafusion/pull/15933) (nuno-faria) +- Implement Parquet filter pushdown via new filter pushdown APIs [#15769](https://github.com/apache/datafusion/pull/15769) (adriangb) +- Enable repartitioning on MemTable. 
[#15409](https://github.com/apache/datafusion/pull/15409) (wiedld) +- Updated extending operators documentation [#15612](https://github.com/apache/datafusion/pull/15612) (the0ninjas) +- chore: Replace MSRV link on main page with Github badge [#16020](https://github.com/apache/datafusion/pull/16020) (comphead) +- Add note to upgrade guide for removal of `ParquetExec`, `AvroExec`, `CsvExec`, `JsonExec` [#16034](https://github.com/apache/datafusion/pull/16034) (alamb) +- docs: Clarify that it is only the name of the field that is ignored [#16052](https://github.com/apache/datafusion/pull/16052) (alamb) +- [Docs]: Added SQL example for all window functions [#16074](https://github.com/apache/datafusion/pull/16074) (Adez017) +- Fix CI on main: Add window function examples in code [#16102](https://github.com/apache/datafusion/pull/16102) (alamb) +- chore: Remove SMJ experimental status in docs [#16072](https://github.com/apache/datafusion/pull/16072) (comphead) +- doc: fix indent format explain [#16085](https://github.com/apache/datafusion/pull/16085) (chenkovsky) +- Update documentation for `datafusion.execution.collect_statistics` [#16100](https://github.com/apache/datafusion/pull/16100) (alamb) +- Make `SessionContext::register_parquet` obey `collect_statistics` config [#16080](https://github.com/apache/datafusion/pull/16080) (adriangb) +- Improve the DML / DDL Documentation [#16115](https://github.com/apache/datafusion/pull/16115) (alamb) +- docs: Fix typos and minor grammatical issues in Architecture docs [#16119](https://github.com/apache/datafusion/pull/16119) (patrickcsullivan) +- Set `TrackConsumersPool` as default in datafusion-cli [#16081](https://github.com/apache/datafusion/pull/16081) (ding-young) +- Minor: Fix links in substrait readme [#16156](https://github.com/apache/datafusion/pull/16156) (alamb) +- Add macro for creating DataFrame (#16090) [#16104](https://github.com/apache/datafusion/pull/16104) (cj-zhukov) +- doc: Move `dataframe!` example into dedicated example [#16197](https://github.com/apache/datafusion/pull/16197) (comphead) +- doc: add diagram to describe how DataSource, FileSource, and DataSourceExec are related [#16181](https://github.com/apache/datafusion/pull/16181) (onlyjackfrost) +- Clarify documentation about gathering statistics for parquet files [#16157](https://github.com/apache/datafusion/pull/16157) (alamb) +- Add change to VARCHAR in the upgrade guide [#16216](https://github.com/apache/datafusion/pull/16216) (alamb) +- Add iceberg-rust to user list [#16246](https://github.com/apache/datafusion/pull/16246) (jonathanc-n) +- Prepare for 48.0.0 release: Version and Changelog [#16238](https://github.com/apache/datafusion/pull/16238) (xudong963) + +**Other:** + +- Enable setting default values for target_partitions and planning_concurrency [#15712](https://github.com/apache/datafusion/pull/15712) (nuno-faria) +- minor: fix doc comment [#15733](https://github.com/apache/datafusion/pull/15733) (niebayes) +- chore(deps-dev): bump http-proxy-middleware from 2.0.6 to 2.0.9 in /datafusion/wasmtest/datafusion-wasm-app [#15738](https://github.com/apache/datafusion/pull/15738) (dependabot[bot]) +- Avoid computing unnecessary statstics [#15729](https://github.com/apache/datafusion/pull/15729) (xudong963) +- chore(deps): bump libc from 0.2.171 to 0.2.172 [#15745](https://github.com/apache/datafusion/pull/15745) (dependabot[bot]) +- Final release note touchups [#15741](https://github.com/apache/datafusion/pull/15741) (alamb) +- Refactor regexp slt tests 
[#15709](https://github.com/apache/datafusion/pull/15709) (kumarlokesh) +- ExecutionPlan: add APIs for filter pushdown & optimizer rule to apply them [#15566](https://github.com/apache/datafusion/pull/15566) (adriangb) +- Coerce and simplify FixedSizeBinary equality to literal binary [#15726](https://github.com/apache/datafusion/pull/15726) (leoyvens) +- Minor: simplify code in datafusion-proto [#15752](https://github.com/apache/datafusion/pull/15752) (alamb) +- chore(deps): bump clap from 4.5.35 to 4.5.36 [#15759](https://github.com/apache/datafusion/pull/15759) (dependabot[bot]) +- Support `Accumulator` for avg duration [#15468](https://github.com/apache/datafusion/pull/15468) (shruti2522) +- Show current SQL recursion limit in RecursionLimitExceeded error message [#15644](https://github.com/apache/datafusion/pull/15644) (kumarlokesh) +- Minor: fix flaky test in `aggregate.slt` [#15786](https://github.com/apache/datafusion/pull/15786) (xudong963) +- Minor: remove unused logic for limit pushdown [#15730](https://github.com/apache/datafusion/pull/15730) (zhuqi-lucas) +- chore(deps): bump sqllogictest from 0.28.0 to 0.28.1 [#15788](https://github.com/apache/datafusion/pull/15788) (dependabot[bot]) +- Add try_new for LogicalPlan::Join [#15757](https://github.com/apache/datafusion/pull/15757) (kumarlokesh) +- Minor: eliminate unnecessary struct creation in session state build [#15800](https://github.com/apache/datafusion/pull/15800) (Rachelint) +- chore(deps): bump half from 2.5.0 to 2.6.0 [#15806](https://github.com/apache/datafusion/pull/15806) (dependabot[bot]) +- Add `or_fun_call` and `unnecessary_lazy_evaluations` lints on `core` [#15807](https://github.com/apache/datafusion/pull/15807) (Rachelint) +- chore(deps): bump env_logger from 0.11.7 to 0.11.8 [#15823](https://github.com/apache/datafusion/pull/15823) (dependabot[bot]) +- Support unparsing `UNION` for distinct results [#15814](https://github.com/apache/datafusion/pull/15814) (phillipleblanc) +- Add `MemoryPool::memory_limit` to expose setting memory usage limit [#15828](https://github.com/apache/datafusion/pull/15828) (Rachelint) +- Preserve projection for inline scan [#15825](https://github.com/apache/datafusion/pull/15825) (jayzhan211) +- Minor: cleanup hash table after emit all [#15834](https://github.com/apache/datafusion/pull/15834) (jayzhan211) +- chore(deps): bump pyo3 from 0.24.1 to 0.24.2 [#15838](https://github.com/apache/datafusion/pull/15838) (dependabot[bot]) +- Minor: fix potential flaky test in aggregate.slt [#15829](https://github.com/apache/datafusion/pull/15829) (bikbov) +- Fix `ILIKE` expression support in SQL unparser [#15820](https://github.com/apache/datafusion/pull/15820) (ewgenius) +- Make `Diagnostic` easy/convinient to attach by using macro and avoiding `map_err` [#15796](https://github.com/apache/datafusion/pull/15796) (logan-keede) +- Feature/benchmark config from env [#15782](https://github.com/apache/datafusion/pull/15782) (ctsk) +- predicate pruning: support cast and try_cast for more types [#15764](https://github.com/apache/datafusion/pull/15764) (adriangb) +- Fix: fetch is missing in `plan_with_order_breaking_variants` method [#15842](https://github.com/apache/datafusion/pull/15842) (xudong963) +- Fix `CoalescePartitionsExec` proto serialization [#15824](https://github.com/apache/datafusion/pull/15824) (lewiszlw) +- Fix build failure caused by new `CoalescePartitionsExec::with_fetch` method [#15849](https://github.com/apache/datafusion/pull/15849) (lewiszlw) +- Fix ScalarValue::List comparison when 
the compared lists have different lengths [#15856](https://github.com/apache/datafusion/pull/15856) (gabotechs) +- chore: More details to `No UDF registered` error [#15843](https://github.com/apache/datafusion/pull/15843) (comphead) +- chore(deps): bump clap from 4.5.36 to 4.5.37 [#15853](https://github.com/apache/datafusion/pull/15853) (dependabot[bot]) +- Remove usage of `dbg!` [#15858](https://github.com/apache/datafusion/pull/15858) (phillipleblanc) +- Minor: Interval singleton [#15859](https://github.com/apache/datafusion/pull/15859) (jayzhan211) +- Make aggr fuzzer query builder more configurable [#15851](https://github.com/apache/datafusion/pull/15851) (Rachelint) +- chore(deps): bump aws-config from 1.6.1 to 1.6.2 [#15874](https://github.com/apache/datafusion/pull/15874) (dependabot[bot]) +- Add slt tests for `datafusion.execution.parquet.coerce_int96` setting [#15723](https://github.com/apache/datafusion/pull/15723) (alamb) +- Improve `ListingTable` / `ListingTableOptions` docs [#15767](https://github.com/apache/datafusion/pull/15767) (alamb) +- Migrate Optimizer tests to insta, part2 [#15884](https://github.com/apache/datafusion/pull/15884) (qstommyshu) +- Improve documentation for `FileSource`, `DataSource` and `DataSourceExec` [#15766](https://github.com/apache/datafusion/pull/15766) (alamb) +- Implement min max for dictionary types [#15827](https://github.com/apache/datafusion/pull/15827) (XiangpengHao) +- chore(deps): bump blake3 from 1.8.1 to 1.8.2 [#15890](https://github.com/apache/datafusion/pull/15890) (dependabot[bot]) +- Respect ignore_nulls in array_agg [#15544](https://github.com/apache/datafusion/pull/15544) (joroKr21) +- Set HashJoin seed [#15783](https://github.com/apache/datafusion/pull/15783) (ctsk) +- Saner handling of nulls inside arrays [#15149](https://github.com/apache/datafusion/pull/15149) (joroKr21) +- Keeping pull request in sync with the base branch [#15894](https://github.com/apache/datafusion/pull/15894) (xudong963) +- Fix `flatten` scalar function when inner list is `FixedSizeList` [#15898](https://github.com/apache/datafusion/pull/15898) (gstvg) +- support OR operator in binary `evaluate_bounds` [#15716](https://github.com/apache/datafusion/pull/15716) (davidhewitt) +- infer placeholder datatype for IN lists [#15864](https://github.com/apache/datafusion/pull/15864) (kczimm) +- Fix allow_update_branch [#15904](https://github.com/apache/datafusion/pull/15904) (xudong963) +- chore(deps): bump tokio from 1.44.1 to 1.44.2 [#15900](https://github.com/apache/datafusion/pull/15900) (dependabot[bot]) +- chore(deps): bump assert_cmd from 2.0.16 to 2.0.17 [#15909](https://github.com/apache/datafusion/pull/15909) (dependabot[bot]) +- Factor out Substrait consumers into separate files [#15794](https://github.com/apache/datafusion/pull/15794) (gabotechs) +- Unparse `UNNEST` projection with the table column alias [#15879](https://github.com/apache/datafusion/pull/15879) (goldmedal) +- Migrate Optimizer tests to insta, part3 [#15893](https://github.com/apache/datafusion/pull/15893) (qstommyshu) +- Minor: cleanup datafusion-spark scalar functions [#15921](https://github.com/apache/datafusion/pull/15921) (alamb) +- Fix ClickBench extended queries after update to APPROX_PERCENTILE_CONT [#15929](https://github.com/apache/datafusion/pull/15929) (alamb) +- Add extended query for checking improvement for blocked groups optimization [#15936](https://github.com/apache/datafusion/pull/15936) (Rachelint) +- Speedup `character_length` 
[#15931](https://github.com/apache/datafusion/pull/15931) (Dandandan) +- chore(deps): bump tokio-util from 0.7.14 to 0.7.15 [#15918](https://github.com/apache/datafusion/pull/15918) (dependabot[bot]) +- Migrate Optimizer tests to insta, part4 [#15937](https://github.com/apache/datafusion/pull/15937) (qstommyshu) +- fix query results for predicates referencing partition columns and data columns [#15935](https://github.com/apache/datafusion/pull/15935) (adriangb) +- chore(deps): bump substrait from 0.55.0 to 0.55.1 [#15941](https://github.com/apache/datafusion/pull/15941) (dependabot[bot]) +- Fix main CI by adding `rowsort` to slt test [#15942](https://github.com/apache/datafusion/pull/15942) (xudong963) +- Improve sqllogictest error reporting [#15905](https://github.com/apache/datafusion/pull/15905) (gabotechs) +- refactor filter pushdown apis [#15801](https://github.com/apache/datafusion/pull/15801) (adriangb) +- Add additional tests for filter pushdown apis [#15955](https://github.com/apache/datafusion/pull/15955) (adriangb) +- Improve filter pushdown optimizer rule performance [#15959](https://github.com/apache/datafusion/pull/15959) (adriangb) +- Reduce rehashing cost for primitive grouping by also reusing hash value [#15962](https://github.com/apache/datafusion/pull/15962) (Rachelint) +- chore(deps): bump chrono from 0.4.40 to 0.4.41 [#15956](https://github.com/apache/datafusion/pull/15956) (dependabot[bot]) +- refactor: replace `unwrap_or` with `unwrap_or_else` for improved lazy… [#15841](https://github.com/apache/datafusion/pull/15841) (NevroHelios) +- add benchmark code for `Reuse rows in row cursor stream` [#15913](https://github.com/apache/datafusion/pull/15913) (acking-you) +- [Update] : Removal of duplicate CI jobs [#15966](https://github.com/apache/datafusion/pull/15966) (Adez017) +- Segfault in ByteGroupValueBuilder [#15968](https://github.com/apache/datafusion/pull/15968) (thinkharderdev) +- make can_expr_be_pushed_down_with_schemas public again [#15971](https://github.com/apache/datafusion/pull/15971) (adriangb) +- re-export can_expr_be_pushed_down_with_schemas to be public [#15974](https://github.com/apache/datafusion/pull/15974) (adriangb) +- Migrate Optimizer tests to insta, part5 [#15945](https://github.com/apache/datafusion/pull/15945) (qstommyshu) +- Show LogicalType name for `INFORMATION_SCHEMA` [#15965](https://github.com/apache/datafusion/pull/15965) (goldmedal) +- chore(deps): bump sha2 from 0.10.8 to 0.10.9 [#15970](https://github.com/apache/datafusion/pull/15970) (dependabot[bot]) +- chore(deps): bump insta from 1.42.2 to 1.43.1 [#15988](https://github.com/apache/datafusion/pull/15988) (dependabot[bot]) +- [datafusion-spark] Add Spark-compatible hex function [#15947](https://github.com/apache/datafusion/pull/15947) (andygrove) +- refactor: remove deprecated `AvroExec` [#15987](https://github.com/apache/datafusion/pull/15987) (miroim) +- Substrait: Handle inner map fields in schema renaming [#15869](https://github.com/apache/datafusion/pull/15869) (cht42) +- refactor: remove deprecated `CsvExec` [#15991](https://github.com/apache/datafusion/pull/15991) (miroim) +- Migrate Optimizer tests to insta, part6 [#15984](https://github.com/apache/datafusion/pull/15984) (qstommyshu) +- chore(deps): bump nix from 0.29.0 to 0.30.1 [#16002](https://github.com/apache/datafusion/pull/16002) (dependabot[bot]) +- Implement RightSemi join for SortMergeJoin [#15972](https://github.com/apache/datafusion/pull/15972) (irenjj) +- Migrate Optimizer tests to insta, part7 
[#16010](https://github.com/apache/datafusion/pull/16010) (qstommyshu) +- chore(deps): bump sysinfo from 0.34.2 to 0.35.1 [#16027](https://github.com/apache/datafusion/pull/16027) (dependabot[bot]) +- refactor: move `should_enable_page_index` from `mod.rs` to `opener.rs` [#16026](https://github.com/apache/datafusion/pull/16026) (miroim) +- chore(deps): bump sqllogictest from 0.28.1 to 0.28.2 [#16037](https://github.com/apache/datafusion/pull/16037) (dependabot[bot]) +- chores: Add lint rule to enforce string formatting style [#16024](https://github.com/apache/datafusion/pull/16024) (Lordworms) +- Use human-readable byte sizes in `EXPLAIN` [#16043](https://github.com/apache/datafusion/pull/16043) (tlm365) +- Docs: Add example of creating a field in `return_field_from_args` [#16039](https://github.com/apache/datafusion/pull/16039) (alamb) +- Support `MIN` and `MAX` for `DataType::List` [#16025](https://github.com/apache/datafusion/pull/16025) (gabotechs) +- Improve docs for Exprs and scalar functions [#16036](https://github.com/apache/datafusion/pull/16036) (alamb) +- Add h2o window benchmark [#16003](https://github.com/apache/datafusion/pull/16003) (2010YOUY01) +- Fix Infer prepare statement type tests [#15743](https://github.com/apache/datafusion/pull/15743) (brayanjuls) +- style: simplify some strings for readability [#15999](https://github.com/apache/datafusion/pull/15999) (hamirmahal) +- support simple/cross lateral joins [#16015](https://github.com/apache/datafusion/pull/16015) (jayzhan211) +- Improve error message on Out of Memory [#16050](https://github.com/apache/datafusion/pull/16050) (ding-young) +- chore(deps): bump the arrow-parquet group with 7 updates [#16047](https://github.com/apache/datafusion/pull/16047) (dependabot[bot]) +- chore(deps): bump petgraph from 0.7.1 to 0.8.1 [#15669](https://github.com/apache/datafusion/pull/15669) (dependabot[bot]) +- [datafusion-spark] Add Spark-compatible `char` expression [#15994](https://github.com/apache/datafusion/pull/15994) (andygrove) +- chore(deps): bump substrait from 0.55.1 to 0.56.0 [#16091](https://github.com/apache/datafusion/pull/16091) (dependabot[bot]) +- Add test that demonstrate behavior for `collect_statistics` [#16098](https://github.com/apache/datafusion/pull/16098) (alamb) +- Refactor substrait producer into multiple files [#16089](https://github.com/apache/datafusion/pull/16089) (gabotechs) +- Fix temp dir leak in tests [#16094](https://github.com/apache/datafusion/pull/16094) (findepi) +- Label Spark functions PRs with spark label [#16095](https://github.com/apache/datafusion/pull/16095) (findepi) +- Added SLT tests for IMDB benchmark queries [#16067](https://github.com/apache/datafusion/pull/16067) (kumarlokesh) +- chore(CI) Upgrade toolchain to Rust-1.87 [#16068](https://github.com/apache/datafusion/pull/16068) (kadai0308) +- minor: Add benchmark query and corresponding documentation for Average Duration [#16105](https://github.com/apache/datafusion/pull/16105) (logan-keede) +- Use qualified names on DELETE selections [#16033](https://github.com/apache/datafusion/pull/16033) (nuno-faria) +- chore(deps): bump testcontainers from 0.23.3 to 0.24.0 [#15989](https://github.com/apache/datafusion/pull/15989) (dependabot[bot]) +- Clean up ExternalSorter and use upstream kernel [#16109](https://github.com/apache/datafusion/pull/16109) (alamb) +- Test Duration in aggregation `fuzz` tests [#16111](https://github.com/apache/datafusion/pull/16111) (alamb) +- Move PruningStatistics into datafusion::common 
[#16069](https://github.com/apache/datafusion/pull/16069) (adriangb) +- Revert use file schema in parquet pruning [#16086](https://github.com/apache/datafusion/pull/16086) (adriangb) +- Minor: Add `ScalarFunctionArgs::return_type` method [#16113](https://github.com/apache/datafusion/pull/16113) (alamb) +- Fix `contains` function expression [#16046](https://github.com/apache/datafusion/pull/16046) (liamzwbao) +- chore: Use materialized data for filter pushdown tests [#16123](https://github.com/apache/datafusion/pull/16123) (comphead) +- chore: Upgrade rand crate and some other minor crates [#16062](https://github.com/apache/datafusion/pull/16062) (comphead) +- Include data types in logical plans of inferred prepare statements [#16019](https://github.com/apache/datafusion/pull/16019) (brayanjuls) +- CI: Fix extended test failure [#16144](https://github.com/apache/datafusion/pull/16144) (2010YOUY01) +- Fix: handle column name collisions when combining UNION logical inputs & nested Column expressions in maybe_fix_physical_column_name [#16064](https://github.com/apache/datafusion/pull/16064) (LiaCastaneda) +- adding support for Min/Max over LargeList and FixedSizeList [#16071](https://github.com/apache/datafusion/pull/16071) (logan-keede) +- Move prepare/parameter handling tests into `params.rs` [#16141](https://github.com/apache/datafusion/pull/16141) (liamzwbao) +- Minor: Add `Accumulator::return_type` and `StateFieldsArgs::return_type` to help with upgrade to 48 [#16112](https://github.com/apache/datafusion/pull/16112) (alamb) +- Support filtering specific sqllogictests identified by line number [#16029](https://github.com/apache/datafusion/pull/16029) (gabotechs) +- Enrich GroupedHashAggregateStream name to ease debugging Resources exhausted errors [#16152](https://github.com/apache/datafusion/pull/16152) (ahmed-mez) +- chore(deps): bump uuid from 1.16.0 to 1.17.0 [#16162](https://github.com/apache/datafusion/pull/16162) (dependabot[bot]) +- Clarify docs and names in parquet predicate pushdown tests [#16155](https://github.com/apache/datafusion/pull/16155) (alamb) +- Minor: Fix name() for FilterPushdown physical optimizer rule [#16175](https://github.com/apache/datafusion/pull/16175) (adriangb) +- migrate tests in `pool.rs` to use insta [#16145](https://github.com/apache/datafusion/pull/16145) (lifan-ake) +- refactor(optimizer): Add support for dynamically adding test tables [#16138](https://github.com/apache/datafusion/pull/16138) (atahanyorganci) +- [Minor] Speedup TPC-H benchmark run with memtable option [#16159](https://github.com/apache/datafusion/pull/16159) (Dandandan) +- Fast path for joins with distinct values in build side [#16153](https://github.com/apache/datafusion/pull/16153) (Dandandan) +- chore: Reduce repetition in the parameter type inference tests [#16079](https://github.com/apache/datafusion/pull/16079) (jsai28) +- chore(deps): bump tokio from 1.45.0 to 1.45.1 [#16190](https://github.com/apache/datafusion/pull/16190) (dependabot[bot]) +- Improve `unproject_sort_expr` to handle arbitrary expressions [#16127](https://github.com/apache/datafusion/pull/16127) (phillipleblanc) +- chore(deps): bump rustyline from 15.0.0 to 16.0.0 [#16194](https://github.com/apache/datafusion/pull/16194) (dependabot[bot]) +- migrate `logical_plan` tests to insta [#16184](https://github.com/apache/datafusion/pull/16184) (lifan-ake) +- chore(deps): bump clap from 4.5.38 to 4.5.39 [#16204](https://github.com/apache/datafusion/pull/16204) (dependabot[bot]) +- implement 
`AggregateExec.partition_statistics` [#15954](https://github.com/apache/datafusion/pull/15954) (UBarney) +- Propagate .execute() calls immediately in `RepartitionExec` [#16093](https://github.com/apache/datafusion/pull/16093) (gabotechs) +- Set aggregation hash seed [#16165](https://github.com/apache/datafusion/pull/16165) (ctsk) +- Fix ScalarStructBuilder::build() for an empty struct [#16205](https://github.com/apache/datafusion/pull/16205) (Blizzara) +- Return an error on overflow in `do_append_val_inner` [#16201](https://github.com/apache/datafusion/pull/16201) (liamzwbao) +- chore(deps): bump testcontainers-modules from 0.12.0 to 0.12.1 [#16212](https://github.com/apache/datafusion/pull/16212) (dependabot[bot]) +- Substrait: handle identical grouping expressions [#16189](https://github.com/apache/datafusion/pull/16189) (cht42) +- Add new stats pruning helpers to allow combining partition values in file level stats [#16139](https://github.com/apache/datafusion/pull/16139) (adriangb) +- Implement schema adapter support for FileSource and add integration tests [#16148](https://github.com/apache/datafusion/pull/16148) (kosiew) +- Minor: update documentation for PrunableStatistics [#16213](https://github.com/apache/datafusion/pull/16213) (alamb) +- Remove use of deprecated dict_ordered in datafusion-proto (#16218) [#16220](https://github.com/apache/datafusion/pull/16220) (cj-zhukov) +- Minor: Print cargo command in bench script [#16236](https://github.com/apache/datafusion/pull/16236) (2010YOUY01) +- Simplify FileSource / SchemaAdapterFactory API [#16214](https://github.com/apache/datafusion/pull/16214) (alamb) +- Add dicts to aggregation fuzz testing [#16232](https://github.com/apache/datafusion/pull/16232) (blaginin) +- chore(deps): bump sysinfo from 0.35.1 to 0.35.2 [#16247](https://github.com/apache/datafusion/pull/16247) (dependabot[bot]) +- Improve performance of constant aggregate window expression [#16234](https://github.com/apache/datafusion/pull/16234) (suibianwanwank) +- Support compound identifier when parsing tuples [#16225](https://github.com/apache/datafusion/pull/16225) (hozan23) +- Schema adapter helper [#16108](https://github.com/apache/datafusion/pull/16108) (kosiew) +- Update tpch, clickbench, sort_tpch to mark failed queries [#16182](https://github.com/apache/datafusion/pull/16182) (ding-young) +- Adjust slttest to pass without RUST_BACKTRACE enabled [#16251](https://github.com/apache/datafusion/pull/16251) (alamb) +- Handle dicts for distinct count [#15871](https://github.com/apache/datafusion/pull/15871) (blaginin) +- Add `--substrait-round-trip` option in sqllogictests [#16183](https://github.com/apache/datafusion/pull/16183) (gabotechs) +- Minor: fix upgrade papercut `pub use PruningStatistics` [#16264](https://github.com/apache/datafusion/pull/16264) (alamb) + +## Credits + +Thank you to everyone who contributed to this release. Here is a breakdown of commits (PRs merged) per contributor. 
+ +``` + 30 dependabot[bot] + 29 Andrew Lamb + 16 xudong.w + 14 Adrian Garcia Badaracco + 10 Chen Chongchen + 8 Gabriel + 8 Oleks V + 7 miro + 6 Tommy shu + 6 kamille + 5 Lokesh + 5 Tim Saucer + 4 Dmitrii Blaginin + 4 Jay Zhan + 4 Nuno Faria + 4 Yongting You + 4 logan-keede + 3 Christian + 3 Daniël Heres + 3 Liam Bao + 3 Phillip LeBlanc + 3 Piotr Findeisen + 3 ding-young + 2 Andy Grove + 2 Atahan Yorgancı + 2 Brayan Jules + 2 Georgi Krastev + 2 Jax Liu + 2 Jérémie Drouet + 2 LB7666 + 2 Leonardo Yvens + 2 Qi Zhu + 2 Sergey Zhukov + 2 Shruti Sharma + 2 Tai Le Manh + 2 aditya singh rathore + 2 ake + 2 cht42 + 2 gstvg + 2 kosiew + 2 niebayes + 2 张林伟 + 1 Ahmed Mezghani + 1 Alexander Droste + 1 Andy Yen + 1 Arka Dash + 1 Arttu + 1 Dan Harris + 1 David Hewitt + 1 Davy + 1 Ed Seidl + 1 Eshed Schacham + 1 Evgenii Khramkov + 1 Florent Monjalet + 1 Galim Bikbov + 1 Garam Choi + 1 Hamir Mahal + 1 Hendrik Makait + 1 Jonathan Chen + 1 Joseph Fahnestock + 1 Kevin Zimmerman + 1 Lordworms + 1 Lía Adriana + 1 Matt Butrovich + 1 Namgung Chan + 1 Nelson Antunes + 1 Patrick Sullivan + 1 Raz Luvaton + 1 Ruihang Xia + 1 Ryan Roelke + 1 Sam Hughes + 1 Shehab Amin + 1 Sile Zhou + 1 Simon Vandel Sillesen + 1 Tom Montgomery + 1 UBarney + 1 Victorien + 1 Xiangpeng Hao + 1 Zaki + 1 chen quan + 1 delamarch3 + 1 discord9 + 1 hozan23 + 1 irenjj + 1 jsai28 + 1 m09526 + 1 suibianwanwan + 1 the0ninjas + 1 wiedld +``` + +Thank you also to everyone who contributed in other ways such as filing issues, reviewing PRs, and providing feedback on this release. diff --git a/dev/release/README.md b/dev/release/README.md index 6e4079de8f06..f1b0d286e895 100644 --- a/dev/release/README.md +++ b/dev/release/README.md @@ -278,17 +278,23 @@ Verify that the Cargo.toml in the tarball contains the correct version (cd datafusion/optimizer && cargo publish) (cd datafusion/common-runtime && cargo publish) (cd datafusion/physical-plan && cargo publish) +(cd datafusion/session && cargo publish) (cd datafusion/physical-optimizer && cargo publish) -(cd datafusion/catalog && cargo publish) (cd datafusion/datasource && cargo publish) +(cd datafusion/catalog && cargo publish) (cd datafusion/catalog-listing && cargo publish) (cd datafusion/functions-table && cargo publish) +(cd datafusion/datasource-csv && cargo publish) +(cd datafusion/datasource-json && cargo publish) +(cd datafusion/datasource-parquet && cargo publish) (cd datafusion/core && cargo publish) (cd datafusion/proto-common && cargo publish) (cd datafusion/proto && cargo publish) +(cd datafusion/datasource-avro && cargo publish) (cd datafusion/substrait && cargo publish) (cd datafusion/ffi && cargo publish) (cd datafusion-cli && cargo publish) +(cd datafusion/spark && cargo publish) (cd datafusion/sqllogictest && cargo publish) ``` diff --git a/dev/update_config_docs.sh b/dev/update_config_docs.sh index 585cb77839f9..10f82ce94547 100755 --- a/dev/update_config_docs.sh +++ b/dev/update_config_docs.sh @@ -25,6 +25,8 @@ cd "${SOURCE_DIR}/../" && pwd TARGET_FILE="docs/source/user-guide/configs.md" PRINT_CONFIG_DOCS_COMMAND="cargo run --manifest-path datafusion/core/Cargo.toml --bin print_config_docs" +PRINT_RUNTIME_CONFIG_DOCS_COMMAND="cargo run --manifest-path datafusion/core/Cargo.toml --bin print_runtime_config_docs" + echo "Inserting header" cat <<'EOF' > "$TARGET_FILE" @@ -70,6 +72,27 @@ EOF echo "Running CLI and inserting config docs table" $PRINT_CONFIG_DOCS_COMMAND >> "$TARGET_FILE" +echo "Inserting runtime config header" +cat <<'EOF' >> "$TARGET_FILE" + +# Runtime Configuration 
Settings + +DataFusion runtime configurations can be set via SQL using the `SET` command. + +For example, to configure `datafusion.runtime.memory_limit`: + +```sql +SET datafusion.runtime.memory_limit = '2G'; +``` + +The following runtime configuration settings are available: + +EOF + +echo "Running CLI and inserting runtime config docs table" +$PRINT_RUNTIME_CONFIG_DOCS_COMMAND >> "$TARGET_FILE" + + echo "Running prettier" npx prettier@2.3.2 --write "$TARGET_FILE" diff --git a/docs/requirements.txt b/docs/requirements.txt index bd030fb67044..b1cb76a3cc7e 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -16,6 +16,7 @@ # under the License. sphinx +sphinx-reredirects pydata-sphinx-theme==0.8.0 myst-parser maturin diff --git a/docs/source/conf.py b/docs/source/conf.py index 00037867a092..5e31864e9ad6 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -52,6 +52,7 @@ 'sphinx.ext.viewcode', 'sphinx.ext.napoleon', 'myst_parser', + 'sphinx_reredirects', ] source_suffix = { @@ -121,3 +122,8 @@ # issue for our documentation. So, suppress these warnings to keep our build # log cleaner. suppress_warnings = ['misc.highlighting_failure'] + +redirects = { + "library-user-guide/adding-udfs": "functions/index.html", + "user-guide/runtime_configs": "configs.html", +} \ No newline at end of file diff --git a/docs/source/contributor-guide/index.md b/docs/source/contributor-guide/index.md index e38898db5a92..a4e3e4cb9407 100644 --- a/docs/source/contributor-guide/index.md +++ b/docs/source/contributor-guide/index.md @@ -108,6 +108,26 @@ Features above) prior to acceptance include: [extensions list]: ../library-user-guide/extensions.md [design goal]: https://docs.rs/datafusion/latest/datafusion/index.html#design-goals +### Design Build vs. Big Up Front Design + +Typically, the DataFusion community attacks large problems by solving them bit +by bit and refining a solution iteratively on the `main` branch as a series of +Pull Requests. This is different from projects which front-load the effort +with a more comprehensive design process. + +By "advancing the front" the community always makes tangible progress, and the strategy is +especially effective in a project that relies on individual contributors who may +not have the time or resources to invest in a large upfront design effort. +However, this "bit by bit approach" doesn't always succeed, and sometimes we get +stuck or go down the wrong path and then change directions. + +Our process necessarily results in imperfect solutions being the "state of the +code" in some cases, and larger visions are not yet fully realized. However, the +community is good at driving things to completion in the long run. If you see +something that needs improvement or an area that is not yet fully realized, +please consider submitting an issue or PR to improve it. We are always looking +for more contributions. + # Developer's guide ## Pull Request Overview diff --git a/docs/source/contributor-guide/inviting.md b/docs/source/contributor-guide/inviting.md index c6ed2695cfc1..20b644b2baae 100644 --- a/docs/source/contributor-guide/inviting.md +++ b/docs/source/contributor-guide/inviting.md @@ -280,9 +280,10 @@ If you accept, please let us know by replying to private@datafusion.apache.org. ## New PMC Members -See also the ASF instructions on [how to add a PMC member]. +This is a DataFusion specific cookbook for the Apache Software Foundation +instructions on [how to add a PMC member]. 
-[how to add a pmc member]: https://www.apache.org/dev/pmc.html#newpmc +[how to add a pmc member]: https://www.apache.org/dev/pmc.html#pmcmembers ### Step 1: Start a Discussion Thread @@ -333,29 +334,18 @@ Thanks, Your Name ``` -### Step 3: Send Notice to ASF Board - -The DataFusion PMC Chair then sends a NOTICE to `board@apache.org` (cc'ing -`private@`) like this: +If this vote succeeds, send a "RESULT" email to `private@` like this: ``` -To: board@apache.org -Cc: private@datafusion.apache.org -Subject: [NOTICE] $NEW_PMC_MEMBER to join DataFusion PMC - -DataFusion proposes to invite $NEW_PMC_MEMBER ($NEW_PMC_MEMBER_APACHE_ID) to join the PMC. - -The vote result is available here: -$VOTE_RESULT_URL +To: private@datafusion.apache.org +Subject: [RESULT][VOTE] $NEW_PMC_MEMBER for PMC -FYI: Full vote details: -$VOTE_URL +The vote carries with N +1 votes and no -1 votes. I will send an invitation ``` -### Step 4: Send invitation email +### Step 3: Send invitation email -Once, the PMC chair has confirmed that the email sent to `board@apache.org` has -made it to the archives, the Chair sends an invitation e-mail to the new PMC +Assuming the vote passes, the Chair sends an invitation e-mail to the new PMC member (cc'ing `private@`) like this: ``` @@ -405,11 +395,11 @@ With the expectation of your acceptance, welcome! The Apache DataFusion PMC ``` -### Step 5: Chair Promotes the Committer to PMC +### Step 4: Chair Promotes the Committer to PMC The PMC chair adds the user to the PMC using the [Whimsy Roster Tool]. -### Step 6: Announce and Celebrate the New PMC Member +### Step 5: Announce and Celebrate the New PMC Member Send an email such as the following to `dev@datafusion.apache.org` to celebrate: diff --git a/docs/source/contributor-guide/roadmap.md b/docs/source/contributor-guide/roadmap.md index 3d9c1ee371fe..79add1b86f47 100644 --- a/docs/source/contributor-guide/roadmap.md +++ b/docs/source/contributor-guide/roadmap.md @@ -46,81 +46,13 @@ make review efficient and avoid surprises. # Quarterly Roadmap -A quarterly roadmap will be published to give the DataFusion community -visibility into the priorities of the projects contributors. This roadmap is not -binding and we would welcome any/all contributions to help keep this list up to -date. +The DataFusion roadmap is driven by the priorities of contributors rather than +any single organization or coordinating committee. We typically discuss our +roadmap using GitHub issues, approximately quarterly, and invite you to join the +discussion. 
-## 2023 Q4 +For more information: -- Improve data output (`COPY`, `INSERT` and DataFrame) output capability [#6569](https://github.com/apache/datafusion/issues/6569) -- Implementation of `ARRAY` types and related functions [#6980](https://github.com/apache/datafusion/issues/6980) -- Write an industrial paper about DataFusion for SIGMOD [#6782](https://github.com/apache/datafusion/issues/6782) - -## 2022 Q2 - -### DataFusion Core - -- IO Improvements - - Reading, registering, and writing more file formats from both DataFrame API and SQL - - Additional options for IO including partitioning and metadata support -- Work Scheduling - - Improve predictability, observability and performance of IO and CPU-bound work - - Develop a more explicit story for managing parallelism during plan execution -- Memory Management - - Add more operators for memory limited execution -- Performance - - Incorporate row-format into operators such as aggregate - - Add row-format benchmarks - - Explore JIT-compiling complex expressions - - Explore LLVM for JIT, with inline Rust functions as the primary goal - - Improve performance of Sort and Merge using Row Format / JIT expressions -- Documentation - - General improvements to DataFusion website - - Publish design documents -- Streaming - - Create `StreamProvider` trait - -### Ballista - -- Make production ready - - Shuffle file cleanup - - Fill functional gaps between DataFusion and Ballista - - Improve task scheduling and data exchange efficiency - - Better error handling - - Task failure - - Executor lost - - Schedule restart - - Improve monitoring and logging - - Auto scaling support -- Support for multi-scheduler deployments. Initially for resiliency and fault tolerance but ultimately to support sharding for scalability and more efficient caching. -- Executor deployment grouping based on resource allocation - -### Extensions ([datafusion-contrib](https://github.com/datafusion-contrib)) - -### [DataFusion-Python](https://github.com/datafusion-contrib/datafusion-python) - -- Add missing functionality to DataFrame and SessionContext -- Improve documentation - -### [DataFusion-S3](https://github.com/datafusion-contrib/datafusion-objectstore-s3) - -- Create Python bindings to use with datafusion-python - -### [DataFusion-Tui](https://github.com/datafusion-contrib/datafusion-tui) - -- Create multiple SQL editors -- Expose more Context and query metadata -- Support new data sources - - BigTable, HDFS, HTTP APIs - -### [DataFusion-BigTable](https://github.com/datafusion-contrib/datafusion-bigtable) - -- Python binding to use with datafusion-python -- Timestamp range predicate pushdown -- Multi-threaded partition aware execution -- Production ready Rust SDK - -### [DataFusion-Streams](https://github.com/datafusion-contrib/datafusion-streams) - -- Create experimental implementation of `StreamProvider` trait +1. [Search for issues labeled `roadmap`](https://github.com/apache/datafusion/issues?q=is%3Aissue%20%20%20roadmap) +2. [DataFusion Road Map: Q3-Q4 2025](https://github.com/apache/datafusion/issues/15878) +3. 
[2024 Q4 / 2025 Q1 Roadmap](https://github.com/apache/datafusion/issues/13274) diff --git a/docs/source/index.rst b/docs/source/index.rst index 0dc947fdea57..01f39bcb7c2e 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -132,8 +132,9 @@ To get started, see library-user-guide/using-the-dataframe-api library-user-guide/building-logical-plans library-user-guide/catalogs - library-user-guide/adding-udfs + library-user-guide/functions/index library-user-guide/custom-table-providers + library-user-guide/table-constraints library-user-guide/extending-operators library-user-guide/profiling library-user-guide/query-optimizer diff --git a/docs/source/library-user-guide/catalogs.md b/docs/source/library-user-guide/catalogs.md index 906039ba2300..d4e6633d40ba 100644 --- a/docs/source/library-user-guide/catalogs.md +++ b/docs/source/library-user-guide/catalogs.md @@ -23,11 +23,14 @@ This section describes how to create and manage catalogs, schemas, and tables in ## General Concepts -CatalogProviderList, Catalogs, schemas, and tables are organized in a hierarchy. A CatalogProviderList contains catalog providers, a catalog provider contains schemas and a schema contains tables. +Catalog providers, catalogs, schemas, and tables are organized in a hierarchy. A `CatalogProviderList` contains `CatalogProvider`s, a `CatalogProvider` contains `SchemaProviders` and a `SchemaProvider` contains `TableProvider`s. DataFusion comes with a basic in memory catalog functionality in the [`catalog` module]. You can use these in memory implementations as is, or extend DataFusion with your own catalog implementations, for example based on local files or files on remote object storage. +DataFusion supports DDL queries (e.g. `CREATE TABLE`) using the catalog API described in this section. See the [TableProvider] section for information on DML queries (e.g. `INSERT INTO`). + [`catalog` module]: https://docs.rs/datafusion/latest/datafusion/catalog/index.html +[tableprovider]: ./custom-table-providers.md Similarly to other concepts in DataFusion, you'll implement various traits to create your own catalogs, schemas, and tables. The following sections describe the traits you'll need to implement. diff --git a/docs/source/library-user-guide/custom-table-providers.md b/docs/source/library-user-guide/custom-table-providers.md index 886ac9629566..695cb16ac860 100644 --- a/docs/source/library-user-guide/custom-table-providers.md +++ b/docs/source/library-user-guide/custom-table-providers.md @@ -19,17 +19,25 @@ # Custom Table Provider -Like other areas of DataFusion, you extend DataFusion's functionality by implementing a trait. The `TableProvider` and associated traits, have methods that allow you to implement a custom table provider, i.e. use DataFusion's other functionality with your custom data source. +Like other areas of DataFusion, you extend DataFusion's functionality by implementing a trait. The [`TableProvider`] and associated traits allow you to implement a custom table provider, i.e. use DataFusion's other functionality with your custom data source. -This section will also touch on how to have DataFusion use the new `TableProvider` implementation. +This section describes how to create a [`TableProvider`] and how to configure DataFusion to use it for reading. + +For details on how table constraints such as primary keys or unique +constraints are handled, see [Table Constraint Enforcement](table-constraints.md). ## Table Provider and Scan -The `scan` method on the `TableProvider` is likely its most important. 
It returns an `ExecutionPlan` that DataFusion will use to read the actual data during execution of the query. +The [`TableProvider::scan`] method reads data from the table and is likely the most important. It returns an [`ExecutionPlan`] that DataFusion will use to read the actual data during execution of the query. The [`TableProvider::insert_into`] method is used to `INSERT` data into the table. ### Scan -As mentioned, `scan` returns an execution plan, and in particular a `Result<Arc<dyn ExecutionPlan>>`. The core of this is returning something that can be dynamically dispatched to an `ExecutionPlan`. And as per the general DataFusion idea, we'll need to implement it. +As mentioned, [`TableProvider::scan`] returns an execution plan, and in particular a `Result<Arc<dyn ExecutionPlan>>`. The core of this is returning something that can be dynamically dispatched to an `ExecutionPlan`. And as per the general DataFusion idea, we'll need to implement it. + +[`tableprovider`]: https://docs.rs/datafusion/latest/datafusion/datasource/trait.TableProvider.html +[`tableprovider::scan`]: https://docs.rs/datafusion/latest/datafusion/datasource/trait.TableProvider.html#tymethod.scan +[`tableprovider::insert_into`]: https://docs.rs/datafusion/latest/datafusion/datasource/trait.TableProvider.html#tymethod.insert_into +[`executionplan`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.ExecutionPlan.html #### Execution Plan diff --git a/docs/source/library-user-guide/extending-operators.md b/docs/source/library-user-guide/extending-operators.md index 631bdc67975a..3d491806a4e6 100644 --- a/docs/source/library-user-guide/extending-operators.md +++ b/docs/source/library-user-guide/extending-operators.md @@ -19,4 +19,41 @@ # Extending DataFusion's operators: custom LogicalPlan and Execution Plans -Coming soon +DataFusion supports extension of operators by transforming logical plans and execution plans through customized [optimizer rules](https://docs.rs/datafusion/latest/datafusion/optimizer/trait.OptimizerRule.html). This section will use the µWheel project to illustrate such capabilities. + +## About DataFusion µWheel + +[DataFusion µWheel](https://github.com/uwheel/datafusion-uwheel/tree/main) is a native DataFusion optimizer which improves query performance for time-based analytics through fast temporal aggregation and pruning using custom indices. The integration of µWheel into DataFusion is a joint effort with the DataFusion community. + +### Optimizing Logical Plan + +The `rewrite` function transforms logical plans by identifying temporal patterns and aggregation functions that match the stored wheel indices. When a match is found, it queries the corresponding index to retrieve pre-computed aggregate values, stores these results in a [MemTable](https://docs.rs/datafusion/latest/datafusion/datasource/memory/struct.MemTable.html), and returns them as a new `LogicalPlan::TableScan`. If no match is found, the original plan proceeds unchanged through DataFusion's standard execution path. + +```rust,ignore +fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, +) -> Result<Transformed<LogicalPlan>> { + // Attempts to rewrite a logical plan to a uwheel-based plan that either provides + // plan-time aggregates or skips execution based on min/max pruning.
+ if let Some(rewritten) = self.try_rewrite(&plan) { + Ok(Transformed::yes(rewritten)) + } else { + Ok(Transformed::no(plan)) + } +} +``` + +```rust,ignore +// Converts a uwheel aggregate result to a TableScan with a MemTable as source +fn agg_to_table_scan(result: f64, schema: SchemaRef) -> Result<LogicalPlan> { + let data = Float64Array::from(vec![result]); + let record_batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(data)])?; + let df_schema = Arc::new(DFSchema::try_from(schema.clone())?); + let mem_table = MemTable::try_new(schema, vec![vec![record_batch]])?; + mem_table_as_table_scan(mem_table, df_schema) +} +``` + +To get a deeper dive into the usage of the µWheel project, visit the [blog post](https://uwheel.rs/post/datafusion_uwheel/) by Max Meldrum. diff --git a/docs/source/library-user-guide/adding-udfs.md b/docs/source/library-user-guide/functions/adding-udfs.md similarity index 84% rename from docs/source/library-user-guide/adding-udfs.md rename to docs/source/library-user-guide/functions/adding-udfs.md index 8fb8a59fb860..cf5624f68d04 100644 --- a/docs/source/library-user-guide/adding-udfs.md +++ b/docs/source/library-user-guide/functions/adding-udfs.md @@ -23,19 +23,20 @@ User Defined Functions (UDFs) are functions that can be used in the context of D This page covers how to add UDFs to DataFusion. In particular, it covers how to add Scalar, Window, and Aggregate UDFs. -| UDF Type | Description | Example | -| --------- | ---------------------------------------------------------------------------------------------------------- | ------------------- | -| Scalar | A function that takes a row of data and returns a single value. | [simple_udf.rs][1] | -| Window | A function that takes a row of data and returns a single value, but also has access to the rows around it. | [simple_udwf.rs][2] | -| Aggregate | A function that takes a group of rows and returns a single value. | [simple_udaf.rs][3] | -| Table | A function that takes parameters and returns a `TableProvider` to be used in an query plan. | [simple_udtf.rs][4] | +| UDF Type | Description | Example | +| ------------ | -------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------- | +| Scalar | A function that takes a row of data and returns a single value. | [simple_udf.rs][1] | +| Window | A function that takes a row of data and returns a single value, but also has access to the rows around it. | [simple_udwf.rs][2] | +| Aggregate | A function that takes a group of rows and returns a single value. | [simple_udaf.rs][3] | +| Table | A function that takes parameters and returns a `TableProvider` to be used in a query plan. | [simple_udtf.rs][4] | +| Async Scalar | A scalar function that natively supports asynchronous execution, allowing you to perform async operations (such as network or I/O calls) within the UDF. | [async_udf.rs][5] | First we'll talk about adding a Scalar UDF end-to-end, then we'll talk about the differences between the different types of UDFs. ## Adding a Scalar UDF -A Scalar UDF is a function that takes a row of data and returns a single value. In order for good performance +A Scalar UDF is a function that takes a row of data and returns a single value. To achieve good performance, such functions are "vectorized" in DataFusion, meaning they get one or more Arrow Arrays as input and produce an Arrow Array with the same number of rows as output.
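As a minimal sketch of this "arrays in, equal-length array out" contract (the helper name `add_one_kernel` and the standalone `main` are illustrative only, not part of the walkthrough that follows), a vectorized function body might look like:

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int64Array};
use datafusion::common::cast::as_int64_array;
use datafusion::common::error::Result;

/// Illustrative only: adds 1 to every value of a single Int64 column,
/// returning a new array with the same number of rows and the same nulls.
fn add_one_kernel(args: &[ArrayRef]) -> Result<ArrayRef> {
    // Downcast the first argument to an Int64Array (errors if the type differs).
    let values = as_int64_array(&args[0])?;
    // Apply the scalar operation element-wise; null values stay null.
    let incremented: Int64Array = values.iter().map(|v| v.map(|x| x + 1)).collect();
    Ok(Arc::new(incremented) as ArrayRef)
}

fn main() -> Result<()> {
    let input: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), None, Some(41)]));
    let output = add_one_kernel(&[input])?;
    assert_eq!(output.len(), 3); // same row count as the input
    Ok(())
}
```

The sections below show how such a body is exposed to DataFusion through the UDF traits such as `ScalarUDFImpl`.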
@@ -47,8 +48,8 @@ To create a Scalar UDF, you In the following example, we will add a function that takes a single i64 and returns a single i64 with 1 added to it: -For brevity, we'll skipped some error handling, but e.g. you may want to check that `args.len()` is the expected number -of arguments. +For brevity, we'll skip some error handling. +For production code, you may want to check, for example, that `args.len()` matches the expected number of arguments. ### Adding by `impl ScalarUDFImpl` @@ -344,6 +345,232 @@ async fn main() { } ``` +## Adding a Scalar Async UDF + +A Scalar Async UDF allows you to implement user-defined functions that support +asynchronous execution, such as performing network or I/O operations within the +UDF. + +To add a Scalar Async UDF, you need to: + +1. Implement the `AsyncScalarUDFImpl` trait to define your async function logic, signature, and types. +2. Wrap your implementation with `AsyncScalarUDF::new` and register it with the `SessionContext`. + +### Adding by `impl AsyncScalarUDFImpl` + +```rust +use arrow::array::{ArrayIter, ArrayRef, AsArray, StringArray}; +use arrow_schema::DataType; +use async_trait::async_trait; +use datafusion::common::error::Result; +use datafusion::common::{internal_err, not_impl_err}; +use datafusion::common::types::logical_string; +use datafusion::config::ConfigOptions; +use datafusion_expr::ScalarUDFImpl; +use datafusion::logical_expr::async_udf::AsyncScalarUDFImpl; +use datafusion::logical_expr::{ + ColumnarValue, Signature, TypeSignature, TypeSignatureClass, Volatility, ScalarFunctionArgs +}; +use datafusion::logical_expr_common::signature::Coercion; +use log::trace; +use std::any::Any; +use std::sync::Arc; + +#[derive(Debug)] +pub struct AsyncUpper { + signature: Signature, +} + +impl Default for AsyncUpper { + fn default() -> Self { + Self::new() + } +} + +impl AsyncUpper { + pub fn new() -> Self { + Self { + signature: Signature::new( + TypeSignature::Coercible(vec![Coercion::Exact { + desired_type: TypeSignatureClass::Native(logical_string()), + }]), + Volatility::Volatile, + ), + } + } +} + +/// Implement the normal ScalarUDFImpl trait for AsyncUpper +#[async_trait] +impl ScalarUDFImpl for AsyncUpper { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "async_upper" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { + Ok(DataType::Utf8) + } + + fn invoke_with_args( + &self, + _args: ScalarFunctionArgs, + ) -> Result<ColumnarValue> { + not_impl_err!("AsyncUpper can only be called from async contexts") + } +} + +/// The actual implementation of the async UDF +#[async_trait] +impl AsyncScalarUDFImpl for AsyncUpper { + fn ideal_batch_size(&self) -> Option<usize> { + Some(10) + } + + async fn invoke_async_with_args( + &self, + args: ScalarFunctionArgs, + _option: &ConfigOptions, + ) -> Result<ArrayRef> { + trace!("Invoking async_upper with args: {:?}", args); + let value = &args.args[0]; + let result = match value { + ColumnarValue::Array(array) => { + let string_array = array.as_string::<i32>(); + let iter = ArrayIter::new(string_array); + let result = iter + .map(|string| string.map(|s| s.to_uppercase())) + .collect::<StringArray>(); + Arc::new(result) as ArrayRef + } + _ => return internal_err!("Expected a string argument, got {:?}", value), + }; + Ok(result) + } +} +``` + +We can now convert the async UDF into a normal `ScalarUDF` using `into_scalar_udf` and register the function with DataFusion so that it can be used in the context of a query.
+ +```rust +# use arrow::array::{ArrayIter, ArrayRef, AsArray, StringArray}; +# use arrow_schema::DataType; +# use async_trait::async_trait; +# use datafusion::common::error::Result; +# use datafusion::common::{internal_err, not_impl_err}; +# use datafusion::common::types::logical_string; +# use datafusion::config::ConfigOptions; +# use datafusion_expr::ScalarUDFImpl; +# use datafusion::logical_expr::async_udf::AsyncScalarUDFImpl; +# use datafusion::logical_expr::{ +# ColumnarValue, Signature, TypeSignature, TypeSignatureClass, Volatility, ScalarFunctionArgs +# }; +# use datafusion::logical_expr_common::signature::Coercion; +# use log::trace; +# use std::any::Any; +# use std::sync::Arc; +# +# #[derive(Debug)] +# pub struct AsyncUpper { +# signature: Signature, +# } +# +# impl Default for AsyncUpper { +# fn default() -> Self { +# Self::new() +# } +# } +# +# impl AsyncUpper { +# pub fn new() -> Self { +# Self { +# signature: Signature::new( +# TypeSignature::Coercible(vec![Coercion::Exact { +# desired_type: TypeSignatureClass::Native(logical_string()), +# }]), +# Volatility::Volatile, +# ), +# } +# } +# } +# +# #[async_trait] +# impl ScalarUDFImpl for AsyncUpper { +# fn as_any(&self) -> &dyn Any { +# self +# } +# +# fn name(&self) -> &str { +# "async_upper" +# } +# +# fn signature(&self) -> &Signature { +# &self.signature +# } +# +# fn return_type(&self, _arg_types: &[DataType]) -> Result { +# Ok(DataType::Utf8) +# } +# +# fn invoke_with_args( +# &self, +# _args: ScalarFunctionArgs, +# ) -> Result { +# not_impl_err!("AsyncUpper can only be called from async contexts") +# } +# } +# +# #[async_trait] +# impl AsyncScalarUDFImpl for AsyncUpper { +# fn ideal_batch_size(&self) -> Option { +# Some(10) +# } +# +# async fn invoke_async_with_args( +# &self, +# args: ScalarFunctionArgs, +# _option: &ConfigOptions, +# ) -> Result { +# trace!("Invoking async_upper with args: {:?}", args); +# let value = &args.args[0]; +# let result = match value { +# ColumnarValue::Array(array) => { +# let string_array = array.as_string::(); +# let iter = ArrayIter::new(string_array); +# let result = iter +# .map(|string| string.map(|s| s.to_uppercase())) +# .collect::(); +# Arc::new(result) as ArrayRef +# } +# _ => return internal_err!("Expected a string argument, got {:?}", value), +# }; +# Ok(result) +# } +# } +use datafusion::execution::context::SessionContext; +use datafusion::logical_expr::async_udf::AsyncScalarUDF; + +let async_upper = AsyncUpper::new(); +let udf = AsyncScalarUDF::new(Arc::new(async_upper)); +let mut ctx = SessionContext::new(); +ctx.register_udf(udf.into_scalar_udf()); +``` + +After registration, you can use these async UDFs directly in SQL queries, for example: + +```sql +SELECT async_upper('datafusion'); +``` + +For async UDF implementation details, see [`async_udf.rs`](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/async_udf.rs). 
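If you are driving queries from Rust rather than the CLI, the registered function can also be exercised through `SessionContext::sql`; a minimal sketch (reusing the `ctx` from the registration snippet above, with error handling elided):

```rust
# /* comment to avoid running
// assumes `ctx` is the SessionContext where `async_upper` was registered above
let df = ctx.sql("SELECT async_upper('datafusion') AS upper").await?;
df.show().await?;
# */
```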
+ [`scalarudf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.ScalarUDF.html [`create_udf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udf.html [`process_scalar_func_inputs`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/functions/fn.process_scalar_func_inputs.html @@ -1076,7 +1303,7 @@ pub struct EchoFunction {} impl TableFunctionImpl for EchoFunction { fn call(&self, exprs: &[Expr]) -> Result> { - let Some(Expr::Literal(ScalarValue::Int64(Some(value)))) = exprs.get(0) else { + let Some(Expr::Literal(ScalarValue::Int64(Some(value)), _)) = exprs.get(0) else { return plan_err!("First argument must be an integer"); }; @@ -1117,7 +1344,7 @@ With the UDTF implemented, you can register it with the `SessionContext`: # # impl TableFunctionImpl for EchoFunction { # fn call(&self, exprs: &[Expr]) -> Result> { -# let Some(Expr::Literal(ScalarValue::Int64(Some(value)))) = exprs.get(0) else { +# let Some(Expr::Literal(ScalarValue::Int64(Some(value)), _)) = exprs.get(0) else { # return plan_err!("First argument must be an integer"); # }; # @@ -1244,8 +1471,3 @@ async fn main() -> Result<()> { Ok(()) } ``` - -[1]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udf.rs -[2]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udwf.rs -[3]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udaf.rs -[4]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udtf.rs diff --git a/docs/source/library-user-guide/functions/index.rst b/docs/source/library-user-guide/functions/index.rst new file mode 100644 index 000000000000..d6127446c228 --- /dev/null +++ b/docs/source/library-user-guide/functions/index.rst @@ -0,0 +1,25 @@ +.. Licensed to the Apache Software Foundation (ASF) under one +.. or more contributor license agreements. See the NOTICE file +.. distributed with this work for additional information +.. regarding copyright ownership. The ASF licenses this file +.. to you under the Apache License, Version 2.0 (the +.. "License"); you may not use this file except in compliance +.. with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, +.. software distributed under the License is distributed on an +.. "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +.. KIND, either express or implied. See the License for the +.. specific language governing permissions and limitations +.. under the License. + +Functions +============= + +.. toctree:: + :maxdepth: 2 + + adding-udfs + spark diff --git a/docs/source/library-user-guide/functions/spark.md b/docs/source/library-user-guide/functions/spark.md new file mode 100644 index 000000000000..c371ae1cb5a8 --- /dev/null +++ b/docs/source/library-user-guide/functions/spark.md @@ -0,0 +1,29 @@ + + +# Spark Compatible Functions + +The [`datafusion-spark`] crate provides Apache Spark-compatible expressions for +use with DataFusion. + +[`datafusion-spark`]: https://crates.io/crates/datafusion-spark + +Please see the documentation for the [`datafusion-spark` crate] for more details. 
+ +[`datafusion-spark` crate]: https://docs.rs/datafusion-spark/latest/datafusion_spark/ diff --git a/docs/source/library-user-guide/query-optimizer.md b/docs/source/library-user-guide/query-optimizer.md index 03cd7b5bbbbe..a1ccd0a15a7e 100644 --- a/docs/source/library-user-guide/query-optimizer.md +++ b/docs/source/library-user-guide/query-optimizer.md @@ -193,7 +193,7 @@ Looking at the `EXPLAIN` output we can see that the optimizer has effectively re `3 as "1 + 2"`: ```text -> explain select 1 + 2; +> explain format indent select 1 + 2; +---------------+-------------------------------------------------+ | plan_type | plan | +---------------+-------------------------------------------------+ diff --git a/docs/source/library-user-guide/table-constraints.md b/docs/source/library-user-guide/table-constraints.md new file mode 100644 index 000000000000..dea746463d23 --- /dev/null +++ b/docs/source/library-user-guide/table-constraints.md @@ -0,0 +1,42 @@ + + +# Table Constraint Enforcement + +Table providers can describe table constraints using the +[`TableConstraint`] and [`Constraints`] APIs. These constraints include +primary keys, unique keys, foreign keys and check constraints. + +DataFusion does **not** currently enforce these constraints at runtime. +They are provided for informational purposes and can be used by custom +`TableProvider` implementations or other parts of the system. + +- **Nullability**: The only property enforced by DataFusion is the + nullability of each [`Field`] in a schema. Returning data with null values + for columns marked as not nullable will result in runtime errors during execution. DataFusion + does not check or enforce nullability when data is ingested. +- **Primary and unique keys**: DataFusion does not verify that the data + satisfies primary or unique key constraints. Table providers that + require this behaviour must implement their own checks. +- **Foreign keys and check constraints**: These constraints are parsed + but are not validated or used during query planning. + +[`tableconstraint`]: https://docs.rs/datafusion/latest/datafusion/sql/planner/enum.TableConstraint.html +[`constraints`]: https://docs.rs/datafusion/latest/datafusion/common/functional_dependencies/struct.Constraints.html +[`field`]: https://docs.rs/arrow/latest/arrow/datatypes/struct.Field.html diff --git a/docs/source/library-user-guide/upgrading.md b/docs/source/library-user-guide/upgrading.md index 11fd49566522..499e7b14304e 100644 --- a/docs/source/library-user-guide/upgrading.md +++ b/docs/source/library-user-guide/upgrading.md @@ -19,6 +19,425 @@ # Upgrade Guides +## DataFusion `49.0.0` + +**Note:** DataFusion `49.0.0` has not been released yet. The information provided in this section pertains to features and changes that have already been merged to the main branch and are awaiting release in this version. +You can see the current [status of the `49.0.0` release here](https://github.com/apache/datafusion/issues/16235) + +### `datafusion.execution.collect_statistics` now defaults to `true` + +The default value of the `datafusion.execution.collect_statistics` configuration +setting is now `true`. This change impacts users who read that value directly and relied +on its default being `false`. + +This change also restores the previous default behavior of `ListingTable`. If you use it directly, +you can keep the old (non-collecting) behavior by overriding the default value in your code:
+ +```rust +# /* comment to avoid running +ListingOptions::new(Arc::new(ParquetFormat::default())) + .with_collect_stat(false) + // other options +# */ +``` + +### Metadata is now represented by `FieldMetadata` + +Metadata from the Arrow `Field` is now stored using the `FieldMetadata` +structure. In prior versions it was stored as both a `HashMap<String, String>` +and a `BTreeMap<String, String>`. `FieldMetadata` is easier to work with and +is more efficient. + +To create `FieldMetadata` from a `Field`: + +```rust +# /* comment to avoid running + let metadata = FieldMetadata::from(&field); +# */ +``` + +To add metadata to a `Field`, use the `add_to_field` method: + +```rust +# /* comment to avoid running +let updated_field = metadata.add_to_field(field); +# */ +``` + +See [#16317] for details. + +[#16317]: https://github.com/apache/datafusion/pull/16317 + +### New `datafusion.execution.spill_compression` configuration option + +DataFusion 49.0.0 adds support for compressing spill files when data is written to disk during spilling query execution. A new configuration option `datafusion.execution.spill_compression` controls the compression codec used. + +**Configuration:** + +- **Key**: `datafusion.execution.spill_compression` +- **Default**: `uncompressed` +- **Valid values**: `uncompressed`, `lz4_frame`, `zstd` + +**Usage:** + +```rust +# /* comment to avoid running +use datafusion::prelude::*; +use datafusion_common::config::SpillCompression; + +let config = SessionConfig::default() + .with_spill_compression(SpillCompression::Zstd); +let ctx = SessionContext::new_with_config(config); +# */ +``` + +Or via SQL: + +```sql +SET datafusion.execution.spill_compression = 'zstd'; +``` + +For more details about this configuration option, including performance trade-offs between different compression codecs, see the [Configuration Settings](../user-guide/configs.md) documentation. + +## DataFusion `48.0.0` + +### `Expr::Literal` has optional metadata + +The [`Expr::Literal`] variant now includes optional metadata, which allows for +carrying through Arrow field metadata to support extension types and other uses. + +This means code such as + +```rust +# /* comment to avoid running +match expr { +... + Expr::Literal(scalar) => ... +... +} +# */ +``` + +Should be updated to: + +```rust +# /* comment to avoid running +match expr { +... + Expr::Literal(scalar, _metadata) => ... +... +} +# */ +``` + +Likewise constructing `Expr::Literal` requires metadata as well. The [`lit`] function +has not changed and returns an `Expr::Literal` with no metadata. + +[`expr::literal`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/enum.Expr.html#variant.Literal +[`lit`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.lit.html + +### `Expr::WindowFunction` is now `Box`ed + +`Expr::WindowFunction` is now a `Box<WindowFunction>` instead of a `WindowFunction` directly. +This change was made to reduce the size of `Expr` and improve performance when +planning queries (see [details on #16207]). + +This is a breaking change, so you will need to update your code if you match +on `Expr::WindowFunction` directly. For example, if you have code like this: + +```rust +# /* comment to avoid running +match expr { + Expr::WindowFunction(WindowFunction { + params: + WindowFunctionParams { + partition_by, + order_by, + ..
+ } + }) => { + // Use partition_by and order_by as needed + } + _ => { + // other expr + } +} +# */ +``` + +You will need to change it to: + +```rust +# /* comment to avoid running +match expr { + Expr::WindowFunction(window_fun) => { + let WindowFunction { + fun, + params: WindowFunctionParams { + args, + partition_by, + .. + }, + } = window_fun.as_ref(); + // Use partition_by and order_by as needed + } + _ => { + // other expr + } +} +# */ +``` + +[details on #16207]: https://github.com/apache/datafusion/pull/16207#issuecomment-2922659103 + +### The `VARCHAR` SQL type is now represented as `Utf8View` in Arrow + +The mapping of the SQL `VARCHAR` type has been changed from `Utf8` to `Utf8View` +which improves performance for many string operations. You can read more about +`Utf8View` in the [DataFusion blog post on German-style strings] + +[datafusion blog post on german-style strings]: https://datafusion.apache.org/blog/2024/09/13/string-view-german-style-strings-part-1/ + +This means that when you create a table with a `VARCHAR` column, it will now use +`Utf8View` as the underlying data type. For example: + +```sql +> CREATE TABLE my_table (my_column VARCHAR); +0 row(s) fetched. +Elapsed 0.001 seconds. + +> DESCRIBE my_table; ++-------------+-----------+-------------+ +| column_name | data_type | is_nullable | ++-------------+-----------+-------------+ +| my_column | Utf8View | YES | ++-------------+-----------+-------------+ +1 row(s) fetched. +Elapsed 0.000 seconds. +``` + +You can restore the old behavior of using `Utf8` by changing the +`datafusion.sql_parser.map_varchar_to_utf8view` configuration setting. For +example + +```sql +> set datafusion.sql_parser.map_varchar_to_utf8view = false; +0 row(s) fetched. +Elapsed 0.001 seconds. + +> CREATE TABLE my_table (my_column VARCHAR); +0 row(s) fetched. +Elapsed 0.014 seconds. + +> DESCRIBE my_table; ++-------------+-----------+-------------+ +| column_name | data_type | is_nullable | ++-------------+-----------+-------------+ +| my_column | Utf8 | YES | ++-------------+-----------+-------------+ +1 row(s) fetched. +Elapsed 0.004 seconds. +``` + +### `ListingOptions` default for `collect_stat` changed from `true` to `false` + +This makes it agree with the default for `SessionConfig`. +Most users won't be impacted by this change but if you were using `ListingOptions` directly +and relied on the default value of `collect_stat` being `true`, you will need to +explicitly set it to `true` in your code. + +```rust +# /* comment to avoid running +ListingOptions::new(Arc::new(ParquetFormat::default())) + .with_collect_stat(true) + // other options +# */ +``` + +### Processing `FieldRef` instead of `DataType` for user defined functions + +In order to support metadata handling and extension types, user defined functions are +now switching to traits which use `FieldRef` rather than a `DataType` and nullability. +This gives a single interface to both of these parameters and additionally allows +access to metadata fields, which can be used for extension types. + +To upgrade structs which implement `ScalarUDFImpl`, if you have implemented +`return_type_from_args` you need instead to implement `return_field_from_args`. +If your functions do not need to handle metadata, this should be straightforward +repackaging of the output data into a `FieldRef`. The name you specify on the +field is not important. It will be overwritten during planning. `ReturnInfo` +has been removed, so you will need to remove all references to it. 
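A rough sketch of what that repackaging can look like is below (the `MyUdf` type is a placeholder and the field name is arbitrary because it is rewritten during planning; check the `ScalarUDFImpl` docs for the precise `ReturnFieldArgs` signature):

```rust
# /* comment to avoid running
impl ScalarUDFImpl for MyUdf {
    // ... other trait methods unchanged ...

    // previously: fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result<ReturnInfo>
    fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result<FieldRef> {
        // The field name is ignored and overwritten during planning;
        // only the data type, nullability, and metadata matter here.
        Ok(Arc::new(Field::new(self.name(), DataType::Int64, true)))
    }
}
# */
```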
+ +`ScalarFunctionArgs` now contains a field called `arg_fields`. You can use this +to access the metadata associated with the columnar values during invocation. + +To upgrade user defined aggregate functions, there is now a function +`return_field` that will allow you to specify both metadata and nullability of +your function. You are not required to implement this if you do not need to +handle metadata. + +The largest change to aggregate functions happens in the accumulator arguments. +Both the `AccumulatorArgs` and `StateFieldsArgs` now contain `FieldRef` rather +than `DataType`. + +To upgrade window functions, `ExpressionArgs` now contains input fields instead +of input data types. When setting these fields, the name of the field is +not important since this gets overwritten during the planning stage. All you +should need to do is wrap your existing data types in fields with nullability +set depending on your use case. + +### Physical Expression return `Field` + +To support the changes to user defined functions processing metadata, the +`PhysicalExpr` trait now must specify a return `Field` based on the input +schema. To upgrade structs which implement `PhysicalExpr` you need to implement +the `return_field` function. There are numerous examples in the `physical-expr` +crate. + +### `FileFormat::supports_filters_pushdown` replaced with `FileSource::try_pushdown_filters` + +To support more general filter pushdown, `FileFormat::supports_filters_pushdown` was replaced with +`FileSource::try_pushdown_filters`. +If you implemented a custom `FileFormat` that uses a custom `FileSource` you will need to implement +`FileSource::try_pushdown_filters`. +See `ParquetSource::try_pushdown_filters` for an example of how to implement this. + +`FileFormat::supports_filters_pushdown` has been removed. + +### `ParquetExec`, `AvroExec`, `CsvExec`, `JsonExec` Removed + +`ParquetExec`, `AvroExec`, `CsvExec`, and `JsonExec` were deprecated in +DataFusion 46 and are removed in DataFusion 48. This is sooner than the normal +process described in the [API Deprecation Guidelines] because all the tests +cover the new `DataSourceExec` rather than the older structures. As we evolve +`DataSource`, the old structures began to show signs of "bit rotting" (not +working but no one knows due to lack of test coverage). + +[api deprecation guidelines]: https://datafusion.apache.org/contributor-guide/api-health.html#deprecation-guidelines + +### `PartitionedFile` added as an argument to the `FileOpener` trait + +This is necessary to properly fix filter pushdown for filters that combine partition +columns and file columns (e.g. `day = username['dob']`). + +If you implemented a custom `FileOpener` you will need to add the `PartitionedFile` argument +but are not required to use it in any way. + +## DataFusion `47.0.0` + +This section calls out some of the major changes in the `47.0.0` release of DataFusion.
+ +Here are some example upgrade PRs that demonstrate changes required when upgrading from DataFusion 46.0.0: + +- [delta-rs Upgrade to `47.0.0`](https://github.com/delta-io/delta-rs/pull/3378) +- [DataFusion Comet Upgrade to `47.0.0`](https://github.com/apache/datafusion-comet/pull/1563) +- [Sail Upgrade to `47.0.0`](https://github.com/lakehq/sail/pull/434) + +### Upgrades to `arrow-rs` and `arrow-parquet` 55.0.0 and `object_store` 0.12.0 + +Several APIs are changed in the underlying arrow and parquet libraries to use a +`u64` instead of `usize` to better support WASM (See [#7371] and [#6961]) + +Additionally `ObjectStore::list` and `ObjectStore::list_with_offset` have been changed to return `static` lifetimes (See [#6619]) + +[#6619]: https://github.com/apache/arrow-rs/pull/6619 +[#7371]: https://github.com/apache/arrow-rs/pull/7371 + +This requires converting from `usize` to `u64` occasionally as well as changes to `ObjectStore` implementations such as + +```rust +# /* comment to avoid running +impl Objectstore { + ... + // The range is now a u64 instead of usize + async fn get_range(&self, location: &Path, range: Range) -> ObjectStoreResult { + self.inner.get_range(location, range).await + } + ... + // the lifetime is now 'static instead of `_ (meaning the captured closure can't contain references) + // (this also applies to list_with_offset) + fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, ObjectStoreResult> { + self.inner.list(prefix) + } +} +# */ +``` + +The `ParquetObjectReader` has been updated to no longer require the object size +(it can be fetched using a single suffix request). See [#7334] for details + +[#7334]: https://github.com/apache/arrow-rs/pull/7334 + +Pattern in DataFusion `46.0.0`: + +```rust +# /* comment to avoid running +let meta: ObjectMeta = ...; +let reader = ParquetObjectReader::new(store, meta); +# */ +``` + +Pattern in DataFusion `47.0.0`: + +```rust +# /* comment to avoid running +let meta: ObjectMeta = ...; +let reader = ParquetObjectReader::new(store, location) + .with_file_size(meta.size); +# */ +``` + +### `DisplayFormatType::TreeRender` + +DataFusion now supports [`tree` style explain plans]. Implementations of +`Executionplan` must also provide a description in the +`DisplayFormatType::TreeRender` format. This can be the same as the existing +`DisplayFormatType::Default`. + +[`tree` style explain plans]: https://datafusion.apache.org/user-guide/sql/explain.html#tree-format-default + +### Removed Deprecated APIs + +Several APIs have been removed in this release. These were either deprecated +previously or were hard to use correctly such as the multiple different +`ScalarUDFImpl::invoke*` APIs. See [#15130], [#15123], and [#15027] for more +details. + +[#15130]: https://github.com/apache/datafusion/pull/15130 +[#15123]: https://github.com/apache/datafusion/pull/15123 +[#15027]: https://github.com/apache/datafusion/pull/15027 + +### `FileScanConfig` --> `FileScanConfigBuilder` + +Previously, `FileScanConfig::build()` directly created ExecutionPlans. In +DataFusion 47.0.0 this has been changed to use `FileScanConfigBuilder`. See +[#15352] for details. + +[#15352]: https://github.com/apache/datafusion/pull/15352 + +Pattern in DataFusion `46.0.0`: + +```rust +# /* comment to avoid running +let plan = FileScanConfig::new(url, schema, Arc::new(file_source)) + .with_statistics(stats) + ... 
+ .build() +# */ +``` + +Pattern in DataFusion `47.0.0`: + +```rust +# /* comment to avoid running +let config = FileScanConfigBuilder::new(url, schema, Arc::new(file_source)) + .with_statistics(stats) + ... + .build(); +let scan = DataSourceExec::from_data_source(config); +# */ +``` + ## DataFusion `46.0.0` ### Use `invoke_with_args` instead of `invoke()` and `invoke_batch()` @@ -39,7 +458,7 @@ below. See [PR 14876] for an example. Given existing code like this: ```rust -# /* +# /* comment to avoid running impl ScalarUDFImpl for SparkConcat { ... fn invoke_batch(&self, args: &[ColumnarValue], number_rows: usize) -> Result { @@ -59,7 +478,7 @@ impl ScalarUDFImpl for SparkConcat { To ```rust -# /* comment out so they don't run +# /* comment to avoid running impl ScalarUDFImpl for SparkConcat { ... fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -164,7 +583,7 @@ let mut file_source = ParquetSource::new(parquet_options) // Add filter if let Some(predicate) = logical_filter { if config.enable_parquet_pushdown { - file_source = file_source.with_predicate(Arc::clone(&file_schema), predicate); + file_source = file_source.with_predicate(predicate); } }; diff --git a/docs/source/library-user-guide/working-with-exprs.md b/docs/source/library-user-guide/working-with-exprs.md index df4e5e3940aa..ce3d42cd1360 100644 --- a/docs/source/library-user-guide/working-with-exprs.md +++ b/docs/source/library-user-guide/working-with-exprs.md @@ -75,7 +75,7 @@ Please see [expr_api.rs](https://github.com/apache/datafusion/blob/main/datafusi ## A Scalar UDF Example -We'll use a `ScalarUDF` expression as our example. This necessitates implementing an actual UDF, and for ease we'll use the same example from the [adding UDFs](./adding-udfs.md) guide. +We'll use a `ScalarUDF` expression as our example. This necessitates implementing an actual UDF, and for ease we'll use the same example from the [adding UDFs](functions/adding-udfs.md) guide. So assuming you've written that function, you can use it to create an `Expr`: diff --git a/docs/source/user-guide/cli/datasources.md b/docs/source/user-guide/cli/datasources.md index 2e14f1f54c6c..c15b8a5e46c9 100644 --- a/docs/source/user-guide/cli/datasources.md +++ b/docs/source/user-guide/cli/datasources.md @@ -82,22 +82,29 @@ select count(*) from 'https://datasets.clickhouse.com/hits_compatible/athena_par To read from an AWS S3 or GCS, use `s3` or `gs` as a protocol prefix. For example, to read a file in an S3 bucket named `my-data-bucket` use the URL `s3://my-data-bucket`and set the relevant access credentials as environmental -variables (e.g. for AWS S3 you need to at least `AWS_ACCESS_KEY_ID` and +variables (e.g. for AWS S3 you can use `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY`). ```sql -select count(*) from 's3://my-data-bucket/athena_partitioned/hits.parquet' +> select count(*) from 's3://altinity-clickhouse-data/nyc_taxi_rides/data/tripdata_parquet/'; ++------------+ +| count(*) | ++------------+ +| 1310903963 | ++------------+ ``` -See the [`CREATE EXTERNAL TABLE`](#create-external-table) section for +See the [`CREATE EXTERNAL TABLE`](#create-external-table) section below for additional configuration options. # `CREATE EXTERNAL TABLE` It is also possible to create a table backed by files or remote locations via -`CREATE EXTERNAL TABLE` as shown below. Note that DataFusion does not support wildcards (e.g. `*`) in file paths; instead, specify the directory path directly to read all compatible files in that directory. 
+`CREATE EXTERNAL TABLE` as shown below. Note that DataFusion does not support +wildcards (e.g. `*`) in file paths; instead, specify the directory path directly +to read all compatible files in that directory. -For example, to create a table `hits` backed by a local parquet file, use: +For example, to create a table `hits` backed by a local parquet file named `hits.parquet`: ```sql CREATE EXTERNAL TABLE hits @@ -105,7 +112,7 @@ STORED AS PARQUET LOCATION 'hits.parquet'; ``` -To create a table `hits` backed by a remote parquet file via HTTP(S), use +To create a table `hits` backed by a remote parquet file via HTTP(S): ```sql CREATE EXTERNAL TABLE hits @@ -127,7 +134,11 @@ select count(*) from hits; **Why Wildcards Are Not Supported** -Although wildcards (e.g., _.parquet or \*\*/_.parquet) may work for local filesystems in some cases, they are not officially supported by DataFusion. This is because wildcards are not universally applicable across all storage backends (e.g., S3, GCS). Instead, DataFusion expects the user to specify the directory path, and it will automatically read all compatible files within that directory. +Although wildcards (e.g., _.parquet or \*\*/_.parquet) may work for local +filesystems in some cases, they are not supported by DataFusion CLI. This +is because wildcards are not universally applicable across all storage backends +(e.g., S3, GCS). Instead, DataFusion expects the user to specify the directory +path, and it will automatically read all compatible files within that directory. For example, the following usage is not supported: @@ -148,7 +159,7 @@ CREATE EXTERNAL TABLE test ( day DATE ) STORED AS PARQUET -LOCATION 'gs://bucket/my_table'; +LOCATION 'gs://bucket/my_table/'; ``` # Formats @@ -168,17 +179,63 @@ LOCATION '/mnt/nyctaxi/tripdata.parquet'; Register a single folder parquet datasource. Note: All files inside must be valid parquet files and have compatible schemas +:::{note} +Paths must end in Slash `/` +: The path must end in `/` otherwise DataFusion will treat the path as a file and not a directory +::: + ```sql CREATE EXTERNAL TABLE taxi STORED AS PARQUET LOCATION '/mnt/nyctaxi/'; ``` +### Parquet Specific Options + +You can specify additional options for parquet files using the `OPTIONS` clause. +For example, to read and write a parquet directory with encryption settings you could use: + +```sql +CREATE EXTERNAL TABLE encrypted_parquet_table +( +double_field double, +float_field float +) +STORED AS PARQUET LOCATION 'pq/' OPTIONS ( + -- encryption + 'format.crypto.file_encryption.encrypt_footer' 'true', + 'format.crypto.file_encryption.footer_key_as_hex' '30313233343536373839303132333435', -- b"0123456789012345" + 'format.crypto.file_encryption.column_key_as_hex::double_field' '31323334353637383930313233343530', -- b"1234567890123450" + 'format.crypto.file_encryption.column_key_as_hex::float_field' '31323334353637383930313233343531', -- b"1234567890123451" + -- decryption + 'format.crypto.file_decryption.footer_key_as_hex' '30313233343536373839303132333435', -- b"0123456789012345" + 'format.crypto.file_decryption.column_key_as_hex::double_field' '31323334353637383930313233343530', -- b"1234567890123450" + 'format.crypto.file_decryption.column_key_as_hex::float_field' '31323334353637383930313233343531', -- b"1234567890123451" +); +``` + +Here the keys are specified in hexadecimal format because they are binary data. 
These can be encoded in SQL using: + +```sql +select encode('0123456789012345', 'hex'); +/* ++----------------------------------------------+ +| encode(Utf8("0123456789012345"),Utf8("hex")) | ++----------------------------------------------+ +| 30313233343536373839303132333435 | ++----------------------------------------------+ +*/ +``` + +For more details on the available options, refer to the Rust +[TableParquetOptions](https://docs.rs/datafusion/latest/datafusion/common/config/struct.TableParquetOptions.html) +documentation in DataFusion. + ## CSV DataFusion will infer the CSV schema automatically or you can provide it explicitly. -Register a single file csv datasource with a header row. +Register a single file csv datasource with a header row: ```sql CREATE EXTERNAL TABLE test @@ -187,7 +244,7 @@ LOCATION '/path/to/aggregate_test_100.csv' OPTIONS ('has_header' 'true'); ``` -Register a single file csv datasource with explicitly defined schema. +Register a single file csv datasource with explicitly defined schema: ```sql CREATE EXTERNAL TABLE test ( @@ -213,7 +270,7 @@ LOCATION '/path/to/aggregate_test_100.csv'; ## HTTP(s) -To read from a remote parquet file via HTTP(S) you can use the following: +To read from a remote parquet file via HTTP(S): ```sql CREATE EXTERNAL TABLE hits @@ -223,9 +280,12 @@ LOCATION 'https://datasets.clickhouse.com/hits_compatible/athena_partitioned/hit ## S3 -[AWS S3](https://aws.amazon.com/s3/) data sources must have connection credentials configured. +DataFusion CLI supports configuring [AWS S3](https://aws.amazon.com/s3/) via the +`CREATE EXTERNAL TABLE` statement and standard AWS configuration methods (via the +[`aws-config`] AWS SDK crate). -To create an external table from a file in an S3 bucket: +To create an external table from a file in an S3 bucket with explicit +credentials: ```sql CREATE EXTERNAL TABLE test @@ -238,7 +298,7 @@ OPTIONS( LOCATION 's3://bucket/path/file.parquet'; ``` -It is also possible to specify the access information using environment variables: +To create an external table using environment variables: ```bash $ export AWS_DEFAULT_REGION=us-east-2 @@ -247,7 +307,7 @@ $ export AWS_ACCESS_KEY_ID=****** $ datafusion-cli `datafusion-cli v21.0.0 -> create external table test stored as parquet location 's3://bucket/path/file.parquet'; +> create CREATE TABLE test STORED AS PARQUET LOCATION 's3://bucket/path/file.parquet'; 0 rows in set. Query took 0.374 seconds. > select * from test; +----------+----------+ @@ -258,19 +318,39 @@ $ datafusion-cli 1 row in set. Query took 0.171 seconds. ``` +To read from a public S3 bucket without signatures, use the +`aws.SKIP_SIGNATURE` option: + +```sql +CREATE EXTERNAL TABLE nyc_taxi_rides +STORED AS PARQUET LOCATION 's3://altinity-clickhouse-data/nyc_taxi_rides/data/tripdata_parquet/' +OPTIONS(aws.SKIP_SIGNATURE true); +``` + +Credentials are taken in this order of precedence: + +1. Explicitly specified in the `OPTIONS` clause of the `CREATE EXTERNAL TABLE` statement. +2. Determined by [`aws-config`] crate (standard environment variables such as `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` as well as other AWS specific features). + +If no credentials are specified, DataFusion CLI will use unsigned requests to S3, +which allows reading from public buckets. 
+ Supported configuration options are: -| Environment Variable | Configuration Option | Description | -| ---------------------------------------- | ----------------------- | ---------------------------------------------------- | -| `AWS_ACCESS_KEY_ID` | `aws.access_key_id` | | -| `AWS_SECRET_ACCESS_KEY` | `aws.secret_access_key` | | -| `AWS_DEFAULT_REGION` | `aws.region` | | -| `AWS_ENDPOINT` | `aws.endpoint` | | -| `AWS_SESSION_TOKEN` | `aws.token` | | -| `AWS_CONTAINER_CREDENTIALS_RELATIVE_URI` | | See [IAM Roles] | -| `AWS_ALLOW_HTTP` | | set to "true" to permit HTTP connections without TLS | +| Environment Variable | Configuration Option | Description | +| ---------------------------------------- | ----------------------- | ---------------------------------------------- | +| `AWS_ACCESS_KEY_ID` | `aws.access_key_id` | | +| `AWS_SECRET_ACCESS_KEY` | `aws.secret_access_key` | | +| `AWS_DEFAULT_REGION` | `aws.region` | | +| `AWS_ENDPOINT` | `aws.endpoint` | | +| `AWS_SESSION_TOKEN` | `aws.token` | | +| `AWS_CONTAINER_CREDENTIALS_RELATIVE_URI` | | See [IAM Roles] | +| `AWS_ALLOW_HTTP` | | If "true", permit HTTP connections without TLS | +| `AWS_SKIP_SIGNATURE` | `aws.skip_signature` | If "true", does not sign requests | +| | `aws.nosign` | Alias for `skip_signature` | [iam roles]: https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-iam-roles.html +[`aws-config`]: https://docs.rs/aws-config/latest/aws_config/ ## OSS diff --git a/docs/source/user-guide/cli/usage.md b/docs/source/user-guide/cli/usage.md index 68b09d319984..13f0e7cff175 100644 --- a/docs/source/user-guide/cli/usage.md +++ b/docs/source/user-guide/cli/usage.md @@ -57,6 +57,9 @@ OPTIONS: --mem-pool-type Specify the memory pool type 'greedy' or 'fair', default to 'greedy' + --top-memory-consumers + The number of top memory consumers to display when query fails due to memory exhaustion. To disable memory consumer tracking, set this value to 0 [default: 3] + -d, --disk-limit Available disk space for spilling queries (e.g. '10g'), default to None (uses DataFusion's default value of '100g') diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index a90da66e4b0b..c618aa18c231 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -35,100 +35,128 @@ Values are parsed according to the [same rules used in casts from Utf8](https:// If the value in the environment variable cannot be cast to the type of the configuration option, the default value will be used instead and a warning emitted. Environment variables are read during `SessionConfig` initialisation so they must be set beforehand and will not affect running sessions. -| key | default | description | -| ----------------------------------------------------------------------- | ------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -| datafusion.catalog.create_default_catalog_and_schema | true | Whether the default catalog and schema should be created automatically. 
| -| datafusion.catalog.default_catalog | datafusion | The default catalog name - this impacts what SQL queries use if not specified | -| datafusion.catalog.default_schema | public | The default schema name - this impacts what SQL queries use if not specified | -| datafusion.catalog.information_schema | false | Should DataFusion provide access to `information_schema` virtual tables for displaying schema information | -| datafusion.catalog.location | NULL | Location scanned to load tables for `default` schema | -| datafusion.catalog.format | NULL | Type of `TableProvider` to use when loading `default` schema | -| datafusion.catalog.has_header | true | Default value for `format.has_header` for `CREATE EXTERNAL TABLE` if not specified explicitly in the statement. | -| datafusion.catalog.newlines_in_values | false | Specifies whether newlines in (quoted) CSV values are supported. This is the default value for `format.newlines_in_values` for `CREATE EXTERNAL TABLE` if not specified explicitly in the statement. Parsing newlines in quoted values may be affected by execution behaviour such as parallel file scanning. Setting this to `true` ensures that newlines in values are parsed successfully, which may reduce performance. | -| datafusion.execution.batch_size | 8192 | Default batch size while creating new batches, it's especially useful for buffer-in-memory batches since creating tiny batches would result in too much metadata memory consumption | -| datafusion.execution.coalesce_batches | true | When set to true, record batches will be examined between each operator and small batches will be coalesced into larger batches. This is helpful when there are highly selective filters or joins that could produce tiny output batches. The target batch size is determined by the configuration setting | -| datafusion.execution.collect_statistics | false | Should DataFusion collect statistics after listing files | -| datafusion.execution.target_partitions | 0 | Number of partitions for query execution. Increasing partitions can increase concurrency. Defaults to the number of CPU cores on the system | -| datafusion.execution.time_zone | +00:00 | The default time zone Some functions, e.g. `EXTRACT(HOUR from SOME_TIME)`, shift the underlying datetime according to this time zone, and then extract the hour | -| datafusion.execution.parquet.enable_page_index | true | (reading) If true, reads the Parquet data page level metadata (the Page Index), if present, to reduce the I/O and number of rows decoded. | -| datafusion.execution.parquet.pruning | true | (reading) If true, the parquet reader attempts to skip entire row groups based on the predicate in the query and the metadata (min/max values) stored in the parquet file | -| datafusion.execution.parquet.skip_metadata | true | (reading) If true, the parquet reader skip the optional embedded metadata that may be in the file Schema. This setting can help avoid schema conflicts when querying multiple parquet files with schemas containing compatible types but different metadata | -| datafusion.execution.parquet.metadata_size_hint | NULL | (reading) If specified, the parquet reader will try and fetch the last `size_hint` bytes of the parquet file optimistically. 
If not specified, two reads are required: One read to fetch the 8-byte parquet footer and another to fetch the metadata length encoded in the footer | -| datafusion.execution.parquet.pushdown_filters | false | (reading) If true, filter expressions are be applied during the parquet decoding operation to reduce the number of rows decoded. This optimization is sometimes called "late materialization". | -| datafusion.execution.parquet.reorder_filters | false | (reading) If true, filter expressions evaluated during the parquet decoding operation will be reordered heuristically to minimize the cost of evaluation. If false, the filters are applied in the same order as written in the query | -| datafusion.execution.parquet.schema_force_view_types | true | (reading) If true, parquet reader will read columns of `Utf8/Utf8Large` with `Utf8View`, and `Binary/BinaryLarge` with `BinaryView`. | -| datafusion.execution.parquet.binary_as_string | false | (reading) If true, parquet reader will read columns of `Binary/LargeBinary` with `Utf8`, and `BinaryView` with `Utf8View`. Parquet files generated by some legacy writers do not correctly set the UTF8 flag for strings, causing string columns to be loaded as BLOB instead. | -| datafusion.execution.parquet.coerce_int96 | NULL | (reading) If true, parquet reader will read columns of physical type int96 as originating from a different resolution than nanosecond. This is useful for reading data from systems like Spark which stores microsecond resolution timestamps in an int96 allowing it to write values with a larger date range than 64-bit timestamps with nanosecond resolution. | -| datafusion.execution.parquet.data_pagesize_limit | 1048576 | (writing) Sets best effort maximum size of data page in bytes | -| datafusion.execution.parquet.write_batch_size | 1024 | (writing) Sets write_batch_size in bytes | -| datafusion.execution.parquet.writer_version | 1.0 | (writing) Sets parquet writer version valid values are "1.0" and "2.0" | -| datafusion.execution.parquet.skip_arrow_metadata | false | (writing) Skip encoding the embedded arrow metadata in the KV_meta This is analogous to the `ArrowWriterOptions::with_skip_arrow_metadata`. Refer to | -| datafusion.execution.parquet.compression | zstd(3) | (writing) Sets default parquet compression codec. Valid values are: uncompressed, snappy, gzip(level), lzo, brotli(level), lz4, zstd(level), and lz4_raw. These values are not case sensitive. If NULL, uses default parquet writer setting Note that this default setting is not the same as the default parquet writer setting. | -| datafusion.execution.parquet.dictionary_enabled | true | (writing) Sets if dictionary encoding is enabled. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.dictionary_page_size_limit | 1048576 | (writing) Sets best effort maximum dictionary page size, in bytes | -| datafusion.execution.parquet.statistics_enabled | page | (writing) Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.max_statistics_size | 4096 | (writing) Sets max statistics size for any column. If NULL, uses default parquet writer setting max_statistics_size is deprecated, currently it is not being used | -| datafusion.execution.parquet.max_row_group_size | 1048576 | (writing) Target maximum number of rows in each row group (defaults to 1M rows). 
Writing larger row groups requires more memory to write, but can get better compression and be faster to read. | -| datafusion.execution.parquet.created_by | datafusion version 46.0.1 | (writing) Sets "created by" property | -| datafusion.execution.parquet.column_index_truncate_length | 64 | (writing) Sets column index truncate length | -| datafusion.execution.parquet.statistics_truncate_length | NULL | (writing) Sets statictics truncate length. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.data_page_row_count_limit | 20000 | (writing) Sets best effort maximum number of rows in data page | -| datafusion.execution.parquet.encoding | NULL | (writing) Sets default encoding for any column. Valid values are: plain, plain_dictionary, rle, bit_packed, delta_binary_packed, delta_length_byte_array, delta_byte_array, rle_dictionary, and byte_stream_split. These values are not case sensitive. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.bloom_filter_on_read | true | (writing) Use any available bloom filters when reading parquet files | -| datafusion.execution.parquet.bloom_filter_on_write | false | (writing) Write bloom filters for all columns when creating parquet files | -| datafusion.execution.parquet.bloom_filter_fpp | NULL | (writing) Sets bloom filter false positive probability. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.bloom_filter_ndv | NULL | (writing) Sets bloom filter number of distinct values. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.allow_single_file_parallelism | true | (writing) Controls whether DataFusion will attempt to speed up writing parquet files by serializing them in parallel. Each column in each row group in each output file are serialized in parallel leveraging a maximum possible core count of n_files*n_row_groups*n_columns. | -| datafusion.execution.parquet.maximum_parallel_row_group_writers | 1 | (writing) By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. | -| datafusion.execution.parquet.maximum_buffered_record_batches_per_stream | 2 | (writing) By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. | -| datafusion.execution.planning_concurrency | 0 | Fan-out during initial physical planning. This is mostly use to plan `UNION` children in parallel. Defaults to the number of CPU cores on the system | -| datafusion.execution.skip_physical_aggregate_schema_check | false | When set to true, skips verifying that the schema produced by planning the input of `LogicalPlan::Aggregate` exactly matches the schema of the input plan. When set to false, if the schema does not match exactly (including nullability and metadata), a planning error will be raised. 
This is used to workaround bugs in the planner that are now caught by the new schema verification step. | -| datafusion.execution.sort_spill_reservation_bytes | 10485760 | Specifies the reserved memory for each spillable sort operation to facilitate an in-memory merge. When a sort operation spills to disk, the in-memory data must be sorted and merged before being written to a file. This setting reserves a specific amount of memory for that in-memory sort/merge process. Note: This setting is irrelevant if the sort operation cannot spill (i.e., if there's no `DiskManager` configured). | -| datafusion.execution.sort_in_place_threshold_bytes | 1048576 | When sorting, below what size should data be concatenated and sorted in a single RecordBatch rather than sorted in batches and merged. | -| datafusion.execution.meta_fetch_concurrency | 32 | Number of files to read in parallel when inferring schema and statistics | -| datafusion.execution.minimum_parallel_output_files | 4 | Guarantees a minimum level of output files running in parallel. RecordBatches will be distributed in round robin fashion to each parallel writer. Each writer is closed and a new file opened once soft_max_rows_per_output_file is reached. | -| datafusion.execution.soft_max_rows_per_output_file | 50000000 | Target number of rows in output files when writing multiple. This is a soft max, so it can be exceeded slightly. There also will be one file smaller than the limit if the total number of rows written is not roughly divisible by the soft max | -| datafusion.execution.max_buffered_batches_per_output_file | 2 | This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption | -| datafusion.execution.listing_table_ignore_subdirectory | true | Should sub directories be ignored when scanning directories for data files. Defaults to true (ignores subdirectories), consistent with Hive. Note that this setting does not affect reading partitioned tables (e.g. `/table/year=2021/month=01/data.parquet`). | -| datafusion.execution.enable_recursive_ctes | true | Should DataFusion support recursive CTEs | -| datafusion.execution.split_file_groups_by_statistics | false | Attempt to eliminate sorts by packing & sorting files with non-overlapping statistics into the same file groups. Currently experimental | -| datafusion.execution.keep_partition_by_columns | false | Should DataFusion keep the columns used for partition_by in the output RecordBatches | -| datafusion.execution.skip_partial_aggregation_probe_ratio_threshold | 0.8 | Aggregation ratio (number of distinct groups / number of input rows) threshold for skipping partial aggregation. If the value is greater then partial aggregation will skip aggregation for further input | -| datafusion.execution.skip_partial_aggregation_probe_rows_threshold | 100000 | Number of input rows partial aggregation partition should process, before aggregation ratio check and trying to switch to skipping aggregation mode | -| datafusion.execution.use_row_number_estimates_to_optimize_partitioning | false | Should DataFusion use row number estimates at the input to decide whether increasing parallelism is beneficial or not. By default, only exact row numbers (not estimates) are used for this decision. Setting this flag to `true` will likely produce better plans. if the source of statistics is accurate. We plan to make this the default in the future. 
| -| datafusion.execution.enforce_batch_size_in_joins | false | Should DataFusion enforce batch size in joins or not. By default, DataFusion will not enforce batch size in joins. Enforcing batch size in joins can reduce memory usage when joining large tables with a highly-selective join filter, but is also slightly slower. | -| datafusion.optimizer.enable_distinct_aggregation_soft_limit | true | When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. | -| datafusion.optimizer.enable_round_robin_repartition | true | When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores | -| datafusion.optimizer.enable_topk_aggregation | true | When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible | -| datafusion.optimizer.filter_null_join_keys | false | When set to true, the optimizer will insert filters before a join between a nullable and non-nullable column to filter out nulls on the nullable side. This filter can add additional overhead when the file format does not fully support predicate push down. | -| datafusion.optimizer.repartition_aggregations | true | Should DataFusion repartition data using the aggregate keys to execute aggregates in parallel using the provided `target_partitions` level | -| datafusion.optimizer.repartition_file_min_size | 10485760 | Minimum total files size in bytes to perform file scan repartitioning. | -| datafusion.optimizer.repartition_joins | true | Should DataFusion repartition data using the join keys to execute joins in parallel using the provided `target_partitions` level | -| datafusion.optimizer.allow_symmetric_joins_without_pruning | true | Should DataFusion allow symmetric hash joins for unbounded data sources even when its inputs do not have any ordering or filtering If the flag is not enabled, the SymmetricHashJoin operator will be unable to prune its internal buffers, resulting in certain join types - such as Full, Left, LeftAnti, LeftSemi, Right, RightAnti, and RightSemi - being produced only at the end of the execution. This is not typical in stream processing. Additionally, without proper design for long runner execution, all types of joins may encounter out-of-memory errors. | -| datafusion.optimizer.repartition_file_scans | true | When set to `true`, file groups will be repartitioned to achieve maximum parallelism. Currently Parquet and CSV formats are supported. If set to `true`, all files will be repartitioned evenly (i.e., a single large file might be partitioned into smaller chunks) for parallel scanning. If set to `false`, different files will be read in parallel, but repartitioning won't happen within a single file. | -| datafusion.optimizer.repartition_windows | true | Should DataFusion repartition data using the partitions keys to execute window functions in parallel using the provided `target_partitions` level | -| datafusion.optimizer.repartition_sorts | true | Should DataFusion execute sorts in a per-partition fashion and merge afterwards instead of coalescing first and sorting globally. 
With this flag is enabled, plans in the form below `text "SortExec: [a@0 ASC]", " CoalescePartitionsExec", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", ` would turn into the plan below which performs better in multithreaded environments `text "SortPreservingMergeExec: [a@0 ASC]", " SortExec: [a@0 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", ` | -| datafusion.optimizer.prefer_existing_sort | false | When true, DataFusion will opportunistically remove sorts when the data is already sorted, (i.e. setting `preserve_order` to true on `RepartitionExec` and using `SortPreservingMergeExec`) When false, DataFusion will maximize plan parallelism using `RepartitionExec` even if this requires subsequently resorting data using a `SortExec`. | -| datafusion.optimizer.skip_failed_rules | false | When set to true, the logical plan optimizer will produce warning messages if any optimization rules produce errors and then proceed to the next rule. When set to false, any rules that produce errors will cause the query to fail | -| datafusion.optimizer.max_passes | 3 | Number of times that the optimizer will attempt to optimize the plan | -| datafusion.optimizer.top_down_join_key_reordering | true | When set to true, the physical plan optimizer will run a top down process to reorder the join keys | -| datafusion.optimizer.prefer_hash_join | true | When set to true, the physical plan optimizer will prefer HashJoin over SortMergeJoin. HashJoin can work more efficiently than SortMergeJoin but consumes more memory | -| datafusion.optimizer.hash_join_single_partition_threshold | 1048576 | The maximum estimated size in bytes for one input side of a HashJoin will be collected into a single partition | -| datafusion.optimizer.hash_join_single_partition_threshold_rows | 131072 | The maximum estimated size in rows for one input side of a HashJoin will be collected into a single partition | -| datafusion.optimizer.default_filter_selectivity | 20 | The default filter selectivity used by Filter Statistics when an exact selectivity cannot be determined. Valid values are between 0 (no selectivity) and 100 (all rows are selected). | -| datafusion.optimizer.prefer_existing_union | false | When set to true, the optimizer will not attempt to convert Union to Interleave | -| datafusion.optimizer.expand_views_at_output | false | When set to true, if the returned type is a view type then the output will be coerced to a non-view. Coerces `Utf8View` to `LargeUtf8`, and `BinaryView` to `LargeBinary`. | -| datafusion.explain.logical_plan_only | false | When set to true, the explain statement will only print logical plans | -| datafusion.explain.physical_plan_only | false | When set to true, the explain statement will only print physical plans | -| datafusion.explain.show_statistics | false | When set to true, the explain statement will print operator statistics for physical plans | -| datafusion.explain.show_sizes | true | When set to true, the explain statement will print the partition sizes | -| datafusion.explain.show_schema | false | When set to true, the explain statement will print schema information | -| datafusion.explain.format | indent | Display format of explain. Default is "indent". When set to "tree", it will print the plan in a tree-rendered format. 
| -| datafusion.sql_parser.parse_float_as_decimal | false | When set to true, SQL parser will parse float as decimal type | -| datafusion.sql_parser.enable_ident_normalization | true | When set to true, SQL parser will normalize ident (convert ident to lowercase when not quoted) | -| datafusion.sql_parser.enable_options_value_normalization | false | When set to true, SQL parser will normalize options value (convert value to lowercase). Note that this option is ignored and will be removed in the future. All case-insensitive values are normalized automatically. | -| datafusion.sql_parser.dialect | generic | Configure the SQL dialect used by DataFusion's parser; supported values include: Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, Ansi, DuckDB and Databricks. | -| datafusion.sql_parser.support_varchar_with_length | true | If true, permit lengths for `VARCHAR` such as `VARCHAR(20)`, but ignore the length. If false, error if a `VARCHAR` with a length is specified. The Arrow type system does not have a notion of maximum string length and thus DataFusion can not enforce such limits. | -| datafusion.sql_parser.map_varchar_to_utf8view | false | If true, `VARCHAR` is mapped to `Utf8View` during SQL planning. If false, `VARCHAR` is mapped to `Utf8` during SQL planning. Default is false. | -| datafusion.sql_parser.collect_spans | false | When set to true, the source locations relative to the original SQL query (i.e. [`Span`](https://docs.rs/sqlparser/latest/sqlparser/tokenizer/struct.Span.html)) will be collected and recorded in the logical plan nodes. | -| datafusion.sql_parser.recursion_limit | 50 | Specifies the recursion depth limit when parsing complex SQL Queries | +| key | default | description | +| ----------------------------------------------------------------------- | ------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| datafusion.catalog.create_default_catalog_and_schema | true | Whether the default catalog and schema should be created automatically. 
| +| datafusion.catalog.default_catalog | datafusion | The default catalog name - this impacts what SQL queries use if not specified | +| datafusion.catalog.default_schema | public | The default schema name - this impacts what SQL queries use if not specified | +| datafusion.catalog.information_schema | false | Should DataFusion provide access to `information_schema` virtual tables for displaying schema information | +| datafusion.catalog.location | NULL | Location scanned to load tables for `default` schema | +| datafusion.catalog.format | NULL | Type of `TableProvider` to use when loading `default` schema | +| datafusion.catalog.has_header | true | Default value for `format.has_header` for `CREATE EXTERNAL TABLE` if not specified explicitly in the statement. | +| datafusion.catalog.newlines_in_values | false | Specifies whether newlines in (quoted) CSV values are supported. This is the default value for `format.newlines_in_values` for `CREATE EXTERNAL TABLE` if not specified explicitly in the statement. Parsing newlines in quoted values may be affected by execution behaviour such as parallel file scanning. Setting this to `true` ensures that newlines in values are parsed successfully, which may reduce performance. | +| datafusion.execution.batch_size | 8192 | Default batch size while creating new batches, it's especially useful for buffer-in-memory batches since creating tiny batches would result in too much metadata memory consumption | +| datafusion.execution.coalesce_batches | true | When set to true, record batches will be examined between each operator and small batches will be coalesced into larger batches. This is helpful when there are highly selective filters or joins that could produce tiny output batches. The target batch size is determined by the configuration setting | +| datafusion.execution.collect_statistics | true | Should DataFusion collect statistics when first creating a table. Has no effect after the table is created. Applies to the default `ListingTableProvider` in DataFusion. Defaults to true. | +| datafusion.execution.target_partitions | 0 | Number of partitions for query execution. Increasing partitions can increase concurrency. Defaults to the number of CPU cores on the system | +| datafusion.execution.time_zone | +00:00 | The default time zone Some functions, e.g. `EXTRACT(HOUR from SOME_TIME)`, shift the underlying datetime according to this time zone, and then extract the hour | +| datafusion.execution.parquet.enable_page_index | true | (reading) If true, reads the Parquet data page level metadata (the Page Index), if present, to reduce the I/O and number of rows decoded. | +| datafusion.execution.parquet.pruning | true | (reading) If true, the parquet reader attempts to skip entire row groups based on the predicate in the query and the metadata (min/max values) stored in the parquet file | +| datafusion.execution.parquet.skip_metadata | true | (reading) If true, the parquet reader skip the optional embedded metadata that may be in the file Schema. This setting can help avoid schema conflicts when querying multiple parquet files with schemas containing compatible types but different metadata | +| datafusion.execution.parquet.metadata_size_hint | NULL | (reading) If specified, the parquet reader will try and fetch the last `size_hint` bytes of the parquet file optimistically. 
If not specified, two reads are required: One read to fetch the 8-byte parquet footer and another to fetch the metadata length encoded in the footer | +| datafusion.execution.parquet.pushdown_filters | false | (reading) If true, filter expressions are be applied during the parquet decoding operation to reduce the number of rows decoded. This optimization is sometimes called "late materialization". | +| datafusion.execution.parquet.reorder_filters | false | (reading) If true, filter expressions evaluated during the parquet decoding operation will be reordered heuristically to minimize the cost of evaluation. If false, the filters are applied in the same order as written in the query | +| datafusion.execution.parquet.schema_force_view_types | true | (reading) If true, parquet reader will read columns of `Utf8/Utf8Large` with `Utf8View`, and `Binary/BinaryLarge` with `BinaryView`. | +| datafusion.execution.parquet.binary_as_string | false | (reading) If true, parquet reader will read columns of `Binary/LargeBinary` with `Utf8`, and `BinaryView` with `Utf8View`. Parquet files generated by some legacy writers do not correctly set the UTF8 flag for strings, causing string columns to be loaded as BLOB instead. | +| datafusion.execution.parquet.coerce_int96 | NULL | (reading) If true, parquet reader will read columns of physical type int96 as originating from a different resolution than nanosecond. This is useful for reading data from systems like Spark which stores microsecond resolution timestamps in an int96 allowing it to write values with a larger date range than 64-bit timestamps with nanosecond resolution. | +| datafusion.execution.parquet.bloom_filter_on_read | true | (reading) Use any available bloom filters when reading parquet files | +| datafusion.execution.parquet.data_pagesize_limit | 1048576 | (writing) Sets best effort maximum size of data page in bytes | +| datafusion.execution.parquet.write_batch_size | 1024 | (writing) Sets write_batch_size in bytes | +| datafusion.execution.parquet.writer_version | 1.0 | (writing) Sets parquet writer version valid values are "1.0" and "2.0" | +| datafusion.execution.parquet.skip_arrow_metadata | false | (writing) Skip encoding the embedded arrow metadata in the KV_meta This is analogous to the `ArrowWriterOptions::with_skip_arrow_metadata`. Refer to | +| datafusion.execution.parquet.compression | zstd(3) | (writing) Sets default parquet compression codec. Valid values are: uncompressed, snappy, gzip(level), lzo, brotli(level), lz4, zstd(level), and lz4_raw. These values are not case sensitive. If NULL, uses default parquet writer setting Note that this default setting is not the same as the default parquet writer setting. | +| datafusion.execution.parquet.dictionary_enabled | true | (writing) Sets if dictionary encoding is enabled. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.dictionary_page_size_limit | 1048576 | (writing) Sets best effort maximum dictionary page size, in bytes | +| datafusion.execution.parquet.statistics_enabled | page | (writing) Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.max_statistics_size | 4096 | (writing) Sets max statistics size for any column. 
If NULL, uses default parquet writer setting max_statistics_size is deprecated, currently it is not being used | +| datafusion.execution.parquet.max_row_group_size | 1048576 | (writing) Target maximum number of rows in each row group (defaults to 1M rows). Writing larger row groups requires more memory to write, but can get better compression and be faster to read. | +| datafusion.execution.parquet.created_by | datafusion version 48.0.0 | (writing) Sets "created by" property | +| datafusion.execution.parquet.column_index_truncate_length | 64 | (writing) Sets column index truncate length | +| datafusion.execution.parquet.statistics_truncate_length | NULL | (writing) Sets statictics truncate length. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.data_page_row_count_limit | 20000 | (writing) Sets best effort maximum number of rows in data page | +| datafusion.execution.parquet.encoding | NULL | (writing) Sets default encoding for any column. Valid values are: plain, plain_dictionary, rle, bit_packed, delta_binary_packed, delta_length_byte_array, delta_byte_array, rle_dictionary, and byte_stream_split. These values are not case sensitive. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.bloom_filter_on_write | false | (writing) Write bloom filters for all columns when creating parquet files | +| datafusion.execution.parquet.bloom_filter_fpp | NULL | (writing) Sets bloom filter false positive probability. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.bloom_filter_ndv | NULL | (writing) Sets bloom filter number of distinct values. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.allow_single_file_parallelism | true | (writing) Controls whether DataFusion will attempt to speed up writing parquet files by serializing them in parallel. Each column in each row group in each output file are serialized in parallel leveraging a maximum possible core count of n_files*n_row_groups*n_columns. | +| datafusion.execution.parquet.maximum_parallel_row_group_writers | 1 | (writing) By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. | +| datafusion.execution.parquet.maximum_buffered_record_batches_per_stream | 2 | (writing) By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. | +| datafusion.execution.planning_concurrency | 0 | Fan-out during initial physical planning. This is mostly use to plan `UNION` children in parallel. Defaults to the number of CPU cores on the system | +| datafusion.execution.skip_physical_aggregate_schema_check | false | When set to true, skips verifying that the schema produced by planning the input of `LogicalPlan::Aggregate` exactly matches the schema of the input plan. 
When set to false, if the schema does not match exactly (including nullability and metadata), a planning error will be raised. This is used to workaround bugs in the planner that are now caught by the new schema verification step. | +| datafusion.execution.spill_compression | uncompressed | Sets the compression codec used when spilling data to disk. Since datafusion writes spill files using the Arrow IPC Stream format, only codecs supported by the Arrow IPC Stream Writer are allowed. Valid values are: uncompressed, lz4_frame, zstd. Note: lz4_frame offers faster (de)compression, but typically results in larger spill files. In contrast, zstd achieves higher compression ratios at the cost of slower (de)compression speed. | +| datafusion.execution.sort_spill_reservation_bytes | 10485760 | Specifies the reserved memory for each spillable sort operation to facilitate an in-memory merge. When a sort operation spills to disk, the in-memory data must be sorted and merged before being written to a file. This setting reserves a specific amount of memory for that in-memory sort/merge process. Note: This setting is irrelevant if the sort operation cannot spill (i.e., if there's no `DiskManager` configured). | +| datafusion.execution.sort_in_place_threshold_bytes | 1048576 | When sorting, below what size should data be concatenated and sorted in a single RecordBatch rather than sorted in batches and merged. | +| datafusion.execution.meta_fetch_concurrency | 32 | Number of files to read in parallel when inferring schema and statistics | +| datafusion.execution.minimum_parallel_output_files | 4 | Guarantees a minimum level of output files running in parallel. RecordBatches will be distributed in round robin fashion to each parallel writer. Each writer is closed and a new file opened once soft_max_rows_per_output_file is reached. | +| datafusion.execution.soft_max_rows_per_output_file | 50000000 | Target number of rows in output files when writing multiple. This is a soft max, so it can be exceeded slightly. There also will be one file smaller than the limit if the total number of rows written is not roughly divisible by the soft max | +| datafusion.execution.max_buffered_batches_per_output_file | 2 | This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption | +| datafusion.execution.listing_table_ignore_subdirectory | true | Should sub directories be ignored when scanning directories for data files. Defaults to true (ignores subdirectories), consistent with Hive. Note that this setting does not affect reading partitioned tables (e.g. `/table/year=2021/month=01/data.parquet`). | +| datafusion.execution.enable_recursive_ctes | true | Should DataFusion support recursive CTEs | +| datafusion.execution.split_file_groups_by_statistics | false | Attempt to eliminate sorts by packing & sorting files with non-overlapping statistics into the same file groups. Currently experimental | +| datafusion.execution.keep_partition_by_columns | false | Should DataFusion keep the columns used for partition_by in the output RecordBatches | +| datafusion.execution.skip_partial_aggregation_probe_ratio_threshold | 0.8 | Aggregation ratio (number of distinct groups / number of input rows) threshold for skipping partial aggregation. 
If the value is greater then partial aggregation will skip aggregation for further input | +| datafusion.execution.skip_partial_aggregation_probe_rows_threshold | 100000 | Number of input rows partial aggregation partition should process, before aggregation ratio check and trying to switch to skipping aggregation mode | +| datafusion.execution.use_row_number_estimates_to_optimize_partitioning | false | Should DataFusion use row number estimates at the input to decide whether increasing parallelism is beneficial or not. By default, only exact row numbers (not estimates) are used for this decision. Setting this flag to `true` will likely produce better plans. if the source of statistics is accurate. We plan to make this the default in the future. | +| datafusion.execution.enforce_batch_size_in_joins | false | Should DataFusion enforce batch size in joins or not. By default, DataFusion will not enforce batch size in joins. Enforcing batch size in joins can reduce memory usage when joining large tables with a highly-selective join filter, but is also slightly slower. | +| datafusion.execution.objectstore_writer_buffer_size | 10485760 | Size (bytes) of data buffer DataFusion uses when writing output files. This affects the size of the data chunks that are uploaded to remote object stores (e.g. AWS S3). If very large (>= 100 GiB) output files are being written, it may be necessary to increase this size to avoid errors from the remote end point. | +| datafusion.optimizer.enable_distinct_aggregation_soft_limit | true | When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. | +| datafusion.optimizer.enable_round_robin_repartition | true | When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores | +| datafusion.optimizer.enable_topk_aggregation | true | When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible | +| datafusion.optimizer.enable_dynamic_filter_pushdown | true | When set to true attempts to push down dynamic filters generated by operators into the file scan phase. For example, for a query such as `SELECT * FROM t ORDER BY timestamp DESC LIMIT 10`, the optimizer will attempt to push down the current top 10 timestamps that the TopK operator references into the file scans. This means that if we already have 10 timestamps in the year 2025 any files that only have timestamps in the year 2024 can be skipped / pruned at various stages in the scan. | +| datafusion.optimizer.filter_null_join_keys | false | When set to true, the optimizer will insert filters before a join between a nullable and non-nullable column to filter out nulls on the nullable side. This filter can add additional overhead when the file format does not fully support predicate push down. | +| datafusion.optimizer.repartition_aggregations | true | Should DataFusion repartition data using the aggregate keys to execute aggregates in parallel using the provided `target_partitions` level | +| datafusion.optimizer.repartition_file_min_size | 10485760 | Minimum total files size in bytes to perform file scan repartitioning. 
| +| datafusion.optimizer.repartition_joins | true | Should DataFusion repartition data using the join keys to execute joins in parallel using the provided `target_partitions` level | +| datafusion.optimizer.allow_symmetric_joins_without_pruning | true | Should DataFusion allow symmetric hash joins for unbounded data sources even when its inputs do not have any ordering or filtering If the flag is not enabled, the SymmetricHashJoin operator will be unable to prune its internal buffers, resulting in certain join types - such as Full, Left, LeftAnti, LeftSemi, Right, RightAnti, and RightSemi - being produced only at the end of the execution. This is not typical in stream processing. Additionally, without proper design for long runner execution, all types of joins may encounter out-of-memory errors. | +| datafusion.optimizer.repartition_file_scans | true | When set to `true`, datasource partitions will be repartitioned to achieve maximum parallelism. This applies to both in-memory partitions and FileSource's file groups (1 group is 1 partition). For FileSources, only Parquet and CSV formats are currently supported. If set to `true` for a FileSource, all files will be repartitioned evenly (i.e., a single large file might be partitioned into smaller chunks) for parallel scanning. If set to `false` for a FileSource, different files will be read in parallel, but repartitioning won't happen within a single file. If set to `true` for an in-memory source, all memtable's partitions will have their batches repartitioned evenly to the desired number of `target_partitions`. Repartitioning can change the total number of partitions and batches per partition, but does not slice the initial record tables provided to the MemTable on creation. | +| datafusion.optimizer.repartition_windows | true | Should DataFusion repartition data using the partitions keys to execute window functions in parallel using the provided `target_partitions` level | +| datafusion.optimizer.repartition_sorts | true | Should DataFusion execute sorts in a per-partition fashion and merge afterwards instead of coalescing first and sorting globally. With this flag is enabled, plans in the form below `text "SortExec: [a@0 ASC]", " CoalescePartitionsExec", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", ` would turn into the plan below which performs better in multithreaded environments `text "SortPreservingMergeExec: [a@0 ASC]", " SortExec: [a@0 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", ` | +| datafusion.optimizer.prefer_existing_sort | false | When true, DataFusion will opportunistically remove sorts when the data is already sorted, (i.e. setting `preserve_order` to true on `RepartitionExec` and using `SortPreservingMergeExec`) When false, DataFusion will maximize plan parallelism using `RepartitionExec` even if this requires subsequently resorting data using a `SortExec`. | +| datafusion.optimizer.skip_failed_rules | false | When set to true, the logical plan optimizer will produce warning messages if any optimization rules produce errors and then proceed to the next rule. 
When set to false, any rules that produce errors will cause the query to fail | +| datafusion.optimizer.max_passes | 3 | Number of times that the optimizer will attempt to optimize the plan | +| datafusion.optimizer.top_down_join_key_reordering | true | When set to true, the physical plan optimizer will run a top down process to reorder the join keys | +| datafusion.optimizer.prefer_hash_join | true | When set to true, the physical plan optimizer will prefer HashJoin over SortMergeJoin. HashJoin can work more efficiently than SortMergeJoin but consumes more memory | +| datafusion.optimizer.hash_join_single_partition_threshold | 1048576 | The maximum estimated size in bytes for one input side of a HashJoin will be collected into a single partition | +| datafusion.optimizer.hash_join_single_partition_threshold_rows | 131072 | The maximum estimated size in rows for one input side of a HashJoin will be collected into a single partition | +| datafusion.optimizer.default_filter_selectivity | 20 | The default filter selectivity used by Filter Statistics when an exact selectivity cannot be determined. Valid values are between 0 (no selectivity) and 100 (all rows are selected). | +| datafusion.optimizer.prefer_existing_union | false | When set to true, the optimizer will not attempt to convert Union to Interleave | +| datafusion.optimizer.expand_views_at_output | false | When set to true, if the returned type is a view type then the output will be coerced to a non-view. Coerces `Utf8View` to `LargeUtf8`, and `BinaryView` to `LargeBinary`. | +| datafusion.explain.logical_plan_only | false | When set to true, the explain statement will only print logical plans | +| datafusion.explain.physical_plan_only | false | When set to true, the explain statement will only print physical plans | +| datafusion.explain.show_statistics | false | When set to true, the explain statement will print operator statistics for physical plans | +| datafusion.explain.show_sizes | true | When set to true, the explain statement will print the partition sizes | +| datafusion.explain.show_schema | false | When set to true, the explain statement will print schema information | +| datafusion.explain.format | indent | Display format of explain. Default is "indent". When set to "tree", it will print the plan in a tree-rendered format. | +| datafusion.sql_parser.parse_float_as_decimal | false | When set to true, SQL parser will parse float as decimal type | +| datafusion.sql_parser.enable_ident_normalization | true | When set to true, SQL parser will normalize ident (convert ident to lowercase when not quoted) | +| datafusion.sql_parser.enable_options_value_normalization | false | When set to true, SQL parser will normalize options value (convert value to lowercase). Note that this option is ignored and will be removed in the future. All case-insensitive values are normalized automatically. | +| datafusion.sql_parser.dialect | generic | Configure the SQL dialect used by DataFusion's parser; supported values include: Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, Ansi, DuckDB and Databricks. | +| datafusion.sql_parser.support_varchar_with_length | true | If true, permit lengths for `VARCHAR` such as `VARCHAR(20)`, but ignore the length. If false, error if a `VARCHAR` with a length is specified. The Arrow type system does not have a notion of maximum string length and thus DataFusion can not enforce such limits. 
| +| datafusion.sql_parser.map_string_types_to_utf8view | true | If true, string types (VARCHAR, CHAR, Text, and String) are mapped to `Utf8View` during SQL planning. If false, they are mapped to `Utf8`. Default is true. | +| datafusion.sql_parser.collect_spans | false | When set to true, the source locations relative to the original SQL query (i.e. [`Span`](https://docs.rs/sqlparser/latest/sqlparser/tokenizer/struct.Span.html)) will be collected and recorded in the logical plan nodes. | +| datafusion.sql_parser.recursion_limit | 50 | Specifies the recursion depth limit when parsing complex SQL Queries | +| datafusion.format.safe | true | If set to `true` any formatting errors will be written to the output instead of being converted into a [`std::fmt::Error`] | +| datafusion.format.null | | Format string for nulls | +| datafusion.format.date_format | %Y-%m-%d | Date format for date arrays | +| datafusion.format.datetime_format | %Y-%m-%dT%H:%M:%S%.f | Format for DateTime arrays | +| datafusion.format.timestamp_format | %Y-%m-%dT%H:%M:%S%.f | Timestamp format for timestamp arrays | +| datafusion.format.timestamp_tz_format | NULL | Timestamp format for timestamp with timezone arrays. When `None`, ISO 8601 format is used. | +| datafusion.format.time_format | %H:%M:%S%.f | Time format for time arrays | +| datafusion.format.duration_format | pretty | Duration format. Can be either `"pretty"` or `"ISO8601"` | +| datafusion.format.types_info | false | Show types in visual representation batches | + +# Runtime Configuration Settings + +DataFusion runtime configurations can be set via SQL using the `SET` command. + +For example, to configure `datafusion.runtime.memory_limit`: + +```sql +SET datafusion.runtime.memory_limit = '2G'; +``` + +The following runtime configuration settings are available: + +| key | default | description | +| ------------------------------- | ------- | ------------------------------------------------------------------------------------------------------------------------------------------- | +| datafusion.runtime.memory_limit | NULL | Maximum memory limit for query execution. Supports suffixes K (kilobytes), M (megabytes), and G (gigabytes). Example: '2G' for 2 gigabytes. | diff --git a/docs/source/user-guide/dataframe.md b/docs/source/user-guide/dataframe.md index 96be1bb9e256..82f1eeb2823d 100644 --- a/docs/source/user-guide/dataframe.md +++ b/docs/source/user-guide/dataframe.md @@ -50,13 +50,38 @@ use datafusion::prelude::*; Here is a minimal example showing the execution of a query using the DataFrame API. +Create DataFrame using macro API from in memory rows + ```rust use datafusion::prelude::*; use datafusion::error::Result; + +#[tokio::main] +async fn main() -> Result<()> { + // Create a new dataframe with in-memory data using macro + let df = dataframe!( + "a" => [1, 2, 3], + "b" => [true, true, false], + "c" => [Some("foo"), Some("bar"), None] + )?; + df.show().await?; + Ok(()) +} +``` + +Create DataFrame from file or in memory rows using standard API + +```rust +use datafusion::arrow::array::{Int32Array, RecordBatch, StringArray}; +use datafusion::arrow::datatypes::{DataType, Field, Schema}; +use datafusion::error::Result; use datafusion::functions_aggregate::expr_fn::min; +use datafusion::prelude::*; +use std::sync::Arc; #[tokio::main] async fn main() -> Result<()> { + // Read the data from a csv file let ctx = SessionContext::new(); let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; let df = df.filter(col("a").lt_eq(col("b")))? 
@@ -64,6 +89,22 @@ async fn main() -> Result<()> { .limit(0, Some(100))?; // Print results df.show().await?; + + // Create a new dataframe with in-memory data + let schema = Schema::new(vec![ + Field::new("id", DataType::Int32, true), + Field::new("name", DataType::Utf8, true), + ]); + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec!["foo", "bar", "baz"])), + ], + )?; + let df = ctx.read_batch(batch)?; + df.show().await?; + Ok(()) } ``` diff --git a/docs/source/user-guide/explain-usage.md b/docs/source/user-guide/explain-usage.md index d89ed5f0e7ea..68712012f43f 100644 --- a/docs/source/user-guide/explain-usage.md +++ b/docs/source/user-guide/explain-usage.md @@ -40,7 +40,7 @@ Let's see how DataFusion runs a query that selects the top 5 watch lists for the site `http://domcheloveplanet.ru/`: ```sql -EXPLAIN SELECT "WatchID" AS wid, "hits.parquet"."ClientIP" AS ip +EXPLAIN FORMAT INDENT SELECT "WatchID" AS wid, "hits.parquet"."ClientIP" AS ip FROM 'hits.parquet' WHERE starts_with("URL", 'http://domcheloveplanet.ru/') ORDER BY wid ASC, ip DESC @@ -268,7 +268,7 @@ LIMIT 10; We can again see the query plan by using `EXPLAIN`: ```sql -> EXPLAIN SELECT "UserID", COUNT(*) FROM 'hits.parquet' GROUP BY "UserID" ORDER BY COUNT(*) DESC LIMIT 10; +> EXPLAIN FORMAT INDENT SELECT "UserID", COUNT(*) FROM 'hits.parquet' GROUP BY "UserID" ORDER BY COUNT(*) DESC LIMIT 10; +---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | plan_type | plan | +---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ diff --git a/docs/source/user-guide/features.md b/docs/source/user-guide/features.md index 1f73ce7eac11..4faeb0acf197 100644 --- a/docs/source/user-guide/features.md +++ b/docs/source/user-guide/features.md @@ -93,7 +93,8 @@ - [x] Memory limits enforced - [x] Spilling (to disk) Sort - [x] Spilling (to disk) Grouping -- [ ] Spilling (to disk) Joins +- [x] Spilling (to disk) Sort Merge Join +- [ ] Spilling (to disk) Hash Join ## Data Sources diff --git a/docs/source/user-guide/introduction.md b/docs/source/user-guide/introduction.md index 14d6ab177dc3..040405f8f63e 100644 --- a/docs/source/user-guide/introduction.md +++ b/docs/source/user-guide/introduction.md @@ -40,9 +40,9 @@ Arrow](https://arrow.apache.org/). 
## Features - Feature-rich [SQL support](https://datafusion.apache.org/user-guide/sql/index.html) and [DataFrame API](https://datafusion.apache.org/user-guide/dataframe.html) -- Blazingly fast, vectorized, multi-threaded, streaming execution engine. +- Blazingly fast, vectorized, multithreaded, streaming execution engine. - Native support for Parquet, CSV, JSON, and Avro file formats. Support - for custom file formats and non file datasources via the `TableProvider` trait. + for custom file formats and non-file datasources via the `TableProvider` trait. - Many extension points: user defined scalar/aggregate/window functions, DataSources, SQL, other query languages, custom plan and execution nodes, optimizer passes, and more. - Streaming, asynchronous IO directly from popular object stores, including AWS S3, @@ -68,14 +68,14 @@ DataFusion can be used without modification as an embedded SQL engine or can be customized and used as a foundation for building new systems. -While most current usecases are "analytic" or (throughput) some +While most current use cases are "analytic" or (throughput) some components of DataFusion such as the plan representations, are suitable for "streaming" and "transaction" style systems (low latency). Here are some example systems built using DataFusion: -- Specialized Analytical Database systems such as [HoraeDB] and more general Apache Spark like system such a [Ballista]. +- Specialized Analytical Database systems such as [HoraeDB] and more general Apache Spark like system such as [Ballista] - New query language engines such as [prql-query] and accelerators such as [VegaFusion] - Research platform for new Database Systems, such as [Flock] - SQL support to another library, such as [dask sql] @@ -95,19 +95,22 @@ Here are some active projects using DataFusion: - [Arroyo](https://github.com/ArroyoSystems/arroyo) Distributed stream processing engine in Rust +- [ArkFlow](https://github.com/arkflow-rs/arkflow) High-performance Rust stream processing engine - [Ballista](https://github.com/apache/datafusion-ballista) Distributed SQL Query Engine - [Blaze](https://github.com/kwai/blaze) The Blaze accelerator for Apache Spark leverages native vectorized execution to accelerate query processing - [CnosDB](https://github.com/cnosdb/cnosdb) Open Source Distributed Time Series Database - [Comet](https://github.com/apache/datafusion-comet) Apache Spark native query execution plugin -- [Cube Store](https://github.com/cube-js/cube.js/tree/master/rust) +- [Cube Store](https://github.com/cube-js/cube.js/tree/master/rust) Cube’s universal semantic layer platform is the next evolution of OLAP technology for AI, BI, spreadsheets, and embedded analytics - [Dask SQL](https://github.com/dask-contrib/dask-sql) Distributed SQL query engine in Python - [datafusion-dft](https://github.com/datafusion-contrib/datafusion-dft) Batteries included CLI, TUI, and server implementations for DataFusion. - [delta-rs](https://github.com/delta-io/delta-rs) Native Rust implementation of Delta Lake - [Exon](https://github.com/wheretrue/exon) Analysis toolkit for life-science applications +- [Feldera](https://github.com/feldera/feldera) Fast query engine for incremental computation - [Funnel](https://funnel.io/) Data Platform powering Marketing Intelligence applications. - [GlareDB](https://github.com/GlareDB/glaredb) Fast SQL database for querying and analyzing distributed data. 
- [GreptimeDB](https://github.com/GreptimeTeam/greptimedb) Open Source & Cloud Native Distributed Time Series Database - [HoraeDB](https://github.com/apache/incubator-horaedb) Distributed Time-Series Database +- [Iceberg-rust](https://github.com/apache/iceberg-rust) Rust implementation of Apache Iceberg - [InfluxDB](https://github.com/influxdata/influxdb) Time Series Database - [Kamu](https://github.com/kamu-data/kamu-cli/) Planet-scale streaming data pipeline - [LakeSoul](https://github.com/lakesoul-io/LakeSoul) Open source LakeHouse framework with native IO in Rust. @@ -118,11 +121,11 @@ Here are some active projects using DataFusion: - [Polygon.io](https://polygon.io/) Stock Market API - [qv](https://github.com/timvw/qv) Quickly view your data - [Restate](https://github.com/restatedev) Easily build resilient applications using distributed durable async/await -- [ROAPI](https://github.com/roapi/roapi) -- [Sail](https://github.com/lakehq/sail) Unifying stream, batch, and AI workloads with Apache Spark compatibility +- [ROAPI](https://github.com/roapi/roapi) Create full-fledged APIs for slowly moving datasets without writing a single line of code +- [Sail](https://github.com/lakehq/sail) Unifying stream, batch and AI workloads with Apache Spark compatibility - [Seafowl](https://github.com/splitgraph/seafowl) CDN-friendly analytical database - [Sleeper](https://github.com/gchq/sleeper) Serverless, cloud-native, log-structured merge tree based, scalable key-value store -- [Spice.ai](https://github.com/spiceai/spiceai) Unified SQL query interface & materialization engine +- [Spice.ai](https://github.com/spiceai/spiceai) Building blocks for data-driven AI applications - [Synnada](https://synnada.ai/) Streaming-first framework for data products - [VegaFusion](https://vegafusion.io/) Server-side acceleration for the [Vega](https://vega.github.io/) visualization grammar - [Telemetry](https://telemetry.sh/) Structured logging made easy @@ -179,6 +182,20 @@ provide integrations with other systems, some of which are described below: ## Why DataFusion? - _High Performance_: Leveraging Rust and Arrow's memory model, DataFusion is very fast. -- _Easy to Connect_: Being part of the Apache Arrow ecosystem (Arrow, Parquet and Flight), DataFusion works well with the rest of the big data ecosystem +- _Easy to Connect_: Being part of the Apache Arrow ecosystem (Arrow, Parquet, and Flight), DataFusion works well with the rest of the big data ecosystem - _Easy to Embed_: Allowing extension at almost any point in its design, and published regularly as a crate on [crates.io](http://crates.io), DataFusion can be integrated and tailored for your specific usecase. - _High Quality_: Extensively tested, both by itself and with the rest of the Arrow ecosystem, DataFusion can and is used as the foundation for production systems. + +## Rust Version Compatibility Policy + +The Rust toolchain releases are tracked at [Rust Versions](https://releases.rs) and follow +[semantic versioning](https://semver.org/). A Rust toolchain release can be identified +by a version string like `1.80.0`, or more generally `major.minor.patch`. + +DataFusion supports the last 4 stable Rust minor versions released and any such versions released within the last 4 months. + +For example, given the releases `1.78.0`, `1.79.0`, `1.80.0`, `1.80.1` and `1.81.0`, DataFusion will support 1.78.0, which is 3 minor versions prior to the most recent minor version, `1.81`.
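+
+For illustration, here is a minimal sketch (using the example MSRV `1.78.0` from the paragraph above, not necessarily the actual current MSRV) of how a downstream project might verify a compatible toolchain with standard `rustup` and `cargo` commands:
+
+```shell
+# Install the example MSRV toolchain; substitute the `rust-version` from DataFusion's Cargo.toml
+rustup toolchain install 1.78.0
+# Type-check the project against that toolchain to confirm compatibility
+cargo +1.78.0 check
+```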
+ +Note: If a Rust hotfix is released for the current MSRV, the MSRV will be updated to the specific minor version that includes all applicable hotfixes preceding other policies. + +DataFusion enforces MSRV policy using a [MSRV CI Check](https://github.com/search?q=repo%3Aapache%2Fdatafusion+rust-version+language%3ATOML+path%3A%2F%5ECargo.toml%2F&type=code) diff --git a/docs/source/user-guide/sql/aggregate_functions.md b/docs/source/user-guide/sql/aggregate_functions.md index 684db52e6323..774a4fae6bf3 100644 --- a/docs/source/user-guide/sql/aggregate_functions.md +++ b/docs/source/user-guide/sql/aggregate_functions.md @@ -808,7 +808,7 @@ approx_distinct(expression) ### `approx_median` -Returns the approximate median (50th percentile) of input values. It is an alias of `approx_percentile_cont(x, 0.5)`. +Returns the approximate median (50th percentile) of input values. It is an alias of `approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY x)`. ```sql approx_median(expression) @@ -834,7 +834,7 @@ approx_median(expression) Returns the approximate percentile of input values using the t-digest algorithm. ```sql -approx_percentile_cont(expression, percentile, centroids) +approx_percentile_cont(percentile, centroids) WITHIN GROUP (ORDER BY expression) ``` #### Arguments @@ -846,12 +846,12 @@ approx_percentile_cont(expression, percentile, centroids) #### Example ```sql -> SELECT approx_percentile_cont(column_name, 0.75, 100) FROM table_name; -+-------------------------------------------------+ -| approx_percentile_cont(column_name, 0.75, 100) | -+-------------------------------------------------+ -| 65.0 | -+-------------------------------------------------+ +> SELECT approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name; ++-----------------------------------------------------------------------+ +| approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) | ++-----------------------------------------------------------------------+ +| 65.0 | ++-----------------------------------------------------------------------+ ``` ### `approx_percentile_cont_with_weight` @@ -859,7 +859,7 @@ approx_percentile_cont(expression, percentile, centroids) Returns the weighted approximate percentile of input values using the t-digest algorithm. 
```sql -approx_percentile_cont_with_weight(expression, weight, percentile) +approx_percentile_cont_with_weight(weight, percentile) WITHIN GROUP (ORDER BY expression) ``` #### Arguments @@ -871,10 +871,10 @@ approx_percentile_cont_with_weight(expression, weight, percentile) #### Example ```sql -> SELECT approx_percentile_cont_with_weight(column_name, weight_column, 0.90) FROM table_name; -+----------------------------------------------------------------------+ -| approx_percentile_cont_with_weight(column_name, weight_column, 0.90) | -+----------------------------------------------------------------------+ -| 78.5 | -+----------------------------------------------------------------------+ +> SELECT approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) FROM table_name; ++---------------------------------------------------------------------------------------------+ +| approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) | ++---------------------------------------------------------------------------------------------+ +| 78.5 | ++---------------------------------------------------------------------------------------------+ ``` diff --git a/docs/source/user-guide/sql/ddl.md b/docs/source/user-guide/sql/ddl.md index 71475cff9a39..89294b7518a4 100644 --- a/docs/source/user-guide/sql/ddl.md +++ b/docs/source/user-guide/sql/ddl.md @@ -74,7 +74,7 @@ LOCATION := ( , ...) ``` -For a detailed list of write related options which can be passed in the OPTIONS key_value_list, see [Write Options](write_options). +For a comprehensive list of format-specific options that can be specified in the `OPTIONS` clause, see [Format Options](format_options.md). `file_type` is one of `CSV`, `ARROW`, `PARQUET`, `AVRO` or `JSON` @@ -82,6 +82,8 @@ For a detailed list of write related options which can be passed in the OPTIONS a path to a file or directory of partitioned files locally or on an object store. +### Example: Parquet + Parquet data sources can be registered by executing a `CREATE EXTERNAL TABLE` SQL statement such as the following. It is not necessary to provide schema information for Parquet files. @@ -91,6 +93,23 @@ STORED AS PARQUET LOCATION '/mnt/nyctaxi/tripdata.parquet'; ``` +:::{note} +Statistics +: By default, when a table is created, DataFusion will read the files +to gather statistics, which can be expensive but can accelerate subsequent +queries substantially. If you don't want to gather statistics +when creating a table, set the `datafusion.execution.collect_statistics` +configuration option to `false` before creating the table. For example: + +```sql +SET datafusion.execution.collect_statistics = false; +``` + +See the [config settings docs](../configs.md) for more details. +::: + +### Example: Comma Separated Value (CSV) + CSV data sources can also be registered by executing a `CREATE EXTERNAL TABLE` SQL statement. The schema will be inferred based on scanning a subset of the file. @@ -101,6 +120,8 @@ LOCATION '/path/to/aggregate_simple.csv' OPTIONS ('has_header' 'true'); ``` +### Example: Compression + It is also possible to use compressed files, such as `.csv.gz`: ```sql @@ -111,6 +132,8 @@ LOCATION '/path/to/aggregate_simple.csv.gz' OPTIONS ('has_header' 'true'); ``` +### Example: Specifying Schema + It is also possible to specify the schema manually. 
```sql @@ -134,6 +157,8 @@ LOCATION '/path/to/aggregate_test_100.csv' OPTIONS ('has_header' 'true'); ``` +### Example: Partitioned Tables + It is also possible to specify a directory that contains a partitioned table (multiple files with the same schema) @@ -144,7 +169,9 @@ LOCATION '/path/to/directory/of/files' OPTIONS ('has_header' 'true'); ``` -With `CREATE UNBOUNDED EXTERNAL TABLE` SQL statement. We can create unbounded data sources such as following: +### Example: Unbounded Data Sources + +We can create unbounded data sources using the `CREATE UNBOUNDED EXTERNAL TABLE` SQL statement. ```sql CREATE UNBOUNDED EXTERNAL TABLE taxi @@ -154,6 +181,8 @@ LOCATION '/mnt/nyctaxi/tripdata.parquet'; Note that this statement actually reads data from a fixed-size file, so a better example would involve reading from a FIFO file. Nevertheless, once Datafusion sees the `UNBOUNDED` keyword in a data source, it tries to execute queries that refer to this unbounded source in streaming fashion. If this is not possible according to query specifications, plan generation fails stating it is not possible to execute given query in streaming fashion. Note that queries that can run with unbounded sources (i.e. in streaming mode) are a subset of those that can with bounded sources. A query that fails with unbounded source(s) may work with bounded source(s). +### Example: `WITH ORDER` Clause + When creating an output from a data source that is already ordered by an expression, you can pre-specify the order of the data using the `WITH ORDER` clause. This applies even if the expression used for @@ -190,7 +219,7 @@ WITH ORDER (sort_expression1 [ASC | DESC] [NULLS { FIRST | LAST }] [, sort_expression2 [ASC | DESC] [NULLS { FIRST | LAST }] ...]) ``` -### Cautions when using the WITH ORDER Clause +#### Cautions when using the WITH ORDER Clause - It's important to understand that using the `WITH ORDER` clause in the `CREATE EXTERNAL TABLE` statement only specifies the order in which the data should be read from the external file. If the data in the file is not already sorted according to the specified order, then the results may not be correct. @@ -287,3 +316,78 @@ DROP VIEW [ IF EXISTS ] view_name; -- drop users_v view from the customer_a schema DROP VIEW IF EXISTS customer_a.users_v; ``` + +## DESCRIBE + +Displays the schema of a table, showing column names, data types, and nullable status. Both `DESCRIBE` and `DESC` are supported as aliases. + +

+```sql
+{ DESCRIBE | DESC } table_name
+```
+ +The output contains three columns: + +- `column_name`: The name of the column +- `data_type`: The data type of the column (e.g., Int32, Utf8, Boolean) +- `is_nullable`: Whether the column can contain null values (YES/NO) + +### Example: Basic table description + +```sql +-- Create a table +CREATE TABLE users AS VALUES (1, 'Alice', true), (2, 'Bob', false); + +-- Describe the table structure +DESCRIBE users; +``` + +Output: + +```sql ++--------------+-----------+-------------+ +| column_name | data_type | is_nullable | ++--------------+-----------+-------------+ +| column1 | Int64 | YES | +| column2 | Utf8 | YES | +| column3 | Boolean | YES | ++--------------+-----------+-------------+ +``` + +### Example: Using DESC alias + +```sql +-- DESC is an alias for DESCRIBE +DESC users; +``` + +### Example: Describing external tables + +```sql +-- Create an external table +CREATE EXTERNAL TABLE taxi +STORED AS PARQUET +LOCATION '/mnt/nyctaxi/tripdata.parquet'; + +-- Describe its schema +DESCRIBE taxi; +``` + +Output might show: + +```sql ++--------------------+-----------------------------+-------------+ +| column_name | data_type | is_nullable | ++--------------------+-----------------------------+-------------+ +| vendor_id | Int32 | YES | +| pickup_datetime | Timestamp(Nanosecond, None) | NO | +| passenger_count | Int32 | YES | +| trip_distance | Float64 | YES | ++--------------------+-----------------------------+-------------+ +``` + +The `DESCRIBE` command works with all table types in DataFusion, including: + +- Regular tables created with `CREATE TABLE` +- External tables created with `CREATE EXTERNAL TABLE` +- Views created with `CREATE VIEW` +- Tables in different schemas using qualified names (e.g., `DESCRIBE schema_name.table_name`) diff --git a/docs/source/user-guide/sql/dml.md b/docs/source/user-guide/sql/dml.md index 4eda59d6dea1..c29447f23cd9 100644 --- a/docs/source/user-guide/sql/dml.md +++ b/docs/source/user-guide/sql/dml.md @@ -49,7 +49,7 @@ The output format is determined by the first match of the following rules: 1. Value of `STORED AS` 2. Filename extension (e.g. `foo.parquet` implies `PARQUET` format) -For a detailed list of valid OPTIONS, see [Write Options](write_options). +For a detailed list of valid OPTIONS, see [Format Options](format_options.md). ### Examples diff --git a/docs/source/user-guide/sql/explain.md b/docs/source/user-guide/sql/explain.md index 9984de147ecc..c5e2e215a6b6 100644 --- a/docs/source/user-guide/sql/explain.md +++ b/docs/source/user-guide/sql/explain.md @@ -118,7 +118,7 @@ See [Reading Explain Plans](../explain-usage.md) for more information on how to 0 row(s) fetched. Elapsed 0.004 seconds. -> EXPLAIN SELECT SUM(x) FROM t GROUP BY b; +> EXPLAIN FORMAT INDENT SELECT SUM(x) FROM t GROUP BY b; +---------------+-------------------------------------------------------------------------------+ | plan_type | plan | +---------------+-------------------------------------------------------------------------------+ diff --git a/docs/source/user-guide/sql/format_options.md b/docs/source/user-guide/sql/format_options.md new file mode 100644 index 000000000000..e8008eafb166 --- /dev/null +++ b/docs/source/user-guide/sql/format_options.md @@ -0,0 +1,180 @@ + + +# Format Options + +DataFusion supports customizing how data is read from or written to disk as a result of a `COPY`, `INSERT INTO`, or `CREATE EXTERNAL TABLE` statements. There are a few special options, file format (e.g., CSV or Parquet) specific options, and Parquet column-specific options. 
In some cases, Options can be specified in multiple ways with a set order of precedence. + +## Specifying Options and Order of Precedence + +Format-related options can be specified in three ways, in decreasing order of precedence: + +- `CREATE EXTERNAL TABLE` syntax +- `COPY` option tuples +- Session-level config defaults + +For a list of supported session-level config defaults, see [Configuration Settings](../configs). These defaults apply to all operations but have the lowest level of precedence. + +If creating an external table, table-specific format options can be specified when the table is created using the `OPTIONS` clause: + +```sql +CREATE EXTERNAL TABLE + my_table(a bigint, b bigint) + STORED AS csv + LOCATION '/tmp/my_csv_table/' + OPTIONS( + NULL_VALUE 'NAN', + 'has_header' 'true', + 'format.delimiter' ';' + ); +``` + +When running `INSERT INTO my_table ...`, the options from the `CREATE TABLE` will be respected (e.g., gzip compression, special delimiter, and header row included). Note that compression, header, and delimiter settings can also be specified within the `OPTIONS` tuple list. Dedicated syntax within the SQL statement always takes precedence over arbitrary option tuples, so if both are specified, the `OPTIONS` setting will be ignored. + +For example, with the table defined above, running the following command: + +```sql +INSERT INTO my_table VALUES(1,2); +``` + +Results in a new CSV file with the specified options: + +```shell +$ cat /tmp/my_csv_table/bmC8zWFvLMtWX68R_0.csv +a;b +1;2 +``` + +Finally, options can be passed when running a `COPY` command. + +```sql +COPY source_table + TO 'test/table_with_options' + PARTITIONED BY (column3, column4) + OPTIONS ( + format parquet, + compression snappy, + 'compression::column1' 'zstd(5)', + ) +``` + +In this example, we write the entire `source_table` out to a folder of Parquet files. One Parquet file will be written in parallel to the folder for each partition in the query. The next option `compression` set to `snappy` indicates that unless otherwise specified, all columns should use the snappy compression codec. The option `compression::col1` sets an override, so that the column `col1` in the Parquet file will use the ZSTD compression codec with compression level `5`. In general, Parquet options that support column-specific settings can be specified with the syntax `OPTION::COLUMN.NESTED.PATH`. + +# Available Options + +## JSON Format Options + +The following options are available when reading or writing JSON files. Note: If any unsupported option is specified, an error will be raised and the query will fail. + +| Option | Description | Default Value | +| ----------- | ---------------------------------------------------------------------------------------------------------------------------------- | ------------- | +| COMPRESSION | Sets the compression that should be applied to the entire JSON file. Supported values are GZIP, BZIP2, XZ, ZSTD, and UNCOMPRESSED. | UNCOMPRESSED | + +**Example:** + +```sql +CREATE EXTERNAL TABLE t(a int) +STORED AS JSON +LOCATION '/tmp/foo/' +OPTIONS('COMPRESSION' 'gzip'); +``` + +## CSV Format Options + +The following options are available when reading or writing CSV files. Note: If any unsupported option is specified, an error will be raised and the query will fail. 
+ +| Option | Description | Default Value | +| -------------------- | --------------------------------------------------------------------------------------------------------------------------------- | ------------------ | +| COMPRESSION | Sets the compression that should be applied to the entire CSV file. Supported values are GZIP, BZIP2, XZ, ZSTD, and UNCOMPRESSED. | UNCOMPRESSED | +| HAS_HEADER | Sets if the CSV file should include column headers. If not set, uses session or system default. | None | +| DELIMITER | Sets the character which should be used as the column delimiter within the CSV file. | `,` (comma) | +| QUOTE | Sets the character which should be used for quoting values within the CSV file. | `"` (double quote) | +| TERMINATOR | Sets the character which should be used as the line terminator within the CSV file. | None | +| ESCAPE | Sets the character which should be used for escaping special characters within the CSV file. | None | +| DOUBLE_QUOTE | Sets if quotes within quoted fields should be escaped by doubling them (e.g., `"aaa""bbb"`). | None | +| NEWLINES_IN_VALUES | Sets if newlines in quoted values are supported. If not set, uses session or system default. | None | +| DATE_FORMAT | Sets the format that dates should be encoded in within the CSV file. | None | +| DATETIME_FORMAT | Sets the format that datetimes should be encoded in within the CSV file. | None | +| TIMESTAMP_FORMAT | Sets the format that timestamps should be encoded in within the CSV file. | None | +| TIMESTAMP_TZ_FORMAT | Sets the format that timestamps with timezone should be encoded in within the CSV file. | None | +| TIME_FORMAT | Sets the format that times should be encoded in within the CSV file. | None | +| NULL_VALUE | Sets the string which should be used to indicate null values within the CSV file. | None | +| NULL_REGEX | Sets the regex pattern to match null values when loading CSVs. | None | +| SCHEMA_INFER_MAX_REC | Sets the maximum number of records to scan to infer the schema. | None | +| COMMENT | Sets the character which should be used to indicate comment lines in the CSV file. | None | + +**Example:** + +```sql +CREATE EXTERNAL TABLE t (col1 varchar, col2 int, col3 boolean) +STORED AS CSV +LOCATION '/tmp/foo/' +OPTIONS('DELIMITER' '|', 'HAS_HEADER' 'true', 'NEWLINES_IN_VALUES' 'true'); +``` + +## Parquet Format Options + +The following options are available when reading or writing Parquet files. If any unsupported option is specified, an error will be raised and the query will fail. If a column-specific option is specified for a column that does not exist, the option will be ignored without error. + +| Option | Can be Column Specific? | Description | OPTIONS Key | Default Value | +| ------------------------------------------ | ----------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------- | ------------------------ | +| COMPRESSION | Yes | Sets the internal Parquet **compression codec** for data pages, optionally including the compression level. Applies globally if set without `::col`, or specifically to a column if set using `'compression::column_name'`. 
Valid values: `uncompressed`, `snappy`, `gzip(level)`, `lzo`, `brotli(level)`, `lz4`, `zstd(level)`, `lz4_raw`. | `'compression'` or `'compression::col'` | zstd(3) | +| ENCODING | Yes | Sets the **encoding** scheme for data pages. Valid values: `plain`, `plain_dictionary`, `rle`, `bit_packed`, `delta_binary_packed`, `delta_length_byte_array`, `delta_byte_array`, `rle_dictionary`, `byte_stream_split`. Use key `'encoding'` or `'encoding::col'` in OPTIONS. | `'encoding'` or `'encoding::col'` | None | +| DICTIONARY_ENABLED | Yes | Sets whether dictionary encoding should be enabled globally or for a specific column. | `'dictionary_enabled'` or `'dictionary_enabled::col'` | true | +| STATISTICS_ENABLED | Yes | Sets the level of statistics to write (`none`, `chunk`, `page`). | `'statistics_enabled'` or `'statistics_enabled::col'` | page | +| BLOOM_FILTER_ENABLED | Yes | Sets whether a bloom filter should be written for a specific column. | `'bloom_filter_enabled::column_name'` | None | +| BLOOM_FILTER_FPP | Yes | Sets bloom filter false positive probability (global or per column). | `'bloom_filter_fpp'` or `'bloom_filter_fpp::col'` | None | +| BLOOM_FILTER_NDV | Yes | Sets bloom filter number of distinct values (global or per column). | `'bloom_filter_ndv'` or `'bloom_filter_ndv::col'` | None | +| MAX_ROW_GROUP_SIZE | No | Sets the maximum number of rows per row group. Larger groups require more memory but can improve compression and scan efficiency. | `'max_row_group_size'` | 1048576 | +| ENABLE_PAGE_INDEX | No | If true, reads the Parquet data page level metadata (the Page Index), if present, to reduce I/O and decoding. | `'enable_page_index'` | true | +| PRUNING | No | If true, enables row group pruning based on min/max statistics. | `'pruning'` | true | +| SKIP_METADATA | No | If true, skips optional embedded metadata in the file schema. | `'skip_metadata'` | true | +| METADATA_SIZE_HINT | No | Sets the size hint (in bytes) for fetching Parquet file metadata. | `'metadata_size_hint'` | None | +| PUSHDOWN_FILTERS | No | If true, enables filter pushdown during Parquet decoding. | `'pushdown_filters'` | false | +| REORDER_FILTERS | No | If true, enables heuristic reordering of filters during Parquet decoding. | `'reorder_filters'` | false | +| SCHEMA_FORCE_VIEW_TYPES | No | If true, reads Utf8/Binary columns as view types. | `'schema_force_view_types'` | true | +| BINARY_AS_STRING | No | If true, reads Binary columns as strings. | `'binary_as_string'` | false | +| DATA_PAGESIZE_LIMIT | No | Sets best effort maximum size of data page in bytes. | `'data_pagesize_limit'` | 1048576 | +| DATA_PAGE_ROW_COUNT_LIMIT | No | Sets best effort maximum number of rows in data page. | `'data_page_row_count_limit'` | 20000 | +| DICTIONARY_PAGE_SIZE_LIMIT | No | Sets best effort maximum dictionary page size, in bytes. | `'dictionary_page_size_limit'` | 1048576 | +| WRITE_BATCH_SIZE | No | Sets write_batch_size in bytes. | `'write_batch_size'` | 1024 | +| WRITER_VERSION | No | Sets the Parquet writer version (`1.0` or `2.0`). | `'writer_version'` | 1.0 | +| SKIP_ARROW_METADATA | No | If true, skips writing Arrow schema information into the Parquet file metadata. | `'skip_arrow_metadata'` | false | +| CREATED_BY | No | Sets the "created by" string in the Parquet file metadata. | `'created_by'` | datafusion version X.Y.Z | +| COLUMN_INDEX_TRUNCATE_LENGTH | No | Sets the length (in bytes) to truncate min/max values in column indexes. 
| `'column_index_truncate_length'` | 64 | +| STATISTICS_TRUNCATE_LENGTH | No | Sets statistics truncate length. | `'statistics_truncate_length'` | None | +| BLOOM_FILTER_ON_WRITE | No | Sets whether bloom filters should be written for all columns by default (can be overridden per column). | `'bloom_filter_on_write'` | false | +| ALLOW_SINGLE_FILE_PARALLELISM | No | Enables parallel serialization of columns in a single file. | `'allow_single_file_parallelism'` | true | +| MAXIMUM_PARALLEL_ROW_GROUP_WRITERS | No | Maximum number of parallel row group writers. | `'maximum_parallel_row_group_writers'` | 1 | +| MAXIMUM_BUFFERED_RECORD_BATCHES_PER_STREAM | No | Maximum number of buffered record batches per stream. | `'maximum_buffered_record_batches_per_stream'` | 2 | +| KEY_VALUE_METADATA | No (Key is specific) | Adds custom key-value pairs to the file metadata. Use the format `'metadata::your_key_name' 'your_value'`. Multiple entries allowed. | `'metadata::key_name'` | None | + +**Example:** + +```sql +CREATE EXTERNAL TABLE t (id bigint, value double, category varchar) +STORED AS PARQUET +LOCATION '/tmp/parquet_data/' +OPTIONS( + 'COMPRESSION::user_id' 'snappy', + 'ENCODING::col_a' 'delta_binary_packed', + 'MAX_ROW_GROUP_SIZE' '1000000', + 'BLOOM_FILTER_ENABLED::id' 'true' +); +``` diff --git a/docs/source/user-guide/sql/index.rst b/docs/source/user-guide/sql/index.rst index 8e3f51bf8b0b..a13d40334b63 100644 --- a/docs/source/user-guide/sql/index.rst +++ b/docs/source/user-guide/sql/index.rst @@ -33,5 +33,5 @@ SQL Reference window_functions scalar_functions special_functions - write_options + format_options prepared_statements diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 0f08934c8a9c..eb4b86e4b486 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -2133,7 +2133,7 @@ _Alias of [date_trunc](#date_trunc)._ ### `from_unixtime` -Converts an integer to RFC3339 timestamp format (`YYYY-MM-DDT00:00:00.000000000Z`). Integers and unsigned integers are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`) return the corresponding timestamp. +Converts an integer to RFC3339 timestamp format (`YYYY-MM-DDT00:00:00.000000000Z`). Integers and unsigned integers are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`) return the corresponding timestamp. ```sql from_unixtime(expression[, timezone]) @@ -2552,6 +2552,7 @@ _Alias of [current_date](#current_date)._ - [array_join](#array_join) - [array_length](#array_length) - [array_max](#array_max) +- [array_min](#array_min) - [array_ndims](#array_ndims) - [array_pop_back](#array_pop_back) - [array_pop_front](#array_pop_front) @@ -3058,6 +3059,29 @@ array_max(array) - list_max +### `array_min` + +Returns the minimum value in the array. + +```sql +array_min(array) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. + +#### Example + +```sql +> select array_min([3,1,4,2]); ++-----------------------------------------+ +| array_min(List([3,1,4,2])) | ++-----------------------------------------+ +| 1 | ++-----------------------------------------+ +``` + ### `array_ndims` Returns the number of dimensions of the array. @@ -3142,7 +3166,7 @@ array_pop_front(array) ### `array_position` -Returns the position of the first occurrence of the specified element in the array. 
+Returns the position of the first occurrence of the specified element in the array, or NULL if not found. ```sql array_position(array, element) @@ -3153,7 +3177,7 @@ array_position(array, element, index) - **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. - **element**: Element to search for position in the array. -- **index**: Index at which to start searching. +- **index**: Index at which to start searching (1-indexed). #### Example @@ -4105,6 +4129,7 @@ select struct(a as field_a, b) from t; - [element_at](#element_at) - [map](#map) +- [map_entries](#map_entries) - [map_extract](#map_extract) - [map_keys](#map_keys) - [map_values](#map_values) @@ -4162,6 +4187,30 @@ SELECT MAKE_MAP(['key1', 'key2'], ['value1', null]); {key1: value1, key2: } ``` +### `map_entries` + +Returns a list of all entries in the map. + +```sql +map_entries(map) +``` + +#### Arguments + +- **map**: Map expression. Can be a constant, column, or function, and any combination of map operators. + +#### Example + +```sql +SELECT map_entries(MAP {'a': 1, 'b': NULL, 'c': 3}); +---- +[{'key': a, 'value': 1}, {'key': b, 'value': NULL}, {'key': c, 'value': 3}] + +SELECT map_entries(map([100, 5], [42, 43])); +---- +[{'key': 100, 'value': 42}, {'key': 5, 'value': 43}] +``` + ### `map_extract` Returns a list containing the value for the given key or an empty list if the key is not present in the map. @@ -4404,6 +4453,7 @@ sha512(expression) Functions to work with the union data type, also know as tagged unions, variant types, enums or sum types. Note: Not related to the SQL UNION operator - [union_extract](#union_extract) +- [union_tag](#union_tag) ### `union_extract` @@ -4433,6 +4483,33 @@ union_extract(union, field_name) +--------------+----------------------------------+----------------------------------+ ``` +### `union_tag` + +Returns the name of the currently selected field in the union + +```sql +union_tag(union_expression) +``` + +#### Arguments + +- **union**: Union expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +❯ select union_column, union_tag(union_column) from table_with_union; ++--------------+-------------------------+ +| union_column | union_tag(union_column) | ++--------------+-------------------------+ +| {a=1} | a | +| {b=3.0} | b | +| {a=4} | a | +| {b=} | b | +| {a=} | a | ++--------------+-------------------------+ +``` + ## Other Functions - [arrow_cast](#arrow_cast) diff --git a/docs/source/user-guide/sql/select.md b/docs/source/user-guide/sql/select.md index b2fa0a630588..84aac431a6a2 100644 --- a/docs/source/user-guide/sql/select.md +++ b/docs/source/user-guide/sql/select.md @@ -84,7 +84,7 @@ SELECT a FROM table WHERE a > 10 ## JOIN clause -DataFusion supports `INNER JOIN`, `LEFT OUTER JOIN`, `RIGHT OUTER JOIN`, `FULL OUTER JOIN`, `NATURAL JOIN` and `CROSS JOIN`. +DataFusion supports `INNER JOIN`, `LEFT OUTER JOIN`, `RIGHT OUTER JOIN`, `FULL OUTER JOIN`, `NATURAL JOIN`, `CROSS JOIN`, `LEFT SEMI JOIN`, `RIGHT SEMI JOIN`, `LEFT ANTI JOIN`, and `RIGHT ANTI JOIN`. The following examples are based on this table: @@ -102,7 +102,7 @@ select * from x; The keywords `JOIN` or `INNER JOIN` define a join that only shows rows where there is a match in both tables. 
```sql -select * from x inner join x y ON x.column_1 = y.column_1; +SELECT * FROM x INNER JOIN x y ON x.column_1 = y.column_1; +----------+----------+----------+----------+ | column_1 | column_2 | column_1 | column_2 | +----------+----------+----------+----------+ @@ -116,7 +116,7 @@ The keywords `LEFT JOIN` or `LEFT OUTER JOIN` define a join that includes all ro is not a match in the right table. When there is no match, null values are produced for the right side of the join. ```sql -select * from x left join x y ON x.column_1 = y.column_2; +SELECT * FROM x LEFT JOIN x y ON x.column_1 = y.column_2; +----------+----------+----------+----------+ | column_1 | column_2 | column_1 | column_2 | +----------+----------+----------+----------+ @@ -130,7 +130,7 @@ The keywords `RIGHT JOIN` or `RIGHT OUTER JOIN` define a join that includes all is not a match in the left table. When there is no match, null values are produced for the left side of the join. ```sql -select * from x right join x y ON x.column_1 = y.column_2; +SELECT * FROM x RIGHT JOIN x y ON x.column_1 = y.column_2; +----------+----------+----------+----------+ | column_1 | column_2 | column_1 | column_2 | +----------+----------+----------+----------+ @@ -145,7 +145,7 @@ The keywords `FULL JOIN` or `FULL OUTER JOIN` define a join that is effectively either side of the join where there is not a match. ```sql -select * from x full outer join x y ON x.column_1 = y.column_2; +SELECT * FROM x FULL OUTER JOIN x y ON x.column_1 = y.column_2; +----------+----------+----------+----------+ | column_1 | column_2 | column_1 | column_2 | +----------+----------+----------+----------+ @@ -156,11 +156,11 @@ select * from x full outer join x y ON x.column_1 = y.column_2; ### NATURAL JOIN -A natural join defines an inner join based on common column names found between the input tables. When no common -column names are found, it behaves like a cross join. +A `NATURAL JOIN` defines an inner join based on common column names found between the input tables. When no common +column names are found, it behaves like a `CROSS JOIN`. ```sql -select * from x natural join x y; +SELECT * FROM x NATURAL JOIN x y; +----------+----------+ | column_1 | column_2 | +----------+----------+ @@ -170,11 +170,11 @@ select * from x natural join x y; ### CROSS JOIN -A cross join produces a cartesian product that matches every row in the left side of the join with every row in the +A `CROSS JOIN` produces a cartesian product that matches every row in the left side of the join with every row in the right side of the join. ```sql -select * from x cross join x y; +SELECT * FROM x CROSS JOIN x y; +----------+----------+----------+----------+ | column_1 | column_2 | column_1 | column_2 | +----------+----------+----------+----------+ @@ -182,6 +182,60 @@ select * from x cross join x y; +----------+----------+----------+----------+ ``` +### LEFT SEMI JOIN + +The `LEFT SEMI JOIN` returns all rows from the left table that have at least one matching row in the right table, and +projects only the columns from the left table. + +```sql +SELECT * FROM x LEFT SEMI JOIN x y ON x.column_1 = y.column_1; ++----------+----------+ +| column_1 | column_2 | ++----------+----------+ +| 1 | 2 | ++----------+----------+ +``` + +### RIGHT SEMI JOIN + +The `RIGHT SEMI JOIN` returns all rows from the right table that have at least one matching row in the left table, and +only projects the columns from the right table. 
+ +```sql +SELECT * FROM x RIGHT SEMI JOIN x y ON x.column_1 = y.column_1; ++----------+----------+ +| column_1 | column_2 | ++----------+----------+ +| 1 | 2 | ++----------+----------+ +``` + +### LEFT ANTI JOIN + +The `LEFT ANTI JOIN` returns all rows from the left table that do not have any matching row in the right table, projecting +only the left table’s columns. + +```sql +SELECT * FROM x LEFT ANTI JOIN x y ON x.column_1 = y.column_1; ++----------+----------+ +| column_1 | column_2 | ++----------+----------+ ++----------+----------+ +``` + +### RIGHT ANTI JOIN + +The `RIGHT ANTI JOIN` returns all rows from the right table that do not have any matching row in the left table, projecting +only the right table’s columns. + +```sql +SELECT * FROM x RIGHT ANTI JOIN x y ON x.column_1 = y.column_1; ++----------+----------+ +| column_1 | column_2 | ++----------+----------+ ++----------+----------+ +``` + ## GROUP BY clause Example: diff --git a/docs/source/user-guide/sql/window_functions.md b/docs/source/user-guide/sql/window_functions.md index 68a700380312..bcb33bad7fb5 100644 --- a/docs/source/user-guide/sql/window_functions.md +++ b/docs/source/user-guide/sql/window_functions.md @@ -193,6 +193,29 @@ Returns the rank of the current row without gaps. This function ranks rows in a dense_rank() ``` +#### Example + +```sql + --Example usage of the dense_rank window function: + SELECT department, + salary, + dense_rank() OVER (PARTITION BY department ORDER BY salary DESC) AS dense_rank + FROM employees; +``` + +```sql ++-------------+--------+------------+ +| department | salary | dense_rank | ++-------------+--------+------------+ +| Sales | 70000 | 1 | +| Sales | 50000 | 2 | +| Sales | 50000 | 2 | +| Sales | 30000 | 3 | +| Engineering | 90000 | 1 | +| Engineering | 80000 | 2 | ++-------------+--------+------------+ +``` + ### `ntile` Integer ranging from 1 to the argument value, dividing the partition as equally as possible @@ -205,6 +228,31 @@ ntile(expression) - **expression**: An integer describing the number groups the partition should be split into +#### Example + +```sql + --Example usage of the ntile window function: + SELECT employee_id, + salary, + ntile(4) OVER (ORDER BY salary DESC) AS quartile + FROM employees; +``` + +```sql ++-------------+--------+----------+ +| employee_id | salary | quartile | ++-------------+--------+----------+ +| 1 | 90000 | 1 | +| 2 | 85000 | 1 | +| 3 | 80000 | 2 | +| 4 | 70000 | 2 | +| 5 | 60000 | 3 | +| 6 | 50000 | 3 | +| 7 | 40000 | 4 | +| 8 | 30000 | 4 | ++-------------+--------+----------+ +``` + ### `percent_rank` Returns the percentage rank of the current row within its partition. The value ranges from 0 to 1 and is computed as `(rank - 1) / (total_rows - 1)`. @@ -213,6 +261,26 @@ Returns the percentage rank of the current row within its partition. The value r percent_rank() ``` +#### Example + +```sql + --Example usage of the percent_rank window function: + SELECT employee_id, + salary, + percent_rank() OVER (ORDER BY salary) AS percent_rank + FROM employees; +``` + +```sql ++-------------+--------+---------------+ +| employee_id | salary | percent_rank | ++-------------+--------+---------------+ +| 1 | 30000 | 0.00 | +| 2 | 50000 | 0.50 | +| 3 | 70000 | 1.00 | ++-------------+--------+---------------+ +``` + ### `rank` Returns the rank of the current row within its partition, allowing gaps between ranks. This function provides a ranking similar to `row_number`, but skips ranks for identical values. 
@@ -221,6 +289,29 @@ Returns the rank of the current row within its partition, allowing gaps between
 rank()
 ```
 
+#### Example
+
+```sql
+  --Example usage of the rank window function:
+  SELECT department,
+         salary,
+         rank() OVER (PARTITION BY department ORDER BY salary DESC) AS rank
+  FROM employees;
+```
+
+```sql
++-------------+--------+------+
+| department  | salary | rank |
++-------------+--------+------+
+| Sales       | 70000  | 1    |
+| Sales       | 50000  | 2    |
+| Sales       | 50000  | 2    |
+| Sales       | 30000  | 4    |
+| Engineering | 90000  | 1    |
+| Engineering | 80000  | 2    |
++-------------+--------+------+
+```
+
 ### `row_number`
 
 Number of the current row within its partition, counting from 1.
@@ -229,6 +320,30 @@ Number of the current row within its partition, counting from 1.
 row_number()
 ```
 
+#### Example
+
+```sql
+  --Example usage of the row_number window function:
+  SELECT department,
+         salary,
+         row_number() OVER (PARTITION BY department ORDER BY salary DESC) AS row_num
+  FROM employees;
+```
+
+```sql
++-------------+--------+---------+
+| department  | salary | row_num |
++-------------+--------+---------+
+| Sales       | 70000  | 1       |
+| Sales       | 50000  | 2       |
+| Sales       | 50000  | 3       |
+| Sales       | 30000  | 4       |
+| Engineering | 90000  | 1       |
+| Engineering | 80000  | 2       |
++-------------+--------+---------+
+```
+
 ## Analytical Functions
 
 - [first_value](#first_value)
@@ -243,12 +358,35 @@ Returns value evaluated at the row that is the first row of the window frame.
 
 ```sql
 first_value(expression)
 ```
 
 #### Arguments
 
 - **expression**: Expression to operate on
 
+#### Example
+
+```sql
+  --Example usage of the first_value window function:
+  SELECT department,
+         employee_id,
+         salary,
+         first_value(salary) OVER (PARTITION BY department ORDER BY salary DESC) AS top_salary
+  FROM employees;
+```
+
+```sql
++-------------+-------------+--------+------------+
+| department  | employee_id | salary | top_salary |
++-------------+-------------+--------+------------+
+| Sales       | 1           | 70000  | 70000      |
+| Sales       | 2           | 50000  | 70000      |
+| Sales       | 3           | 30000  | 70000      |
+| Engineering | 4           | 90000  | 90000      |
+| Engineering | 5           | 80000  | 90000      |
++-------------+-------------+--------+------------+
+```
+
 ### `lag`
 
 Returns value evaluated at the row that is offset rows before the current row within the partition; if there is no such row, instead return default (which must be of the same type as value).
 
@@ -263,6 +401,27 @@ lag(expression, offset, default)
 
 - **offset**: Integer. Specifies how many rows back the value of expression should be retrieved. Defaults to 1.
 - **default**: The default value if the offset is not within the partition. Must be of the same type as expression.
 
+#### Example
+
+```sql
+  --Example usage of the lag window function:
+  SELECT employee_id,
+         salary,
+         lag(salary, 1, 0) OVER (ORDER BY employee_id) AS prev_salary
+  FROM employees;
+```
+
+```sql
++-------------+--------+-------------+
+| employee_id | salary | prev_salary |
++-------------+--------+-------------+
+| 1           | 30000  | 0           |
+| 2           | 50000  | 30000       |
+| 3           | 70000  | 50000       |
+| 4           | 60000  | 70000       |
++-------------+--------+-------------+
+```
+
 ### `last_value`
 
 Returns value evaluated at the row that is the last row of the window frame.
@@ -275,6 +434,29 @@ last_value(expression) - **expression**: Expression to operate on +#### Example + +```sql +-- SQL example of last_value: +SELECT department, + employee_id, + salary, + last_value(salary) OVER (PARTITION BY department ORDER BY salary) AS running_last_salary +FROM employees; +``` + +```sql ++-------------+-------------+--------+---------------------+ +| department | employee_id | salary | running_last_salary | ++-------------+-------------+--------+---------------------+ +| Sales | 1 | 30000 | 30000 | +| Sales | 2 | 50000 | 50000 | +| Sales | 3 | 70000 | 70000 | +| Engineering | 4 | 40000 | 40000 | +| Engineering | 5 | 60000 | 60000 | ++-------------+-------------+--------+---------------------+ +``` + ### `lead` Returns value evaluated at the row that is offset rows after the current row within the partition; if there is no such row, instead return default (which must be of the same type as value). @@ -289,6 +471,30 @@ lead(expression, offset, default) - **offset**: Integer. Specifies how many rows forward the value of expression should be retrieved. Defaults to 1. - **default**: The default value if the offset is not within the partition. Must be of the same type as expression. +#### Example + +```sql +-- Example usage of lead() : +SELECT + employee_id, + department, + salary, + lead(salary, 1, 0) OVER (PARTITION BY department ORDER BY salary) AS next_salary +FROM employees; +``` + +```sql ++-------------+-------------+--------+--------------+ +| employee_id | department | salary | next_salary | ++-------------+-------------+--------+--------------+ +| 1 | Sales | 30000 | 50000 | +| 2 | Sales | 50000 | 70000 | +| 3 | Sales | 70000 | 0 | +| 4 | Engineering | 40000 | 60000 | +| 5 | Engineering | 60000 | 0 | ++-------------+-------------+--------+--------------+ +``` + ### `nth_value` Returns the value evaluated at the nth row of the window frame (counting from 1). Returns NULL if no such row exists. diff --git a/docs/source/user-guide/sql/write_options.md b/docs/source/user-guide/sql/write_options.md deleted file mode 100644 index 521e29436212..000000000000 --- a/docs/source/user-guide/sql/write_options.md +++ /dev/null @@ -1,127 +0,0 @@ - - -# Write Options - -DataFusion supports customizing how data is written out to disk as a result of a `COPY` or `INSERT INTO` query. There are a few special options, file format (e.g. CSV or parquet) specific options, and parquet column specific options. Options can also in some cases be specified in multiple ways with a set order of precedence. - -## Specifying Options and Order of Precedence - -Write related options can be specified in the following ways: - -- Session level config defaults -- `CREATE EXTERNAL TABLE` options -- `COPY` option tuples - -For a list of supported session level config defaults see [Configuration Settings](../configs). These defaults apply to all write operations but have the lowest level of precedence. - -If inserting to an external table, table specific write options can be specified when the table is created using the `OPTIONS` clause: - -```sql -CREATE EXTERNAL TABLE - my_table(a bigint, b bigint) - STORED AS csv - COMPRESSION TYPE gzip - LOCATION '/test/location/my_csv_table/' - OPTIONS( - NULL_VALUE 'NAN', - 'has_header' 'true', - 'format.delimiter' ';' - ) -``` - -When running `INSERT INTO my_table ...`, the options from the `CREATE TABLE` will be respected (gzip compression, special delimiter, and header row included). 
There will be a single output file if the output path doesn't have folder format, i.e. ending with a `\`. Note that compression, header, and delimiter settings can also be specified within the `OPTIONS` tuple list. Dedicated syntax within the SQL statement always takes precedence over arbitrary option tuples, so if both are specified the `OPTIONS` setting will be ignored. NULL_VALUE is a CSV format specific option that determines how null values should be encoded within the CSV file. - -Finally, options can be passed when running a `COPY` command. - - - -```sql -COPY source_table - TO 'test/table_with_options' - PARTITIONED BY (column3, column4) - OPTIONS ( - format parquet, - compression snappy, - 'compression::column1' 'zstd(5)', - ) -``` - -In this example, we write the entirety of `source_table` out to a folder of parquet files. One parquet file will be written in parallel to the folder for each partition in the query. The next option `compression` set to `snappy` indicates that unless otherwise specified all columns should use the snappy compression codec. The option `compression::col1` sets an override, so that the column `col1` in the parquet file will use `ZSTD` compression codec with compression level `5`. In general, parquet options which support column specific settings can be specified with the syntax `OPTION::COLUMN.NESTED.PATH`. - -## Available Options - -### Execution Specific Options - -The following options are available when executing a `COPY` query. - -| Option | Description | Default Value | -| ----------------------------------- | ---------------------------------------------------------------------------------- | ------------- | -| execution.keep_partition_by_columns | Flag to retain the columns in the output data when using `PARTITIONED BY` queries. | false | - -Note: `execution.keep_partition_by_columns` flag can also be enabled through `ExecutionOptions` within `SessionConfig`. - -### JSON Format Specific Options - -The following options are available when writing JSON files. Note: If any unsupported option is specified, an error will be raised and the query will fail. - -| Option | Description | Default Value | -| ----------- | ---------------------------------------------------------------------------------------------------------------------------------- | ------------- | -| COMPRESSION | Sets the compression that should be applied to the entire JSON file. Supported values are GZIP, BZIP2, XZ, ZSTD, and UNCOMPRESSED. | UNCOMPRESSED | - -### CSV Format Specific Options - -The following options are available when writing CSV files. Note: if any unsupported options is specified an error will be raised and the query will fail. - -| Option | Description | Default Value | -| --------------- | --------------------------------------------------------------------------------------------------------------------------------- | ---------------- | -| COMPRESSION | Sets the compression that should be applied to the entire CSV file. Supported values are GZIP, BZIP2, XZ, ZSTD, and UNCOMPRESSED. 
| UNCOMPRESSED | -| HEADER | Sets if the CSV file should include column headers | false | -| DATE_FORMAT | Sets the format that dates should be encoded in within the CSV file | arrow-rs default | -| DATETIME_FORMAT | Sets the format that datetimes should be encoded in within the CSV file | arrow-rs default | -| TIME_FORMAT | Sets the format that times should be encoded in within the CSV file | arrow-rs default | -| RFC3339 | If true, uses RFC339 format for date and time encodings | arrow-rs default | -| NULL_VALUE | Sets the string which should be used to indicate null values within the CSV file. | arrow-rs default | -| DELIMITER | Sets the character which should be used as the column delimiter within the CSV file. | arrow-rs default | - -### Parquet Format Specific Options - -The following options are available when writing parquet files. If any unsupported option is specified an error will be raised and the query will fail. If a column specific option is specified for a column which does not exist, the option will be ignored without error. For default values, see: [Configuration Settings](https://datafusion.apache.org/user-guide/configs.html). - -| Option | Can be Column Specific? | Description | -| ---------------------------- | ----------------------- | ----------------------------------------------------------------------------------------------------------------------------------- | -| COMPRESSION | Yes | Sets the compression codec and if applicable compression level to use | -| MAX_ROW_GROUP_SIZE | No | Sets the maximum number of rows that can be encoded in a single row group. Larger row groups require more memory to write and read. | -| DATA_PAGESIZE_LIMIT | No | Sets the best effort maximum page size in bytes | -| WRITE_BATCH_SIZE | No | Maximum number of rows written for each column in a single batch | -| WRITER_VERSION | No | Parquet writer version (1.0 or 2.0) | -| DICTIONARY_PAGE_SIZE_LIMIT | No | Sets best effort maximum dictionary page size in bytes | -| CREATED_BY | No | Sets the "created by" property in the parquet file | -| COLUMN_INDEX_TRUNCATE_LENGTH | No | Sets the max length of min/max value fields in the column index. | -| DATA_PAGE_ROW_COUNT_LIMIT | No | Sets best effort maximum number of rows in a data page. | -| BLOOM_FILTER_ENABLED | Yes | Sets whether a bloom filter should be written into the file. | -| ENCODING | Yes | Sets the encoding that should be used (e.g. PLAIN or RLE) | -| DICTIONARY_ENABLED | Yes | Sets if dictionary encoding is enabled. Use this instead of ENCODING to set dictionary encoding. | -| STATISTICS_ENABLED | Yes | Sets if statistics are enabled at PAGE or ROW_GROUP level. | -| MAX_STATISTICS_SIZE | Yes | Sets the maximum size in bytes that statistics can take up. | -| BLOOM_FILTER_FPP | Yes | Sets the false positive probability (fpp) for the bloom filter. Implicitly sets BLOOM_FILTER_ENABLED to true. | -| BLOOM_FILTER_NDV | Yes | Sets the number of distinct values (ndv) for the bloom filter. Implicitly sets bloom_filter_enabled to true. | diff --git a/parquet-testing b/parquet-testing index 6e851ddd768d..107b36603e05 160000 --- a/parquet-testing +++ b/parquet-testing @@ -1 +1 @@ -Subproject commit 6e851ddd768d6af741c7b15dc594874399fc3cff +Subproject commit 107b36603e051aee26bd93e04b871034f6c756c0 diff --git a/rust-toolchain.toml b/rust-toolchain.toml index a85e6fa54299..c52dd7322d9a 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -19,5 +19,5 @@ # to compile this workspace and run CI jobs. 
[toolchain]
-channel = "1.86.0"
+channel = "1.87.0"
components = ["rustfmt", "clippy"]
diff --git a/test-utils/src/array_gen/binary.rs b/test-utils/src/array_gen/binary.rs
index d342118fa85d..9740eeae5e7f 100644
--- a/test-utils/src/array_gen/binary.rs
+++ b/test-utils/src/array_gen/binary.rs
@@ -46,11 +46,11 @@ impl BinaryArrayGenerator {
         // Pick num_binaries randomly from the distinct binary table
         let indices: UInt32Array = (0..self.num_binaries)
             .map(|_| {
-                if self.rng.gen::<f64>() < self.null_pct {
+                if self.rng.random::<f64>() < self.null_pct {
                     None
                 } else if self.num_distinct_binaries > 1 {
                     let range = 0..(self.num_distinct_binaries as u32);
-                    Some(self.rng.gen_range(range))
+                    Some(self.rng.random_range(range))
                 } else {
                     Some(0)
                 }
@@ -68,11 +68,11 @@ impl BinaryArrayGenerator {
         let indices: UInt32Array = (0..self.num_binaries)
             .map(|_| {
-                if self.rng.gen::<f64>() < self.null_pct {
+                if self.rng.random::<f64>() < self.null_pct {
                     None
                 } else if self.num_distinct_binaries > 1 {
                     let range = 0..(self.num_distinct_binaries as u32);
-                    Some(self.rng.gen_range(range))
+                    Some(self.rng.random_range(range))
                 } else {
                     Some(0)
                 }
@@ -88,7 +88,7 @@ fn random_binary(rng: &mut StdRng, max_len: usize) -> Vec<u8> {
     if max_len == 0 {
         Vec::new()
     } else {
-        let len = rng.gen_range(1..=max_len);
-        (0..len).map(|_| rng.gen()).collect()
+        let len = rng.random_range(1..=max_len);
+        (0..len).map(|_| rng.random()).collect()
     }
 }
diff --git a/test-utils/src/array_gen/boolean.rs b/test-utils/src/array_gen/boolean.rs
index f3b83dd245f7..004d615b4caa 100644
--- a/test-utils/src/array_gen/boolean.rs
+++ b/test-utils/src/array_gen/boolean.rs
@@ -34,7 +34,7 @@ impl BooleanArrayGenerator {
         // Table of booleans from which to draw (distinct means 1 or 2)
         let distinct_booleans: BooleanArray = match self.num_distinct_booleans {
             1 => {
-                let value = self.rng.gen::<bool>();
+                let value = self.rng.random::<bool>();
                 let mut builder = BooleanBuilder::with_capacity(1);
                 builder.append_value(value);
                 builder.finish()
@@ -51,10 +51,10 @@ impl BooleanArrayGenerator {
         // Generate indices to select from the distinct booleans
         let indices: UInt32Array = (0..self.num_booleans)
             .map(|_| {
-                if self.rng.gen::<f64>() < self.null_pct {
+                if self.rng.random::<f64>() < self.null_pct {
                     None
                 } else if self.num_distinct_booleans > 1 {
-                    Some(self.rng.gen_range(0..self.num_distinct_booleans as u32))
+                    Some(self.rng.random_range(0..self.num_distinct_booleans as u32))
                 } else {
                     Some(0)
                 }
diff --git a/test-utils/src/array_gen/decimal.rs b/test-utils/src/array_gen/decimal.rs
index d46ea9fe5457..c5ec8ac5e893 100644
--- a/test-utils/src/array_gen/decimal.rs
+++ b/test-utils/src/array_gen/decimal.rs
@@ -62,11 +62,11 @@ impl DecimalArrayGenerator {
         // pick num_decimals randomly from the distinct decimal table
         let indices: UInt32Array = (0..self.num_decimals)
             .map(|_| {
-                if self.rng.gen::<f64>() < self.null_pct {
+                if self.rng.random::<f64>() < self.null_pct {
                     None
                 } else if self.num_distinct_decimals > 1 {
                     let range = 1..(self.num_distinct_decimals as u32);
-                    Some(self.rng.gen_range(range))
+                    Some(self.rng.random_range(range))
                 } else {
                     Some(0)
                 }
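The binary, boolean, and decimal generators above (and the primitive and string generators that follow) share the same null-injection pattern: draw an `f64`, compare it against `null_pct`, and otherwise pick a random index into the table of distinct values. A minimal, self-contained sketch of that pattern using the rand 0.9 names (the function name `random_indices` and the fixed seed are illustrative, not part of this patch):

```rust
use arrow::array::UInt32Array;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};

/// Build `len` nullable indices into a table of `num_distinct` values,
/// where roughly `null_pct` of the entries are NULL.
fn random_indices(len: usize, num_distinct: usize, null_pct: f64) -> UInt32Array {
    let mut rng = StdRng::seed_from_u64(42);
    (0..len)
        .map(|_| {
            if rng.random::<f64>() < null_pct {
                None
            } else if num_distinct > 1 {
                Some(rng.random_range(0..num_distinct as u32))
            } else {
                Some(0)
            }
        })
        .collect()
}

fn main() {
    let indices = random_indices(10, 4, 0.2);
    println!("{indices:?}");
}
```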
diff --git a/test-utils/src/array_gen/primitive.rs b/test-utils/src/array_gen/primitive.rs
index 58d39c14e65d..62a38a1b4ce1 100644
--- a/test-utils/src/array_gen/primitive.rs
+++ b/test-utils/src/array_gen/primitive.rs
@@ -18,7 +18,8 @@
 use arrow::array::{ArrayRef, ArrowPrimitiveType, PrimitiveArray, UInt32Array};
 use arrow::datatypes::DataType;
 use chrono_tz::{Tz, TZ_VARIANTS};
-use rand::{rngs::StdRng, seq::SliceRandom, thread_rng, Rng};
+use rand::prelude::IndexedRandom;
+use rand::{rng, rngs::StdRng, Rng};
 use std::sync::Arc;
 
 use super::random_data::RandomNativeData;
@@ -66,6 +67,7 @@ impl PrimitiveArrayGenerator {
             | DataType::Time32(_)
             | DataType::Time64(_)
             | DataType::Interval(_)
+            | DataType::Duration(_)
             | DataType::Binary
             | DataType::LargeBinary
             | DataType::BinaryView
@@ -81,11 +83,11 @@ impl PrimitiveArrayGenerator {
         // pick num_primitives randomly from the distinct string table
         let indices: UInt32Array = (0..self.num_primitives)
             .map(|_| {
-                if self.rng.gen::<f64>() < self.null_pct {
+                if self.rng.random::<f64>() < self.null_pct {
                     None
                 } else if self.num_distinct_primitives > 1 {
                     let range = 1..(self.num_distinct_primitives as u32);
-                    Some(self.rng.gen_range(range))
+                    Some(self.rng.random_range(range))
                 } else {
                     Some(0)
                 }
@@ -102,7 +104,7 @@ impl PrimitiveArrayGenerator {
     /// - `Some(Arc<str>)` containing the timezone name.
     /// - `None` if no timezone is selected.
     fn generate_timezone() -> Option<Arc<str>> {
-        let mut rng = thread_rng();
+        let mut rng = rng();
 
         // Allows for timezones + None
         let mut timezone_options: Vec<Option<Arc<str>>> = vec![None];
diff --git a/test-utils/src/array_gen/random_data.rs b/test-utils/src/array_gen/random_data.rs
index a7297d45fdf0..78518b7bf9dc 100644
--- a/test-utils/src/array_gen/random_data.rs
+++ b/test-utils/src/array_gen/random_data.rs
@@ -17,15 +17,16 @@
 use arrow::array::ArrowPrimitiveType;
 use arrow::datatypes::{
-    i256, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Float32Type,
-    Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTime,
-    IntervalDayTimeType, IntervalMonthDayNano, IntervalMonthDayNanoType,
-    IntervalYearMonthType, Time32MillisecondType, Time32SecondType,
-    Time64MicrosecondType, Time64NanosecondType, TimestampMicrosecondType,
-    TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type,
-    UInt32Type, UInt64Type, UInt8Type,
+    i256, Date32Type, Date64Type, Decimal128Type, Decimal256Type,
+    DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType,
+    DurationSecondType, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type,
+    Int8Type, IntervalDayTime, IntervalDayTimeType, IntervalMonthDayNano,
+    IntervalMonthDayNanoType, IntervalYearMonthType, Time32MillisecondType,
+    Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
+    TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
+    TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
 };
-use rand::distributions::Standard;
+use rand::distr::StandardUniform;
 use rand::prelude::Distribution;
 use rand::rngs::StdRng;
 use rand::Rng;
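The `basic_random_data!` macro in the next hunk keeps its `Distribution<Self::Native>` bound and only swaps `Standard` for `StandardUniform`. A small free-standing illustration of what that bound expresses (the `generate` helper name and seed are for illustration only, not part of the patch):

```rust
use rand::distr::StandardUniform;
use rand::prelude::Distribution;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};

/// Produce a value of any type `T` that the standard distribution can generate.
fn generate<T>(rng: &mut StdRng) -> T
where
    StandardUniform: Distribution<T>,
{
    rng.random::<T>()
}

fn main() {
    let mut rng = StdRng::seed_from_u64(0);
    let a: i64 = generate(&mut rng);
    let b: f32 = generate(&mut rng);
    println!("{a} {b}");
}
```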
basic_random_data { ($ARROW_TYPE: ty) => { impl RandomNativeData for $ARROW_TYPE where - Standard: Distribution, + StandardUniform: Distribution, { #[inline] fn generate_random_native_data(rng: &mut StdRng) -> Self::Native { - rng.gen::() + rng.random::() } } }; @@ -71,11 +72,16 @@ basic_random_data!(TimestampSecondType); basic_random_data!(TimestampMillisecondType); basic_random_data!(TimestampMicrosecondType); basic_random_data!(TimestampNanosecondType); +// Note DurationSecondType is restricted to i64::MIN / 1000 to i64::MAX / 1000 +// due to https://github.com/apache/arrow-rs/issues/7533 so handle it specially below +basic_random_data!(DurationMillisecondType); +basic_random_data!(DurationMicrosecondType); +basic_random_data!(DurationNanosecondType); impl RandomNativeData for Date64Type { fn generate_random_native_data(rng: &mut StdRng) -> Self::Native { // TODO: constrain this range to valid dates if necessary - let date_value = rng.gen_range(i64::MIN..=i64::MAX); + let date_value = rng.random_range(i64::MIN..=i64::MAX); let millis_per_day = 86_400_000; date_value - (date_value % millis_per_day) } @@ -84,8 +90,8 @@ impl RandomNativeData for Date64Type { impl RandomNativeData for IntervalDayTimeType { fn generate_random_native_data(rng: &mut StdRng) -> Self::Native { IntervalDayTime { - days: rng.gen::(), - milliseconds: rng.gen::(), + days: rng.random::(), + milliseconds: rng.random::(), } } } @@ -93,15 +99,24 @@ impl RandomNativeData for IntervalDayTimeType { impl RandomNativeData for IntervalMonthDayNanoType { fn generate_random_native_data(rng: &mut StdRng) -> Self::Native { IntervalMonthDayNano { - months: rng.gen::(), - days: rng.gen::(), - nanoseconds: rng.gen::(), + months: rng.random::(), + days: rng.random::(), + nanoseconds: rng.random::(), } } } +// Restrict Duration(Seconds) to i64::MIN / 1000 to i64::MAX / 1000 to +// avoid panics on pretty printing. 
+// https://github.com/apache/arrow-rs/issues/7533
+impl RandomNativeData for DurationSecondType {
+    fn generate_random_native_data(rng: &mut StdRng) -> Self::Native {
+        rng.random::<i64>() / 1000
+    }
+}
+
 impl RandomNativeData for Decimal256Type {
     fn generate_random_native_data(rng: &mut StdRng) -> Self::Native {
-        i256::from_parts(rng.gen::<u128>(), rng.gen::<i128>())
+        i256::from_parts(rng.random::<u128>(), rng.random::<i128>())
     }
 }
diff --git a/test-utils/src/array_gen/string.rs b/test-utils/src/array_gen/string.rs
index ac659ae67bc0..546485fd8dc1 100644
--- a/test-utils/src/array_gen/string.rs
+++ b/test-utils/src/array_gen/string.rs
@@ -18,6 +18,7 @@
 use arrow::array::{
     ArrayRef, GenericStringArray, OffsetSizeTrait, StringViewArray, UInt32Array,
 };
+use rand::distr::StandardUniform;
 use rand::rngs::StdRng;
 use rand::Rng;
 
@@ -47,11 +48,11 @@ impl StringArrayGenerator {
         // pick num_strings randomly from the distinct string table
         let indices: UInt32Array = (0..self.num_strings)
             .map(|_| {
-                if self.rng.gen::<f64>() < self.null_pct {
+                if self.rng.random::<f64>() < self.null_pct {
                     None
                 } else if self.num_distinct_strings > 1 {
                     let range = 1..(self.num_distinct_strings as u32);
-                    Some(self.rng.gen_range(range))
+                    Some(self.rng.random_range(range))
                 } else {
                     Some(0)
                 }
@@ -71,11 +72,11 @@ impl StringArrayGenerator {
         // pick num_strings randomly from the distinct string table
         let indices: UInt32Array = (0..self.num_strings)
             .map(|_| {
-                if self.rng.gen::<f64>() < self.null_pct {
+                if self.rng.random::<f64>() < self.null_pct {
                     None
                 } else if self.num_distinct_strings > 1 {
                     let range = 1..(self.num_distinct_strings as u32);
-                    Some(self.rng.gen_range(range))
+                    Some(self.rng.random_range(range))
                 } else {
                     Some(0)
                 }
@@ -92,10 +93,10 @@ fn random_string(rng: &mut StdRng, max_len: usize) -> String {
     // pick characters at random (not just ascii)
     match max_len {
         0 => "".to_string(),
-        1 => String::from(rng.gen::<char>()),
+        1 => String::from(rng.random::<char>()),
         _ => {
-            let len = rng.gen_range(1..=max_len);
-            rng.sample_iter::<char, _>(rand::distributions::Standard)
+            let len = rng.random_range(1..=max_len);
+            rng.sample_iter::<char, _>(StandardUniform)
                 .take(len)
                 .collect()
         }
diff --git a/test-utils/src/data_gen.rs b/test-utils/src/data_gen.rs
index 7ac6f3d3e255..2228010b28dd 100644
--- a/test-utils/src/data_gen.rs
+++ b/test-utils/src/data_gen.rs
@@ -104,10 +104,11 @@ impl BatchBuilder {
     }
 
     fn append(&mut self, rng: &mut StdRng, host: &str, service: &str) {
-        let num_pods = rng.gen_range(self.options.pods_per_host.clone());
+        let num_pods = rng.random_range(self.options.pods_per_host.clone());
         let pods = generate_sorted_strings(rng, num_pods, 30..40);
         for pod in pods {
-            let num_containers = rng.gen_range(self.options.containers_per_pod.clone());
+            let num_containers =
+                rng.random_range(self.options.containers_per_pod.clone());
 
             for container_idx in 0..num_containers {
                 let container = format!("{service}_container_{container_idx}");
                 let image = format!(
                 );
 
                 let num_entries =
-                    rng.gen_range(self.options.entries_per_container.clone());
+                    rng.random_range(self.options.entries_per_container.clone());
                 for i in 0..num_entries {
                     if self.is_finished() {
                         return;
@@ -154,7 +155,7 @@ impl BatchBuilder {
         if self.options.include_nulls {
             // Append a null value if the option is set
             // Use both "NULL" as a string and a null value
-            if rng.gen_bool(0.5) {
+            if rng.random_bool(0.5) {
                 self.client_addr.append_null();
             } else {
                 self.client_addr.append_value("NULL");
@@ -162,26 +163,26 @@ impl BatchBuilder {
         } else {
             self.client_addr.append_value(format!(
                 "{}.{}.{}.{}",
-                rng.gen::<u8>(),
-                rng.gen::<u8>(),
-                rng.gen::<u8>(),
-                rng.gen::<u8>()
+                rng.random::<u8>(),
+                rng.random::<u8>(),
+                rng.random::<u8>(),
+                rng.random::<u8>()
             ));
         }
 
-        self.request_duration.append_value(rng.gen());
+        self.request_duration.append_value(rng.random());
         self.request_user_agent
             .append_value(random_string(rng, 20..100));
         self.request_method
-            .append_value(methods[rng.gen_range(0..methods.len())]);
+            .append_value(methods[rng.random_range(0..methods.len())]);
         self.request_host
            .append_value(format!("https://{service}.mydomain.com"));
         self.request_bytes
-            .append_option(rng.gen_bool(0.9).then(|| rng.gen()));
+            .append_option(rng.random_bool(0.9).then(|| rng.random()));
         self.response_bytes
-            .append_option(rng.gen_bool(0.9).then(|| rng.gen()));
+            .append_option(rng.random_bool(0.9).then(|| rng.random()));
         self.response_status
-            .append_value(status[rng.gen_range(0..status.len())]);
+            .append_value(status[rng.random_range(0..status.len())]);
         self.prices_status.append_value(self.row_count as i128);
     }
 
@@ -216,9 +217,9 @@ impl BatchBuilder {
 }
 
 fn random_string(rng: &mut StdRng, len_range: Range<usize>) -> String {
-    let len = rng.gen_range(len_range);
+    let len = rng.random_range(len_range);
     (0..len)
-        .map(|_| rng.gen_range(b'a'..=b'z') as char)
+        .map(|_| rng.random_range(b'a'..=b'z') as char)
         .collect::<String>()
 }
 
@@ -364,7 +365,7 @@ impl Iterator for AccessLogGenerator {
         self.host_idx += 1;
 
         for service in &["frontend", "backend", "database", "cache"] {
-            if self.rng.gen_bool(0.5) {
+            if self.rng.random_bool(0.5) {
                 continue;
             }
             if builder.is_finished() {
diff --git a/test-utils/src/lib.rs b/test-utils/src/lib.rs
index 47f23de4951e..be2bc0712afb 100644
--- a/test-utils/src/lib.rs
+++ b/test-utils/src/lib.rs
@@ -67,9 +67,9 @@ pub fn add_empty_batches(
         .flat_map(|batch| {
             // insert 0, or 1 empty batches before and after the current batch
             let empty_batch = RecordBatch::new_empty(schema.clone());
-            std::iter::repeat_n(empty_batch.clone(), rng.gen_range(0..2))
+            std::iter::repeat_n(empty_batch.clone(), rng.random_range(0..2))
                 .chain(std::iter::once(batch))
-                .chain(std::iter::repeat_n(empty_batch, rng.gen_range(0..2)))
+                .chain(std::iter::repeat_n(empty_batch, rng.random_range(0..2)))
         })
         .collect()
 }
@@ -100,7 +100,7 @@ pub fn stagger_batch_with_seed(batch: RecordBatch, seed: u64) -> Vec<RecordBatch> {
     while remainder.num_rows() > 0 {
-        let batch_size = rng.gen_range(0..remainder.num_rows() + 1);
+        let batch_size = rng.random_range(0..remainder.num_rows() + 1);
 
         batches.push(remainder.slice(0, batch_size));
         remainder = remainder.slice(batch_size, remainder.num_rows() - batch_size);
diff --git a/test-utils/src/string_gen.rs b/test-utils/src/string_gen.rs
index b598241db1e9..75ed03898a27 100644
--- a/test-utils/src/string_gen.rs
+++ b/test-utils/src/string_gen.rs
@@ -19,7 +19,7 @@
 use crate::array_gen::StringArrayGenerator;
 use crate::stagger_batch;
 use arrow::record_batch::RecordBatch;
 use rand::rngs::StdRng;
-use rand::{thread_rng, Rng, SeedableRng};
+use rand::{rng, Rng, SeedableRng};
 
 /// Randomly generate strings
 pub struct StringBatchGenerator(StringArrayGenerator);
@@ -56,18 +56,18 @@ impl StringBatchGenerator {
         stagger_batch(batch)
     }
 
-    /// Return an set of `BatchGenerator`s that cover a range of interesting
+    /// Return a set of `BatchGenerator`s that cover a range of interesting
     /// cases
     pub fn interesting_cases() -> Vec<Self> {
         let mut cases = vec![];
-        let mut rng = thread_rng();
+        let mut rng = rng();
         for null_pct in [0.0, 0.01, 0.1, 0.5] {
             for _ in 0..10 {
                 // max length of generated strings
-                let max_len = rng.gen_range(1..50);
-                let num_strings =
rng.gen_range(1..100); + let max_len = rng.random_range(1..50); + let num_strings = rng.random_range(1..100); let num_distinct_strings = if num_strings > 1 { - rng.gen_range(1..num_strings) + rng.random_range(1..num_strings) } else { num_strings }; @@ -76,7 +76,7 @@ impl StringBatchGenerator { num_strings, num_distinct_strings, null_pct, - rng: StdRng::from_seed(rng.gen()), + rng: StdRng::from_seed(rng.random()), })) } }
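Taken together, the `test-utils` changes above are a mechanical migration to the rand 0.9 naming: `gen` → `random`, `gen_range` → `random_range`, `gen_bool` → `random_bool`, `thread_rng()` → `rng()`, `distributions::Standard` → `distr::StandardUniform`, and the `SliceRandom` import → `IndexedRandom`. A compact, self-contained sketch of the new calls (the values and example slice are arbitrary, not taken from the patch):

```rust
use rand::prelude::IndexedRandom;
use rand::{rng, Rng};

fn main() {
    // `rand::rng()` replaces `rand::thread_rng()`
    let mut r = rng();

    // `random`, `random_range`, and `random_bool` replace `gen`, `gen_range`, and `gen_bool`
    let x: f64 = r.random();
    let n = r.random_range(1..=6);
    let heads = r.random_bool(0.5);

    // `IndexedRandom::choose` replaces the old `SliceRandom::choose` import path
    let services = ["frontend", "backend", "database", "cache"];
    let picked = services.choose(&mut r);

    println!("{x} {n} {heads} {picked:?}");
}
```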